[Feature] Copy Imputer #72
base: develop
Conversation
Can you please add more detail to the description about where this imputer may be useful, along with some example code/config of it in action?
@HCookie Done ✔️ Is it clearer now?
```yaml
default: "none"
x:
  - y
  - q
```
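One plausible reading of this config (semantics assumed here, not confirmed by the PR): `x` is the source field, and `y` and `q` are imputed by copying its values. A minimal plain-Python sketch of that reading, with hypothetical `copy_impute` and sample names:

```python
import math

# Hypothetical config: keys are source variables, lists are the targets to
# impute from that source ("default: none" meaning no default source).
# This interpretation is an assumption, not confirmed by the PR.
config = {"default": "none", "x": ["y", "q"]}

def copy_impute(sample: dict, config: dict) -> dict:
    """Fill NaNs in each target variable with values copied from its source."""
    out = {name: list(values) for name, values in sample.items()}
    for source, targets in config.items():
        if source == "default":
            continue
        for target in targets:
            for i, v in enumerate(out[target]):
                if math.isnan(v):
                    out[target][i] = out[source][i]
    return out

sample = {"x": [1.0, 2.0], "y": [float("nan"), 5.0], "q": [7.0, float("nan")]}
filled = copy_impute(sample, config)
```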
Some comments on what these characters represent will help here
```python
@@ -303,3 +408,50 @@ def __init__(
    "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
    The model will be trained to predict imputed values. This might deteriorate performances."
)


class DynamicCopyImputer(CopyImputer):
```
A little more detail on the exact differences and uses would be helpful to other users.
```python
self.loss_mask_training = torch.ones(
    (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
)

# Choose correct index based on number of variables
if x.shape[-1] == self.num_training_input_vars:
    index = self.index_training_input
elif x.shape[-1] == self.num_inference_input_vars:
    index = self.index_inference_input
else:
    raise ValueError(
        f"Input tensor ({x.shape[-1]}) does not match the training "
        f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
    )

# Replace values
for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
    if idx_dst is not None:
        assert not torch.isnan(
            x[..., self.data_indices.data.input.name_to_index[value]][nan_locations[..., idx_src]]
        ).any(), f"NaNs found in {value}."
        x[..., idx_dst][nan_locations[..., idx_src]] = x[
            ..., self.data_indices.data.input.name_to_index[value]
        ][nan_locations[..., idx_src]]

return x
```
Given the similarities between this and the parent class implementation, would it make sense to factor it out into a shared function, or a generic parent class with two children?
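A sketch of that kind of refactor, using toy classes and hypothetical names (`BaseImputer`, `_select_index`; the attributes mirror those in the diff but the structure is illustrative, not the PR's actual design): the shape-dependent index selection lives in the parent, and children only decide what to do with the selected index.

```python
class BaseImputer:
    """Toy parent holding the index selection shared by both imputers."""

    def __init__(self, num_training_input_vars, num_inference_input_vars):
        self.num_training_input_vars = num_training_input_vars
        self.num_inference_input_vars = num_inference_input_vars
        self.index_training_input = list(range(num_training_input_vars))
        self.index_inference_input = list(range(num_inference_input_vars))

    def _select_index(self, num_vars):
        # Choose correct index based on number of variables (as in the diff).
        if num_vars == self.num_training_input_vars:
            return self.index_training_input
        if num_vars == self.num_inference_input_vars:
            return self.index_inference_input
        raise ValueError(
            f"Input tensor ({num_vars}) does not match the training "
            f"({self.num_training_input_vars}) or inference shape "
            f"({self.num_inference_input_vars})"
        )


class CopySketch(BaseImputer):
    """Child supplies only the copy-specific step; here it just reports
    how many variables would be processed."""

    def impute_width(self, num_vars):
        return len(self._select_index(num_vars))
```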
Add an imputer that copies missing information from another field.
When a value is missing at a certain pressure level, it can be useful to copy the value from the pressure level above.
This technique ensures constant imputation in areas with missing pressure levels and avoids a big delta in values with respect to vertical changes.
Solves #71
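As a minimal sketch of the idea (plain Python; the assumption that lower list indices correspond to the level "above" is illustrative, and `impute_from_level_above` is a hypothetical name, not the PR's API):

```python
import math

def impute_from_level_above(levels):
    """Fill a NaN at a pressure level with the value from the level above it
    (here: the previous list index). Consecutive missing levels all receive
    the same copied value, giving the constant imputation described above."""
    filled = list(levels)
    for i in range(1, len(filled)):
        if math.isnan(filled[i]):
            filled[i] = filled[i - 1]
    return filled

# Toy temperature profile with two missing levels
profile = [280.0, float("nan"), float("nan"), 250.0]
result = impute_from_level_above(profile)  # → [280.0, 280.0, 280.0, 250.0]
```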