Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conserve rescale shape #863

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

JasperMartins
Copy link
Contributor

This PR updates the rescale functions of PriorDict and ConditionalPriorDict to preserve the correct shape of the rescaled samples.

The PR also updates the behavior of JointPrior.rescale and BaseJointPriorDist.

JointPrior.rescale now returns a mutable numpy array. If not all required keys have been requested for rescaling, the arrays is filled with np.nan. BaseJointPriorDist keeps track of this array. Once all required keys are requested, the rescale operation is performed, and the returned arrays are populated with the rescaled values.
This change enables out-of-order sets of keys in rescale - previously priordict.rescale(keys=["a", "JointPrior_b", "c", "JointPrior_d"], theta=...) where "JointPrior_b" and "JointPrior_d" share the same dist would have resulted in an order [*samples_a, *samples_c, *samples_b, *samples_d].

Relates to the discussion in #850

samples.append(samps)
for i, samps in enumerate(samples):
# turns 0d-arrays into scalars
samples[i] = np.squeeze(samps).tolist()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would not make more sense to preserve numpy arrays to make it easier to define conversion_functions and so on without needing to cast to array again

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is would be worth returning arrays, I think it's worth accepting that this is going to not be fully backward compatible.

Copy link
Collaborator

@ColmTalbot ColmTalbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for opening this, I think this will be a good change, and I'm in favour of also not pushing everything through lists.

I don't really follow the logic in the joint prior, so some additional docstrings/comments would probably help there.

samples.append(samps)
for i, samps in enumerate(samples):
# turns 0d-arrays into scalars
samples[i] = np.squeeze(samps).tolist()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is would be worth returning arrays, I think it's worth accepting that this is going to not be fully backward compatible.


Returns
=======
list: List of floats containing the rescaled sample
list:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be worth returning the same type as the input, i.e., return a dict if a dict is passed.

Another option would be to add a specific method to handle dicts, I've thought about doing this a few times before, e.g., PriorDict.rescale_dict that just wraps PriorDict.rescale.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could also apply to the conservation of numpy arrays, ie return a dict, list or array depending on the input.

def set_rescale(self, key, values):
values = np.array(values)
self._rescale_parameters[key] = values
self._rescaled_parameters[key] = np.atleast_1d(np.ones_like(values)) * np.nan
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't see where these values are set to a non-trivial value, what is the distinction between these two?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The values are updated in-place in the rescale function:

if value is None:
for i, key in enumerate(self.names):
output = self.get_rescaled(key)
# update in-place for proper handling in PriorDict-instances
output[:] = samp[:, i]
return np.squeeze(samp)

self.dist.set_rescale(self.name, val)
if self.dist.filled_rescale():
self.dist.rescale(value=None, **kwargs)
output = self.dist.get_rescaled(self.name)
self.dist.reset_rescale()
else:
output = self.dist.get_rescaled(self.name)
# have to return raw output to conserve in-place modifications
return output

So, in self._rescaled_parameters, the values of the rescaled parameters are stored once the rescaling operation is performed. Prior to this, the array contains only np.NaN to clarify that the rescaling has not yet happened. The per-key numpy arrays stored in self._rescaled_parameters are returned as references in JointPrior.rescale. Once all keys have been requested, the values in the arrays are updated in place, meaning that they will also be updated in the array returned by JointPrior.rescale.

I agree that this should be documented better, particularly to ensure that PriorDict.rescale is not updated in a way that loses track of the reference to the arrays stored in BaseJointPriorDict._rescaled_parameters in the future.

@JasperMartins JasperMartins force-pushed the Conserve-rescale-shape branch from 94623a9 to 60d7be1 Compare January 21, 2025 15:35
@JasperMartins
Copy link
Contributor Author

JasperMartins commented Jan 21, 2025

I have updated the changes to joint priors with more comments / docstrings and changed variable names to be more reflective of what happens. I have also noticed the implementation attempts to solve a similar issue as #848 when rescaling joint priors.

For now, I have left the return-values of PriorDict.rescale() and ConditionalPriorDict.rescale() as lists. Changing to arrays should, I think, be accompanied by a larger-scale switch to numpy arrays for the return values of methods of Prior() that I don't have time for right at the moment.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants