Skip to content

Commit

Permalink
Improved nomenclature, added comments, new tests that include joint p…
Browse files Browse the repository at this point in the history
…riors
  • Loading branch information
JasperMartins committed Jan 21, 2025
1 parent 27f4ef6 commit 60d7be1
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 26 deletions.
49 changes: 30 additions & 19 deletions bilby/core/prior/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@ def __init__(self, names, bounds=None):
self.requested_parameters = dict()
self.reset_request()

# a dictionary of the rescale(d) parameters
self._rescale_parameters = dict()
self._rescaled_parameters = dict()
# a dictionary that stores the unit-cube values of parameters for later rescaling
self._current_unit_cube_parameter_values = dict()
# a dictionary of arrays that are used as intermediate return values of JointPrior.rescale()
# and updated in-place once all parameters have been requested
self._current_rescaled_parameter_values = dict()
self.reset_rescale()

# a list of sampled parameters
Expand Down Expand Up @@ -95,24 +97,24 @@ def filled_rescale(self):
Check if all the rescaled parameters have been filled.
"""

return not np.any([val is None for val in self._rescale_parameters.values()])
return not np.any([val is None for val in self._current_unit_cube_parameter_values.values()])

def set_rescale(self, key, values):
values = np.array(values)
self._rescale_parameters[key] = values
self._rescaled_parameters[key] = np.atleast_1d(np.ones_like(values)) * np.nan
self._current_unit_cube_parameter_values[key] = np.array(values)
self._current_rescaled_parameter_values[key] = np.full_like(values, np.nan, dtype=float)

def reset_rescale(self):
"""
Reset the rescaled parameters to None.
"""

for name in self.names:
self._rescale_parameters[name] = None
self._rescaled_parameters[name] = None
self._current_unit_cube_parameter_values[name] = None
self._current_rescaled_parameter_values[name] = None

def get_rescaled(self, key):
return self._rescaled_parameters[key]
"""Return an array that will be updated in-place once the rescale-operation
has been performed."""
return self._current_rescaled_parameter_values[key]

def get_instantiation_dict(self):
subclass_args = infer_args_from_method(self.__init__)
Expand Down Expand Up @@ -317,7 +319,7 @@ def rescale(self, value, **kwargs):
If given, a 1d vector sample (one for each parameter) drawn from a uniform
distribution between 0 and 1, or a 2d NxM array of samples where
N is the number of samples and M is the number of parameters.
If None, values previously set using BaseJointPriorDist.set_rescale() are used.
If None, the values previously set using BaseJointPriorDist.set_rescale() are used.
kwargs: dict
All keyword args that need to be passed to _rescale method, these keyword
args are called in the JointPrior rescale methods for each parameter
Expand All @@ -329,9 +331,11 @@ def rescale(self, value, **kwargs):
distribution.
"""
if value is None:
samp = np.array(list(self._rescale_parameters.values())).T
samp = np.array(list(self._current_unit_cube_parameter_values.values())).T
else:
samp = np.array(value)
for key, val in zip(self.names, value):
self.set_rescale(key, val)
samp = np.asarray(value)

if len(samp.shape) == 1:
samp = samp.reshape(1, self.num_vars)
Expand All @@ -342,11 +346,12 @@ def rescale(self, value, **kwargs):
raise ValueError("Array is the wrong shape")

samp = self._rescale(samp, **kwargs)
if value is None:
for i, key in enumerate(self.names):
output = self.get_rescaled(key)
# update in-place for proper handling in PriorDict-instances
output[:] = samp[:, i]
for i, key in enumerate(self.names):
# get the numpy array used for indermediate outputs
# prior to a full rescale-operation
output = self.get_rescaled(key)
# update the array in-place
output[...] = samp[:, i]
return np.squeeze(samp)

def _rescale(self, samp, **kwargs):
Expand Down Expand Up @@ -819,10 +824,16 @@ def rescale(self, val, **kwargs):
self.dist.set_rescale(self.name, val)

if self.dist.filled_rescale():
# If all names have been filled, perform rescale operation
self.dist.rescale(value=None, **kwargs)
# get the rescaled values for the requested parameter
output = self.dist.get_rescaled(self.name)
# reset the rescale operation
self.dist.reset_rescale()
else:
# If not all names have been filled, return a *numpy array*
# filled only with `np.nan`. Once all names have been requested,
# this array is updated *in-place* with the rescaled values.
output = self.dist.get_rescaled(self.name)

# have to return raw output to conserve in-place modifications
Expand Down
8 changes: 5 additions & 3 deletions test/core/prior/conditional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def test_rescale_with_joint_prior(self):

# set multivariate Gaussian distribution
names = ["mvgvar_0", "mvgvar_1"]
mu = [[0.79, -0.83]]
mu = [[1, 1]]
cov = [[[0.03, 0.], [0., 0.04]]]
mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov)

Expand All @@ -349,7 +349,7 @@ def test_rescale_with_joint_prior(self):
)
)

