From 130e3a24e63a60679c945bc089fcd6fc37edfa8c Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:15:39 -0500 Subject: [PATCH 01/52] partial forces --- nequip/data/_keys.py | 1 + nequip/model/__init__.py | 15 ++++---- nequip/model/_grads.py | 18 ++++++++++ nequip/nn/__init__.py | 2 +- nequip/nn/_grad_output.py | 55 ++++++++++++++++++++++++++++++ tests/unit/model/test_eng_force.py | 28 +++++++++++++++ 6 files changed, 111 insertions(+), 8 deletions(-) diff --git a/nequip/data/_keys.py b/nequip/data/_keys.py index 58918657..d87f52a7 100644 --- a/nequip/data/_keys.py +++ b/nequip/data/_keys.py @@ -38,6 +38,7 @@ PER_ATOM_ENERGY_KEY: Final[str] = "atomic_energy" TOTAL_ENERGY_KEY: Final[str] = "total_energy" FORCE_KEY: Final[str] = "forces" +PARTIAL_FORCE_KEY: Final[str] = "partial_forces" BATCH_KEY: Final[str] = "batch" diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py index b849efed..eea9a716 100644 --- a/nequip/model/__init__.py +++ b/nequip/model/__init__.py @@ -1,15 +1,16 @@ from ._eng import EnergyModel -from ._grads import ForceOutput +from ._grads import ForceOutput, PartialForceOutput from ._scaling import RescaleEnergyEtc, PerSpeciesRescale from ._weight_init import uniform_initialize_FCs from ._build import model_from_config __all__ = [ - "EnergyModel", - "ForceOutput", - "RescaleEnergyEtc", - "PerSpeciesRescale", - "uniform_initialize_FCs", - "model_from_config", + EnergyModel, + ForceOutput, + PartialForceOutput, + RescaleEnergyEtc, + PerSpeciesRescale, + uniform_initialize_FCs, + model_from_config, ] diff --git a/nequip/model/_grads.py b/nequip/model/_grads.py index dcf85f9b..12c275c6 100644 --- a/nequip/model/_grads.py +++ b/nequip/model/_grads.py @@ -1,4 +1,5 @@ from nequip.nn import GraphModuleMixin, GradientOutput +from nequip.nn import PartialForceOutput as PartialForceOutputModule from nequip.data import AtomicDataDict @@ -20,3 +21,20 @@ def ForceOutput(model: GraphModuleMixin) -> GradientOutput: out_field=AtomicDataDict.FORCE_KEY, sign=-1, # force is the negative gradient ) + + +def PartialForceOutput(model: GraphModuleMixin) -> GradientOutput: + r"""Add forces and partial forces to a model that predicts energy. + + Args: + energy_model: the model to wrap. Must have ``AtomicDataDict.TOTAL_ENERGY_KEY`` as an output. + + Returns: + A ``GradientOutput`` wrapping ``energy_model``. + """ + if ( + AtomicDataDict.FORCE_KEY in model.irreps_out + or AtomicDataDict.PARTIAL_FORCE_KEY in model.irreps_out + ): + raise ValueError("This model already has force outputs.") + return PartialForceOutputModule(func=model) diff --git a/nequip/nn/__init__.py b/nequip/nn/__init__.py index 383f619b..12ab085b 100644 --- a/nequip/nn/__init__.py +++ b/nequip/nn/__init__.py @@ -6,7 +6,7 @@ PerSpeciesScaleShift, ) # noqa: F401 from ._interaction_block import InteractionBlock # noqa: F401 -from ._grad_output import GradientOutput # noqa: F401 +from ._grad_output import GradientOutput, PartialForceOutput # noqa: F401 from ._rescale import RescaleOutput # noqa: F401 from ._convnetlayer import ConvNetLayer # noqa: F401 from ._util import SaveForOutput # noqa: F401 diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index b5ee9efc..cae0731f 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -97,3 +97,58 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: data[k].requires_grad_(req_grad) return data + + +@compile_mode("unsupported") +class PartialForceOutput(GraphModuleMixin, torch.nn.Module): + r"""Wrap a model and include as an output its gradient.""" + vectorize: bool + + def __init__( + self, + func: GraphModuleMixin, + vectorize: bool = False, + ): + super().__init__() + # TODO wrap: + self.func = func + self.vectorize = vectorize + if vectorize: + # See https://pytorch.org/docs/stable/generated/torch.autograd.functional.jacobian.html + torch._C._debug_only_display_vmap_fallback_warnings(True) + + # check and init irreps + self._init_irreps( + irreps_in=func.irreps_in, + my_irreps_in={AtomicDataDict.PER_ATOM_ENERGY_KEY: Irreps("0e")}, + irreps_out=func.irreps_out, + ) + self.irreps_out[AtomicDataDict.PARTIAL_FORCE_KEY] = Irreps("1o") + self.irreps_out[AtomicDataDict.FORCE_KEY] = Irreps("1o") + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data = data.copy() + out_data = {} + + def wrapper(pos: torch.Tensor) -> torch.Tensor: + """Wrapper from pos to atomic energy""" + nonlocal data, out_data + data[AtomicDataDict.POSITIONS_KEY] = pos + out_data = self.func(data) + return out_data[AtomicDataDict.PER_ATOM_ENERGY_KEY].squeeze(-1) + + pos = data[AtomicDataDict.POSITIONS_KEY] + + partial_forces = torch.autograd.functional.jacobian( + func=wrapper, + inputs=pos, + create_graph=self.training, # needed to allow gradients of this output during training + vectorize=self.vectorize, + ) + partial_forces = partial_forces.negative() + # output is [n_at, n_at, 3] + + out_data[AtomicDataDict.PARTIAL_FORCE_KEY] = partial_forces + out_data[AtomicDataDict.FORCE_KEY] = partial_forces.sum(dim=0) + + return out_data diff --git a/tests/unit/model/test_eng_force.py b/tests/unit/model/test_eng_force.py index 036b690e..3f041d16 100644 --- a/tests/unit/model/test_eng_force.py +++ b/tests/unit/model/test_eng_force.py @@ -223,6 +223,34 @@ def test_numeric_gradient(self, config, atomic_batch, device, float_tolerance): numeric, analytical, rtol=5e-2 ) + def test_partial_forces(self, atomic_batch, device): + config = minimal_config1.copy() + config["model_builders"] = [ + "EnergyModel", + "ForceOutput", + ] + partial_config = config.copy() + partial_config["model_builders"] = [ + "EnergyModel", + "PartialForceOutput", + ] + model = model_from_config(config=config, initialize=True) + partial_model = model_from_config(config=partial_config, initialize=True) + model.to(device) + partial_model.to(device) + partial_model.load_state_dict(model.state_dict()) + data = atomic_batch.to(device) + output = model(AtomicData.to_AtomicDataDict(data)) + output_partial = partial_model(AtomicData.to_AtomicDataDict(data)) + assert torch.allclose( + output[AtomicDataDict.FORCE_KEY], + output_partial[AtomicDataDict.FORCE_KEY], + atol=1e-6, + ) + n_at = data[AtomicDataDict.POSITIONS_KEY].shape[0] + assert output_partial[AtomicDataDict.PARTIAL_FORCE_KEY].shape == (n_at, n_at, 3) + # TODO check sparsity + class TestAutoGradient: def test_cross_frame_grad(self, config, nequip_dataset): From 794d6cc0075d5a5f5b6c1f28350e503889c7270d Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:19:10 -0500 Subject: [PATCH 02/52] test all keys --- tests/unit/model/test_eng_force.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/unit/model/test_eng_force.py b/tests/unit/model/test_eng_force.py index 3f041d16..d215c991 100644 --- a/tests/unit/model/test_eng_force.py +++ b/tests/unit/model/test_eng_force.py @@ -242,14 +242,23 @@ def test_partial_forces(self, atomic_batch, device): data = atomic_batch.to(device) output = model(AtomicData.to_AtomicDataDict(data)) output_partial = partial_model(AtomicData.to_AtomicDataDict(data)) - assert torch.allclose( - output[AtomicDataDict.FORCE_KEY], - output_partial[AtomicDataDict.FORCE_KEY], - atol=1e-6, - ) + # everything should be the same + # including the + for k in output: + assert k != AtomicDataDict.PARTIAL_FORCE_KEY + assert k in output_partial + if output[k].is_floating_point(): + assert torch.allclose( + output[k], + output_partial[k], + atol=1e-6 if k == AtomicDataDict.FORCE_KEY else 1e-8, + ) + else: + assert torch.equal(output[k], output_partial[k]) n_at = data[AtomicDataDict.POSITIONS_KEY].shape[0] - assert output_partial[AtomicDataDict.PARTIAL_FORCE_KEY].shape == (n_at, n_at, 3) - # TODO check sparsity + partial_forces = output_partial[AtomicDataDict.PARTIAL_FORCE_KEY] + assert partial_forces.shape == (n_at, n_at, 3) + # TODO check sparsity? class TestAutoGradient: From f108374b06efe9acc5a15a5863ef7300444b115a Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 17 Nov 2021 15:28:54 -0500 Subject: [PATCH 03/52] warnings --- nequip/nn/_grad_output.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index cae0731f..afe965ae 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -101,19 +101,26 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: @compile_mode("unsupported") class PartialForceOutput(GraphModuleMixin, torch.nn.Module): - r"""Wrap a model and include as an output its gradient.""" + r"""Generate partial and total forces from an energy model. + + Args: + func: the energy model + vectorize: the vectorize option to ``torch.autograd.functional.jacobian``, + false by default since it doesn't work well. + """ vectorize: bool def __init__( self, func: GraphModuleMixin, vectorize: bool = False, + vectorize_warnings: bool = False, ): super().__init__() # TODO wrap: self.func = func self.vectorize = vectorize - if vectorize: + if vectorize_warnings: # See https://pytorch.org/docs/stable/generated/torch.autograd.functional.jacobian.html torch._C._debug_only_display_vmap_fallback_warnings(True) From c0c17e37ca0e95ae61d967de8fe8d97431e91860 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 17 Nov 2021 23:31:25 -0500 Subject: [PATCH 04/52] initial --- configs/minimal.yaml | 1 + nequip/model/_build.py | 8 ++++++-- nequip/model/_eng.py | 40 ++++++++++++++++++++++++++++++++++------ 3 files changed, 41 insertions(+), 8 deletions(-) diff --git a/configs/minimal.yaml b/configs/minimal.yaml index fa05bbbb..a3af4532 100644 --- a/configs/minimal.yaml +++ b/configs/minimal.yaml @@ -35,6 +35,7 @@ chemical_symbol_to_type: # 0: my_type # 1: atom # 2: thing +avg_num_neighbors: auto # logging wandb: false diff --git a/nequip/model/_build.py b/nequip/model/_build.py index 4f1ae7dd..4247cf48 100644 --- a/nequip/model/_build.py +++ b/nequip/model/_build.py @@ -75,9 +75,13 @@ def model_from_config( if "dataset" in pnames: if "initialize" not in pnames: raise ValueError("Cannot request dataset without requesting initialize") - if initialize and dataset is None: + if ( + initialize + and pnames["dataset"].default != inspect.Parameter.empty + and dataset is None + ): raise RuntimeError( - f"Builder {builder.__name__} asked for the dataset, initialize is true, but no dataset was provided to `model_from_config`." + f"Builder {builder.__name__} requires the dataset, initialize is true, but no dataset was provided to `model_from_config`." ) params["dataset"] = dataset if "model" in pnames: diff --git a/nequip/model/_eng.py b/nequip/model/_eng.py index 314bef32..49ca2e95 100644 --- a/nequip/model/_eng.py +++ b/nequip/model/_eng.py @@ -1,6 +1,10 @@ +from typing import Optional import logging -from nequip.data import AtomicDataDict +import torch +from nequip import data + +from nequip.data import AtomicDataDict, AtomicDataset from nequip.nn import ( SequentialGraphNetwork, AtomwiseLinear, @@ -14,13 +18,40 @@ ) -def EnergyModel(config) -> SequentialGraphNetwork: +def EnergyModel( + config, initialize: bool, dataset: Optional[AtomicDataset] = None +) -> SequentialGraphNetwork: """Base default energy model archetecture. For minimal and full configuration option listings, see ``minimal.yaml`` and ``example.yaml``. """ logging.debug("Start building the network model") + # Compute avg_num_neighbors + annkey: str = "avg_num_neighbors" + if config.get(annkey, None) == "auto" and initialize: + if dataset is None: + raise ValueError( + "When avg_num_neighbors = auto, the dataset is required to build+initialize a model" + ) + config[annkey] = dataset.statistics( + fields=[ + lambda data: ( + torch.unique( + data[AtomicDataDict.EDGE_INDEX_KEY][0], return_counts=True + )[1], + "node", + ) + ], + modes=["mean_std"], + stride=config.dataset_statistics_stride, + )[0][0].item() + else: + # make sure its valid + ann = config.get(annkey, None) + if ann is not None: + assert isinstance(ann, float) or isinstance(ann, int) + num_layers = config.get("num_layers", 3) layers = { @@ -59,7 +90,4 @@ def EnergyModel(config) -> SequentialGraphNetwork: ), ) - return SequentialGraphNetwork.from_parameters( - shared_params=config, - layers=layers, - ) + return SequentialGraphNetwork.from_parameters(shared_params=config, layers=layers,) From 69c071758378aadf2ba1152afe2b98837012310f Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 22 Nov 2021 21:56:26 -0500 Subject: [PATCH 05/52] refactor --- nequip/model/__init__.py | 16 ++++++++------- nequip/model/_eng.py | 37 +++++++++-------------------------- nequip/model/builder_utils.py | 36 ++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 35 deletions(-) create mode 100644 nequip/model/builder_utils.py diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py index b849efed..7f0e523d 100644 --- a/nequip/model/__init__.py +++ b/nequip/model/__init__.py @@ -2,14 +2,16 @@ from ._grads import ForceOutput from ._scaling import RescaleEnergyEtc, PerSpeciesRescale from ._weight_init import uniform_initialize_FCs - from ._build import model_from_config +from . import builder_utils + __all__ = [ - "EnergyModel", - "ForceOutput", - "RescaleEnergyEtc", - "PerSpeciesRescale", - "uniform_initialize_FCs", - "model_from_config", + EnergyModel, + ForceOutput, + RescaleEnergyEtc, + PerSpeciesRescale, + uniform_initialize_FCs, + model_from_config, + builder_utils, ] diff --git a/nequip/model/_eng.py b/nequip/model/_eng.py index 49ca2e95..30051328 100644 --- a/nequip/model/_eng.py +++ b/nequip/model/_eng.py @@ -1,9 +1,6 @@ from typing import Optional import logging -import torch -from nequip import data - from nequip.data import AtomicDataDict, AtomicDataset from nequip.nn import ( SequentialGraphNetwork, @@ -17,6 +14,8 @@ SphericalHarmonicEdgeAttrs, ) +from . import builder_utils + def EnergyModel( config, initialize: bool, dataset: Optional[AtomicDataset] = None @@ -27,30 +26,9 @@ def EnergyModel( """ logging.debug("Start building the network model") - # Compute avg_num_neighbors - annkey: str = "avg_num_neighbors" - if config.get(annkey, None) == "auto" and initialize: - if dataset is None: - raise ValueError( - "When avg_num_neighbors = auto, the dataset is required to build+initialize a model" - ) - config[annkey] = dataset.statistics( - fields=[ - lambda data: ( - torch.unique( - data[AtomicDataDict.EDGE_INDEX_KEY][0], return_counts=True - )[1], - "node", - ) - ], - modes=["mean_std"], - stride=config.dataset_statistics_stride, - )[0][0].item() - else: - # make sure its valid - ann = config.get(annkey, None) - if ann is not None: - assert isinstance(ann, float) or isinstance(ann, int) + builder_utils.add_avg_num_neighbors( + config=config, initialize=initialize, dataset=dataset + ) num_layers = config.get("num_layers", 3) @@ -90,4 +68,7 @@ def EnergyModel( ), ) - return SequentialGraphNetwork.from_parameters(shared_params=config, layers=layers,) + return SequentialGraphNetwork.from_parameters( + shared_params=config, + layers=layers, + ) diff --git a/nequip/model/builder_utils.py b/nequip/model/builder_utils.py new file mode 100644 index 00000000..52a856b8 --- /dev/null +++ b/nequip/model/builder_utils.py @@ -0,0 +1,36 @@ +from typing import Optional + +import torch + +from nequip.utils import Config +from nequip.data import AtomicDataset, AtomicDataDict + + +def add_avg_num_neighbors( + config: Config, initialize: bool, dataset: Optional[AtomicDataset] = None +) -> Optional[float]: + # Compute avg_num_neighbors + annkey: str = "avg_num_neighbors" + if config.get(annkey, None) == "auto" and initialize: + if dataset is None: + raise ValueError( + "When avg_num_neighbors = auto, the dataset is required to build+initialize a model" + ) + config[annkey] = dataset.statistics( + fields=[ + lambda data: ( + torch.unique( + data[AtomicDataDict.EDGE_INDEX_KEY][0], return_counts=True + )[1], + "node", + ) + ], + modes=["mean_std"], + stride=config.dataset_statistics_stride, + )[0][0].item() + + # make sure its valid + ann = config.get(annkey, None) + if ann is not None: + config[annkey] = float(config[annkey]) + return config[annkey] From 2cb8f752ec85e11a9f1e9dd29aefbc568b411fec Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 22 Nov 2021 22:13:37 -0500 Subject: [PATCH 06/52] tests --- nequip/model/builder_utils.py | 20 ++++++++------ tests/unit/model/test_builder_utils.py | 38 ++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) create mode 100644 tests/unit/model/test_builder_utils.py diff --git a/nequip/model/builder_utils.py b/nequip/model/builder_utils.py index 52a856b8..2f93c51f 100644 --- a/nequip/model/builder_utils.py +++ b/nequip/model/builder_utils.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -7,16 +7,20 @@ def add_avg_num_neighbors( - config: Config, initialize: bool, dataset: Optional[AtomicDataset] = None + config: Config, + initialize: bool, + dataset: Optional[AtomicDataset] = None, + default: Optional[Union[str, float]] = "auto", ) -> Optional[float]: # Compute avg_num_neighbors annkey: str = "avg_num_neighbors" - if config.get(annkey, None) == "auto" and initialize: + ann = config.get(annkey, default) + if ann == "auto" and initialize: if dataset is None: raise ValueError( "When avg_num_neighbors = auto, the dataset is required to build+initialize a model" ) - config[annkey] = dataset.statistics( + ann = dataset.statistics( fields=[ lambda data: ( torch.unique( @@ -26,11 +30,11 @@ def add_avg_num_neighbors( ) ], modes=["mean_std"], - stride=config.dataset_statistics_stride, + stride=config.get("dataset_statistics_stride", 1), )[0][0].item() # make sure its valid - ann = config.get(annkey, None) if ann is not None: - config[annkey] = float(config[annkey]) - return config[annkey] + ann = float(ann) + config[annkey] = ann + return ann diff --git a/tests/unit/model/test_builder_utils.py b/tests/unit/model/test_builder_utils.py new file mode 100644 index 00000000..dcb45d4d --- /dev/null +++ b/tests/unit/model/test_builder_utils.py @@ -0,0 +1,38 @@ +import pytest + +import torch + +from nequip.data import AtomicDataDict + +from nequip.model.builder_utils import add_avg_num_neighbors + + +def test_avg_num_neighbors(nequip_dataset): + # test basic options + annkey = "avg_num_neighbors" + config = {annkey: 3} + add_avg_num_neighbors(config, initialize=False, dataset=None) + assert config[annkey] == 3.0 # nothing should happen + assert isinstance(config[annkey], float) + + config = {annkey: 3} + # dont need dataset if config isn't auto + add_avg_num_neighbors(config, initialize=False, dataset=None) + with pytest.raises(ValueError): + # need if it is + config = {annkey: "auto"} + add_avg_num_neighbors(config, initialize=True, dataset=None) + + # compute dumb truth + num_neigh = [] + for i in range(len(nequip_dataset)): + frame = nequip_dataset[i] + num_neigh.append( + torch.bincount(frame[AtomicDataDict.EDGE_INDEX_KEY][0]).float() + ) + avg_num_neighbor_truth = torch.mean(torch.cat(num_neigh, dim=0)) + + # compare + config = {annkey: "auto"} + add_avg_num_neighbors(config, initialize=True, dataset=nequip_dataset) + assert config[annkey] == avg_num_neighbor_truth From 84912dd8f755da0915536dfb661e72e2eec7ce8e Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 22 Nov 2021 22:14:25 -0500 Subject: [PATCH 07/52] update default --- configs/example.yaml | 2 +- configs/full.yaml | 2 +- configs/minimal.yaml | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/configs/example.yaml b/configs/example.yaml index 70dfdd98..765623c2 100644 --- a/configs/example.yaml +++ b/configs/example.yaml @@ -39,7 +39,7 @@ PolynomialCutoff_p: 6 # radial network invariant_layers: 2 # number of radial layers, usually 1-3 works best, smaller is faster invariant_neurons: 64 # number of hidden neurons in radial function, smaller is faster -avg_num_neighbors: null # number of neighbors to divide by, null => no normalization. +avg_num_neighbors: auto # number of neighbors to divide by, null => no normalization. use_sc: true # use self-connection or not, usually gives big improvement # data set diff --git a/configs/full.yaml b/configs/full.yaml index f19e6f5c..6b369293 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -45,7 +45,7 @@ PolynomialCutoff_p: 6 # radial network invariant_layers: 2 # number of radial layers, usually 1-3 works best, smaller is faster invariant_neurons: 64 # number of hidden neurons in radial function, smaller is faster -avg_num_neighbors: null # number of neighbors to divide by, null => no normalization. +avg_num_neighbors: auto # number of neighbors to divide by, null => no normalization. use_sc: true # use self-connection or not, usually gives big improvement compile_model: false # whether to compile the constructed model to TorchScript diff --git a/configs/minimal.yaml b/configs/minimal.yaml index a3af4532..fa05bbbb 100644 --- a/configs/minimal.yaml +++ b/configs/minimal.yaml @@ -35,7 +35,6 @@ chemical_symbol_to_type: # 0: my_type # 1: atom # 2: thing -avg_num_neighbors: auto # logging wandb: false From 7752e5c52fc57087da1bab9d26f9fa6fc511521a Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 22 Nov 2021 22:20:06 -0500 Subject: [PATCH 08/52] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 133982db..41e84a1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. ## [Unreleased] - 0.5.0 +### Added +- Added `avg_num_neighbors: auto` option + ### Changed - Allow e3nn 0.4.*, which changes the default normalization of `TensorProduct`s; this change _should_ not affect typical NequIP networks - Deployed are now frozen on load, rather than compile From 26a78fbb4bfc080176cf405564278b9ed720627b Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 22 Nov 2021 22:29:22 -0500 Subject: [PATCH 09/52] make tests pass --- tests/unit/model/test_eng_force.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/model/test_eng_force.py b/tests/unit/model/test_eng_force.py index 05fca69a..4b9feb3b 100644 --- a/tests/unit/model/test_eng_force.py +++ b/tests/unit/model/test_eng_force.py @@ -21,6 +21,7 @@ COMMON_CONFIG = { "num_types": 3, "types_names": ["H", "C", "O"], + "avg_num_neighbors": None, } r_max = 3 minimal_config1 = dict( From 0a93d99a81a911beb9547814a692f0cfc8edefe0 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 22 Nov 2021 22:39:37 -0500 Subject: [PATCH 10/52] fix check --- nequip/model/_build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/model/_build.py b/nequip/model/_build.py index 4247cf48..c618bad3 100644 --- a/nequip/model/_build.py +++ b/nequip/model/_build.py @@ -77,7 +77,7 @@ def model_from_config( raise ValueError("Cannot request dataset without requesting initialize") if ( initialize - and pnames["dataset"].default != inspect.Parameter.empty + and pnames["dataset"].default == inspect.Parameter.empty and dataset is None ): raise RuntimeError( From a46e6ca563f5b5c65872122b5d0bfc8a92500262 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Sun, 28 Nov 2021 13:40:32 -0500 Subject: [PATCH 11/52] fix translation testing --- CHANGELOG.md | 2 ++ nequip/utils/test.py | 26 ++++++++++++++++++-------- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7040c1b..d368a0c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. ## [Unreleased] +### Fixed +- Equivariance testing no longer unintentionally skips translation ## [0.5.0] - 2021-11-24 ### Changed diff --git a/nequip/utils/test.py b/nequip/utils/test.py index da2fe4ef..359bde01 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -151,14 +151,24 @@ def assert_AtomicData_equivariant( # == Test rotation, parity, and translation using e3nn == irreps_in = {k: None for k in AtomicDataDict.ALLOWED_KEYS} - irreps_in.update( - { - AtomicDataDict.POSITIONS_KEY: "cartesian_points", - AtomicDataDict.CELL_KEY: "3x1o", - } - ) irreps_in.update(func.irreps_in) irreps_in = {k: v for k, v in irreps_in.items() if k in data_in} + irreps_out = func.irreps_out.copy() + # for certain things, we don't care what the given irreps are... + # make sure that we test correctly for equivariance: + for irps in (irreps_in, irreps_out): + if AtomicDataDict.POSITIONS_KEY in irps: + # it should always have been 1o vectors + # since that's actually a valid Irreps + assert o3.Irreps(irps[AtomicDataDict.POSITIONS_KEY]) == o3.Irreps("1o") + irps[AtomicDataDict.POSITIONS_KEY] = "cartesian_points" + if AtomicDataDict.CELL_KEY in irps: + prev_cell_irps = irps[AtomicDataDict.CELL_KEY] + assert prev_cell_irps is None or o3.Irreps(prev_cell_irps) == o3.Irreps( + "3x1o" + ) + # must be this to actually rotate it + irps[AtomicDataDict.CELL_KEY] = "3x1o" def wrapper(*args): arg_dict = {k: v for k, v in zip(irreps_in, args)} @@ -175,7 +185,7 @@ def wrapper(*args): cell = arg_dict[AtomicDataDict.CELL_KEY] assert cell.shape[-2:] == (3, 3) arg_dict[AtomicDataDict.CELL_KEY] = cell.reshape(cell.shape[:-2] + (9,)) - return [output[k] for k in func.irreps_out] + return [output[k] for k in irreps_out] data_in = AtomicData.to_AtomicDataDict(data_in) # cell is a special case @@ -191,7 +201,7 @@ def wrapper(*args): wrapper, args_in=args_in, irreps_in=list(irreps_in.values()), - irreps_out=list(func.irreps_out.values()), + irreps_out=list(irreps_out.values()), **kwargs, ) From c9ec76fd1d31f67ef42f416f3d1f352aa5458cc7 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 00:19:03 -0500 Subject: [PATCH 12/52] fix per graph field cat dim --- CHANGELOG.md | 1 + nequip/data/AtomicData.py | 18 ++++++------------ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d368a0c9..0b2cb106 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Most recent change on the bottom. ## [Unreleased] ### Fixed - Equivariance testing no longer unintentionally skips translation +- Correct cat dim for all registered per-graph fields ## [0.5.0] - 2021-11-24 ### Changed diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 81a2978d..256d7d53 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -504,15 +504,13 @@ def irreps(self): return self.__irreps__ def __cat_dim__(self, key, value): - if key in ( - AtomicDataDict.CELL_KEY, - AtomicDataDict.PBC_KEY, - AtomicDataDict.TOTAL_ENERGY_KEY, - ): - # the cell and PBC are graph-level properties and so need a new batch dimension + if key == AtomicDataDict.EDGE_INDEX_KEY: + return 1 # always cat in the edge dimension + elif key in _GRAPH_FIELDS: + # graph-level properties and so need a new batch dimension return None else: - return super().__cat_dim__(key, value) + return 0 # cat along node/edge dimension def without_nodes(self, which_nodes): """Return a copy of ``self`` with ``which_nodes`` removed. @@ -665,9 +663,5 @@ def neighbor_list_and_relative_vec( (torch.LongTensor(first_idex), torch.LongTensor(second_idex)) ).to(device=out_device) - shifts = torch.as_tensor( - shifts, - dtype=out_dtype, - device=out_device, - ) + shifts = torch.as_tensor(shifts, dtype=out_dtype, device=out_device,) return edge_index, shifts, cell_tensor From 7000f7bd0d8fbadf0368e7f77046ab7f73aff87b Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 01:09:36 -0500 Subject: [PATCH 13/52] add `chemical_symbols` --- CHANGELOG.md | 3 +++ configs/example.yaml | 8 ++++---- configs/full.yaml | 23 ++++++++++++++--------- configs/minimal.yaml | 15 +++++---------- configs/minimal_eng.yaml | 15 +++++---------- docs/options/dataset.rst | 5 +++++ nequip/data/dataset.py | 18 ++++++------------ nequip/data/transforms.py | 20 ++++++++++++++++++-- tests/unit/model/test_eng_force.py | 16 ++++------------ 9 files changed, 64 insertions(+), 59 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b2cb106..2bdd51b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. ## [Unreleased] +### Added +- The types may now be specified with a simpler `chemical_symbols` option + ### Fixed - Equivariance testing no longer unintentionally skips translation - Correct cat dim for all registered per-graph fields diff --git a/configs/example.yaml b/configs/example.yaml index 70dfdd98..d8214462 100644 --- a/configs/example.yaml +++ b/configs/example.yaml @@ -58,10 +58,10 @@ key_mapping: npz_fixed_field_keys: # fields that are repeated across different examples - atomic_numbers -# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes. -chemical_symbol_to_type: - H: 0 - C: 1 +# A list of atomic types to be found in the data. The NequIP types will be named with the chemical symbols, and inputs with the correct atomic numbers will be mapped to the corresponding types. +chemical_symbols: + - H + - C # logging wandb: true # we recommend using wandb for logging, we'll turn it off here as it's optional diff --git a/configs/full.yaml b/configs/full.yaml index 57d99094..7dfa442e 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -88,16 +88,21 @@ npz_fixed_field_keys: # key_mapping: # free_energy: total_energy -# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes. -chemical_symbol_to_type: - H: 0 - C: 1 - -# Alternatively, if the dataset has type indicess, the total number of types is all that is required: +# A list of atomic types to be found in the data. The NequIP types will be named with the chemical symbols, and inputs with the correct atomic numbers will be mapped to the corresponding types. +chemical_symbols: + - H + - C +# Alternatively, you may explicitly specify which chemical species maps to which type in NequIP (type index; the name is still taken from the chemical symbol) +# chemical_symbol_to_type: +# H: 0 +# C: 1 + +# Alternatively, if the dataset has type indices, you may give the names for the types in order: +# (this also sets the number of types) # type_names: -# 0: my_type -# 1: atom -# 2: thing +# - my_type +# - atom +# - thing # As an alternative option to npz, you can also pass data ase ASE Atoms-objects # This can often be easier to work with, simply make sure the ASE Atoms object diff --git a/configs/minimal.yaml b/configs/minimal.yaml index fa05bbbb..b5afb0f0 100644 --- a/configs/minimal.yaml +++ b/configs/minimal.yaml @@ -25,16 +25,11 @@ key_mapping: R: pos # raw atomic positions npz_fixed_field_keys: # fields that are repeated across different examples - atomic_numbers -# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes. -chemical_symbol_to_type: - H: 0 - C: 1 - O: 2 -# Alternatively, if the dataset has type indexes, the total number of types is all that is required: -# type_names: -# 0: my_type -# 1: atom -# 2: thing + +chemical_symbols: + - H + - O + - C # logging wandb: false diff --git a/configs/minimal_eng.yaml b/configs/minimal_eng.yaml index 0b4ad0d3..fe002fbc 100644 --- a/configs/minimal_eng.yaml +++ b/configs/minimal_eng.yaml @@ -29,16 +29,11 @@ key_mapping: R: pos # raw atomic positions npz_fixed_field_keys: # fields that are repeated across different examples - atomic_numbers -# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes. -chemical_symbol_to_type: - H: 0 - C: 1 - O: 2 -# Alternatively, if the dataset has type indexes, the total number of types is all that is required: -# type_names: -# 0: my_type -# 1: atom -# 2: thing + +chemical_symbols: + - H + - O + - C # logging wandb: false diff --git a/docs/options/dataset.rst b/docs/options/dataset.rst index 22cd701e..54b39fc9 100644 --- a/docs/options/dataset.rst +++ b/docs/options/dataset.rst @@ -13,6 +13,11 @@ type_names | Type: NoneType | Default: ``None`` +chemical_symbols +^^^^^^^^^^^^^^^^ + | Type: NoneType + | Default: ``None`` + chemical_symbol_to_type ^^^^^^^^^^^^^^^^^^^^^^^ | Type: NoneType diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py index a83d65c9..f650c6cc 100644 --- a/nequip/data/dataset.py +++ b/nequip/data/dataset.py @@ -526,10 +526,7 @@ def statistics( @staticmethod def _per_atom_statistics( - ana_mode: str, - arr: torch.Tensor, - batch: torch.Tensor, - unbiased: bool = True, + ana_mode: str, arr: torch.Tensor, batch: torch.Tensor, unbiased: bool = True, ): """Compute "per-atom" statistics that are normalized by the number of atoms in the system. @@ -743,8 +740,8 @@ class ASEDataset(AtomicInMemoryDataset): - user_label key_mapping: user_label: label0 - chemical_symbol_to_type: - H: 0 + chemical_symbols: + - H ``` for VASP parser, the yaml input should be @@ -755,8 +752,8 @@ class ASEDataset(AtomicInMemoryDataset): format: vasp-out key_mapping: free_energy: total_energy - chemical_symbol_to_type: - H: 0 + chemical_symbols: + - H ``` """ @@ -846,10 +843,7 @@ def get_data(self): atoms_list = self.get_atoms() # skip the None arguments - kwargs = dict( - include_keys=self.include_keys, - key_mapping=self.key_mapping, - ) + kwargs = dict(include_keys=self.include_keys, key_mapping=self.key_mapping,) kwargs = {k: v for k, v in kwargs.items() if v is not None} kwargs.update(self.extra_fixed_fields) diff --git a/nequip/data/transforms.py b/nequip/data/transforms.py index 7ac9d724..0c6735d1 100644 --- a/nequip/data/transforms.py +++ b/nequip/data/transforms.py @@ -20,7 +20,23 @@ def __init__( self, type_names: Optional[List[str]] = None, chemical_symbol_to_type: Optional[Dict[str, int]] = None, + chemical_symbols: Optional[List[str]] = None, ): + if chemical_symbols is not None: + if chemical_symbol_to_type is not None: + raise ValueError( + "Cannot provide both `chemical_symbols` and `chemical_symbol_to_type`" + ) + # repro old, sane NequIP behaviour + # checks also for validity of keys + atomic_nums = [ase.data.atomic_numbers[sym] for sym in chemical_symbols] + # https://stackoverflow.com/questions/29876580/how-to-sort-a-list-according-to-another-list-python + chemical_symbols = [ + e[1] for e in sorted(zip(atomic_nums, chemical_symbols)) + ] + chemical_symbol_to_type = {k: i for i, k in enumerate(chemical_symbols)} + del chemical_symbols + # Build from chem->type mapping, if provided self.chemical_symbol_to_type = chemical_symbol_to_type if self.chemical_symbol_to_type is not None: @@ -57,7 +73,7 @@ def __init__( # check if type_names is None: raise ValueError( - "Neither chemical_symbol_to_type nor type_names was provided; one or the other is required" + "None of chemical_symbols, chemical_symbol_to_type, nor type_names was provided; exactly one is required" ) # validate type names assert all( @@ -79,7 +95,7 @@ def __call__( elif AtomicDataDict.ATOMIC_NUMBERS_KEY in data: assert ( self.chemical_symbol_to_type is not None - ), "Atomic numbers provided but there is no chemical_symbol_to_type mapping!" + ), "Atomic numbers provided but there is no chemical_symbols/chemical_symbol_to_type mapping!" atomic_numbers = data[AtomicDataDict.ATOMIC_NUMBERS_KEY] del data[AtomicDataDict.ATOMIC_NUMBERS_KEY] diff --git a/tests/unit/model/test_eng_force.py b/tests/unit/model/test_eng_force.py index 05fca69a..c0f30d36 100644 --- a/tests/unit/model/test_eng_force.py +++ b/tests/unit/model/test_eng_force.py @@ -79,14 +79,8 @@ def config(request): @pytest.fixture( params=[ - ( - ["EnergyModel", "ForceOutput"], - AtomicDataDict.FORCE_KEY, - ), - ( - ["EnergyModel"], - AtomicDataDict.TOTAL_ENERGY_KEY, - ), + (["EnergyModel", "ForceOutput"], AtomicDataDict.FORCE_KEY,), + (["EnergyModel"], AtomicDataDict.TOTAL_ENERGY_KEY,), ] ) def model(request, config): @@ -138,9 +132,7 @@ def test_jit(self, model, atomic_batch, device): model_script = script(instance) assert torch.allclose( - instance(data)[out_field], - model_script(data)[out_field], - atol=1e-6, + instance(data)[out_field], model_script(data)[out_field], atol=1e-6, ) # - Try saving, loading in another process, and running - @@ -273,7 +265,7 @@ def test_large_separation(self, model, config, molecules): atoms2.positions += 40.0 + np.random.randn(3) atoms_both = atoms1.copy() atoms_both.extend(atoms2) - tm = TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}) + tm = TypeMapper(chemical_symbols=["H", "C", "O"]) data1 = tm(AtomicData.from_ase(atoms1, r_max=r_max)) data2 = tm(AtomicData.from_ase(atoms2, r_max=r_max)) data_both = tm(AtomicData.from_ase(atoms_both, r_max=r_max)) From 5413a91ae3bfb401bd05782e86df8f683fafdd63 Mon Sep 17 00:00:00 2001 From: Lixin Sun Date: Fri, 3 Dec 2021 12:30:32 -0800 Subject: [PATCH 14/52] customize data fields for scalars, long and others. (#108) * add stride argument * update unit tests * format * update kwargs * add scalar field and per atom energies to ase one * add long field * change abbreviation * flakes * Update nequip/train/_key.py * handle shape difference * update graph_field * remove unsqueeze for graph_field --- nequip/data/AtomicData.py | 50 +++++++++++++++++++++++------- nequip/data/_build.py | 6 +--- nequip/model/__init__.py | 3 +- nequip/train/_key.py | 1 + tests/unit/data/test_AtomicData.py | 5 +-- 5 files changed, 46 insertions(+), 19 deletions(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 256d7d53..f0e0a72f 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -24,6 +24,17 @@ # A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case) PBC = Union[bool, Tuple[bool, bool, bool]] +_DEFAULT_SCALAR_FIELDS: Set[str] = { + AtomicDataDict.ATOMIC_NUMBERS_KEY, + AtomicDataDict.ATOM_TYPE_KEY, + AtomicDataDict.BATCH_KEY, +} +_DEFAULT_LONG_FIELDS: Set[str] = { + AtomicDataDict.EDGE_INDEX_KEY, + AtomicDataDict.ATOMIC_NUMBERS_KEY, + AtomicDataDict.ATOM_TYPE_KEY, + AtomicDataDict.BATCH_KEY, +} _DEFAULT_NODE_FIELDS: Set[str] = { AtomicDataDict.POSITIONS_KEY, AtomicDataDict.WEIGHTS_KEY, @@ -44,16 +55,22 @@ } _DEFAULT_GRAPH_FIELDS: Set[str] = { AtomicDataDict.TOTAL_ENERGY_KEY, + AtomicDataDict.PBC_KEY, + AtomicDataDict.CELL_KEY, } _NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS) _EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS) _GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS) +_SCALAR_FIELDS: Set[str] = set(_DEFAULT_SCALAR_FIELDS) +_LONG_FIELDS: Set[str] = set(_DEFAULT_LONG_FIELDS) def register_fields( node_fields: Sequence[str] = [], edge_fields: Sequence[str] = [], graph_fields: Sequence[str] = [], + scalar_fields: Sequence[str] = [], + long_fields: Sequence[str] = [], ) -> None: r"""Register fields as being per-atom, per-edge, or per-frame. @@ -64,11 +81,14 @@ def register_fields( node_fields: set = set(node_fields) edge_fields: set = set(edge_fields) graph_fields: set = set(graph_fields) + scalar_fields: set = set(scalar_fields) allfields = node_fields.union(edge_fields, graph_fields) assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields) _NODE_FIELDS.update(node_fields) _EDGE_FIELDS.update(edge_fields) _GRAPH_FIELDS.update(graph_fields) + _SCALAR_FIELDS.update(scalar_fields) + _LONG_FIELDS.update(long_fields) if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < ( len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS) ): @@ -135,12 +155,7 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): AtomicDataDict.validate_keys(kwargs) # Deal with _some_ dtype issues for k, v in kwargs.items(): - if ( - k == AtomicDataDict.EDGE_INDEX_KEY - or k == AtomicDataDict.ATOMIC_NUMBERS_KEY - or k == AtomicDataDict.ATOM_TYPE_KEY - or k == AtomicDataDict.BATCH_KEY - ): + if k in _LONG_FIELDS: # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems) # int32 would pass later checks, but is actually disallowed by torch kwargs[k] = torch.as_tensor(v, dtype=torch.long) @@ -165,7 +180,15 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): for k, v in kwargs.items(): - if len(kwargs[k].shape) == 0: + if len(v.shape) == 0: + kwargs[k] = v.unsqueeze(-1) + v = kwargs[k] + + if ( + k in set.union(_NODE_FIELDS, _EDGE_FIELDS) + and k not in _SCALAR_FIELDS + and len(v.shape) == 1 + ): kwargs[k] = v.unsqueeze(-1) v = kwargs[k] @@ -184,9 +207,7 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): f"{k} is a edge field but has the wrong dimension {v.shape}" ) elif k in _GRAPH_FIELDS: - if num_frames == 1 and v.shape[0] != 1: - kwargs[k] = v.unsqueeze(0) - elif v.shape[0] != num_frames: + if num_frames > 1 and v.shape[0] != num_frames: raise ValueError(f"Wrong shape for graph property {k}") super().__init__(num_nodes=len(kwargs["pos"]), **kwargs) @@ -425,6 +446,7 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: cell = getattr(self, AtomicDataDict.CELL_KEY, None) batch = getattr(self, AtomicDataDict.BATCH_KEY, None) energy = getattr(self, AtomicDataDict.TOTAL_ENERGY_KEY, None) + energies = getattr(self, AtomicDataDict.PER_ATOM_ENERGY_KEY, None) force = getattr(self, AtomicDataDict.FORCE_KEY, None) do_calc = energy is not None or force is not None @@ -456,6 +478,8 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: if do_calc: fields = {} + if energies is not None: + fields["energies"] = energies[mask].cpu().numpy() if energy is not None: fields["energy"] = energy[batch_idx].cpu().numpy() if force is not None: @@ -663,5 +687,9 @@ def neighbor_list_and_relative_vec( (torch.LongTensor(first_idex), torch.LongTensor(second_idex)) ).to(device=out_device) - shifts = torch.as_tensor(shifts, dtype=out_dtype, device=out_device,) + shifts = torch.as_tensor( + shifts, + dtype=out_dtype, + device=out_device, + ) return edge_index, shifts, cell_tensor diff --git a/nequip/data/_build.py b/nequip/data/_build.py index 670b6156..7645bcb7 100644 --- a/nequip/data/_build.py +++ b/nequip/data/_build.py @@ -72,11 +72,7 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset: type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config) # Register fields: - register_fields( - node_fields=config.get("node_fields", []), - edge_fields=config.get("edge_fields", []), - graph_fields=config.get("graph_fields", []), - ) + instantiate(register_fields, all_args=config) instance, _ = instantiate( class_name, diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py index b849efed..670004dc 100644 --- a/nequip/model/__init__.py +++ b/nequip/model/__init__.py @@ -1,7 +1,7 @@ from ._eng import EnergyModel from ._grads import ForceOutput from ._scaling import RescaleEnergyEtc, PerSpeciesRescale -from ._weight_init import uniform_initialize_FCs +from ._weight_init import uniform_initialize_FCs, initialize_from_state from ._build import model_from_config @@ -11,5 +11,6 @@ "RescaleEnergyEtc", "PerSpeciesRescale", "uniform_initialize_FCs", + "initialize_from_state", "model_from_config", ] diff --git a/nequip/train/_key.py b/nequip/train/_key.py index f3582ebd..17057b8b 100644 --- a/nequip/train/_key.py +++ b/nequip/train/_key.py @@ -12,6 +12,7 @@ ABBREV = { AtomicDataDict.TOTAL_ENERGY_KEY: "e", + AtomicDataDict.PER_ATOM_ENERGY_KEY: "Ei", AtomicDataDict.FORCE_KEY: "f", LOSS_KEY: "loss", VALIDATION: "val", diff --git a/tests/unit/data/test_AtomicData.py b/tests/unit/data/test_AtomicData.py index f5d6dc27..5a1631db 100644 --- a/tests/unit/data/test_AtomicData.py +++ b/tests/unit/data/test_AtomicData.py @@ -41,8 +41,9 @@ def test_to_ase_batches(atomic_batch): assert np.array_equal( atoms.get_atomic_numbers(), atomic_data[AtomicDataDict.ATOM_TYPE_KEY][mask] ) - assert np.array_equal(atoms.get_cell(), atomic_data.cell[batch_idx]) - assert np.array_equal(atoms.get_pbc(), atomic_data.pbc[batch_idx]) + + assert np.max(np.abs(atoms.get_cell()[:]-atomic_data.cell[batch_idx].numpy())) == 0 + assert not np.logical_xor(atoms.get_pbc(), atomic_data.pbc[batch_idx].numpy()).all() def test_ase_roundtrip(CuFcc): From 61b8b9d00c80cba71f96f25ce2b51834647e5c91 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 16:08:05 -0500 Subject: [PATCH 15/52] fix for consistant field shapes --- nequip/data/AtomicData.py | 18 +++--------------- nequip/data/dataset.py | 2 +- nequip/nn/embedding/_one_hot.py | 2 +- tests/unit/data/test_AtomicData.py | 12 +++++++++--- tests/unit/data/test_dataset.py | 4 +++- 5 files changed, 17 insertions(+), 21 deletions(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index f0e0a72f..e8bb6596 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -24,11 +24,7 @@ # A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case) PBC = Union[bool, Tuple[bool, bool, bool]] -_DEFAULT_SCALAR_FIELDS: Set[str] = { - AtomicDataDict.ATOMIC_NUMBERS_KEY, - AtomicDataDict.ATOM_TYPE_KEY, - AtomicDataDict.BATCH_KEY, -} + _DEFAULT_LONG_FIELDS: Set[str] = { AtomicDataDict.EDGE_INDEX_KEY, AtomicDataDict.ATOMIC_NUMBERS_KEY, @@ -61,7 +57,6 @@ _NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS) _EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS) _GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS) -_SCALAR_FIELDS: Set[str] = set(_DEFAULT_SCALAR_FIELDS) _LONG_FIELDS: Set[str] = set(_DEFAULT_LONG_FIELDS) @@ -69,7 +64,6 @@ def register_fields( node_fields: Sequence[str] = [], edge_fields: Sequence[str] = [], graph_fields: Sequence[str] = [], - scalar_fields: Sequence[str] = [], long_fields: Sequence[str] = [], ) -> None: r"""Register fields as being per-atom, per-edge, or per-frame. @@ -81,13 +75,11 @@ def register_fields( node_fields: set = set(node_fields) edge_fields: set = set(edge_fields) graph_fields: set = set(graph_fields) - scalar_fields: set = set(scalar_fields) allfields = node_fields.union(edge_fields, graph_fields) assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields) _NODE_FIELDS.update(node_fields) _EDGE_FIELDS.update(edge_fields) _GRAPH_FIELDS.update(graph_fields) - _SCALAR_FIELDS.update(scalar_fields) _LONG_FIELDS.update(long_fields) if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < ( len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS) @@ -184,11 +176,7 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): kwargs[k] = v.unsqueeze(-1) v = kwargs[k] - if ( - k in set.union(_NODE_FIELDS, _EDGE_FIELDS) - and k not in _SCALAR_FIELDS - and len(v.shape) == 1 - ): + if k in set.union(_NODE_FIELDS, _EDGE_FIELDS) and len(v.shape) == 1: kwargs[k] = v.unsqueeze(-1) v = kwargs[k] @@ -470,7 +458,7 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: mask = slice(None) mol = ase.Atoms( - numbers=atomic_nums[mask], + numbers=atomic_nums[mask].view(-1), # must be flat for ASE positions=positions[mask], cell=cell[batch_idx] if cell is not None else None, pbc=pbc[batch_idx] if pbc is not None else None, diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py index a83d65c9..7481b7d4 100644 --- a/nequip/data/dataset.py +++ b/nequip/data/dataset.py @@ -566,7 +566,7 @@ def _per_species_statistics( For a per-node quantity, computes the expected statistic but for each type instead of over all nodes. """ - N = bincount(atom_types, batch) + N = bincount(atom_types.squeeze(-1), batch) N = N[(N > 0).any(dim=1)] # deal with non-contiguous batch indexes if arr_is_per == "graph": diff --git a/nequip/nn/embedding/_one_hot.py b/nequip/nn/embedding/_one_hot.py index 373243d2..f7228400 100644 --- a/nequip/nn/embedding/_one_hot.py +++ b/nequip/nn/embedding/_one_hot.py @@ -34,7 +34,7 @@ def __init__( self._init_irreps(irreps_in=irreps_in, irreps_out=irreps_out) def forward(self, data: AtomicDataDict.Type): - type_numbers = data[AtomicDataDict.ATOM_TYPE_KEY] + type_numbers = data[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1) one_hot = torch.nn.functional.one_hot( type_numbers, num_classes=self.num_types ).to(device=type_numbers.device, dtype=data[AtomicDataDict.POSITIONS_KEY].dtype) diff --git a/tests/unit/data/test_AtomicData.py b/tests/unit/data/test_AtomicData.py index 5a1631db..2bc50743 100644 --- a/tests/unit/data/test_AtomicData.py +++ b/tests/unit/data/test_AtomicData.py @@ -39,11 +39,17 @@ def test_to_ase_batches(atomic_batch): assert np.allclose(atoms.get_positions(), atomic_data.pos[mask]) assert atoms.get_atomic_numbers().shape == (len(atoms),) assert np.array_equal( - atoms.get_atomic_numbers(), atomic_data[AtomicDataDict.ATOM_TYPE_KEY][mask] + atoms.get_atomic_numbers(), + atomic_data[AtomicDataDict.ATOM_TYPE_KEY][mask].view(-1), ) - assert np.max(np.abs(atoms.get_cell()[:]-atomic_data.cell[batch_idx].numpy())) == 0 - assert not np.logical_xor(atoms.get_pbc(), atomic_data.pbc[batch_idx].numpy()).all() + assert ( + np.max(np.abs(atoms.get_cell()[:] - atomic_data.cell[batch_idx].numpy())) + == 0 + ) + assert not np.logical_xor( + atoms.get_pbc(), atomic_data.pbc[batch_idx].numpy() + ).all() def test_ase_roundtrip(CuFcc): diff --git a/tests/unit/data/test_dataset.py b/tests/unit/data/test_dataset.py index 98aa635c..3c3501f8 100644 --- a/tests/unit/data/test_dataset.py +++ b/tests/unit/data/test_dataset.py @@ -210,7 +210,9 @@ def test_per_graph_field( # get species count per graph Ns = [] for i in range(npz_dataset.len()): - Ns.append(torch.bincount(npz_dataset[i][AtomicDataDict.ATOM_TYPE_KEY])) + Ns.append( + torch.bincount(npz_dataset[i][AtomicDataDict.ATOM_TYPE_KEY].view(-1)) + ) n_spec = max(len(e) for e in Ns) N = torch.zeros(len(Ns), n_spec) for i in range(len(Ns)): From f51ad8773d233d6e8b9e2f4c73cadeecff407195 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 16:08:37 -0500 Subject: [PATCH 16/52] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b2cb106..e95dd123 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. ## [Unreleased] +### Changed +- All fields now have consistant [N, dim] shaping + ### Fixed - Equivariance testing no longer unintentionally skips translation - Correct cat dim for all registered per-graph fields From 775db4617fff67c57bd22c91e8b96a4619daeb2f Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 16:17:14 -0500 Subject: [PATCH 17/52] add dataset_seed (#117) * add dataset_seed * fix seed defaults --- CHANGELOG.md | 4 ++++ configs/example.yaml | 3 ++- configs/full.yaml | 3 ++- configs/minimal.yaml | 3 ++- nequip/train/trainer.py | 47 ++++++++++++++++++++++++----------------- 5 files changed, 38 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e95dd123..fc5a5c9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. ## [Unreleased] +### Added +- `dataset_seed` to separately control randomness used to select training data (and their order) + ### Changed - All fields now have consistant [N, dim] shaping +- Changed default `seed` and `dataset_seed` in example YAMLs ### Fixed - Equivariance testing no longer unintentionally skips translation diff --git a/configs/example.yaml b/configs/example.yaml index 70dfdd98..ff0577c9 100644 --- a/configs/example.yaml +++ b/configs/example.yaml @@ -6,7 +6,8 @@ # if 'root'/'run_name' exists, 'root'/'run_name'_'year'-'month'-'day'-'hour'-'min'-'s' will be used instead. root: results/toluene run_name: example-run-toluene -seed: 0 # random number seed for numpy and torch +seed: 123 +dataset_seed: 456 # random number seed for numpy and torch append: true # set true if a restarted run should append to the previous log file default_dtype: float32 # type of float to use, e.g. float32 and float64 diff --git a/configs/full.yaml b/configs/full.yaml index 57d99094..e9e66d29 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -10,7 +10,8 @@ # if 'root'/'run_name' exists, 'root'/'run_name'_'year'-'month'-'day'-'hour'-'min'-'s' will be used instead. root: results/toluene run_name: example-run-toluene -seed: 0 # random number seed for numpy and torch +seed: 123 +dataset_seed: 456 # random number seed for numpy and torch append: true # set true if a restarted run should append to the previous log file default_dtype: float32 # type of float to use, e.g. float32 and float64 allow_tf32: false # whether to use TensorFloat32 if it is available diff --git a/configs/minimal.yaml b/configs/minimal.yaml index fa05bbbb..a07084d6 100644 --- a/configs/minimal.yaml +++ b/configs/minimal.yaml @@ -1,7 +1,8 @@ # general root: results/aspirin run_name: minimal -seed: 0 +seed: 123 +dataset_seed: 456 # network num_basis: 8 diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index e543d116..072953ac 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -131,7 +131,8 @@ class Trainer: Args: model: neural network model - seed (int): random see number + seed (int): random seed number + dataset_seed (int): random seed for dataset operations loss_coeffs (dict): dictionary to store coefficient and loss functions @@ -214,6 +215,7 @@ def __init__( model_builders: Optional[list] = [], device: str = "cuda" if torch.cuda.is_available() else "cpu", seed: Optional[int] = None, + dataset_seed: Optional[int] = None, loss_coeffs: Union[dict, str] = AtomicDataDict.TOTAL_ENERGY_KEY, train_on_keys: Optional[List[str]] = None, metrics_components: Optional[Union[dict, str]] = None, @@ -293,6 +295,10 @@ def __init__( torch.manual_seed(seed) np.random.seed(seed) + self.dataset_rng = torch.Generator() + if dataset_seed is not None: + self.dataset_rng.manual_seed(dataset_seed) + self.logger.info(f"Torch device: {self.device}") self.torch_device = torch.device(self.device) @@ -453,6 +459,7 @@ def as_dict( if item is not None: dictionary["state_dict"][key] = item.state_dict() dictionary["state_dict"]["rng_state"] = torch.get_rng_state() + dictionary["state_dict"]["dataset_rng_state"] = self.dataset_rng.get_state() if torch.cuda.is_available(): dictionary["state_dict"]["cuda_rng_state"] = torch.cuda.get_rng_state( device=self.torch_device @@ -604,6 +611,7 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): trainer._initialized = True torch.set_rng_state(state_dict["rng_state"]) + trainer.dataset_rng.set_state(state_dict["dataset_rng_state"]) if torch.cuda.is_available(): torch.cuda.set_rng_state(state_dict["cuda_rng_state"]) @@ -638,10 +646,7 @@ def load_model_from_training_session( if config.get("compile_model", False): model = torch.jit.load(traindir + "/" + model_name, map_location=device) else: - model = model_from_config( - config=config, - initialize=False, - ) + model = model_from_config(config=config, initialize=False,) if model is not None: # TODO: this is not exactly equivalent to building with # this set as default dtype... does it matter? @@ -847,8 +852,7 @@ def epoch_step(self): self.n_batches = len(dataset) for self.ibatch, batch in enumerate(dataset): self.batch_step( - data=batch, - validation=(category == VALIDATION), + data=batch, validation=(category == VALIDATION), ) self.end_of_batch_log(batch_type=category) for callback in self.end_of_batch_callbacks: @@ -996,11 +1000,7 @@ def end_of_epoch_log(self): lr = self.optim.param_groups[0]["lr"] wall = perf_counter() - self.wall - self.mae_dict = dict( - LR=lr, - epoch=self.iepoch, - wall=wall, - ) + self.mae_dict = dict(LR=lr, epoch=self.iepoch, wall=wall,) header = "epoch, wall, LR" @@ -1100,7 +1100,7 @@ def set_dataset( ) if self.train_val_split == "random": - idcs = torch.randperm(total_n) + idcs = torch.randperm(total_n, generator=self.dataset_rng) elif self.train_val_split == "sequential": idcs = torch.arange(total_n) else: @@ -1116,10 +1116,12 @@ def set_dataset( if self.n_val > len(validation_dataset): raise ValueError("Not enough data in dataset for requested n_train") if self.train_val_split == "random": - self.train_idcs = torch.randperm(len(dataset))[: self.n_train] - self.val_idcs = torch.randperm(len(validation_dataset))[ - : self.n_val - ] + self.train_idcs = torch.randperm( + len(dataset), generator=self.dataset_rng + )[: self.n_train] + self.val_idcs = torch.randperm( + len(validation_dataset), generator=self.dataset_rng + )[: self.n_val] elif self.train_val_split == "sequential": self.train_idcs = torch.arange(self.n_train) self.val_idcs = torch.arange(self.n_val) @@ -1139,7 +1141,6 @@ def set_dataset( # https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#enable-async-data-loading-and-augmentation dl_kwargs = dict( batch_size=self.batch_size, - shuffle=self.shuffle, exclude_keys=self.exclude_keys, num_workers=self.dataloader_num_workers, # keep stuff around in memory @@ -1150,6 +1151,14 @@ def set_dataset( pin_memory=(self.torch_device != torch.device("cpu")), # avoid getting stuck timeout=(10 if self.dataloader_num_workers > 0 else 0), + # use the right randomness + generator=self.dataset_rng, + ) + self.dl_train = DataLoader( + dataset=self.dataset_train, + shuffle=self.shuffle, # training should shuffle + **dl_kwargs, ) - self.dl_train = DataLoader(dataset=self.dataset_train, **dl_kwargs) + # validation, on the other hand, shouldn't shuffle + # we still pass the generator just to be safe self.dl_val = DataLoader(dataset=self.dataset_val, **dl_kwargs) From cff68ff6bdf5ac2de7c50a4d23eb7d184b5a9cdd Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 16:21:03 -0500 Subject: [PATCH 18/52] Asynchronous IO (#109) * Initial async IO * fix error handling * Note * allow multiple filenames * join queue * lint * initial saving groups * saving groups * enable with env var * fix default * comments * more cleanup * lint * delete right file * Revert "lint" This reverts commit 752e045bee35066b19653d88ce84c7a5bf45db11. * remove redundant save_config * reverse to default blocking behavior for save * add Config.from_dict; use internal dict when building models in trainer.from_dict method Co-authored-by: nw13slx --- CHANGELOG.md | 3 +- nequip/data/dataset.py | 9 +- nequip/train/trainer.py | 110 +++++++++++-------- nequip/utils/__init__.py | 10 +- nequip/utils/config.py | 4 + nequip/utils/savenload.py | 225 ++++++++++++++++++++++++++++++++------ 6 files changed, 274 insertions(+), 87 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fc5a5c9d..ab271c4d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. -## [Unreleased] +## [Unreleased] - 0.5.1 ### Added +- Asynchronous IO: during training, models are written asynchronously. - `dataset_seed` to separately control randomness used to select training data (and their order) ### Changed diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py index 7481b7d4..db5e6406 100644 --- a/nequip/data/dataset.py +++ b/nequip/data/dataset.py @@ -301,11 +301,10 @@ def process(self): # it doesn't matter if they overwrite each others cached' # datasets. It only matters that they don't simultaneously try # to write the _same_ file, corrupting it. - with atomic_write(self.processed_paths[0]) as tmppth: - torch.save((data, fixed_fields, self.include_frames), tmppth) - with atomic_write(self.processed_paths[1]) as tmppth: - with open(tmppth, "w") as f: - yaml.dump(self._get_parameters(), f) + with atomic_write(self.processed_paths[0], binary=True) as f: + torch.save((data, fixed_fields, self.include_frames), f) + with atomic_write(self.processed_paths[1], binary=False) as f: + yaml.dump(self._get_parameters(), f) logging.info("Cached processed data to disk") diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 072953ac..17923f00 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -37,6 +37,8 @@ save_file, load_file, atomic_write, + finish_all_writes, + atomic_write_group, dtype_from_name, ) from nequip.utils.git import get_commit @@ -488,15 +490,16 @@ def as_dict( return dictionary - def save_config(self) -> None: + def save_config(self, blocking: bool = True) -> None: save_file( item=self.as_dict(state_dict=False, training_progress=False), supported_formats=dict(yaml=["yaml"]), filename=self.config_path, enforced_format=None, + blocking=blocking, ) - def save(self, filename: Optional[str] = None, format=None): + def save(self, filename: Optional[str] = None, format=None, blocking: bool = True): """save the file as filename Args: @@ -523,11 +526,11 @@ def save(self, filename: Optional[str] = None, format=None): supported_formats=dict(torch=["pth", "pt"], yaml=["yaml"], json=["json"]), filename=filename, enforced_format=format, + blocking=blocking, ) logger.debug(f"Saved trainer to {filename}") - self.save_config() - self.save_model(self.last_model_path) + self.save_model(self.last_model_path, blocking=blocking) logger.debug(f"Saved last model to to {self.last_model_path}") return filename @@ -561,10 +564,10 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): append (bool): if True, append the old model files and append the same logfile """ - d = deepcopy(dictionary) + dictionary = deepcopy(dictionary) for code in [e3nn, nequip, torch]: - version = d.get(f"{code.__name__}_version", None) + version = dictionary.get(f"{code.__name__}_version", None) if version is not None and version != code.__version__: logging.warning( "Loading a pickled model created with different library version(s) may cause issues." @@ -574,14 +577,14 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): # update the restart and append option if append is not None: - d["append"] = append + dictionary["append"] = append model = None iepoch = -1 - if "model" in d: - model = d.pop("model") - elif "progress" in d: - progress = d["progress"] + if "model" in dictionary: + model = dictionary.pop("model") + elif "progress" in dictionary: + progress = dictionary["progress"] # load the model from file iepoch = progress["iepoch"] @@ -592,15 +595,17 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): raise AttributeError("model weights & bias are not saved") model, _ = Trainer.load_model_from_training_session( - traindir=load_path.parent, model_name=load_path.name + traindir=load_path.parent, + model_name=load_path.name, + config_dictionary=dictionary, ) logging.debug(f"Reload the model from {load_path}") - d.pop("progress") + dictionary.pop("progress") - state_dict = d.pop("state_dict", None) + state_dict = dictionary.pop("state_dict", None) - trainer = cls(model=model, **d) + trainer = cls(model=model, **dictionary) if state_dict is not None and trainer.model is not None: logging.debug("Reload optimizer and scheduler states") @@ -615,7 +620,7 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): if torch.cuda.is_available(): torch.cuda.set_rng_state(state_dict["cuda_rng_state"]) - if "progress" in d: + if "progress" in dictionary: trainer.best_metrics = progress["best_metrics"] trainer.best_epoch = progress["best_epoch"] stop_arg = progress.pop("stop_arg", None) @@ -636,12 +641,18 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): @staticmethod def load_model_from_training_session( - traindir, model_name="best_model.pth", device="cpu" + traindir, + model_name="best_model.pth", + device="cpu", + config_dictionary: Optional[dict] = None, ) -> Tuple[torch.nn.Module, Config]: traindir = str(traindir) model_name = str(model_name) - config = Config.from_file(traindir + "/config.yaml") + if config_dictionary is not None: + config = Config.from_dict(config_dictionary) + else: + config = Config.from_file(traindir + "/config.yaml") if config.get("compile_model", False): model = torch.jit.load(traindir + "/" + model_name, map_location=device) @@ -714,8 +725,11 @@ def train(self): self.init_log() self.wall = perf_counter() - if self.iepoch == -1: - self.save() + with atomic_write_group(): + if self.iepoch == -1: + self.save() + if self.iepoch in [-1, 0]: + self.save_config() self.init_metrics() @@ -730,6 +744,7 @@ def train(self): self.final_log() self.save() + finish_all_writes() def batch_step(self, data, validation=False): # no need to have gradients from old steps taking up memory @@ -930,36 +945,36 @@ def end_of_epoch_save(self): """ save model and trainer details """ + with atomic_write_group(): + current_metrics = self.mae_dict[self.metrics_key] + if current_metrics < self.best_metrics: + self.best_metrics = current_metrics + self.best_epoch = self.iepoch - current_metrics = self.mae_dict[self.metrics_key] - if current_metrics < self.best_metrics: - self.best_metrics = current_metrics - self.best_epoch = self.iepoch - - self.save_ema_model(self.best_model_path) + self.save_ema_model(self.best_model_path, blocking=False) - self.logger.info( - f"! Best model {self.best_epoch:8d} {self.best_metrics:8.3f}" - ) + self.logger.info( + f"! Best model {self.best_epoch:8d} {self.best_metrics:8.3f}" + ) - if (self.iepoch + 1) % self.log_epoch_freq == 0: - self.save() + if (self.iepoch + 1) % self.log_epoch_freq == 0: + self.save(blocking=False) - if ( - self.save_checkpoint_freq > 0 - and (self.iepoch + 1) % self.save_checkpoint_freq == 0 - ): - ckpt_path = self.output.generate_file(f"ckpt{self.iepoch+1}.pth") - self.save_model(ckpt_path) + if ( + self.save_checkpoint_freq > 0 + and (self.iepoch + 1) % self.save_checkpoint_freq == 0 + ): + ckpt_path = self.output.generate_file(f"ckpt{self.iepoch+1}.pth") + self.save_model(ckpt_path, blocking=False) - if ( - self.save_ema_checkpoint_freq > 0 - and (self.iepoch + 1) % self.save_ema_checkpoint_freq == 0 - ): - ckpt_path = self.output.generate_file(f"ckpt_ema_{self.iepoch+1}.pth") - self.save_ema_model(ckpt_path) + if ( + self.save_ema_checkpoint_freq > 0 + and (self.iepoch + 1) % self.save_ema_checkpoint_freq == 0 + ): + ckpt_path = self.output.generate_file(f"ckpt_ema_{self.iepoch+1}.pth") + self.save_ema_model(ckpt_path, blocking=False) - def save_ema_model(self, path): + def save_ema_model(self, path, blocking: bool = True): if self.use_ema: # If using EMA, store the EMA validation model @@ -971,11 +986,10 @@ def save_ema_model(self, path): cm = contextlib.nullcontext() with cm: - self.save_model(path) + self.save_model(path, blocking=blocking) - def save_model(self, path): - self.save_config() - with atomic_write(path) as write_to: + def save_model(self, path, blocking: bool = True): + with atomic_write(path, blocking=blocking, binary=True) as write_to: if isinstance(self.model, torch.jit.ScriptModule): torch.jit.save(self.model, write_to) else: diff --git a/nequip/utils/__init__.py b/nequip/utils/__init__.py index 16ad1ee6..79e778a0 100644 --- a/nequip/utils/__init__.py +++ b/nequip/utils/__init__.py @@ -3,7 +3,13 @@ instantiate, get_w_prefix, ) -from .savenload import save_file, load_file, atomic_write +from .savenload import ( + save_file, + load_file, + atomic_write, + finish_all_writes, + atomic_write_group, +) from .config import Config from .output import Output from .modules import find_first_of_type @@ -16,6 +22,8 @@ save_file, load_file, atomic_write, + finish_all_writes, + atomic_write_group, Config, Output, find_first_of_type, diff --git a/nequip/utils/config.py b/nequip/utils/config.py index 72b896a1..d13e0546 100644 --- a/nequip/utils/config.py +++ b/nequip/utils/config.py @@ -262,6 +262,10 @@ def from_file(filename: str, format: Optional[str] = None, defaults: dict = {}): filename=filename, enforced_format=format, ) + return Config.from_dict(dictionary, defaults) + + @staticmethod + def from_dict(dictionary: dict, defaults: dict = {}): c = Config(defaults) c.update(dictionary) return c diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index 0980fef2..0302d221 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -1,49 +1,203 @@ """ utilities that involve file searching and operations (i.e. save/load) """ -from typing import Union +from typing import Union, List, Tuple import sys import logging import contextlib +import contextvars +import tempfile from pathlib import Path -from os import makedirs -from os.path import isfile, isdir, dirname, realpath +import shutil +import os -@contextlib.contextmanager -def atomic_write(filename: Union[Path, str]): - filename = Path(filename) - tmp_path = filename.parent / (f".tmp-{filename.name}~") - # Create the temp file - open(tmp_path, "w").close() +# accumulate writes to group for renaming +_MOVE_SET = contextvars.ContextVar("_move_set", default=None) + + +def _delete_files_if_exist(paths): + # clean up + # better for python 3.8 > + if sys.version_info[1] >= 8: + for f in paths: + f.unlink(missing_ok=True) + else: + # race condition? + for f in paths: + if f.exists(): + f.unlink() + + +def _process_moves(moves: List[Tuple[bool, Path, Path]]): + """blocking to copy (possibly across filesystems) to temp name; then atomic rename to final name""" try: - # do the IO - yield tmp_path - # move the temp file to the final output path, which also removes the temp file - tmp_path.rename(filename) + for _, from_name, to_name in moves: + # blocking copy to temp file in same filesystem + tmp_path = to_name.parent / (f".tmp-{to_name.name}~") + shutil.move(from_name, tmp_path) + # then atomic rename to overwrite + tmp_path.rename(to_name) finally: - # clean up - # better for python 3.8 > - if sys.version_info[1] >= 8: - tmp_path.unlink(missing_ok=True) + _delete_files_if_exist([m[1] for m in moves]) + + +# allow user to enable/disable depending on their filesystem +_ASYNC_ENABLED = os.environ.get("NEQUIP_ASYNC_IO", "true").lower() +assert _ASYNC_ENABLED in ("true", "false") +_ASYNC_ENABLED = _ASYNC_ENABLED == "true" + +if _ASYNC_ENABLED: + import threading + from queue import Queue + + _MOVE_QUEUE = Queue() + _MOVE_THREAD = None + + # Because we use a queue, later writes will always (correctly) + # overwrite earlier writes + def _moving_thread(queue): + while True: + moves = queue.get() + _process_moves(moves) + # logging is thread safe: https://stackoverflow.com/questions/2973900/is-pythons-logging-module-thread-safe + logging.debug(f"Finished writing {', '.join(m[2].name for m in moves)}") + queue.task_done() + + def _submit_move(from_name, to_name, blocking: bool): + global _MOVE_QUEUE + global _MOVE_THREAD + global _MOVE_SET + + # launch thread if its not running + if _MOVE_THREAD is None: + _MOVE_THREAD = threading.Thread( + target=_moving_thread, args=(_MOVE_QUEUE,), daemon=True + ) + _MOVE_THREAD.start() + + # check on health of copier thread + if not _MOVE_THREAD.is_alive(): + _MOVE_THREAD.join() # will raise exception + raise RuntimeError("Writer thread failed.") + + # submit this move + obj = (blocking, from_name, to_name) + if _MOVE_SET.get() is None: + # no current group + _MOVE_QUEUE.put([obj]) + # if it should be blocking, wait for it to be processed + if blocking: + _MOVE_QUEUE.join() + else: + # add and let the group submit and block (or not) + _MOVE_SET.get().append(obj) + + @contextlib.contextmanager + def atomic_write_group(): + global _MOVE_SET + if _MOVE_SET.get() is not None: + # nesting is a no-op + # submit along with outermost context manager + yield + return + token = _MOVE_SET.set(list()) + # run the saves + yield + _MOVE_QUEUE.put(_MOVE_SET.get()) # send it off + # if anyone is blocking, block the whole group: + if any(m[0] for m in _MOVE_SET.get()): + # someone is blocking + _MOVE_QUEUE.join() + # exit context + _MOVE_SET.reset(token) + + def finish_all_writes(): + global _MOVE_QUEUE + _MOVE_QUEUE.join() + # ^ wait for all remaining moves to be processed + + +else: + + def _submit_move(from_name, to_name, blocking: bool): + global _MOVE_SET + obj = (blocking, from_name, to_name) + if _MOVE_SET.get() is None: + # no current group just do it + _process_moves([obj]) else: - # race condition? - if tmp_path.exists(): - tmp_path.unlink() + # add and let the group do it + _MOVE_SET.get().append(obj) + + @contextlib.contextmanager + def atomic_write_group(): + global _MOVE_SET + if _MOVE_SET.get() is not None: + # don't nest them + yield + return + token = _MOVE_SET.set(list()) + yield + _process_moves(_MOVE_SET.get()) # do it + _MOVE_SET.reset(token) + + def finish_all_writes(): + pass # nothing to do since all writes blocked + + +@contextlib.contextmanager +def atomic_write( + filename: Union[Path, str, List[Union[Path, str]]], + blocking: bool = True, + binary: bool = False, +): + aslist: bool = True + if not isinstance(filename, list): + aslist = False + filename = [filename] + filename = [Path(f) for f in filename] + + with contextlib.ExitStack() as stack: + files = [ + stack.enter_context( + tempfile.NamedTemporaryFile( + mode="w" + ("b" if binary else ""), delete=False + ) + ) + for _ in filename + ] + try: + if not aslist: + yield files[0] + else: + yield files + except: # noqa + # ^ noqa cause we want to delete them no matter what if there was a failure + # only remove them if there was an error + _delete_files_if_exist([Path(f.name) for f in files]) + raise + + for tp, fname in zip(files, filename): + _submit_move(Path(tp.name), Path(fname), blocking=blocking) def save_file( - item, supported_formats: dict, filename: str, enforced_format: str = None + item, + supported_formats: dict, + filename: str, + enforced_format: str = None, + blocking: bool = True, ): """ Save file. It can take yaml, json, pickle, json, npz and torch save """ # check whether folder exist - path = dirname(realpath(filename)) - if not isdir(path): + path = os.path.dirname(os.path.realpath(filename)) + if not os.path.isdir(path): logging.debug(f"save_file make dirs {path}") - makedirs(path, exist_ok=True) + os.makedirs(path, exist_ok=True) format, filename = adjust_format_name( supported_formats=supported_formats, @@ -51,17 +205,25 @@ def save_file( enforced_format=enforced_format, ) - with atomic_write(filename) as write_to: + with atomic_write( + filename, + blocking=blocking, + binary={ + "json": False, + "yaml": False, + "pickle": True, + "torch": True, + "npz": True, + }[format], + ) as write_to: if format == "json": import json - with open(write_to, "w+") as fout: - json.dump(item, fout) + json.dump(item, write_to) elif format == "yaml": import yaml - with open(write_to, "w+") as fout: - yaml.dump(item, fout) + yaml.dump(item, write_to) elif format == "torch": import torch @@ -69,8 +231,7 @@ def save_file( elif format == "pickle": import pickle - with open(write_to, "wb") as fout: - pickle.save(item, fout) + pickle.dump(item, write_to) elif format == "npz": import numpy as np @@ -93,7 +254,7 @@ def load_file(supported_formats: dict, filename: str, enforced_format: str = Non else: format = enforced_format - if not isfile(filename): + if not os.path.isfile(filename): abs_path = str(Path(filename).resolve()) raise OSError(f"file {filename} at {abs_path} is not found") From 13525fc7bbccd4ce3731928133a6aae7d561e0da Mon Sep 17 00:00:00 2001 From: nw13slx Date: Fri, 3 Dec 2021 16:28:53 -0500 Subject: [PATCH 19/52] fix batch_key shape --- nequip/data/AtomicData.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index e8bb6596..bd70f323 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -224,7 +224,7 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): ): assert self.atomic_numbers.dtype in _TORCH_INTEGER_DTYPES if "batch" in self and self.batch is not None: - assert self.batch.dim() == 1 and self.batch.shape[0] == self.num_nodes + assert self.batch.dim() == 2 and self.batch.shape[0] == self.num_nodes # Check that there are the right number of cells if "cell" in self and self.cell is not None: cell = self.cell.view(-1, 3, 3) @@ -454,6 +454,7 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: for batch_idx in range(n_batches): if batch is not None: mask = batch == batch_idx + mask = mask.view(-1) else: mask = slice(None) From 905bfb308e41ccecf0f04bdd9afa2fe54609c34f Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 16:45:50 -0500 Subject: [PATCH 20/52] remove weird default --- nequip/model/builder_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nequip/model/builder_utils.py b/nequip/model/builder_utils.py index 2f93c51f..0db2c7ff 100644 --- a/nequip/model/builder_utils.py +++ b/nequip/model/builder_utils.py @@ -10,11 +10,10 @@ def add_avg_num_neighbors( config: Config, initialize: bool, dataset: Optional[AtomicDataset] = None, - default: Optional[Union[str, float]] = "auto", ) -> Optional[float]: # Compute avg_num_neighbors annkey: str = "avg_num_neighbors" - ann = config.get(annkey, default) + ann = config.get(annkey, None) if ann == "auto" and initialize: if dataset is None: raise ValueError( From a5f204b96c8cb22245711ca16c0e148ddb948466 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 16:47:42 -0500 Subject: [PATCH 21/52] more helpful error in initialize =False --- nequip/model/builder_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nequip/model/builder_utils.py b/nequip/model/builder_utils.py index 0db2c7ff..30438d10 100644 --- a/nequip/model/builder_utils.py +++ b/nequip/model/builder_utils.py @@ -14,7 +14,9 @@ def add_avg_num_neighbors( # Compute avg_num_neighbors annkey: str = "avg_num_neighbors" ann = config.get(annkey, None) - if ann == "auto" and initialize: + if ann == "auto": + if not initialize: + raise ValueError("avg_num_neighbors = auto but initialize is False") if dataset is None: raise ValueError( "When avg_num_neighbors = auto, the dataset is required to build+initialize a model" From b2ae635669dfec3a42188cf22a187552a75b1450 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 3 Dec 2021 16:48:10 -0500 Subject: [PATCH 22/52] lint --- nequip/model/builder_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/model/builder_utils.py b/nequip/model/builder_utils.py index 30438d10..f2a402f9 100644 --- a/nequip/model/builder_utils.py +++ b/nequip/model/builder_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import torch From 1eb4a3b85a9ef37d996dbfe52720768473273ddd Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 9 Dec 2021 11:51:40 -0500 Subject: [PATCH 23/52] bump --- nequip/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/_version.py b/nequip/_version.py index edbb9d1b..41f4a9f7 100644 --- a/nequip/_version.py +++ b/nequip/_version.py @@ -2,4 +2,4 @@ # See Python packaging guide # https://packaging.python.org/guides/single-sourcing-package-version/ -__version__ = "0.5.0" +__version__ = "0.5.1" From e0fecb8edfa051a7684795618e4f707a7b0a12dc Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 9 Dec 2021 11:52:56 -0500 Subject: [PATCH 24/52] disable async by default --- CHANGELOG.md | 2 +- nequip/utils/savenload.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7454d723..a4bb3b05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ Most recent change on the bottom. ### Added - Added `avg_num_neighbors: auto` option - Asynchronous IO: during training, models are written asynchronously. -- `dataset_seed` to separately control randomness used to select training data (and their order) +- `dataset_seed` to separately control randomness used to select training data (and their order). Enable this with environment variable `NEQUIP_ASYNC_IO=true`. - The types may now be specified with a simpler `chemical_symbols` option ### Changed diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index 0302d221..2f26c4c0 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -43,7 +43,7 @@ def _process_moves(moves: List[Tuple[bool, Path, Path]]): # allow user to enable/disable depending on their filesystem -_ASYNC_ENABLED = os.environ.get("NEQUIP_ASYNC_IO", "true").lower() +_ASYNC_ENABLED = os.environ.get("NEQUIP_ASYNC_IO", "false").lower() assert _ASYNC_ENABLED in ("true", "false") _ASYNC_ENABLED = _ASYNC_ENABLED == "true" From e1684466daa6c28680e52c8abf000efca6f7c076 Mon Sep 17 00:00:00 2001 From: Simon Batzner Date: Thu, 9 Dec 2021 21:23:05 -0500 Subject: [PATCH 25/52] default device cuda --- configs/full.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/full.yaml b/configs/full.yaml index f3a99c20..632474b3 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -145,12 +145,12 @@ report_init_validation: false # early stopping based on metrics values. # LR, wall and any keys printed in the log file can be used. -# The key can start with Training or Validation. If not defined, the validation value will be used. +# The key can start with Training or validation. If not defined, the validation value will be used. early_stopping_patiences: # stop early if a metric value stopped decreasing for n epochs - Validation_loss: 50 + validation_loss: 50 early_stopping_delta: # If delta is defined, a decrease smaller than delta will not be considered as a decrease - Validation_loss: 0.005 + validation_loss: 0.005 early_stopping_cumulative_delta: false # If True, the minimum value recorded will not be updated when the decrease is smaller than delta From 61170c791fc2331351b097e7f6d924647377d8ca Mon Sep 17 00:00:00 2001 From: Simon Batzner Date: Thu, 9 Dec 2021 21:25:55 -0500 Subject: [PATCH 26/52] comment out device: cuda --- configs/full.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/full.yaml b/configs/full.yaml index 632474b3..a04e0fd4 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -15,7 +15,7 @@ dataset_seed: 456 append: true # set true if a restarted run should append to the previous log file default_dtype: float32 # type of float to use, e.g. float32 and float64 allow_tf32: false # whether to use TensorFloat32 if it is available -device: cuda # which device to use. Default: automatically detected cuda or "cpu" +# device: cuda # which device to use. Default: automatically detected cuda or "cpu" # network r_max: 4.0 # cutoff radius in length units, here Angstrom, this is an important hyperparamter to scan From 4107eb21174e366fe57af1f99955237697c57eb1 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 9 Dec 2021 23:25:21 -0500 Subject: [PATCH 27/52] pre-download data for tests --- .github/workflows/tests.yml | 3 +++ .github/workflows/tests_develop.yml | 3 +++ 2 files changed, 6 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 41e389be..b2bd1b17 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,6 +42,9 @@ jobs: run: | pip install pytest pip install pytest-xdist[psutil] + - name: Download test data + run: | + cd benchmark_data; wget "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip"; cd .. - name: Test with pytest run: | # See https://github.com/pytest-dev/pytest/issues/1075 diff --git a/.github/workflows/tests_develop.yml b/.github/workflows/tests_develop.yml index a69a728d..954d2e3e 100644 --- a/.github/workflows/tests_develop.yml +++ b/.github/workflows/tests_develop.yml @@ -42,6 +42,9 @@ jobs: run: | pip install pytest pip install pytest-xdist[psutil] + - name: Download test data + run: | + cd benchmark_data; wget "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip"; cd .. - name: Test with pytest run: | # See https://github.com/pytest-dev/pytest/issues/1075 From 74e003a68b9cff4d06e532c7743701076585660a Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 9 Dec 2021 23:37:52 -0500 Subject: [PATCH 28/52] correction --- .github/workflows/tests.yml | 1 + .github/workflows/tests_develop.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b2bd1b17..96bbac82 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,6 +44,7 @@ jobs: pip install pytest-xdist[psutil] - name: Download test data run: | + mkdir benchmark_data cd benchmark_data; wget "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip"; cd .. - name: Test with pytest run: | diff --git a/.github/workflows/tests_develop.yml b/.github/workflows/tests_develop.yml index 954d2e3e..d9728d14 100644 --- a/.github/workflows/tests_develop.yml +++ b/.github/workflows/tests_develop.yml @@ -44,6 +44,7 @@ jobs: pip install pytest-xdist[psutil] - name: Download test data run: | + mkdir benchmark_data cd benchmark_data; wget "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip"; cd .. - name: Test with pytest run: | From d724922278d5a3cf960072d05cc5a2bd1525ba36 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Sun, 12 Dec 2021 01:25:10 -0500 Subject: [PATCH 29/52] show full commit --- nequip/utils/git.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/utils/git.py b/nequip/utils/git.py index 168f8921..d58923f3 100644 --- a/nequip/utils/git.py +++ b/nequip/utils/git.py @@ -9,7 +9,7 @@ def get_commit(module: str): path = str(Path(module.__file__).parents[0] / "..") retcode = subprocess.run( - "git show --oneline -s".split(), + "git show --oneline --abbrev=40 -s".split(), cwd=path, stdout=subprocess.PIPE, stderr=subprocess.PIPE, From 8a2d85671b580c23b15c8ace13aa2c4593b983a7 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Sun, 12 Dec 2021 01:25:21 -0500 Subject: [PATCH 30/52] record git commit for builders --- nequip/train/trainer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 17923f00..f790ee51 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -485,8 +485,16 @@ def as_dict( for code in [e3nn, nequip, torch]: dictionary[f"{code.__name__}_version"] = code.__version__ - for code in ["e3nn", "nequip"]: - dictionary[f"{code}_commit"] = get_commit(code) + + codes_for_git = {"e3nn", "nequip"} + for builder in self.model_builders: + builder = builder.split(".") + if len(builder) > 1: + # it's not a single name which is from nequip + codes_for_git.add(builder[0]) + dictionary[f"code_versions"] = { + code: get_commit(code) for code in codes_for_git + } return dictionary From 3d9901c6b961ff5cf2067c5d2d645c41e39812d6 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Sun, 12 Dec 2021 14:26:00 -0500 Subject: [PATCH 31/52] lint --- nequip/train/trainer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index f790ee51..025adbed 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -492,9 +492,7 @@ def as_dict( if len(builder) > 1: # it's not a single name which is from nequip codes_for_git.add(builder[0]) - dictionary[f"code_versions"] = { - code: get_commit(code) for code in codes_for_git - } + dictionary["code_versions"] = {code: get_commit(code) for code in codes_for_git} return dictionary @@ -665,7 +663,10 @@ def load_model_from_training_session( if config.get("compile_model", False): model = torch.jit.load(traindir + "/" + model_name, map_location=device) else: - model = model_from_config(config=config, initialize=False,) + model = model_from_config( + config=config, + initialize=False, + ) if model is not None: # TODO: this is not exactly equivalent to building with # this set as default dtype... does it matter? @@ -875,7 +876,8 @@ def epoch_step(self): self.n_batches = len(dataset) for self.ibatch, batch in enumerate(dataset): self.batch_step( - data=batch, validation=(category == VALIDATION), + data=batch, + validation=(category == VALIDATION), ) self.end_of_batch_log(batch_type=category) for callback in self.end_of_batch_callbacks: @@ -1022,7 +1024,11 @@ def end_of_epoch_log(self): lr = self.optim.param_groups[0]["lr"] wall = perf_counter() - self.wall - self.mae_dict = dict(LR=lr, epoch=self.iepoch, wall=wall,) + self.mae_dict = dict( + LR=lr, + epoch=self.iepoch, + wall=wall, + ) header = "epoch, wall, LR" From abd49171a2815a9f5810954ae8431b673a718676 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Sun, 12 Dec 2021 14:51:20 -0500 Subject: [PATCH 32/52] report keys in equivariance error --- CHANGELOG.md | 1 + nequip/scripts/train.py | 6 +--- nequip/utils/test.py | 71 ++++++++++++++++++++++++++++++++++------- 3 files changed, 62 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a4bb3b05..f410017f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ Most recent change on the bottom. - Asynchronous IO: during training, models are written asynchronously. - `dataset_seed` to separately control randomness used to select training data (and their order). Enable this with environment variable `NEQUIP_ASYNC_IO=true`. - The types may now be specified with a simpler `chemical_symbols` option +- Equivariance testing reports per-field errors ### Changed - All fields now have consistant [N, dim] shaping diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 175979de..b7ec812b 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -180,11 +180,7 @@ def fresh_start(config): # Equivar test if config.equivariance_test: - from e3nn.util.test import format_equivariance_error - - equivar_err = assert_AtomicData_equivariant(final_model, dataset[0]) - errstr = format_equivariance_error(equivar_err) - del equivar_err + errstr = assert_AtomicData_equivariant(final_model, dataset[0]) logging.info(f"Equivariance test passed; equivariance errors:\n{errstr}") del errstr diff --git a/nequip/utils/test.py b/nequip/utils/test.py index 359bde01..2ccc2b53 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -1,8 +1,8 @@ -from typing import Union +from typing import Union, Optional import torch from e3nn import o3 -from e3nn.util.test import assert_equivariant +from e3nn.util.test import equivariance_error, FLOAT_TOLERANCE from nequip.nn import GraphModuleMixin from nequip.data import ( @@ -24,7 +24,9 @@ def _inverse_permutation(perm): def assert_permutation_equivariant( - func: GraphModuleMixin, data_in: AtomicDataDict.Type + func: GraphModuleMixin, + data_in: AtomicDataDict.Type, + tolerance: Optional[float] = None, ): r"""Test the permutation equivariance of ``func``. @@ -39,7 +41,10 @@ def assert_permutation_equivariant( # Prevent pytest from showing this function in the traceback # __tracebackhide__ = True - atol = PERMUTATION_FLOAT_TOLERANCE[torch.get_default_dtype()] + if tolerance is None: + atol = PERMUTATION_FLOAT_TOLERANCE[torch.get_default_dtype()] + else: + atol = tolerance data_in = data_in.copy() device = data_in[AtomicDataDict.POSITIONS_KEY].device @@ -120,8 +125,10 @@ def assert_permutation_equivariant( def assert_AtomicData_equivariant( func: GraphModuleMixin, data_in: Union[AtomicData, AtomicDataDict.Type], + permutation_tolerance: Optional[float] = None, + o3_tolerance: Optional[float] = None, **kwargs, -): +) -> str: r"""Test the rotation, translation, parity, and permutation equivariance of ``func``. For details on permutation testing, see ``assert_permutation_equivariant``. @@ -135,7 +142,7 @@ def assert_AtomicData_equivariant( **kwargs: passed to ``e3nn.util.test.assert_equivariant`` Returns: - Information on equivariance error from ``e3nn.util.test.assert_equivariant`` + A string description of the errors. """ # Prevent pytest from showing this function in the traceback __tracebackhide__ = True @@ -144,10 +151,7 @@ def assert_AtomicData_equivariant( data_in = AtomicData.to_AtomicDataDict(data_in) # == Test permutation of graph nodes == - assert_permutation_equivariant( - func, - data_in, - ) + assert_permutation_equivariant(func, data_in, tolerance=permutation_tolerance) # == Test rotation, parity, and translation using e3nn == irreps_in = {k: None for k in AtomicDataDict.ALLOWED_KEYS} @@ -197,7 +201,7 @@ def wrapper(*args): args_in = [data_in[k] for k in irreps_in] - return assert_equivariant( + errs = equivariance_error( wrapper, args_in=args_in, irreps_in=list(irreps_in.values()), @@ -205,6 +209,51 @@ def wrapper(*args): **kwargs, ) + if o3_tolerance is None: + o3_tolerance = FLOAT_TOLERANCE[torch.get_default_dtype()] + if isinstance(next(iter(errs.values())), float): + # old e3nn doesn't report which key + problems = {k: v for k, v in errs.items() if v > o3_tolerance} + + def _describe(errors): + return "\n".join( + "(parity_k={:d}, did_translate={}) -> max error={:.3e}".format( + int(k[0]), + bool(k[1]), + float(v), + ) + for k, v in errors.items() + ) + + if len(problems) > 0: + raise AssertionError( + "Equivariance test failed for cases:" + _describe(problems) + ) + + return _describe(errs) + else: + # it's newer and tells us which is which + all_errs = [] + for case, err in errs.items(): + for key, this_err in zip(irreps_out.keys(), err): + all_errs.append(case + (key, this_err)) + problems = [e for e in all_errs if e[-1] > o3_tolerance] + + def _describe(errors): + return "\n".join( + " (parity_k={:1d}, did_translate={:5}, field={:20}) -> max error={:.3e}".format( + int(k[0]), str(bool(k[1])), str(k[2]), float(k[3]) + ) + for k in errors + ) + + if len(problems) > 0: + raise AssertionError( + "Equivariance test failed for cases:\n" + _describe(problems) + ) + + return _describe(all_errs) + _DEBUG_HOOKS = None From c4e54d4aede24349df7c6d51b35325d54f73c5ba Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Sun, 12 Dec 2021 14:51:53 -0500 Subject: [PATCH 33/52] handle function builders --- nequip/train/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 025adbed..7bde37fb 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -488,6 +488,8 @@ def as_dict( codes_for_git = {"e3nn", "nequip"} for builder in self.model_builders: + if not isinstance(builder, str): + continue builder = builder.split(".") if len(builder) > 1: # it's not a single name which is from nequip From 9ead52f7091ba885ba1ad9349908c152c68282cd Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 13 Dec 2021 14:04:54 -0500 Subject: [PATCH 34/52] handle prev e3nn version --- nequip/utils/test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nequip/utils/test.py b/nequip/utils/test.py index 2ccc2b53..14b4b399 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -211,7 +211,8 @@ def wrapper(*args): if o3_tolerance is None: o3_tolerance = FLOAT_TOLERANCE[torch.get_default_dtype()] - if isinstance(next(iter(errs.values())), float): + anerr = next(iter(errs.values())) + if isinstance(anerr, float) or anerr.ndim == 0: # old e3nn doesn't report which key problems = {k: v for k, v in errs.items() if v > o3_tolerance} From 72034b8560b9f7e0ecb36e90e6713a9331f03ce2 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 13 Dec 2021 14:15:58 -0500 Subject: [PATCH 35/52] contributing notes --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- CONTRIBUTING.md | 23 +++++++++++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 1f6e8b38..d0ee7a97 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -24,7 +24,7 @@ Resolves: #??? - [ ] My code follows the code style of this project and has been formatted using `black`. -- [ ] All new and existing tests passed. +- [ ] All new and existing tests passed, including on GPU (if relevant). - [ ] I have added tests that cover my changes (if relevant). - [ ] The option documentation (`docs/options`) has been updated with new or changed options. - [ ] I have updated `CHANGELOG.md`. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 65f75f70..5aa0f34a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,8 +1,27 @@ # Contributing to NequIP +Issues and pull requests are welcome! + +**!! If you want to make a major change, or one whose correct location/implementation is not obvious, please reach out to discuss with us first. !!** + +In general: + - Optional additions/alternatives to how the model is built or initialized should be implemented as model builders (see `nequip.model`) + - New model features should be implemented as new modules when possible + - Added options should be documented in the docs and changes in the CHANGELOG.md file + +Unless they fix a significant bug with immediate impact, **all PRs should be onto the `develop` branch!** + ## Code style We use the [`black`](https://black.readthedocs.io/en/stable/index.html) code formatter with default settings and the flake8 linter with settings: ``` ---ignore=E501,W503,E203 -``` \ No newline at end of file +--ignore=E226,E501,E741,E743,C901,W503,E203 --max-line-length=127 +``` + +Please run the formatter before you commit and certainly before you make a PR. The formatter can be easily set up to run automatically on file save in various editors. + +## CUDA support + +All additions should support CUDA/GPU. + +If possible, please test your changes on a GPU— the CI tests on GitHub actions do not have GPU resources available. \ No newline at end of file From a4f17b3cc03046bd86f8e5b148f2517c14d91668 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 13 Dec 2021 14:31:15 -0500 Subject: [PATCH 36/52] equivariance test in real units --- nequip/scripts/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index b7ec812b..6bfa4000 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -180,7 +180,9 @@ def fresh_start(config): # Equivar test if config.equivariance_test: + final_model.eval() errstr = assert_AtomicData_equivariant(final_model, dataset[0]) + final_model.train() logging.info(f"Equivariance test passed; equivariance errors:\n{errstr}") del errstr From d8440f161a02e4b6641819acbb8431d26b0bbe19 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 13 Dec 2021 14:43:24 -0500 Subject: [PATCH 37/52] fix perspecies when scale but no shift --- CHANGELOG.md | 1 + nequip/nn/_atomwise.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f410017f..d0ce6779 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ Most recent change on the bottom. ### Fixed - Equivariance testing no longer unintentionally skips translation - Correct cat dim for all registered per-graph fields +- `PerSpeciesScaleShift` now correctly outputs when scales, but not shifts, are enabled— previously it was broken and would only output updated values when both were enabled. ## [0.5.0] - 2021-11-24 ### Changed diff --git a/nequip/nn/_atomwise.py b/nequip/nn/_atomwise.py index 88718fb8..358bdb6c 100644 --- a/nequip/nn/_atomwise.py +++ b/nequip/nn/_atomwise.py @@ -165,7 +165,8 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: if self.has_scales: in_field = self.scales[species_idx].view(-1, 1) * in_field if self.has_shifts: - data[self.out_field] = self.shifts[species_idx].view(-1, 1) + in_field + in_field = self.shifts[species_idx].view(-1, 1) + in_field + data[self.out_field] = in_field return data def update_for_rescale(self, rescale_module): From 0efd72ce02494d622878fbd8cf763ad6fb9b50d9 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 13 Dec 2021 15:57:14 -0500 Subject: [PATCH 38/52] print note --- nequip/scripts/train.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 6bfa4000..a6a91a4d 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -180,10 +180,18 @@ def fresh_start(config): # Equivar test if config.equivariance_test: - final_model.eval() + # final_model.eval() errstr = assert_AtomicData_equivariant(final_model, dataset[0]) final_model.train() - logging.info(f"Equivariance test passed; equivariance errors:\n{errstr}") + logging.info( + "Equivariance test passed; equivariance errors:\n" + " Errors are in real units, where relevant.\n" + " Please note that the large scale of the typical\n" + " shifts to the (atomic) energy can cause\n" + " catastrophic cancellation and give incorrectly\n" + " the equivariance error as zero for those fields.\n" + f"{errstr}" + ) del errstr # Set the trainer From 32cf2159bd6b6b4d95d7e1817e33e57d72b158d0 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 13 Dec 2021 16:27:34 -0500 Subject: [PATCH 39/52] equivariance testing with multiple frames --- CHANGELOG.md | 2 ++ nequip/scripts/train.py | 19 +++++++++----- nequip/utils/test.py | 57 ++++++++++++++++++++++++----------------- 3 files changed, 48 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0ce6779..cddaf93b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,12 @@ Most recent change on the bottom. - `dataset_seed` to separately control randomness used to select training data (and their order). Enable this with environment variable `NEQUIP_ASYNC_IO=true`. - The types may now be specified with a simpler `chemical_symbols` option - Equivariance testing reports per-field errors +- `--equivariance-test n` tests equivariance on `n` frames from the training dataset ### Changed - All fields now have consistant [N, dim] shaping - Changed default `seed` and `dataset_seed` in example YAMLs +- Equivariance testing can only use training frames now ### Fixed - Equivariance testing no longer unintentionally skips translation diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index a6a91a4d..e9eed8ca 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -77,8 +77,10 @@ def parse_command_line(args=None): parser.add_argument("config", help="configuration file") parser.add_argument( "--equivariance-test", - help="test the model's equivariance before training", - action="store_true", + help="test the model's equivariance before training on n (default 1) random frames from the dataset", + const=1, + type=int, + nargs="?", ) parser.add_argument( "--model-debug-mode", @@ -179,9 +181,14 @@ def fresh_start(config): logging.info("Successfully compiled model...") # Equivar test - if config.equivariance_test: - # final_model.eval() - errstr = assert_AtomicData_equivariant(final_model, dataset[0]) + if config.equivariance_test > 0: + n_train: int = len(trainer.dataset_train) + assert config.equivariance_test <= n_train + final_model.eval() + indexes = torch.randperm(n_train)[: config.equivariance_test] + errstr = assert_AtomicData_equivariant( + final_model, [trainer.dataset_train[i] for i in indexes] + ) final_model.train() logging.info( "Equivariance test passed; equivariance errors:\n" @@ -192,7 +199,7 @@ def fresh_start(config): " the equivariance error as zero for those fields.\n" f"{errstr}" ) - del errstr + del errstr, indexes, n_train # Set the trainer trainer.model = final_model diff --git a/nequip/utils/test.py b/nequip/utils/test.py index 14b4b399..ecfac943 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -1,4 +1,4 @@ -from typing import Union, Optional +from typing import Union, Optional, List import torch from e3nn import o3 @@ -124,7 +124,9 @@ def assert_permutation_equivariant( def assert_AtomicData_equivariant( func: GraphModuleMixin, - data_in: Union[AtomicData, AtomicDataDict.Type], + data_in: Union[ + AtomicData, AtomicDataDict.Type, List[Union[AtomicData, AtomicDataDict.Type]] + ], permutation_tolerance: Optional[float] = None, o3_tolerance: Optional[float] = None, **kwargs, @@ -138,7 +140,7 @@ def assert_AtomicData_equivariant( Args: func: the module or model to test - data_in: the example input data to test with + data_in: the example input data(s) to test with. Only the first is used for permutation testing. **kwargs: passed to ``e3nn.util.test.assert_equivariant`` Returns: @@ -147,16 +149,18 @@ def assert_AtomicData_equivariant( # Prevent pytest from showing this function in the traceback __tracebackhide__ = True - if not isinstance(data_in, dict): - data_in = AtomicData.to_AtomicDataDict(data_in) + if not isinstance(data_in, list): + data_in = [data_in] + data_in = [AtomicData.to_AtomicDataDict(d) for d in data_in] # == Test permutation of graph nodes == - assert_permutation_equivariant(func, data_in, tolerance=permutation_tolerance) + # TODO: since permutation is distinct, run only on one. + assert_permutation_equivariant(func, data_in[0], tolerance=permutation_tolerance) # == Test rotation, parity, and translation using e3nn == irreps_in = {k: None for k in AtomicDataDict.ALLOWED_KEYS} irreps_in.update(func.irreps_in) - irreps_in = {k: v for k, v in irreps_in.items() if k in data_in} + irreps_in = {k: v for k, v in irreps_in.items() if k in data_in[0]} irreps_out = func.irreps_out.copy() # for certain things, we don't care what the given irreps are... # make sure that we test correctly for equivariance: @@ -191,23 +195,28 @@ def wrapper(*args): arg_dict[AtomicDataDict.CELL_KEY] = cell.reshape(cell.shape[:-2] + (9,)) return [output[k] for k in irreps_out] - data_in = AtomicData.to_AtomicDataDict(data_in) - # cell is a special case - if AtomicDataDict.CELL_KEY in data_in: - # flatten - cell = data_in[AtomicDataDict.CELL_KEY] - assert cell.shape[-2:] == (3, 3) - data_in[AtomicDataDict.CELL_KEY] = cell.reshape(cell.shape[:-2] + (9,)) - - args_in = [data_in[k] for k in irreps_in] - - errs = equivariance_error( - wrapper, - args_in=args_in, - irreps_in=list(irreps_in.values()), - irreps_out=list(irreps_out.values()), - **kwargs, - ) + # prepare input data + for d in data_in: + # cell is a special case + if AtomicDataDict.CELL_KEY in d: + # flatten + cell = d[AtomicDataDict.CELL_KEY] + assert cell.shape[-2:] == (3, 3) + d[AtomicDataDict.CELL_KEY] = cell.reshape(cell.shape[:-2] + (9,)) + + errs = [ + equivariance_error( + wrapper, + args_in=[d[k] for k in irreps_in], + irreps_in=list(irreps_in.values()), + irreps_out=list(irreps_out.values()), + **kwargs, + ) + for d in data_in + ] + + # take max across errors + errs = {k: torch.max(torch.vstack([e[k] for e in errs]), dim=0)[0] for k in errs[0]} if o3_tolerance is None: o3_tolerance = FLOAT_TOLERANCE[torch.get_default_dtype()] From 0e527d0fcdd19711e91d66aba00a590d5a5f1fcd Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 13 Dec 2021 17:04:25 -0500 Subject: [PATCH 40/52] refactor shape processing --- nequip/data/AtomicData.py | 111 ++++++++++++++++++++------------------ nequip/data/dataset.py | 24 ++++----- 2 files changed, 69 insertions(+), 66 deletions(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index bd70f323..95e2a4a4 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -106,6 +106,64 @@ def deregister_fields(*fields: Sequence[str]) -> None: _GRAPH_FIELDS.discard(f) +def _process_dict(kwargs): + """Convert a dict of data into correct dtypes/shapes according to key""" + # Deal with _some_ dtype issues + for k, v in kwargs.items(): + if k in _LONG_FIELDS: + # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems) + # int32 would pass later checks, but is actually disallowed by torch + kwargs[k] = torch.as_tensor(v, dtype=torch.long) + elif isinstance(v, np.ndarray): + if np.issubdtype(v.dtype, np.floating): + kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) + else: + kwargs[k] = torch.as_tensor(v) + elif np.issubdtype(type(v), np.floating): + # Force scalars to be tensors with a data dimension + # This makes them play well with irreps + kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) + elif isinstance(v, torch.Tensor) and len(v.shape) == 0: + # ^ this tensor is a scalar; we need to give it + # a data dimension to play nice with irreps + kwargs[k] = v + + if AtomicDataDict.BATCH_KEY in kwargs: + num_frames = kwargs[AtomicDataDict.BATCH_KEY].max() + 1 + else: + num_frames = 1 + + for k, v in kwargs.items(): + + if len(v.shape) == 0: + kwargs[k] = v.unsqueeze(-1) + v = kwargs[k] + + if k in set.union(_NODE_FIELDS, _EDGE_FIELDS) and len(v.shape) == 1: + kwargs[k] = v.unsqueeze(-1) + v = kwargs[k] + + if ( + k in _NODE_FIELDS + and AtomicDataDict.POSITIONS_KEY in kwargs + and v.shape[0] != kwargs[AtomicDataDict.POSITIONS_KEY].shape[0] + ): + raise ValueError( + f"{k} is a node field but has the wrong dimension {v.shape}" + ) + elif ( + k in _EDGE_FIELDS + and AtomicDataDict.EDGE_INDEX_KEY in kwargs + and v.shape[0] != kwargs[AtomicDataDict.EDGE_INDEX_KEY].shape[1] + ): + raise ValueError( + f"{k} is a edge field but has the wrong dimension {v.shape}" + ) + elif k in _GRAPH_FIELDS: + if num_frames > 1 and v.shape[0] != num_frames: + raise ValueError(f"Wrong shape for graph property {k}") + + class AtomicData(Data): """A neighbor graph for points in (periodic triclinic) real space. @@ -145,58 +203,7 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): # Check the keys AtomicDataDict.validate_keys(kwargs) - # Deal with _some_ dtype issues - for k, v in kwargs.items(): - if k in _LONG_FIELDS: - # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems) - # int32 would pass later checks, but is actually disallowed by torch - kwargs[k] = torch.as_tensor(v, dtype=torch.long) - elif isinstance(v, np.ndarray): - if np.issubdtype(v.dtype, np.floating): - kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) - else: - kwargs[k] = torch.as_tensor(v) - elif np.issubdtype(type(v), np.floating): - # Force scalars to be tensors with a data dimension - # This makes them play well with irreps - kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) - elif isinstance(v, torch.Tensor) and len(v.shape) == 0: - # ^ this tensor is a scalar; we need to give it - # a data dimension to play nice with irreps - kwargs[k] = v - - if AtomicDataDict.BATCH_KEY in kwargs: - num_frames = kwargs[AtomicDataDict.BATCH_KEY].max() + 1 - else: - num_frames = 1 - - for k, v in kwargs.items(): - - if len(v.shape) == 0: - kwargs[k] = v.unsqueeze(-1) - v = kwargs[k] - - if k in set.union(_NODE_FIELDS, _EDGE_FIELDS) and len(v.shape) == 1: - kwargs[k] = v.unsqueeze(-1) - v = kwargs[k] - - if ( - k in _NODE_FIELDS - and v.shape[0] != kwargs[AtomicDataDict.POSITIONS_KEY].shape[0] - ): - raise ValueError( - f"{k} is a node field but has the wrong dimension {v.shape}" - ) - elif ( - k in _EDGE_FIELDS - and v.shape[0] != kwargs[AtomicDataDict.EDGE_INDEX_KEY].shape[1] - ): - raise ValueError( - f"{k} is a edge field but has the wrong dimension {v.shape}" - ) - elif k in _GRAPH_FIELDS: - if num_frames > 1 and v.shape[0] != num_frames: - raise ValueError(f"Wrong shape for graph property {k}") + _process_dict(kwargs) super().__init__(num_nodes=len(kwargs["pos"]), **kwargs) diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py index 4ff30ee4..59e7d79d 100644 --- a/nequip/data/dataset.py +++ b/nequip/data/dataset.py @@ -28,6 +28,7 @@ from nequip.utils.regressor import solver from nequip.utils.savenload import atomic_write from .transforms import TypeMapper +from .AtomicData import _process_dict class AtomicDataset(Dataset): @@ -280,18 +281,7 @@ def process(self): del fields # type conversion - for key, value in fixed_fields.items(): - if isinstance(value, np.ndarray): - if np.issubdtype(value.dtype, np.floating): - fixed_fields[key] = torch.as_tensor( - value, dtype=torch.get_default_dtype() - ) - else: - fixed_fields[key] = torch.as_tensor(value) - elif np.issubdtype(type(value), np.floating): - fixed_fields[key] = torch.as_tensor( - value, dtype=torch.get_default_dtype() - ) + _process_dict(fixed_fields) logging.info(f"Loaded data: {data}") @@ -525,7 +515,10 @@ def statistics( @staticmethod def _per_atom_statistics( - ana_mode: str, arr: torch.Tensor, batch: torch.Tensor, unbiased: bool = True, + ana_mode: str, + arr: torch.Tensor, + batch: torch.Tensor, + unbiased: bool = True, ): """Compute "per-atom" statistics that are normalized by the number of atoms in the system. @@ -842,7 +835,10 @@ def get_data(self): atoms_list = self.get_atoms() # skip the None arguments - kwargs = dict(include_keys=self.include_keys, key_mapping=self.key_mapping,) + kwargs = dict( + include_keys=self.include_keys, + key_mapping=self.key_mapping, + ) kwargs = {k: v for k, v in kwargs.items() if v is not None} kwargs.update(self.extra_fixed_fields) From 1b446f7b4e1052494754c8fe6e6e01dfe85b4cf7 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 13 Dec 2021 17:06:16 -0500 Subject: [PATCH 41/52] handle atom types extra final dim --- nequip/train/_loss.py | 3 +-- nequip/train/metrics.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/nequip/train/_loss.py b/nequip/train/_loss.py index 4db16274..7514a509 100644 --- a/nequip/train/_loss.py +++ b/nequip/train/_loss.py @@ -126,11 +126,11 @@ def __call__( reduce_dims = tuple(i + 1 for i in range(len(per_atom_loss.shape) - 1)) + spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1) if has_nan: if len(reduce_dims) > 0: per_atom_loss = per_atom_loss.sum(dim=reduce_dims) - spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY] per_species_loss = scatter(per_atom_loss, spe_idx, dim=0) N = scatter(not_nan, spe_idx, dim=0) @@ -146,7 +146,6 @@ def __call__( per_atom_loss = per_atom_loss.mean(dim=reduce_dims) # offset species index by 1 to use 0 for nan - spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY] _, inverse_species_index = torch.unique(spe_idx, return_inverse=True) per_species_loss = scatter_mean(per_atom_loss, inverse_species_index, dim=0) diff --git a/nequip/train/metrics.py b/nequip/train/metrics.py index 820afd85..2f790bd9 100644 --- a/nequip/train/metrics.py +++ b/nequip/train/metrics.py @@ -169,7 +169,9 @@ def __call__(self, pred: dict, ref: dict): params = {} if per_species: # TO DO, this needs OneHot component. will need to be decoupled - params = {"accumulate_by": pred[AtomicDataDict.ATOM_TYPE_KEY]} + params = { + "accumulate_by": pred[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1) + } if per_atom: if N is None: N = torch.bincount(ref[AtomicDataDict.BATCH_KEY]).unsqueeze(-1) From af0d6c0811fd31e10774810f9fc5492fd72581d8 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 13 Dec 2021 17:22:42 -0500 Subject: [PATCH 42/52] bugfix --- nequip/data/AtomicData.py | 7 ++++++- nequip/data/dataset.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 95e2a4a4..56cf6c31 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -106,10 +106,13 @@ def deregister_fields(*fields: Sequence[str]) -> None: _GRAPH_FIELDS.discard(f) -def _process_dict(kwargs): +def _process_dict(kwargs, ignore_fields=[]): """Convert a dict of data into correct dtypes/shapes according to key""" # Deal with _some_ dtype issues for k, v in kwargs.items(): + if k in ignore_fields: + continue + if k in _LONG_FIELDS: # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems) # int32 would pass later checks, but is actually disallowed by torch @@ -134,6 +137,8 @@ def _process_dict(kwargs): num_frames = 1 for k, v in kwargs.items(): + if k in ignore_fields: + continue if len(v.shape) == 0: kwargs[k] = v.unsqueeze(-1) diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py index 59e7d79d..7b0f6201 100644 --- a/nequip/data/dataset.py +++ b/nequip/data/dataset.py @@ -281,7 +281,7 @@ def process(self): del fields # type conversion - _process_dict(fixed_fields) + _process_dict(fixed_fields, ignore_fields=["r_max"]) logging.info(f"Loaded data: {data}") From 4832db7754bf184d1e1b0ef12346715e8465d240 Mon Sep 17 00:00:00 2001 From: Marcel Langer Date: Sun, 19 Dec 2021 14:42:21 +0100 Subject: [PATCH 43/52] Minimum viable patch for vibes compatibility --- nequip/ase/nequip_calculator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nequip/ase/nequip_calculator.py b/nequip/ase/nequip_calculator.py index 39723dfe..11b7f54a 100644 --- a/nequip/ase/nequip_calculator.py +++ b/nequip/ase/nequip_calculator.py @@ -10,6 +10,11 @@ import nequip.scripts.deploy +def nequip_calculator(model, **kwargs): + """Build ASE Calculator directly from deployed model.""" + return NequIPCalculator.from_deployed_model(model, **kwargs) + + class NequIPCalculator(Calculator): """NequIP ASE Calculator. From 1f51aac8af290eb1c2098dc39145f9a6c3f877c7 Mon Sep 17 00:00:00 2001 From: Simon Batzner Date: Sun, 19 Dec 2021 21:51:53 +0100 Subject: [PATCH 44/52] update CHANGELOD for vibes compatibility --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 78adfb6a..03ed6699 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ Most recent change on the bottom. ### Fixed - `PerSpeciesScaleShift` now correctly outputs when scales, but not shifts, are enabled— previously it was broken and would only output updated values when both were enabled. +### Added +- `NequIPCalculator` can now be built via a `nequip_calculator()` function. This adds a minimal compatibility with [vibes](https://gitlab.com/vibes-developers/vibes/) + ## [0.5.0] - 2021-11-24 ### Changed - Allow e3nn 0.4.*, which changes the default normalization of `TensorProduct`s; this change _should_ not affect typical NequIP networks From 5bcd861759008372b736f65ef1b156532dfd0028 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 6 Jan 2022 16:09:24 -0700 Subject: [PATCH 45/52] link CONTRIBUTING.MD --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b3f2d378..dce25648 100644 --- a/README.md +++ b/README.md @@ -145,9 +145,11 @@ NequIP is being developed by: under the guidance of [Boris Kozinsky at Harvard](https://bkoz.seas.harvard.edu/). -## Contact & questions +## Contact, questions, and contributing If you have questions, please don't hesitate to reach out at batzner[at]g[dot]harvard[dot]edu. If you find a bug or have a proposal for a feature, please post it in the [Issues](https://github.com/mir-group/nequip/issues). If you have a question, topic, or issue that isn't obviously one of those, try our [GitHub Disucssions](https://github.com/mir-group/nequip/discussions). + +If you want to contribute to the code, please read [`CONTRIBUTING.md`](CONTRIBUTING.md). From 1c3144418c2469cc6e47bf5759eea5e30b1640d7 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 13 Jan 2022 13:49:02 -0700 Subject: [PATCH 46/52] add `untransform` --- nequip/data/transforms.py | 13 +++++++++++++ tests/unit/data/test_dataset.py | 9 +++++++++ 2 files changed, 22 insertions(+) diff --git a/nequip/data/transforms.py b/nequip/data/transforms.py index 0c6735d1..68ad6bcc 100644 --- a/nequip/data/transforms.py +++ b/nequip/data/transforms.py @@ -69,6 +69,11 @@ def __init__( for sym, type in self.chemical_symbol_to_type.items(): Z_to_index[ase.data.atomic_numbers[sym] - self._min_Z] = type self._Z_to_index = Z_to_index + self._index_to_Z = torch.zeros( + size=(len(self.chemical_symbol_to_type),), dtype=torch.long + ) + for sym, type_idx in self.chemical_symbol_to_type.items(): + self._index_to_Z[type_idx] = ase.data.atomic_numbers[sym] self._valid_set = set(valid_atomic_numbers) # check if type_names is None: @@ -117,3 +122,11 @@ def transform(self, atomic_numbers): ) return self._Z_to_index[atomic_numbers - self._min_Z] + + def untransform(self, atom_types): + """Transform atom types back into atomic numbers""" + return self._index_to_Z[atom_types] + + @property + def has_chemical_symbols(self) -> bool: + return self.chemical_symbol_to_type is not None diff --git a/tests/unit/data/test_dataset.py b/tests/unit/data/test_dataset.py index 3c3501f8..01b89e2d 100644 --- a/tests/unit/data/test_dataset.py +++ b/tests/unit/data/test_dataset.py @@ -66,6 +66,15 @@ def root(): yield path +def test_type_mapper(): + tm = TypeMapper(chemical_symbol_to_type={"C": 1, "H": 0}) + atomic_numbers = torch.as_tensor([1, 1, 6, 1, 6, 6, 6]) + atom_types = tm.transform(atomic_numbers) + assert atom_types[0] == 0 + untransformed = tm.untransform(atom_types) + assert torch.equal(untransformed, atomic_numbers) + + class TestInit: def test_init(self): with pytest.raises(NotImplementedError) as excinfo: From 33a06d666b48189c5f1150181bac5ab3675c49d4 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 13 Jan 2022 13:49:24 -0700 Subject: [PATCH 47/52] typo --- nequip/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py index 7b0f6201..73e1fcec 100644 --- a/nequip/data/dataset.py +++ b/nequip/data/dataset.py @@ -88,7 +88,7 @@ def __init__( force_fixed_keys: List[str] = [], extra_fixed_fields: Dict[str, Any] = {}, include_frames: Optional[List[int]] = None, - type_mapper: TypeMapper = None, + type_mapper: Optional[TypeMapper] = None, ): # TO DO, this may be simplified # See if a subclass defines some inputs From 48ea45f543bc4c380834a8d1da6b0a93898e0576 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 13 Jan 2022 13:50:24 -0700 Subject: [PATCH 48/52] correct species in XYZ --- CHANGELOG.md | 2 ++ nequip/data/AtomicData.py | 8 +++++++- nequip/scripts/evaluate.py | 5 ++++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 143c7e57..7076b4a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ Most recent change on the bottom. - Equivariance testing no longer unintentionally skips translation - Correct cat dim for all registered per-graph fields - `PerSpeciesScaleShift` now correctly outputs when scales, but not shifts, are enabled— previously it was broken and would only output updated values when both were enabled. +- `nequip-evaluate` outputs correct species to the `extxyz` file when a chemical symbol <-> type mapping exists for the test dataset + ## [0.5.0] - 2021-11-24 ### Changed - Allow e3nn 0.4.*, which changes the default normalization of `TensorProduct`s; this change _should_ not affect typical NequIP networks diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 56cf6c31..7ce7f12b 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -419,13 +419,17 @@ def from_ase( **add_fields, ) - def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: + def to_ase(self, type_mapper=None) -> Union[List[ase.Atoms], ase.Atoms]: """Build a (list of) ``ase.Atoms`` object(s) from an ``AtomicData`` object. For each unique batch number provided in ``AtomicDataDict.BATCH_KEY``, an ``ase.Atoms`` object is created. If ``AtomicDataDict.BATCH_KEY`` does not exist in self, a single ``ase.Atoms`` object is created. + Args: + type_mapper: if provided, will be used to map ``ATOM_TYPES`` back into + elements, if the configuration of the ``type_mapper`` allows. + Returns: A list of ``ase.Atoms`` objects if ``AtomicDataDict.BATCH_KEY`` is in self and is not None. Otherwise, a single ``ase.Atoms`` object is returned. @@ -437,6 +441,8 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: ) if AtomicDataDict.ATOMIC_NUMBERS_KEY in self: atomic_nums = self.atomic_numbers + elif type_mapper is not None and type_mapper.has_chemical_symbols: + atomic_nums = type_mapper.untransform(self[AtomicDataDict.ATOM_TYPE_KEY]) else: warnings.warn( "AtomicData.to_ase(): self didn't contain atomic numbers... using atom_type as atomic numbers instead, but this means the chemical symbols in ASE (outputs) will be wrong" diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py index 67221af1..733f62a3 100644 --- a/nequip/scripts/evaluate.py +++ b/nequip/scripts/evaluate.py @@ -321,10 +321,13 @@ def main(args=None, running_as_script: bool = True): with torch.no_grad(): # Write output + # TODO: make sure don't keep appending to existing file if output is not None: ase.io.write( output, - AtomicData.from_AtomicDataDict(out).to(device="cpu").to_ase(), + AtomicData.from_AtomicDataDict(out) + .to(device="cpu") + .to_ase(type_mapper=dataset.type_mapper), format="extxyz", append=True, ) From 0682c3032bb0a2c2c17a32cdc8d87c2956f246aa Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 13 Jan 2022 14:33:23 -0700 Subject: [PATCH 49/52] typo --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7076b4a8..274bac72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,8 @@ Most recent change on the bottom. ### Added - `NequIPCalculator` can now be built via a `nequip_calculator()` function. This adds a minimal compatibility with [vibes](https://gitlab.com/vibes-developers/vibes/) - Added `avg_num_neighbors: auto` option -- Asynchronous IO: during training, models are written asynchronously. -- `dataset_seed` to separately control randomness used to select training data (and their order). Enable this with environment variable `NEQUIP_ASYNC_IO=true`. +- Asynchronous IO: during training, models are written asynchronously. Enable this with environment variable `NEQUIP_ASYNC_IO=true`. +- `dataset_seed` to separately control randomness used to select training data (and their order). - The types may now be specified with a simpler `chemical_symbols` option - Equivariance testing reports per-field errors - `--equivariance-test n` tests equivariance on `n` frames from the training dataset From 03005b7b17b988a20b008454307b982176403106 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 13 Jan 2022 14:34:42 -0700 Subject: [PATCH 50/52] date --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 274bac72..37371245 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. -## [Unreleased] - 0.5.1 +## [Unreleased] + +## [0.5.1] - 2022-01-13 ### Added - `NequIPCalculator` can now be built via a `nequip_calculator()` function. This adds a minimal compatibility with [vibes](https://gitlab.com/vibes-developers/vibes/) - Added `avg_num_neighbors: auto` option From db1891eb7995322611937217866d87e4fc093c4a Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 13 Jan 2022 14:43:27 -0700 Subject: [PATCH 51/52] backport for 3.6 --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index d01cf27e..f206bba9 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ "e3nn>=0.3.5,<0.5.0", "pyyaml", "contextlib2;python_version<'3.7'", # backport of nullcontext + 'contextvars;python_version<"3.7"', # backport of contextvars for savenload "typing_extensions;python_version<'3.8'", # backport of Final "torch-runstats>=0.2.0", "torch-ema>=0.3.0", From d4ccae2627d29d26bced6e3257d85a8f6a493ced Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Thu, 13 Jan 2022 14:44:26 -0700 Subject: [PATCH 52/52] test on 1.10 --- .github/workflows/tests.yml | 2 +- .github/workflows/tests_develop.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 96bbac82..1f48c043 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: python-version: [3.6, 3.9] - torch-version: [1.8.0, 1.9.0] + torch-version: [1.8.0, 1.10.0] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/tests_develop.yml b/.github/workflows/tests_develop.yml index d9728d14..66a732b1 100644 --- a/.github/workflows/tests_develop.yml +++ b/.github/workflows/tests_develop.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: python-version: [3.9] - torch-version: [1.9.0] + torch-version: [1.10.0] steps: - uses: actions/checkout@v2