Skip to content

Commit

Permalink
add source-wise mus_interpolator
Browse files Browse the repository at this point in the history
  • Loading branch information
hammannr committed Oct 10, 2024
1 parent 0fdb5b2 commit 5baa97a
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions blueice/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, pdf_base_config, likelihood_config=None, **kwargs):
likelihood_config = {}
self.config = likelihood_config
self.config.setdefault('morpher', 'GridInterpolator')
self.source_wise_interpolation = self.config.get('source_wise_interpolation', False)
self.source_wise_interpolation = self.pdf_base_config.get('source_wise_interpolation', False)

# Base model: no variations of any settings
self.base_model = Model(self.pdf_base_config)
Expand Down Expand Up @@ -129,8 +129,7 @@ def _get_model_anchor(self, anchor, source_name):
"""Return the shape anchors of the full model, given the shape anchors of a signle source.
All values of shape parameters not used by this source will be set to None.
"""
shape_keys = self.source_shape_parameters[source_name].keys()
shape_indices = self.get_shape_indices(shape_keys)
shape_indices = self._get_shape_indices(source_name)
model_anchor = [None] * len(self.shape_parameters)
for i, idx in enumerate(shape_indices):
model_anchor[idx] = anchor[i]
Expand Down Expand Up @@ -197,15 +196,36 @@ def prepare(self, n_cores=1, ipp_client=None):
models = [Model(c) for c in tqdm(configs, desc="Loading computed models")]

if self.source_wise_interpolation:
for i,(source_name, morpher) in enumerate(source_morphers.items()):
print("USING SOURCE-WISE INTERPOLATION")
for i,(source_name, morpher) in enumerate(self.source_morphers.items()):
anchors = morpher.get_anchor_points(bounds=None)
self.anchor_sources[source_name] = OrderedDict()
for anchor in anchors:
model_anchor = self._get_model_anchor(anchor, source_name)
model_index = zs_list.index(model_anchor)
self.anchor_sources[source_name][anchor] = models[model_index].sources[i]
# TODO: Implement self.mus_interpolator for source-wise interpolation

mus_interpolators = OrderedDict()
for sn, base_source in zip(self.source_name_list, self.base_model.sources):
if sn in self.source_morphers:
mus_interpolators[sn] = self.source_morphers[sn].make_interpolator(
f=lambda s: s.expected_events,
extra_dims=[1],
anchor_models=self.anchor_sources[sn])
else:
mus_interpolators[sn] = base_source.expected_events
def mus_interpolator(*args):
# take zs, convert to values for each source's interpolator call the respective interpolator
mus = []
for sn in self.source_name_list:
if sn in self.source_shape_parameters:
shape_indices = self._get_shape_indices(sn)
these_args = [args[0][i] for i in shape_indices]
mus.append(mus_interpolators[sn](np.asarray(these_args))[0])
else:
mus.append(mus_interpolators[sn])
return np.array(mus)
self.mus_interpolator = mus_interpolator

else:
# Add the new models to the anchor_models dict
for zs, model in zip(zs_list, models):
Expand Down Expand Up @@ -513,11 +533,11 @@ def ps_interpolator(*args):
for sn in self.source_name_list:
if sn in self.source_shape_parameters:
shape_indices = self._get_shape_indices(sn)
these_args = [args[i] for i in shape_indices]
these_args = [args[0][i] for i in shape_indices]
ps.append(ps_interpolators[sn](np.asarray(these_args)))
else:
ps.append(ps_interpolators[sn])
return ps
return np.array(ps)
self.ps_interpolator = ps_interpolator
else:
self.ps_interpolator = self.morpher.make_interpolator(f=lambda m: m.score_events(d),
Expand Down

0 comments on commit 5baa97a

Please sign in to comment.