diff --git a/phylonco-beast/src/main/java/phylonco/beast/evolution/populationmodel/StochasticVariableSelection.java b/phylonco-beast/src/main/java/phylonco/beast/evolution/populationmodel/StochasticVariableSelection.java index 7355a69..33cb443 100644 --- a/phylonco-beast/src/main/java/phylonco/beast/evolution/populationmodel/StochasticVariableSelection.java +++ b/phylonco-beast/src/main/java/phylonco/beast/evolution/populationmodel/StochasticVariableSelection.java @@ -1,10 +1,11 @@ package phylonco.beast.evolution.populationmodel; +import beast.base.core.BEASTInterface; import beast.base.core.Description; import beast.base.core.Input; import beast.base.core.Loggable; import beast.base.evolution.tree.coalescent.PopulationFunction; -import beast.base.inference.parameter.RealParameter; +import beast.base.inference.parameter.IntegerParameter; import java.io.PrintStream; import java.util.ArrayList; @@ -12,23 +13,33 @@ @Description("Stochastic variable selection for different population growth models.") public class StochasticVariableSelection extends PopulationFunction.Abstract implements Loggable { - public final Input indicatorInput = new Input<>("indicator", + public final Input indicatorInput = new Input<>("indicator", "The indicator for selecting the population model.", Input.Validate.REQUIRED); - public final Input modelsInput = new Input<>("models", - "The array of population models.", new PopulationFunction[0]); + public final Input> modelsInput = new Input<>("models", + "The list of population models.", new ArrayList<>()); private PopulationFunction selectedModel; @Override public void initAndValidate() { + if (indicatorInput.get() != null && indicatorInput.get() instanceof IntegerParameter) { + IntegerParameter IParam = indicatorInput.get(); + IParam.setBounds(Math.max(0, IParam.getLower()), Math.min(3, IParam.getUpper())); + } + int indicator = (int) indicatorInput.get().getArrayValue(); - PopulationFunction[] models = modelsInput.get(); - if (indicator < 0 || indicator >= models.length) { + List modelsList = modelsInput.get(); + + if (indicator < 0 || indicator >= modelsList.size()) { throw new IllegalArgumentException("Invalid indicator value: " + indicator); } - selectedModel = models[indicator]; + if (modelsList.size() != 4) { + throw new IllegalArgumentException("There must be exactly 4 population models."); + } + + selectedModel = modelsList.get(indicator); if (selectedModel == null) { throw new IllegalArgumentException("Selected model is null. Indicator: " + indicator); @@ -53,9 +64,13 @@ public double getInverseIntensity(double x) { @Override public List getParameterIds() { List ids = new ArrayList<>(); - ids.add(indicatorInput.get().getID()); + if (indicatorInput.get() instanceof BEASTInterface) { + ids.add(((BEASTInterface) indicatorInput.get()).getID()); + } for (PopulationFunction model : modelsInput.get()) { - ids.addAll(model.getParameterIds()); + if (model instanceof BEASTInterface) { + ids.add(((BEASTInterface) model).getID()); + } } return ids; } @@ -68,4 +83,4 @@ public void log(long sample, PrintStream out) {} @Override public void close(PrintStream out) {} -} +} \ No newline at end of file