Skip to content

Commit

Permalink
make Randomizer thread aware, which makes it deterministic even when …
Browse files Browse the repository at this point in the history
…running with threads. #1141
  • Loading branch information
rbouckaert committed Mar 19, 2024
1 parent 6ff6d57 commit f070286
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/beast/base/util/MersenneTwisterFast.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public class MersenneTwisterFast implements Serializable {
/**
* Constructor using the time of day as default seed.
*/
private MersenneTwisterFast() {
MersenneTwisterFast() {
this(System.currentTimeMillis() + seedAdditive_);
seedAdditive_ += nextInt();
}
Expand Down
81 changes: 55 additions & 26 deletions src/beast/base/util/Randomizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@

package beast.base.util;

import java.util.HashMap;
import java.util.Map;

/**
* Handy utility functions which have some Mathematical relevance.
Expand Down Expand Up @@ -82,7 +84,7 @@ private Randomizer() {
*/
public static int randomChoice(double[] cf) {

double U = random.nextDouble();
double U = random().nextDouble();

int s;
if (U <= cf[0]) {
Expand All @@ -98,15 +100,38 @@ public static int randomChoice(double[] cf) {
return s;
}

/**

private static Map<String, MersenneTwisterFast> randomMap = new HashMap<>();

private static MersenneTwisterFast random() {
final String threadName = Thread.currentThread().getName();
if (threadName == null) {
// no thread name used, so use default MersenneTwisterFast
return random;
}

// retrieve MersenneTwisterFast from the randomMap
// creating a new one if it is not already in the randomMap
MersenneTwisterFast r = randomMap.get(threadName);
if (r != null) {
return r;
}

r = new MersenneTwisterFast();
r.setSeed(random.getSeed() + randomMap.size());
randomMap.put(threadName, r);
return r;
}

/**
* Binary search to sample an integer given a cumulative probability distribution.
* Modified from {@link java.util.Arrays#binarySearch(double[], double)}.
* @param cpd normalized cumulative probability distribution.
* @return a sample (index of <code>cpd[]</code>) according to CPD.
* Negative integer if something is wrong.
*/
public static int binarySearchSampling(double[] cpd) {
double U = random.nextDouble();
double U = random().nextDouble();

if (U <= cpd[0])
return 0;
Expand Down Expand Up @@ -140,7 +165,7 @@ else if (midVal > U) {
*/
public static int randomChoicePDF(double[] pdf) {

double U = random.nextDouble() * getTotal(pdf);
double U = random().nextDouble() * getTotal(pdf);
for (int i = 0; i < pdf.length; i++) {

U -= pdf[i];
Expand Down Expand Up @@ -202,7 +227,7 @@ public static double getTotal(double[] array) {
*/
public static long getSeed() {
synchronized (random) {
return random.getSeed();
return random().getSeed();
}
}

Expand All @@ -211,7 +236,11 @@ public static long getSeed() {
*/
public static void setSeed(long seed) {
synchronized (random) {
random.setSeed(seed);
random.setSeed(seed == 1 ? 4357 : seed-1);
int i = 0;
for (MersenneTwisterFast r : randomMap.values()) {
r.setSeed(seed + i++);
}
}
}

Expand All @@ -220,7 +249,7 @@ public static void setSeed(long seed) {
*/
public static byte nextByte() {
synchronized (random) {
return random.nextByte();
return random().nextByte();
}
}

Expand All @@ -229,7 +258,7 @@ public static byte nextByte() {
*/
public static boolean nextBoolean() {
synchronized (random) {
return random.nextBoolean();
return random().nextBoolean();
}
}

Expand All @@ -238,7 +267,7 @@ public static boolean nextBoolean() {
*/
public static void nextBytes(byte[] bs) {
synchronized (random) {
random.nextBytes(bs);
random().nextBytes(bs);
}
}

Expand All @@ -247,7 +276,7 @@ public static void nextBytes(byte[] bs) {
*/
public static char nextChar() {
synchronized (random) {
return random.nextChar();
return random().nextChar();
}
}

Expand All @@ -259,7 +288,7 @@ public static char nextChar() {
*/
public static double nextGaussian() {
synchronized (random) {
return random.nextGaussian();
return random().nextGaussian();
}
}

Expand All @@ -274,7 +303,7 @@ public static double nextGaussian() {
*/
public static double nextGamma(double alpha, double lambda) {
synchronized (random) {
return random.nextGamma(alpha, lambda);
return random().nextGamma(alpha, lambda);
}
}

Expand All @@ -287,7 +316,7 @@ public static double nextGamma(double alpha, double lambda) {
*/
public static long nextPoisson(double lambda) {
synchronized (random) {
return random.nextPoisson(lambda);
return random().nextPoisson(lambda);
}
}

Expand All @@ -298,7 +327,7 @@ public static long nextPoisson(double lambda) {
*/
public static double nextDouble() {
synchronized (random) {
return random.nextDouble();
return random().nextDouble();
}
}

Expand All @@ -318,7 +347,7 @@ public static double randomLogDouble() {
*/
public static double nextExponential(double lambda) {
synchronized (random) {
return -1.0 * Math.log(1 - random.nextDouble()) / lambda;
return -1.0 * Math.log(1 - random().nextDouble()) / lambda;
}
}

Expand Down Expand Up @@ -347,7 +376,7 @@ public static long nextGeometric(double p) {
*/
public static float nextFloat() {
synchronized (random) {
return random.nextFloat();
return random().nextFloat();
}
}

Expand All @@ -360,7 +389,7 @@ public static float nextFloat() {
*/
public static long nextLong() {
synchronized (random) {
return random.nextLong();
return random().nextLong();
}
}

Expand All @@ -374,7 +403,7 @@ public static long nextLong() {
*/
public static short nextShort() {
synchronized (random) {
return random.nextShort();
return random().nextShort();
}
}

Expand All @@ -387,7 +416,7 @@ public static short nextShort() {
*/
public static int nextInt() {
synchronized (random) {
return random.nextInt();
return random().nextInt();
}
}

Expand All @@ -400,7 +429,7 @@ public static int nextInt() {
*/
public static int nextInt(int n) {
synchronized (random) {
return random.nextInt(n);
return random().nextInt(n);
}
}

Expand Down Expand Up @@ -436,7 +465,7 @@ public static double uniform(double low, double high) {
*/
public static void shuffle(int[] array) {
synchronized (random) {
random.shuffle(array);
random().shuffle(array);
}
}

Expand All @@ -447,7 +476,7 @@ public static void shuffle(int[] array) {
*/
public static void shuffle(int[] array, int numberOfShuffles) {
synchronized (random) {
random.shuffle(array, numberOfShuffles);
random().shuffle(array, numberOfShuffles);
}
}

Expand All @@ -459,7 +488,7 @@ public static void shuffle(int[] array, int numberOfShuffles) {
*/
public static int[] shuffled(int l) {
synchronized (random) {
return random.shuffled(l);
return random().shuffled(l);
}
}

Expand All @@ -475,7 +504,7 @@ public static int[] sampleIndicesWithReplacement(int l) {
synchronized (random) {
int[] result = new int[l];
for (int i = 0; i < l; i++)
result[i] = random.nextInt(l);
result[i] = random().nextInt(l);
return result;
}
}
Expand All @@ -487,7 +516,7 @@ public static int[] sampleIndicesWithReplacement(int l) {
*/
public static void permute(int[] array) {
synchronized (random) {
random.permute(array);
random().permute(array);
}
}

Expand All @@ -499,7 +528,7 @@ public static void permute(int[] array) {
*/
public static int[] permuted(int l) {
synchronized (random) {
return random.permuted(l);
return random().permuted(l);
}
}

Expand Down

0 comments on commit f070286

Please sign in to comment.