Refactor MACE subclasses - reduce code duplication & clearer logic #97
base: develop
- extracted the layer calculations
- "ScaleShift" applied to non-duplicated layer calculation code
- DipolesMACE & EnergyDipolesMACE inherit from MACE
- EnergyDipolesMACE accepting dict for forward pass + closer match to MACE class
- DipoleOnly versions of blocks added explicitly (to be refactored)
- JIT for DipolesMACE & EnergyDipolesMACE
@davkovacs there's a more general refactor in https://github.com/stenczelt/mace/tree/ENH/refactor-model-v3 where I have separated out the backbone of the MACE model (agnostic of what you're calculating) and the quantity-specific readout blocks and gradient calculations.
I would say let’s have one refactor and test that thoroughly
Roger that @davkovacs !
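To illustrate the split described above (a quantity-agnostic backbone plus quantity-specific readouts), here is a minimal sketch in plain Python. All class and method names (`BackboneSketch`, `EnergyReadoutSketch`, `DipoleReadoutSketch`, `node_features`, and so on) are hypothetical stand-ins chosen for illustration; they are not taken from the linked branch, and the arithmetic is a placeholder for the real layer calculations.

```python
from typing import Dict, List


class BackboneSketch:
    """Hypothetical quantity-agnostic backbone (not the real MACE code)."""

    def __init__(self, num_interactions: int) -> None:
        self.num_interactions = num_interactions

    def node_features(self, data: Dict[str, List[float]]) -> List[List[float]]:
        # One feature list per interaction layer; each layer simply
        # rescales the input as a placeholder for the real calculation.
        return [
            [x * (layer + 1) for x in data["positions"]]
            for layer in range(self.num_interactions)
        ]


class EnergyReadoutSketch:
    """Hypothetical energy readout: sums the features of every layer."""

    def read_energy(self, features: List[List[float]]) -> float:
        return sum(sum(layer) for layer in features)


class DipoleReadoutSketch:
    """Hypothetical dipole readout: uses the last layer only."""

    def read_dipole(self, features: List[List[float]]) -> List[float]:
        return list(features[-1])


class EnergyDipoleSketch(BackboneSketch, EnergyReadoutSketch, DipoleReadoutSketch):
    """Combines the shared backbone with both readouts."""

    def forward(self, data: Dict[str, List[float]]) -> Dict[str, object]:
        features = self.node_features(data)  # computed once, shared by readouts
        return {
            "energy": self.read_energy(features),
            "dipole": self.read_dipole(features),
        }


model = EnergyDipoleSketch(num_interactions=2)
output = model.forward({"positions": [1.0, 2.0]})
```

In this sketch the backbone is evaluated once and each readout only consumes the resulting features, so adding a new quantity would mean adding one readout class rather than duplicating the layer loop.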
interaction_cls: Type[InteractionBlock],
interaction_cls_first: Type[InteractionBlock],
num_interactions: int,
num_elements: int,
The number of elements is redundant here. I propose removing it. ok?
hidden irreducible representations, basically the size of the layer features
and hence direct control on the size of the model
MLP_irreps
avg_num_neighbors
These docstrings are duplicated across the classes exposed to the user as well. I think this is fine, because whichever class you are using, you see the docstring directly. If you disagree, feel free to remove the redundant ones and perhaps keep them only for the main class.
atomic_numbers: List[int],
correlation: int,
gate: Optional[Callable],
radial_MLP: Optional[List[int]] = None,
I have not changed this, for the sake of consistency. However, method parameters are conventionally lower case, so I propose changing these MLP parameters to lowercase.
@compile_mode("script")
class AtomicDipolesMACE(MaceCoreModel, DipoleModelMixin):
Why is this called "AtomicDipolesMACE"? Should it not just be DipolesMACE?
def forward(
self,
data: Dict[str, torch.Tensor],
training: bool = False,
compute_force: bool = True,
compute_force: bool = False,
If this class cannot calculate forces, why don't we remove the associated parameters? Are they intended for compatibility with wrappers?
self,
data: Dict[str, torch.Tensor],
training: bool = False,
compute_force: bool = True,
Same here: these parameters only act as traps. I can understand it if they are there for compatibility and will perhaps be implemented later on.
return output

@compile_mode("script")
class ScaleShiftEnergyDipoleMACE(EnergyDipolesMACE, ScaleShiftEnergyModelMixin):
This is just an extra, which was easy to add because the logic for dipoles and energy is now separated.
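As a rough illustration of why separated logic makes such a variant cheap to add, the sketch below composes a base model with a scale-and-shift mixin in plain Python. The names (`EnergyDipoleModelSketch`, `ScaleShiftMixinSketch`, `finalize_energies`, `scale_shift`) and the hook-based design are assumptions made for this example only; the actual `ScaleShiftEnergyModelMixin` in the PR may work differently.

```python
from typing import Dict, List


class EnergyDipoleModelSketch:
    """Hypothetical base model; not the real EnergyDipolesMACE."""

    def node_energies(self, data: Dict[str, List[float]]) -> List[float]:
        # Placeholder for the real per-node energy calculation.
        return [2.0 * x for x in data["node_inputs"]]

    def finalize_energies(self, energies: List[float]) -> List[float]:
        # Hook for subclasses; the base model leaves energies unchanged.
        return energies

    def forward(self, data: Dict[str, List[float]]) -> Dict[str, object]:
        energies = self.finalize_energies(self.node_energies(data))
        # The dipole path does not touch the energy hook at all.
        return {"energy": sum(energies), "dipole": list(data["node_inputs"])}


class ScaleShiftMixinSketch:
    """Hypothetical mixin that only knows how to scale and shift values."""

    scale: float = 1.0
    shift: float = 0.0

    def scale_shift(self, values: List[float]) -> List[float]:
        return [self.scale * v + self.shift for v in values]


class ScaleShiftEnergyDipoleSketch(EnergyDipoleModelSketch, ScaleShiftMixinSketch):
    """Scaled-and-shifted variant: overrides only the energy hook."""

    def __init__(self, scale: float, shift: float) -> None:
        self.scale = scale
        self.shift = shift

    def finalize_energies(self, energies: List[float]) -> List[float]:
        return self.scale_shift(energies)


batch = {"node_inputs": [1.0, 2.0]}
plain = EnergyDipoleModelSketch().forward(batch)
scaled = ScaleShiftEnergyDipoleSketch(scale=0.5, shift=1.0).forward(batch)
```

Under these assumptions the new class needs only a constructor and a one-line override, and the dipole output is identical for both models.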
assert key in output

def test_scaled_and_shifted(dipole_model_config, data_batch_1):
While the tests were enormously useful while refactoring, there is more to test on all these classes. Should be somewhat easier now.
@davkovacs @ilyes319 I have added comments on what I have done; you might be able to answer these or make the appropriate judgement calls. Please let me know if you need clarification on anything or want me to make changes.
Supersedes #65 and implements the JIT part of #95 as well, for free.
Refactored subclasses of `MACE`:
- `ScaleShiftMACE`: Only scales & shifts MACE
- `AtomicDipolesMACE`: Only defined dipole calculation capabilities
- `EnergyDipolesMACE`: Energy & dipole calculation

Main changes:
- `__init__` is basically the same, there is only a little difference, which is handled with