Skip to content

Commit

Permalink
upadates for version 2.0.0
Browse files Browse the repository at this point in the history
includes time varying rates for version 2 pre release, also adds simulator for SIR models
  • Loading branch information
nicfel committed Aug 13, 2024
1 parent d1b8622 commit 11f40c0
Show file tree
Hide file tree
Showing 22 changed files with 3,065 additions and 119 deletions.
6 changes: 4 additions & 2 deletions src/coalre/distribution/CoalescentWithReassortment.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,13 @@ private double reassortment(NetworkEvent event, double lociMRCA) {
// lp+=Math.log(reassortmentRate.getArrayValue())
// + event.segsSortedLeft * Math.log(intervals.getBinomialProb())
// + (event.segsToSort-event.segsSortedLeft)*Math.log(1-intervals.getBinomialProb())
// + Math.log(2.0);
// + Math.log(2.0);\

double binomval = Math.pow(intervals.getBinomialProb(), event.segsSortedLeft)
* Math.pow(1-intervals.getBinomialProb(), event.segsToSort-event.segsSortedLeft)
+ Math.pow(intervals.getBinomialProb(), event.segsToSort-event.segsSortedLeft)
* Math.pow(1-intervals.getBinomialProb(), event.segsSortedLeft);

if (event.time<=(lociMRCA*maxHeightRatioInput.get())) {
if (isTimeVarying)
return Math.log(timeVaryingReassortmentRates.getPopSize(event.time))
Expand Down Expand Up @@ -148,12 +148,14 @@ private double coalesce(NetworkEvent event) {
private double intervalContribution(NetworkEvent prevEvent, NetworkEvent nextEvent, double lociMRCA) {

double result = 0.0;
// System.out.println(networkIntervalsInput.get().networkInput.get());
// System.out.println(prevEvent.time + " " + nextEvent.time);

// if (nextEvent.time<3.3) {
// System.out.println(prevEvent.time + " " + nextEvent.time);
// System.out.println(timeVaryingReassortmentRates.getIntegral(prevEvent.time, nextEvent.time));
// }


if (nextEvent.time<(lociMRCA*maxHeightRatioInput.get())) {
if (isTimeVarying)
Expand Down
84 changes: 84 additions & 0 deletions src/coalre/dynamics/ComputeErrors.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package coalre.dynamics;

import beast.base.core.Description;
import beast.base.core.Function;
import beast.base.core.Input;
import beast.base.core.Input.Validate;
import beast.base.inference.CalculationNode;


@Description("calculates the differences between the entries of a vector")
public class ComputeErrors extends CalculationNode implements Function {
final public Input<Function> functionInput = new Input<>("arg", "argument for which the differences for entries is calculated", Validate.REQUIRED);
final public Input<Function> casesInput = new Input<>("logCases", "log of the cases", Validate.REQUIRED);
final public Input<Function> overallNeScalerInput = new Input<>("overallNeScaler", "argument for which the differences for entries is calculated", Validate.REQUIRED);

enum Mode {integer_mode, double_mode}

Mode mode;

boolean needsRecompute = true;
double[] errorTerm;
double[] storedErrorTerm;

@Override
public void initAndValidate() {
errorTerm = new double[functionInput.get().getDimension()];
storedErrorTerm = new double[functionInput.get().getDimension()];
}

@Override
public int getDimension() {
return errorTerm.length;
}

@Override
public double getArrayValue() {
if (needsRecompute) {
compute();
}
return errorTerm[0];
}

/**
* do the actual work, and reset flag *
*/
void compute() {

for (int i = 0; i < functionInput.get().getDimension(); i++) {
errorTerm[i] = functionInput.get().getArrayValue(i) - casesInput.get().getArrayValue(i) - overallNeScalerInput.get().getArrayValue(i);
}
needsRecompute = false;
}

@Override
public double getArrayValue(int dim) {
if (needsRecompute) {
compute();
}
return errorTerm[dim];
}

/**
* CalculationNode methods *
*/
@Override
public void store() {
System.arraycopy(errorTerm, 0, storedErrorTerm, 0, errorTerm.length);
super.store();
}

@Override
public void restore() {
double [] tmp = storedErrorTerm;
storedErrorTerm = errorTerm;
errorTerm = tmp;
super.restore();
}

@Override
public boolean requiresRecalculation() {
needsRecompute = true;
return true;
}
} // class Sum
106 changes: 106 additions & 0 deletions src/coalre/dynamics/Difference.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package coalre.dynamics;

import beast.base.core.Description;
import beast.base.core.Function;
import beast.base.core.Input;
import beast.base.core.Input.Validate;
import beast.base.inference.CalculationNode;
import beast.base.inference.parameter.RealParameter;


@Description("calculates the differences between the entries of a vector")
public class Difference extends CalculationNode implements Function {
final public Input<Function> functionInput = new Input<>("arg", "argument for which the differences for entries is calculated", Validate.REQUIRED);
final public Input<RealParameter> rateShiftInput = new Input<>("rateShift", "rate shift parameter");
final public Input<Integer> independentAfter = new Input<>("independentAfter", "ignore difference after that index");

enum Mode {integer_mode, double_mode}

Mode mode;

boolean needsRecompute = true;
double[] difference;
double[] storedDifference;

@Override
public void initAndValidate() {
if (independentAfter.get()!=null) {
difference = new double[functionInput.get().getDimension()-1];
storedDifference = new double[functionInput.get().getDimension()-1];
}else {
difference = new double[functionInput.get().getDimension()];
storedDifference = new double[functionInput.get().getDimension()];
}
}

@Override
public int getDimension() {
return difference.length;
}

@Override
public double getArrayValue() {
if (needsRecompute) {
compute();
}
return difference[0];
}

/**
* do the actual work, and reset flag *
*/
void compute() {
int offset = 1;
if (rateShiftInput.get()==null) {
for (int i = 1; i < functionInput.get().getDimension(); i++) {
if (independentAfter.get() != null && i == independentAfter.get()) {
offset++;
}
difference[i-offset] = functionInput.get().getArrayValue(i-1)-functionInput.get().getArrayValue(i);
}
}else {
for (int i = 1; i < functionInput.get().getDimension(); i++) {
if (independentAfter.get() != null && i == independentAfter.get()) {
offset++;
}
difference[i-offset] = (functionInput.get().getArrayValue(i-1)-functionInput.get().getArrayValue(i))
/ (rateShiftInput.get().getArrayValue(i)-rateShiftInput.get().getArrayValue(i-1));
}
}



needsRecompute = false;
}

@Override
public double getArrayValue(int dim) {
if (needsRecompute) {
compute();
}
return difference[dim];
}

/**
* CalculationNode methods *
*/
@Override
public void store() {
System.arraycopy(difference, 0, storedDifference, 0, difference.length);
super.store();
}

@Override
public void restore() {
double [] tmp = storedDifference;
storedDifference = difference;
difference = tmp;
super.restore();
}

@Override
public boolean requiresRecalculation() {
needsRecompute = true;
return true;
}
} // class Sum
83 changes: 83 additions & 0 deletions src/coalre/dynamics/ExpMean.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package coalre.dynamics;

import beast.base.core.Description;
import beast.base.core.Function;
import beast.base.core.Input;
import beast.base.core.Input.Validate;
import beast.base.inference.CalculationNode;
import beast.base.inference.parameter.RealParameter;


@Description("calculates the differences between the entries of a vector")
public class ExpMean extends CalculationNode implements Function {
final public Input<Function> functionInput = new Input<>("arg", "argument for which the differences for entries is calculated", Validate.REQUIRED);

enum Mode {integer_mode, double_mode}

Mode mode;

boolean needsRecompute = true;
double expMean;
double storedExpMean;

@Override
public void initAndValidate() {
}

@Override
public int getDimension() {
return 1;
}

@Override
public double getArrayValue() {
if (needsRecompute) {
compute();
}
return expMean;
}

/**
* do the actual work, and reset flag *
*/
void compute() {
expMean = 0;
for (int i = 1; i < functionInput.get().getDimension(); i++) {
expMean += Math.exp(functionInput.get().getArrayValue(i));
}
expMean /= (functionInput.get().getDimension());
expMean = Math.exp(expMean);
needsRecompute = false;
}

@Override
public double getArrayValue(int dim) {
if (needsRecompute) {
compute();
}
return expMean;
}

/**
* CalculationNode methods *
*/
@Override
public void store() {
storedExpMean = expMean;
super.store();
}

@Override
public void restore() {
double tmp = storedExpMean;
storedExpMean = expMean;
expMean = tmp;
super.restore();
}

@Override
public boolean requiresRecalculation() {
needsRecompute = true;
return true;
}
} // class Sum
83 changes: 83 additions & 0 deletions src/coalre/dynamics/LogDifference.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package coalre.dynamics;

import beast.base.core.Description;
import beast.base.core.Function;
import beast.base.core.Input;
import beast.base.core.Input.Validate;
import beast.base.inference.CalculationNode;
import beast.base.inference.parameter.RealParameter;


@Description("calculates the differences between the entries of a vector")
public class LogDifference extends CalculationNode implements Function {
final public Input<Function> functionInput = new Input<>("arg", "argument for which the differences for entries is calculated", Validate.REQUIRED);

enum Mode {integer_mode, double_mode}

Mode mode;

boolean needsRecompute = true;
double expMean;
double storedExpMean;

@Override
public void initAndValidate() {
}

@Override
public int getDimension() {
return 1;
}

@Override
public double getArrayValue() {
if (needsRecompute) {
compute();
}
return expMean;
}

/**
* do the actual work, and reset flag *
*/
void compute() {
expMean = 0;
for (int i = 1; i < functionInput.get().getDimension(); i++) {
expMean += functionInput.get().getArrayValue(i);
}
expMean /= (functionInput.get().getDimension());
expMean = expMean;
needsRecompute = false;
}

@Override
public double getArrayValue(int dim) {
if (needsRecompute) {
compute();
}
return expMean;
}

/**
* CalculationNode methods *
*/
@Override
public void store() {
storedExpMean = expMean;
super.store();
}

@Override
public void restore() {
double tmp = storedExpMean;
storedExpMean = expMean;
expMean = tmp;
super.restore();
}

@Override
public boolean requiresRecalculation() {
needsRecompute = true;
return true;
}
} // class Sum
Loading

0 comments on commit 11f40c0

Please sign in to comment.