diff --git a/src/chronumental/__main__.py b/src/chronumental/__main__.py index 6d071d3..c18016d 100644 --- a/src/chronumental/__main__.py +++ b/src/chronumental/__main__.py @@ -90,13 +90,28 @@ def get_parser(): "Scale factor for date distribution. Essentially a measure of how uncertain we think the measured dates are." ) + parser.add_argument( + '--initial_tau', + default=3.2, + type=float, + help="Initial value for the tau parameter in the model. Only applies to Horseshoe.") + + parser.add_argument( + '--hs_scale', + default=86917549.587, + type=float, + help="hs scale parameter in the model. Only applies to Horseshoe.") + + + + parser.add_argument('--steps', - default=1000, + default=20000, type=int, help="Number of steps to use for the SVI. Increasing this number will make runtime increase, but yield more accurate results.") parser.add_argument('--lr', - default=0.01, + default=0.03, type=float, help="Adam learning rate") @@ -317,7 +332,10 @@ def main(): "expected_min_between_transmissions": args.expected_min_between_transmissions, "enforce_exact_clock": args.enforce_exact_clock, - "variance_on_clock_rate": args.variance_on_clock_rate + "variance_on_clock_rate": args.variance_on_clock_rate, + "initial_tau": args.initial_tau, + "hs_scale": args.hs_scale, + "fixed_tau": True } my_model = models.models[args.model]( diff --git a/src/chronumental/models.py b/src/chronumental/models.py index 8915fba..9feb40b 100644 --- a/src/chronumental/models.py +++ b/src/chronumental/models.py @@ -58,6 +58,8 @@ def __init__(self, **kwargs): 'variance_on_clock_rate'] self.expected_min_between_transmissions = kwargs[ 'model_configuration']['expected_min_between_transmissions'] + + super().__init__(**kwargs) @@ -163,6 +165,10 @@ def __init__(self, **kwargs): 'variance_on_clock_rate'] self.expected_min_between_transmissions = kwargs[ 'model_configuration']['expected_min_between_transmissions'] + + self.initial_tau = kwargs['model_configuration']['initial_tau'] + self.fixed_tau = kwargs['model_configuration']['fixed_tau'] + self.hs_scale = kwargs['model_configuration']['hs_scale'] super().__init__(**kwargs) @@ -216,7 +222,7 @@ def model(self): calced_dates = self.calc_dates(branch_times, root_date) - hs_scale = 1 + hs_scale = self.hs_scale tau = numpyro.sample("tau", dist.HalfCauchy(hs_scale)) @@ -254,10 +260,10 @@ def guide(self): constraint=dist.constraints.positive) tau_param = numpyro.param("tau_param", - 1, + self.initial_tau, constraint=dist.constraints.positive) tau = numpyro.sample("tau", - dist.Delta(tau_param)) + dist.Delta(tau_param if not self.fixed_tau else self.initial_tau)) sample_variances = numpyro.sample("lambda", dist.Delta(variances))