ref_variables = list(self.test_sample.values()) + [0.4, 0.1]
ref_variables = list(self.test_sample.values()) + [0.5, 0.5]
keys = list(self.test_sample.keys()) + names
res = priordict.rescale(keys=keys, theta=ref_variables)

Expand All @@ -359,9 +359,11 @@ def test_rescale_with_joint_prior(self):

# check conditional values are still as expected
expected = [self.test_sample["var_0"]]
self.assertFalse(np.any(np.isnan(res)))
for ii in range(1, 4):
expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
self.assertListEqual(expected, res[0:4])
expected.extend([1, 1])
self.assertListEqual(expected, res)

def test_cdf(self):
"""
Expand Down
35 changes: 31 additions & 4 deletions test/core/prior/dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,17 @@ def setUp(self):
name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None
)
self.third_prior = bilby.core.prior.DeltaFunction(name="c", peak=42, unit="m")

mvg = bilby.core.prior.MultivariateGaussianDist(
names=["testa", "testb"],
mus=[1, 1],
covs=np.array([[2.0, 0.5], [0.5, 2.0]]),
weights=1.0,
)
self.testa = bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit")
self.testb = bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit")
self.priors = dict(
mass=self.first_prior, speed=self.second_prior, length=self.third_prior
mass=self.first_prior, speed=self.second_prior, length=self.third_prior, testa=self.testa, testb=self.testb
)
self.prior_set_from_dict = bilby.core.prior.PriorDict(dictionary=self.priors)
self.default_prior_file = os.path.join(
Expand Down Expand Up @@ -70,7 +79,7 @@ def test_prior_set_is_dict(self):
self.assertIsInstance(self.prior_set_from_dict, dict)

def test_prior_set_has_correct_length(self):
self.assertEqual(3, len(self.prior_set_from_dict))
self.assertEqual(5, len(self.prior_set_from_dict))

def test_prior_set_has_expected_priors(self):
self.assertDictEqual(self.priors, dict(self.prior_set_from_dict))
Expand Down Expand Up @@ -160,6 +169,12 @@ def test_to_file(self):
"unit='m/s', boundary=None)\n",
"mass = Uniform(minimum=0, maximum=1, name='a', latex_label='a', "
"unit='kg', boundary=None)\n",
"testa_testb_mvg = MultivariateGaussianDist(names=['testa', 'testb'], nmodes=1, mus=[[1, 1]], "
"sigmas=[[1.4142135623730951, 1.4142135623730951]], "
"corrcoefs=[[[0.9999999999999998, 0.24999999999999994], [0.24999999999999994, 0.9999999999999998]]], "
"covs=[[[2.0, 0.5], [0.5, 2.0]]], weights=[1.0], bounds={'testa': (-inf, inf), 'testb': (-inf, inf)})\n",
"testa = MultivariateGaussian(dist=testa_testb_mvg, name='testa', latex_label='testa', unit='unit')\n",
"testb = MultivariateGaussian(dist=testa_testb_mvg, name='testb', latex_label='testb', unit='unit')\n",
]
self.prior_set_from_dict.to_file(outdir="prior_files", label="to_file_test")
with open("prior_files/to_file_test.prior") as f:
Expand All @@ -178,6 +193,13 @@ def test_from_dict_with_string(self):
self.assertDictEqual(self.prior_set_from_dict, from_dict)

def test_convert_floats_to_delta_functions(self):
mvg = bilby.core.prior.MultivariateGaussianDist(
names=["testa", "testb"],
mus=[1, 1],
covs=np.array([[2.0, 0.5], [0.5, 2.0]]),
weights=1.0,
)

self.prior_set_from_dict["d"] = 5
self.prior_set_from_dict["e"] = 7.3
self.prior_set_from_dict["f"] = "unconvertable"
Expand All @@ -190,6 +212,8 @@ def test_convert_floats_to_delta_functions(self):
name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None
),
length=bilby.core.prior.DeltaFunction(name="c", peak=42, unit="m"),
testa=bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit"),
testb=bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit"),
d=bilby.core.prior.DeltaFunction(peak=5),
e=bilby.core.prior.DeltaFunction(peak=7.3),
f="unconvertable",
Expand Down Expand Up @@ -321,12 +345,15 @@ def test_ln_prob(self):
self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples))

def test_rescale(self):
theta = [0.5, 0.5, 0.5]
theta = [0.5, 0.5, 0.5, 0.5, 0.5]
expected = [
self.first_prior.rescale(0.5),
self.second_prior.rescale(0.5),
self.third_prior.rescale(0.5),
self.testa.rescale(0.5),
self.testb.rescale(0.5)
]
assert not np.any(np.isnan(expected))
self.assertListEqual(
sorted(expected),
sorted(
Expand All @@ -342,7 +369,7 @@ def test_cdf(self):
Note that the format of inputs/outputs is different between the two methods.
"""
sample = self.prior_set_from_dict.sample()
sample = self.prior_set_from_dict.sample_subset(keys=["length", "speed", "mass"])
original = np.array(list(sample.values()))
new = np.array(self.prior_set_from_dict.rescale(
sample.keys(),
Expand Down

0 comments on commit 60d7be1

Please sign in to comment.