Skip to content

Commit

Permalink
Merge pull request #32 from theosanderson/horseshoe
Browse files Browse the repository at this point in the history
Horseshoe fix
  • Loading branch information
theosanderson authored Nov 1, 2023
2 parents 3461244 + e01d4af commit fb0c3e2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
24 changes: 21 additions & 3 deletions src/chronumental/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,28 @@ def get_parser():
"Scale factor for date distribution. Essentially a measure of how uncertain we think the measured dates are."
)

parser.add_argument(
'--initial_tau',
default=3.2,
type=float,
help="Initial value for the tau parameter in the model. Only applies to Horseshoe.")

parser.add_argument(
'--hs_scale',
default=86917549.587,
type=float,
help="hs scale parameter in the model. Only applies to Horseshoe.")




parser.add_argument('--steps',
default=1000,
default=20000,
type=int,
help="Number of steps to use for the SVI. Increasing this number will make runtime increase, but yield more accurate results.")

parser.add_argument('--lr',
default=0.01,
default=0.03,
type=float,
help="Adam learning rate")

Expand Down Expand Up @@ -317,7 +332,10 @@ def main():
"expected_min_between_transmissions":
args.expected_min_between_transmissions,
"enforce_exact_clock": args.enforce_exact_clock,
"variance_on_clock_rate": args.variance_on_clock_rate
"variance_on_clock_rate": args.variance_on_clock_rate,
"initial_tau": args.initial_tau,
"hs_scale": args.hs_scale,
"fixed_tau": True
}

my_model = models.models[args.model](
Expand Down
12 changes: 9 additions & 3 deletions src/chronumental/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def __init__(self, **kwargs):
'variance_on_clock_rate']
self.expected_min_between_transmissions = kwargs[
'model_configuration']['expected_min_between_transmissions']



super().__init__(**kwargs)

Expand Down Expand Up @@ -163,6 +165,10 @@ def __init__(self, **kwargs):
'variance_on_clock_rate']
self.expected_min_between_transmissions = kwargs[
'model_configuration']['expected_min_between_transmissions']

self.initial_tau = kwargs['model_configuration']['initial_tau']
self.fixed_tau = kwargs['model_configuration']['fixed_tau']
self.hs_scale = kwargs['model_configuration']['hs_scale']

super().__init__(**kwargs)

Expand Down Expand Up @@ -216,7 +222,7 @@ def model(self):

calced_dates = self.calc_dates(branch_times, root_date)

hs_scale = 1
hs_scale = self.hs_scale

tau = numpyro.sample("tau",
dist.HalfCauchy(hs_scale))
Expand Down Expand Up @@ -254,10 +260,10 @@ def guide(self):
constraint=dist.constraints.positive)

tau_param = numpyro.param("tau_param",
1,
self.initial_tau,
constraint=dist.constraints.positive)
tau = numpyro.sample("tau",
dist.Delta(tau_param))
dist.Delta(tau_param if not self.fixed_tau else self.initial_tau))

sample_variances = numpyro.sample("lambda",
dist.Delta(variances))
Expand Down

0 comments on commit fb0c3e2

Please sign in to comment.