diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 1f6e8b38..d0ee7a97 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -24,7 +24,7 @@ Resolves: #???
 - [ ] My code follows the code style of this project and has been formatted using `black`.
-- [ ] All new and existing tests passed.
+- [ ] All new and existing tests passed, including on GPU (if relevant).
 - [ ] I have added tests that cover my changes (if relevant).
 - [ ] The option documentation (`docs/options`) has been updated with new or changed options.
 - [ ] I have updated `CHANGELOG.md`.
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 41e389be..1f48c043 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -16,7 +16,7 @@ jobs:
     strategy:
       matrix:
         python-version: [3.6, 3.9]
-        torch-version: [1.8.0, 1.9.0]
+        torch-version: [1.8.0, 1.10.0]
 
     steps:
     - uses: actions/checkout@v2
@@ -42,6 +42,10 @@ jobs:
       run: |
         pip install pytest
         pip install pytest-xdist[psutil]
+    - name: Download test data
+      run: |
+        mkdir benchmark_data
+        cd benchmark_data; wget "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip"; cd ..
     - name: Test with pytest
       run: |
         # See https://github.com/pytest-dev/pytest/issues/1075
diff --git a/.github/workflows/tests_develop.yml b/.github/workflows/tests_develop.yml
index a69a728d..66a732b1 100644
--- a/.github/workflows/tests_develop.yml
+++ b/.github/workflows/tests_develop.yml
@@ -16,7 +16,7 @@ jobs:
     strategy:
       matrix:
         python-version: [3.9]
-        torch-version: [1.9.0]
+        torch-version: [1.10.0]
 
     steps:
     - uses: actions/checkout@v2
@@ -42,6 +42,10 @@ jobs:
       run: |
         pip install pytest
         pip install pytest-xdist[psutil]
+    - name: Download test data
+      run: |
+        mkdir benchmark_data
+        cd benchmark_data; wget "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip"; cd ..
     - name: Test with pytest
       run: |
         # See https://github.com/pytest-dev/pytest/issues/1075
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 78adfb6a..37371245 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,9 +6,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 Most recent change on the bottom.
 
+
 ## [Unreleased]
+
+## [0.5.1] - 2022-01-13
+### Added
+- `NequIPCalculator` can now be built via a `nequip_calculator()` function. This adds minimal compatibility with [vibes](https://gitlab.com/vibes-developers/vibes/)
+- Added `avg_num_neighbors: auto` option
+- Asynchronous IO: during training, models are written asynchronously. Enable this with environment variable `NEQUIP_ASYNC_IO=true`.
+- `dataset_seed` to separately control randomness used to select training data (and their order).
+- The types may now be specified with a simpler `chemical_symbols` option
+- Equivariance testing reports per-field errors
+- `--equivariance-test n` tests equivariance on `n` frames from the training dataset
+
+### Changed
+- All fields now have consistent [N, dim] shaping
+- Changed default `seed` and `dataset_seed` in example YAMLs
+- Equivariance testing can only use training frames now
+
 ### Fixed
+- Equivariance testing no longer unintentionally skips translation
+- Correct cat dim for all registered per-graph fields
 - `PerSpeciesScaleShift` now correctly outputs when scales, but not shifts, are enabled— previously it was broken and would only output updated values when both were enabled.
+- `nequip-evaluate` outputs correct species to the `extxyz` file when a chemical symbol <-> type mapping exists for the test dataset
 
 ## [0.5.0] - 2021-11-24
 ### Changed
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 65f75f70..5aa0f34a 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,8 +1,27 @@
 # Contributing to NequIP
+Issues and pull requests are welcome!
+
+**!! If you want to make a major change, or one whose correct location/implementation is not obvious, please reach out to discuss with us first. !!**
+
+In general:
+ - Optional additions/alternatives to how the model is built or initialized should be implemented as model builders (see `nequip.model`)
+ - New model features should be implemented as new modules when possible
+ - Added options should be documented in the docs, and changes should be noted in the `CHANGELOG.md` file
+
+Unless they fix a significant bug with immediate impact, **all PRs should be onto the `develop` branch!**
+
 ## Code style
 
 We use the [`black`](https://black.readthedocs.io/en/stable/index.html) code formatter with default settings and the flake8 linter with settings:
 ```
---ignore=E501,W503,E203
-```
\ No newline at end of file
+--ignore=E226,E501,E741,E743,C901,W503,E203 --max-line-length=127
+```
+
+Please run the formatter before you commit and certainly before you make a PR. The formatter can be easily set up to run automatically on file save in various editors.
+
+## CUDA support
+
+All additions should support CUDA/GPU.
+
+If possible, please test your changes on a GPU— the CI tests on GitHub actions do not have GPU resources available.
\ No newline at end of file
diff --git a/README.md b/README.md
index b3f2d378..dce25648 100644
--- a/README.md
+++ b/README.md
@@ -145,9 +145,11 @@ NequIP is being developed by:
 
 under the guidance of [Boris Kozinsky at Harvard](https://bkoz.seas.harvard.edu/).
 
-## Contact & questions
+## Contact, questions, and contributing
 
 If you have questions, please don't hesitate to reach out at batzner[at]g[dot]harvard[dot]edu.
 
 If you find a bug or have a proposal for a feature, please post it in the [Issues](https://github.com/mir-group/nequip/issues). If you have a question, topic, or issue that isn't obviously one of those, try our [GitHub Discussions](https://github.com/mir-group/nequip/discussions).
+
+If you want to contribute to the code, please read [`CONTRIBUTING.md`](CONTRIBUTING.md).
 
diff --git a/configs/example.yaml b/configs/example.yaml
index 70dfdd98..e916b1e7 100644
--- a/configs/example.yaml
+++ b/configs/example.yaml
@@ -6,7 +6,8 @@
 # if 'root'/'run_name' exists, 'root'/'run_name'_'year'-'month'-'day'-'hour'-'min'-'s' will be used instead.
 root: results/toluene
 run_name: example-run-toluene
-seed: 0  # random number seed for numpy and torch
+seed: 123
+dataset_seed: 456  # random number seed for numpy and torch
 append: true  # set true if a restarted run should append to the previous log file
 default_dtype: float32  # type of float to use, e.g. float32 and float64
 
@@ -39,7 +40,7 @@ PolynomialCutoff_p: 6
 # radial network
 invariant_layers: 2  # number of radial layers, usually 1-3 works best, smaller is faster
 invariant_neurons: 64  # number of hidden neurons in radial function, smaller is faster
-avg_num_neighbors: null  # number of neighbors to divide by, null => no normalization.
+avg_num_neighbors: auto  # number of neighbors to divide by, null => no normalization.
 use_sc: true  # use self-connection or not, usually gives big improvement
 
 # data set
@@ -58,10 +59,10 @@ key_mapping:
 npz_fixed_field_keys:  # fields that are repeated across different examples
   - atomic_numbers
 
-# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes.
-chemical_symbol_to_type:
-  H: 0
-  C: 1
+# A list of atomic types to be found in the data. The NequIP types will be named with the chemical symbols, and inputs with the correct atomic numbers will be mapped to the corresponding types.
+chemical_symbols:
+  - H
+  - C
 
 # logging
 wandb: true  # we recommend using wandb for logging, we'll turn it off here as it's optional
diff --git a/configs/full.yaml b/configs/full.yaml
index b87e11fc..a04e0fd4 100644
--- a/configs/full.yaml
+++ b/configs/full.yaml
@@ -10,11 +10,12 @@
 # if 'root'/'run_name' exists, 'root'/'run_name'_'year'-'month'-'day'-'hour'-'min'-'s' will be used instead.
 root: results/toluene
 run_name: example-run-toluene
-seed: 0  # random number seed for numpy and torch
+seed: 123
+dataset_seed: 456  # random number seed for numpy and torch
 append: true  # set true if a restarted run should append to the previous log file
 default_dtype: float32  # type of float to use, e.g. float32 and float64
 allow_tf32: false  # whether to use TensorFloat32 if it is available
-# device: cuda  # which device to use. Default: automatically detected cuda or "cpu"
+# device: cuda  # which device to use. Default: automatically detected cuda or "cpu"
 
 # network
 r_max: 4.0  # cutoff radius in length units, here Angstrom, this is an important hyperparamter to scan
@@ -45,7 +46,7 @@ PolynomialCutoff_p: 6
 # radial network
 invariant_layers: 2  # number of radial layers, usually 1-3 works best, smaller is faster
 invariant_neurons: 64  # number of hidden neurons in radial function, smaller is faster
-avg_num_neighbors: null  # number of neighbors to divide by, null => no normalization.
+avg_num_neighbors: auto  # number of neighbors to divide by, null => no normalization.
 use_sc: true  # use self-connection or not, usually gives big improvement
 
 compile_model: false  # whether to compile the constructed model to TorchScript
@@ -88,16 +89,21 @@ npz_fixed_field_keys:
 # key_mapping:
 #   free_energy: total_energy
 
-# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes.
-chemical_symbol_to_type:
-  H: 0
-  C: 1
-
-# Alternatively, if the dataset has type indicess, the total number of types is all that is required:
+# A list of atomic types to be found in the data. The NequIP types will be named with the chemical symbols, and inputs with the correct atomic numbers will be mapped to the corresponding types.
+chemical_symbols:
+  - H
+  - C
+# Alternatively, you may explicitly specify which chemical species maps to which type in NequIP (type index; the name is still taken from the chemical symbol)
+# chemical_symbol_to_type:
+#   H: 0
+#   C: 1
+
+# Alternatively, if the dataset has type indices, you may give the names for the types in order:
+# (this also sets the number of types)
 # type_names:
-#   0: my_type
-#   1: atom
-#   2: thing
+#   - my_type
+#   - atom
+#   - thing
 
 # As an alternative option to npz, you can also pass data ase ASE Atoms-objects
 # This can often be easier to work with, simply make sure the ASE Atoms object
diff --git a/configs/minimal.yaml b/configs/minimal.yaml
index fa05bbbb..489fc7a6 100644
--- a/configs/minimal.yaml
+++ b/configs/minimal.yaml
@@ -1,7 +1,8 @@
 # general
 root: results/aspirin
 run_name: minimal
-seed: 0
+seed: 123
+dataset_seed: 456
 
 # network
 num_basis: 8
@@ -25,16 +26,11 @@ key_mapping:
   R: pos  # raw atomic positions
 npz_fixed_field_keys:  # fields that are repeated across different examples
   - atomic_numbers
-# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes.
-chemical_symbol_to_type:
-  H: 0
-  C: 1
-  O: 2
-# Alternatively, if the dataset has type indexes, the total number of types is all that is required:
-# type_names:
-#   0: my_type
-#   1: atom
-#   2: thing
+
+chemical_symbols:
+  - H
+  - O
+  - C
 
 # logging
 wandb: false
diff --git a/configs/minimal_eng.yaml b/configs/minimal_eng.yaml
index 0b4ad0d3..fe002fbc 100644
--- a/configs/minimal_eng.yaml
+++ b/configs/minimal_eng.yaml
@@ -29,16 +29,11 @@ key_mapping:
   R: pos  # raw atomic positions
 npz_fixed_field_keys:  # fields that are repeated across different examples
   - atomic_numbers
-# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes.
-chemical_symbol_to_type:
-  H: 0
-  C: 1
-  O: 2
-# Alternatively, if the dataset has type indexes, the total number of types is all that is required:
-# type_names:
-#   0: my_type
-#   1: atom
-#   2: thing
+
+chemical_symbols:
+  - H
+  - O
+  - C
 
 # logging
 wandb: false
diff --git a/docs/options/dataset.rst b/docs/options/dataset.rst
index 22cd701e..54b39fc9 100644
--- a/docs/options/dataset.rst
+++ b/docs/options/dataset.rst
@@ -13,6 +13,11 @@ type_names
   | Type: NoneType
   | Default: ``None``
 
+chemical_symbols
+^^^^^^^^^^^^^^^^
+  | Type: NoneType
+  | Default: ``None``
+
 chemical_symbol_to_type
 ^^^^^^^^^^^^^^^^^^^^^^^
   | Type: NoneType
diff --git a/nequip/_version.py b/nequip/_version.py
index edbb9d1b..41f4a9f7 100644
--- a/nequip/_version.py
+++ b/nequip/_version.py
@@ -2,4 +2,4 @@
 # See Python packaging guide
 # https://packaging.python.org/guides/single-sourcing-package-version/
 
-__version__ = "0.5.0"
+__version__ = "0.5.1"
diff --git a/nequip/ase/nequip_calculator.py b/nequip/ase/nequip_calculator.py
index 39723dfe..11b7f54a 100644
--- a/nequip/ase/nequip_calculator.py
+++ b/nequip/ase/nequip_calculator.py
@@ -10,6 +10,11 @@
 import nequip.scripts.deploy
 
 
+def nequip_calculator(model, **kwargs):
+    """Build ASE Calculator directly from deployed model."""
+    return NequIPCalculator.from_deployed_model(model, **kwargs)
+
+
 class NequIPCalculator(Calculator):
     """NequIP ASE Calculator.
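For reference, a minimal usage sketch of the `nequip_calculator()` factory introduced above. The `deployed_model.pth` and `structure.xyz` paths are hypothetical placeholders, and the `device` keyword is assumed to be forwarded to `NequIPCalculator.from_deployed_model` as in the existing calculator:

```python
from ase.io import read

from nequip.ase.nequip_calculator import nequip_calculator

# Hypothetical paths: a model produced by `nequip-deploy` and an input structure.
atoms = read("structure.xyz")
atoms.calc = nequip_calculator("deployed_model.pth", device="cpu")

print(atoms.get_potential_energy())
print(atoms.get_forces())
```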
diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 81a2978d..7ce7f12b 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -24,6 +24,13 @@ # A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case) PBC = Union[bool, Tuple[bool, bool, bool]] + +_DEFAULT_LONG_FIELDS: Set[str] = { + AtomicDataDict.EDGE_INDEX_KEY, + AtomicDataDict.ATOMIC_NUMBERS_KEY, + AtomicDataDict.ATOM_TYPE_KEY, + AtomicDataDict.BATCH_KEY, +} _DEFAULT_NODE_FIELDS: Set[str] = { AtomicDataDict.POSITIONS_KEY, AtomicDataDict.WEIGHTS_KEY, @@ -44,16 +51,20 @@ } _DEFAULT_GRAPH_FIELDS: Set[str] = { AtomicDataDict.TOTAL_ENERGY_KEY, + AtomicDataDict.PBC_KEY, + AtomicDataDict.CELL_KEY, } _NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS) _EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS) _GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS) +_LONG_FIELDS: Set[str] = set(_DEFAULT_LONG_FIELDS) def register_fields( node_fields: Sequence[str] = [], edge_fields: Sequence[str] = [], graph_fields: Sequence[str] = [], + long_fields: Sequence[str] = [], ) -> None: r"""Register fields as being per-atom, per-edge, or per-frame. @@ -69,6 +80,7 @@ def register_fields( _NODE_FIELDS.update(node_fields) _EDGE_FIELDS.update(edge_fields) _GRAPH_FIELDS.update(graph_fields) + _LONG_FIELDS.update(long_fields) if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < ( len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS) ): @@ -94,6 +106,69 @@ def deregister_fields(*fields: Sequence[str]) -> None: _GRAPH_FIELDS.discard(f) +def _process_dict(kwargs, ignore_fields=[]): + """Convert a dict of data into correct dtypes/shapes according to key""" + # Deal with _some_ dtype issues + for k, v in kwargs.items(): + if k in ignore_fields: + continue + + if k in _LONG_FIELDS: + # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems) + # int32 would pass later checks, but is actually disallowed by torch + kwargs[k] = torch.as_tensor(v, dtype=torch.long) + elif isinstance(v, np.ndarray): + if np.issubdtype(v.dtype, np.floating): + kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) + else: + kwargs[k] = torch.as_tensor(v) + elif np.issubdtype(type(v), np.floating): + # Force scalars to be tensors with a data dimension + # This makes them play well with irreps + kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) + elif isinstance(v, torch.Tensor) and len(v.shape) == 0: + # ^ this tensor is a scalar; we need to give it + # a data dimension to play nice with irreps + kwargs[k] = v + + if AtomicDataDict.BATCH_KEY in kwargs: + num_frames = kwargs[AtomicDataDict.BATCH_KEY].max() + 1 + else: + num_frames = 1 + + for k, v in kwargs.items(): + if k in ignore_fields: + continue + + if len(v.shape) == 0: + kwargs[k] = v.unsqueeze(-1) + v = kwargs[k] + + if k in set.union(_NODE_FIELDS, _EDGE_FIELDS) and len(v.shape) == 1: + kwargs[k] = v.unsqueeze(-1) + v = kwargs[k] + + if ( + k in _NODE_FIELDS + and AtomicDataDict.POSITIONS_KEY in kwargs + and v.shape[0] != kwargs[AtomicDataDict.POSITIONS_KEY].shape[0] + ): + raise ValueError( + f"{k} is a node field but has the wrong dimension {v.shape}" + ) + elif ( + k in _EDGE_FIELDS + and AtomicDataDict.EDGE_INDEX_KEY in kwargs + and v.shape[0] != kwargs[AtomicDataDict.EDGE_INDEX_KEY].shape[1] + ): + raise ValueError( + f"{k} is a edge field but has the wrong dimension {v.shape}" + ) + elif k in _GRAPH_FIELDS: + if num_frames > 1 and v.shape[0] != num_frames: + raise 
ValueError(f"Wrong shape for graph property {k}") + + class AtomicData(Data): """A neighbor graph for points in (periodic triclinic) real space. @@ -133,61 +208,7 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): # Check the keys AtomicDataDict.validate_keys(kwargs) - # Deal with _some_ dtype issues - for k, v in kwargs.items(): - if ( - k == AtomicDataDict.EDGE_INDEX_KEY - or k == AtomicDataDict.ATOMIC_NUMBERS_KEY - or k == AtomicDataDict.ATOM_TYPE_KEY - or k == AtomicDataDict.BATCH_KEY - ): - # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems) - # int32 would pass later checks, but is actually disallowed by torch - kwargs[k] = torch.as_tensor(v, dtype=torch.long) - elif isinstance(v, np.ndarray): - if np.issubdtype(v.dtype, np.floating): - kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) - else: - kwargs[k] = torch.as_tensor(v) - elif np.issubdtype(type(v), np.floating): - # Force scalars to be tensors with a data dimension - # This makes them play well with irreps - kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype()) - elif isinstance(v, torch.Tensor) and len(v.shape) == 0: - # ^ this tensor is a scalar; we need to give it - # a data dimension to play nice with irreps - kwargs[k] = v - - if AtomicDataDict.BATCH_KEY in kwargs: - num_frames = kwargs[AtomicDataDict.BATCH_KEY].max() + 1 - else: - num_frames = 1 - - for k, v in kwargs.items(): - - if len(kwargs[k].shape) == 0: - kwargs[k] = v.unsqueeze(-1) - v = kwargs[k] - - if ( - k in _NODE_FIELDS - and v.shape[0] != kwargs[AtomicDataDict.POSITIONS_KEY].shape[0] - ): - raise ValueError( - f"{k} is a node field but has the wrong dimension {v.shape}" - ) - elif ( - k in _EDGE_FIELDS - and v.shape[0] != kwargs[AtomicDataDict.EDGE_INDEX_KEY].shape[1] - ): - raise ValueError( - f"{k} is a edge field but has the wrong dimension {v.shape}" - ) - elif k in _GRAPH_FIELDS: - if num_frames == 1 and v.shape[0] != 1: - kwargs[k] = v.unsqueeze(0) - elif v.shape[0] != num_frames: - raise ValueError(f"Wrong shape for graph property {k}") + _process_dict(kwargs) super().__init__(num_nodes=len(kwargs["pos"]), **kwargs) @@ -215,7 +236,7 @@ def __init__(self, irreps: Dict[str, e3nn.o3.Irreps] = {}, **kwargs): ): assert self.atomic_numbers.dtype in _TORCH_INTEGER_DTYPES if "batch" in self and self.batch is not None: - assert self.batch.dim() == 1 and self.batch.shape[0] == self.num_nodes + assert self.batch.dim() == 2 and self.batch.shape[0] == self.num_nodes # Check that there are the right number of cells if "cell" in self and self.cell is not None: cell = self.cell.view(-1, 3, 3) @@ -398,13 +419,17 @@ def from_ase( **add_fields, ) - def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: + def to_ase(self, type_mapper=None) -> Union[List[ase.Atoms], ase.Atoms]: """Build a (list of) ``ase.Atoms`` object(s) from an ``AtomicData`` object. For each unique batch number provided in ``AtomicDataDict.BATCH_KEY``, an ``ase.Atoms`` object is created. If ``AtomicDataDict.BATCH_KEY`` does not exist in self, a single ``ase.Atoms`` object is created. + Args: + type_mapper: if provided, will be used to map ``ATOM_TYPES`` back into + elements, if the configuration of the ``type_mapper`` allows. + Returns: A list of ``ase.Atoms`` objects if ``AtomicDataDict.BATCH_KEY`` is in self and is not None. Otherwise, a single ``ase.Atoms`` object is returned. 
@@ -416,6 +441,8 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: ) if AtomicDataDict.ATOMIC_NUMBERS_KEY in self: atomic_nums = self.atomic_numbers + elif type_mapper is not None and type_mapper.has_chemical_symbols: + atomic_nums = type_mapper.untransform(self[AtomicDataDict.ATOM_TYPE_KEY]) else: warnings.warn( "AtomicData.to_ase(): self didn't contain atomic numbers... using atom_type as atomic numbers instead, but this means the chemical symbols in ASE (outputs) will be wrong" @@ -425,6 +452,7 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: cell = getattr(self, AtomicDataDict.CELL_KEY, None) batch = getattr(self, AtomicDataDict.BATCH_KEY, None) energy = getattr(self, AtomicDataDict.TOTAL_ENERGY_KEY, None) + energies = getattr(self, AtomicDataDict.PER_ATOM_ENERGY_KEY, None) force = getattr(self, AtomicDataDict.FORCE_KEY, None) do_calc = energy is not None or force is not None @@ -444,11 +472,12 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: for batch_idx in range(n_batches): if batch is not None: mask = batch == batch_idx + mask = mask.view(-1) else: mask = slice(None) mol = ase.Atoms( - numbers=atomic_nums[mask], + numbers=atomic_nums[mask].view(-1), # must be flat for ASE positions=positions[mask], cell=cell[batch_idx] if cell is not None else None, pbc=pbc[batch_idx] if pbc is not None else None, @@ -456,6 +485,8 @@ def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: if do_calc: fields = {} + if energies is not None: + fields["energies"] = energies[mask].cpu().numpy() if energy is not None: fields["energy"] = energy[batch_idx].cpu().numpy() if force is not None: @@ -504,15 +535,13 @@ def irreps(self): return self.__irreps__ def __cat_dim__(self, key, value): - if key in ( - AtomicDataDict.CELL_KEY, - AtomicDataDict.PBC_KEY, - AtomicDataDict.TOTAL_ENERGY_KEY, - ): - # the cell and PBC are graph-level properties and so need a new batch dimension + if key == AtomicDataDict.EDGE_INDEX_KEY: + return 1 # always cat in the edge dimension + elif key in _GRAPH_FIELDS: + # graph-level properties and so need a new batch dimension return None else: - return super().__cat_dim__(key, value) + return 0 # cat along node/edge dimension def without_nodes(self, which_nodes): """Return a copy of ``self`` with ``which_nodes`` removed. 
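To illustrate the new `long_fields` argument to `register_fields` and the consistent `[N, dim]` shaping enforced by `_process_dict`, here is a minimal sketch. The `oxidation_state` field name is purely hypothetical, and the imports assume `register_fields` and `AtomicData` remain exported from `nequip.data`:

```python
import torch

from nequip.data import AtomicData, register_fields

# Register a hypothetical per-atom integer label; listing it in `long_fields`
# makes `_process_dict` cast it to torch.long instead of the default float dtype.
register_fields(node_fields=["oxidation_state"], long_fields=["oxidation_state"])

data = AtomicData.from_points(
    pos=torch.randn(3, 3),
    r_max=4.0,
    oxidation_state=torch.tensor([0, 1, -1]),
)
# Expect a [3, 1] tensor of dtype torch.int64 after processing.
print(data.oxidation_state.shape, data.oxidation_state.dtype)
```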
diff --git a/nequip/data/_build.py b/nequip/data/_build.py index 670b6156..7645bcb7 100644 --- a/nequip/data/_build.py +++ b/nequip/data/_build.py @@ -72,11 +72,7 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset: type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config) # Register fields: - register_fields( - node_fields=config.get("node_fields", []), - edge_fields=config.get("edge_fields", []), - graph_fields=config.get("graph_fields", []), - ) + instantiate(register_fields, all_args=config) instance, _ = instantiate( class_name, diff --git a/nequip/data/_keys.py b/nequip/data/_keys.py index 58918657..d87f52a7 100644 --- a/nequip/data/_keys.py +++ b/nequip/data/_keys.py @@ -38,6 +38,7 @@ PER_ATOM_ENERGY_KEY: Final[str] = "atomic_energy" TOTAL_ENERGY_KEY: Final[str] = "total_energy" FORCE_KEY: Final[str] = "forces" +PARTIAL_FORCE_KEY: Final[str] = "partial_forces" BATCH_KEY: Final[str] = "batch" diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py index a83d65c9..73e1fcec 100644 --- a/nequip/data/dataset.py +++ b/nequip/data/dataset.py @@ -28,6 +28,7 @@ from nequip.utils.regressor import solver from nequip.utils.savenload import atomic_write from .transforms import TypeMapper +from .AtomicData import _process_dict class AtomicDataset(Dataset): @@ -87,7 +88,7 @@ def __init__( force_fixed_keys: List[str] = [], extra_fixed_fields: Dict[str, Any] = {}, include_frames: Optional[List[int]] = None, - type_mapper: TypeMapper = None, + type_mapper: Optional[TypeMapper] = None, ): # TO DO, this may be simplified # See if a subclass defines some inputs @@ -280,18 +281,7 @@ def process(self): del fields # type conversion - for key, value in fixed_fields.items(): - if isinstance(value, np.ndarray): - if np.issubdtype(value.dtype, np.floating): - fixed_fields[key] = torch.as_tensor( - value, dtype=torch.get_default_dtype() - ) - else: - fixed_fields[key] = torch.as_tensor(value) - elif np.issubdtype(type(value), np.floating): - fixed_fields[key] = torch.as_tensor( - value, dtype=torch.get_default_dtype() - ) + _process_dict(fixed_fields, ignore_fields=["r_max"]) logging.info(f"Loaded data: {data}") @@ -301,11 +291,10 @@ def process(self): # it doesn't matter if they overwrite each others cached' # datasets. It only matters that they don't simultaneously try # to write the _same_ file, corrupting it. - with atomic_write(self.processed_paths[0]) as tmppth: - torch.save((data, fixed_fields, self.include_frames), tmppth) - with atomic_write(self.processed_paths[1]) as tmppth: - with open(tmppth, "w") as f: - yaml.dump(self._get_parameters(), f) + with atomic_write(self.processed_paths[0], binary=True) as f: + torch.save((data, fixed_fields, self.include_frames), f) + with atomic_write(self.processed_paths[1], binary=False) as f: + yaml.dump(self._get_parameters(), f) logging.info("Cached processed data to disk") @@ -566,7 +555,7 @@ def _per_species_statistics( For a per-node quantity, computes the expected statistic but for each type instead of over all nodes. 
""" - N = bincount(atom_types, batch) + N = bincount(atom_types.squeeze(-1), batch) N = N[(N > 0).any(dim=1)] # deal with non-contiguous batch indexes if arr_is_per == "graph": @@ -743,8 +732,8 @@ class ASEDataset(AtomicInMemoryDataset): - user_label key_mapping: user_label: label0 - chemical_symbol_to_type: - H: 0 + chemical_symbols: + - H ``` for VASP parser, the yaml input should be @@ -755,8 +744,8 @@ class ASEDataset(AtomicInMemoryDataset): format: vasp-out key_mapping: free_energy: total_energy - chemical_symbol_to_type: - H: 0 + chemical_symbols: + - H ``` """ diff --git a/nequip/data/transforms.py b/nequip/data/transforms.py index 7ac9d724..68ad6bcc 100644 --- a/nequip/data/transforms.py +++ b/nequip/data/transforms.py @@ -20,7 +20,23 @@ def __init__( self, type_names: Optional[List[str]] = None, chemical_symbol_to_type: Optional[Dict[str, int]] = None, + chemical_symbols: Optional[List[str]] = None, ): + if chemical_symbols is not None: + if chemical_symbol_to_type is not None: + raise ValueError( + "Cannot provide both `chemical_symbols` and `chemical_symbol_to_type`" + ) + # repro old, sane NequIP behaviour + # checks also for validity of keys + atomic_nums = [ase.data.atomic_numbers[sym] for sym in chemical_symbols] + # https://stackoverflow.com/questions/29876580/how-to-sort-a-list-according-to-another-list-python + chemical_symbols = [ + e[1] for e in sorted(zip(atomic_nums, chemical_symbols)) + ] + chemical_symbol_to_type = {k: i for i, k in enumerate(chemical_symbols)} + del chemical_symbols + # Build from chem->type mapping, if provided self.chemical_symbol_to_type = chemical_symbol_to_type if self.chemical_symbol_to_type is not None: @@ -53,11 +69,16 @@ def __init__( for sym, type in self.chemical_symbol_to_type.items(): Z_to_index[ase.data.atomic_numbers[sym] - self._min_Z] = type self._Z_to_index = Z_to_index + self._index_to_Z = torch.zeros( + size=(len(self.chemical_symbol_to_type),), dtype=torch.long + ) + for sym, type_idx in self.chemical_symbol_to_type.items(): + self._index_to_Z[type_idx] = ase.data.atomic_numbers[sym] self._valid_set = set(valid_atomic_numbers) # check if type_names is None: raise ValueError( - "Neither chemical_symbol_to_type nor type_names was provided; one or the other is required" + "None of chemical_symbols, chemical_symbol_to_type, nor type_names was provided; exactly one is required" ) # validate type names assert all( @@ -79,7 +100,7 @@ def __call__( elif AtomicDataDict.ATOMIC_NUMBERS_KEY in data: assert ( self.chemical_symbol_to_type is not None - ), "Atomic numbers provided but there is no chemical_symbol_to_type mapping!" + ), "Atomic numbers provided but there is no chemical_symbols/chemical_symbol_to_type mapping!" 
atomic_numbers = data[AtomicDataDict.ATOMIC_NUMBERS_KEY] del data[AtomicDataDict.ATOMIC_NUMBERS_KEY] @@ -101,3 +122,11 @@ def transform(self, atomic_numbers): ) return self._Z_to_index[atomic_numbers - self._min_Z] + + def untransform(self, atom_types): + """Transform atom types back into atomic numbers""" + return self._index_to_Z[atom_types] + + @property + def has_chemical_symbols(self) -> bool: + return self.chemical_symbol_to_type is not None diff --git a/nequip/model/__init__.py b/nequip/model/__init__.py index b849efed..ccb92551 100644 --- a/nequip/model/__init__.py +++ b/nequip/model/__init__.py @@ -1,15 +1,20 @@ from ._eng import EnergyModel -from ._grads import ForceOutput +from ._grads import ForceOutput, PartialForceOutput from ._scaling import RescaleEnergyEtc, PerSpeciesRescale -from ._weight_init import uniform_initialize_FCs +from ._weight_init import uniform_initialize_FCs, initialize_from_state from ._build import model_from_config +from . import builder_utils + __all__ = [ - "EnergyModel", - "ForceOutput", - "RescaleEnergyEtc", - "PerSpeciesRescale", - "uniform_initialize_FCs", - "model_from_config", + EnergyModel, + ForceOutput, + PartialForceOutput, + RescaleEnergyEtc, + PerSpeciesRescale, + uniform_initialize_FCs, + initialize_from_state, + model_from_config, + builder_utils, ] diff --git a/nequip/model/_build.py b/nequip/model/_build.py index 4f1ae7dd..c618bad3 100644 --- a/nequip/model/_build.py +++ b/nequip/model/_build.py @@ -75,9 +75,13 @@ def model_from_config( if "dataset" in pnames: if "initialize" not in pnames: raise ValueError("Cannot request dataset without requesting initialize") - if initialize and dataset is None: + if ( + initialize + and pnames["dataset"].default == inspect.Parameter.empty + and dataset is None + ): raise RuntimeError( - f"Builder {builder.__name__} asked for the dataset, initialize is true, but no dataset was provided to `model_from_config`." + f"Builder {builder.__name__} requires the dataset, initialize is true, but no dataset was provided to `model_from_config`." ) params["dataset"] = dataset if "model" in pnames: diff --git a/nequip/model/_eng.py b/nequip/model/_eng.py index 314bef32..30051328 100644 --- a/nequip/model/_eng.py +++ b/nequip/model/_eng.py @@ -1,6 +1,7 @@ +from typing import Optional import logging -from nequip.data import AtomicDataDict +from nequip.data import AtomicDataDict, AtomicDataset from nequip.nn import ( SequentialGraphNetwork, AtomwiseLinear, @@ -13,14 +14,22 @@ SphericalHarmonicEdgeAttrs, ) +from . import builder_utils -def EnergyModel(config) -> SequentialGraphNetwork: + +def EnergyModel( + config, initialize: bool, dataset: Optional[AtomicDataset] = None +) -> SequentialGraphNetwork: """Base default energy model archetecture. For minimal and full configuration option listings, see ``minimal.yaml`` and ``example.yaml``. 
""" logging.debug("Start building the network model") + builder_utils.add_avg_num_neighbors( + config=config, initialize=initialize, dataset=dataset + ) + num_layers = config.get("num_layers", 3) layers = { diff --git a/nequip/model/_grads.py b/nequip/model/_grads.py index dcf85f9b..12c275c6 100644 --- a/nequip/model/_grads.py +++ b/nequip/model/_grads.py @@ -1,4 +1,5 @@ from nequip.nn import GraphModuleMixin, GradientOutput +from nequip.nn import PartialForceOutput as PartialForceOutputModule from nequip.data import AtomicDataDict @@ -20,3 +21,20 @@ def ForceOutput(model: GraphModuleMixin) -> GradientOutput: out_field=AtomicDataDict.FORCE_KEY, sign=-1, # force is the negative gradient ) + + +def PartialForceOutput(model: GraphModuleMixin) -> GradientOutput: + r"""Add forces and partial forces to a model that predicts energy. + + Args: + energy_model: the model to wrap. Must have ``AtomicDataDict.TOTAL_ENERGY_KEY`` as an output. + + Returns: + A ``GradientOutput`` wrapping ``energy_model``. + """ + if ( + AtomicDataDict.FORCE_KEY in model.irreps_out + or AtomicDataDict.PARTIAL_FORCE_KEY in model.irreps_out + ): + raise ValueError("This model already has force outputs.") + return PartialForceOutputModule(func=model) diff --git a/nequip/model/builder_utils.py b/nequip/model/builder_utils.py new file mode 100644 index 00000000..f2a402f9 --- /dev/null +++ b/nequip/model/builder_utils.py @@ -0,0 +1,41 @@ +from typing import Optional + +import torch + +from nequip.utils import Config +from nequip.data import AtomicDataset, AtomicDataDict + + +def add_avg_num_neighbors( + config: Config, + initialize: bool, + dataset: Optional[AtomicDataset] = None, +) -> Optional[float]: + # Compute avg_num_neighbors + annkey: str = "avg_num_neighbors" + ann = config.get(annkey, None) + if ann == "auto": + if not initialize: + raise ValueError("avg_num_neighbors = auto but initialize is False") + if dataset is None: + raise ValueError( + "When avg_num_neighbors = auto, the dataset is required to build+initialize a model" + ) + ann = dataset.statistics( + fields=[ + lambda data: ( + torch.unique( + data[AtomicDataDict.EDGE_INDEX_KEY][0], return_counts=True + )[1], + "node", + ) + ], + modes=["mean_std"], + stride=config.get("dataset_statistics_stride", 1), + )[0][0].item() + + # make sure its valid + if ann is not None: + ann = float(ann) + config[annkey] = ann + return ann diff --git a/nequip/nn/__init__.py b/nequip/nn/__init__.py index 383f619b..12ab085b 100644 --- a/nequip/nn/__init__.py +++ b/nequip/nn/__init__.py @@ -6,7 +6,7 @@ PerSpeciesScaleShift, ) # noqa: F401 from ._interaction_block import InteractionBlock # noqa: F401 -from ._grad_output import GradientOutput # noqa: F401 +from ._grad_output import GradientOutput, PartialForceOutput # noqa: F401 from ._rescale import RescaleOutput # noqa: F401 from ._convnetlayer import ConvNetLayer # noqa: F401 from ._util import SaveForOutput # noqa: F401 diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index b5ee9efc..afe965ae 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -97,3 +97,65 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: data[k].requires_grad_(req_grad) return data + + +@compile_mode("unsupported") +class PartialForceOutput(GraphModuleMixin, torch.nn.Module): + r"""Generate partial and total forces from an energy model. + + Args: + func: the energy model + vectorize: the vectorize option to ``torch.autograd.functional.jacobian``, + false by default since it doesn't work well. 
+ """ + vectorize: bool + + def __init__( + self, + func: GraphModuleMixin, + vectorize: bool = False, + vectorize_warnings: bool = False, + ): + super().__init__() + # TODO wrap: + self.func = func + self.vectorize = vectorize + if vectorize_warnings: + # See https://pytorch.org/docs/stable/generated/torch.autograd.functional.jacobian.html + torch._C._debug_only_display_vmap_fallback_warnings(True) + + # check and init irreps + self._init_irreps( + irreps_in=func.irreps_in, + my_irreps_in={AtomicDataDict.PER_ATOM_ENERGY_KEY: Irreps("0e")}, + irreps_out=func.irreps_out, + ) + self.irreps_out[AtomicDataDict.PARTIAL_FORCE_KEY] = Irreps("1o") + self.irreps_out[AtomicDataDict.FORCE_KEY] = Irreps("1o") + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data = data.copy() + out_data = {} + + def wrapper(pos: torch.Tensor) -> torch.Tensor: + """Wrapper from pos to atomic energy""" + nonlocal data, out_data + data[AtomicDataDict.POSITIONS_KEY] = pos + out_data = self.func(data) + return out_data[AtomicDataDict.PER_ATOM_ENERGY_KEY].squeeze(-1) + + pos = data[AtomicDataDict.POSITIONS_KEY] + + partial_forces = torch.autograd.functional.jacobian( + func=wrapper, + inputs=pos, + create_graph=self.training, # needed to allow gradients of this output during training + vectorize=self.vectorize, + ) + partial_forces = partial_forces.negative() + # output is [n_at, n_at, 3] + + out_data[AtomicDataDict.PARTIAL_FORCE_KEY] = partial_forces + out_data[AtomicDataDict.FORCE_KEY] = partial_forces.sum(dim=0) + + return out_data diff --git a/nequip/nn/embedding/_one_hot.py b/nequip/nn/embedding/_one_hot.py index 373243d2..f7228400 100644 --- a/nequip/nn/embedding/_one_hot.py +++ b/nequip/nn/embedding/_one_hot.py @@ -34,7 +34,7 @@ def __init__( self._init_irreps(irreps_in=irreps_in, irreps_out=irreps_out) def forward(self, data: AtomicDataDict.Type): - type_numbers = data[AtomicDataDict.ATOM_TYPE_KEY] + type_numbers = data[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1) one_hot = torch.nn.functional.one_hot( type_numbers, num_classes=self.num_types ).to(device=type_numbers.device, dtype=data[AtomicDataDict.POSITIONS_KEY].dtype) diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py index 67221af1..733f62a3 100644 --- a/nequip/scripts/evaluate.py +++ b/nequip/scripts/evaluate.py @@ -321,10 +321,13 @@ def main(args=None, running_as_script: bool = True): with torch.no_grad(): # Write output + # TODO: make sure don't keep appending to existing file if output is not None: ase.io.write( output, - AtomicData.from_AtomicDataDict(out).to(device="cpu").to_ase(), + AtomicData.from_AtomicDataDict(out) + .to(device="cpu") + .to_ase(type_mapper=dataset.type_mapper), format="extxyz", append=True, ) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 175979de..e9eed8ca 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -77,8 +77,10 @@ def parse_command_line(args=None): parser.add_argument("config", help="configuration file") parser.add_argument( "--equivariance-test", - help="test the model's equivariance before training", - action="store_true", + help="test the model's equivariance before training on n (default 1) random frames from the dataset", + const=1, + type=int, + nargs="?", ) parser.add_argument( "--model-debug-mode", @@ -179,14 +181,25 @@ def fresh_start(config): logging.info("Successfully compiled model...") # Equivar test - if config.equivariance_test: - from e3nn.util.test import format_equivariance_error - - equivar_err = 
assert_AtomicData_equivariant(final_model, dataset[0]) - errstr = format_equivariance_error(equivar_err) - del equivar_err - logging.info(f"Equivariance test passed; equivariance errors:\n{errstr}") - del errstr + if config.equivariance_test > 0: + n_train: int = len(trainer.dataset_train) + assert config.equivariance_test <= n_train + final_model.eval() + indexes = torch.randperm(n_train)[: config.equivariance_test] + errstr = assert_AtomicData_equivariant( + final_model, [trainer.dataset_train[i] for i in indexes] + ) + final_model.train() + logging.info( + "Equivariance test passed; equivariance errors:\n" + " Errors are in real units, where relevant.\n" + " Please note that the large scale of the typical\n" + " shifts to the (atomic) energy can cause\n" + " catastrophic cancellation and give incorrectly\n" + " the equivariance error as zero for those fields.\n" + f"{errstr}" + ) + del errstr, indexes, n_train # Set the trainer trainer.model = final_model diff --git a/nequip/train/_key.py b/nequip/train/_key.py index f3582ebd..17057b8b 100644 --- a/nequip/train/_key.py +++ b/nequip/train/_key.py @@ -12,6 +12,7 @@ ABBREV = { AtomicDataDict.TOTAL_ENERGY_KEY: "e", + AtomicDataDict.PER_ATOM_ENERGY_KEY: "Ei", AtomicDataDict.FORCE_KEY: "f", LOSS_KEY: "loss", VALIDATION: "val", diff --git a/nequip/train/_loss.py b/nequip/train/_loss.py index 4db16274..7514a509 100644 --- a/nequip/train/_loss.py +++ b/nequip/train/_loss.py @@ -126,11 +126,11 @@ def __call__( reduce_dims = tuple(i + 1 for i in range(len(per_atom_loss.shape) - 1)) + spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1) if has_nan: if len(reduce_dims) > 0: per_atom_loss = per_atom_loss.sum(dim=reduce_dims) - spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY] per_species_loss = scatter(per_atom_loss, spe_idx, dim=0) N = scatter(not_nan, spe_idx, dim=0) @@ -146,7 +146,6 @@ def __call__( per_atom_loss = per_atom_loss.mean(dim=reduce_dims) # offset species index by 1 to use 0 for nan - spe_idx = pred[AtomicDataDict.ATOM_TYPE_KEY] _, inverse_species_index = torch.unique(spe_idx, return_inverse=True) per_species_loss = scatter_mean(per_atom_loss, inverse_species_index, dim=0) diff --git a/nequip/train/metrics.py b/nequip/train/metrics.py index 820afd85..2f790bd9 100644 --- a/nequip/train/metrics.py +++ b/nequip/train/metrics.py @@ -169,7 +169,9 @@ def __call__(self, pred: dict, ref: dict): params = {} if per_species: # TO DO, this needs OneHot component. 
will need to be decoupled - params = {"accumulate_by": pred[AtomicDataDict.ATOM_TYPE_KEY]} + params = { + "accumulate_by": pred[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1) + } if per_atom: if N is None: N = torch.bincount(ref[AtomicDataDict.BATCH_KEY]).unsqueeze(-1) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index e543d116..7bde37fb 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -37,6 +37,8 @@ save_file, load_file, atomic_write, + finish_all_writes, + atomic_write_group, dtype_from_name, ) from nequip.utils.git import get_commit @@ -131,7 +133,8 @@ class Trainer: Args: model: neural network model - seed (int): random see number + seed (int): random seed number + dataset_seed (int): random seed for dataset operations loss_coeffs (dict): dictionary to store coefficient and loss functions @@ -214,6 +217,7 @@ def __init__( model_builders: Optional[list] = [], device: str = "cuda" if torch.cuda.is_available() else "cpu", seed: Optional[int] = None, + dataset_seed: Optional[int] = None, loss_coeffs: Union[dict, str] = AtomicDataDict.TOTAL_ENERGY_KEY, train_on_keys: Optional[List[str]] = None, metrics_components: Optional[Union[dict, str]] = None, @@ -293,6 +297,10 @@ def __init__( torch.manual_seed(seed) np.random.seed(seed) + self.dataset_rng = torch.Generator() + if dataset_seed is not None: + self.dataset_rng.manual_seed(dataset_seed) + self.logger.info(f"Torch device: {self.device}") self.torch_device = torch.device(self.device) @@ -453,6 +461,7 @@ def as_dict( if item is not None: dictionary["state_dict"][key] = item.state_dict() dictionary["state_dict"]["rng_state"] = torch.get_rng_state() + dictionary["state_dict"]["dataset_rng_state"] = self.dataset_rng.get_state() if torch.cuda.is_available(): dictionary["state_dict"]["cuda_rng_state"] = torch.cuda.get_rng_state( device=self.torch_device @@ -476,20 +485,29 @@ def as_dict( for code in [e3nn, nequip, torch]: dictionary[f"{code.__name__}_version"] = code.__version__ - for code in ["e3nn", "nequip"]: - dictionary[f"{code}_commit"] = get_commit(code) + + codes_for_git = {"e3nn", "nequip"} + for builder in self.model_builders: + if not isinstance(builder, str): + continue + builder = builder.split(".") + if len(builder) > 1: + # it's not a single name which is from nequip + codes_for_git.add(builder[0]) + dictionary["code_versions"] = {code: get_commit(code) for code in codes_for_git} return dictionary - def save_config(self) -> None: + def save_config(self, blocking: bool = True) -> None: save_file( item=self.as_dict(state_dict=False, training_progress=False), supported_formats=dict(yaml=["yaml"]), filename=self.config_path, enforced_format=None, + blocking=blocking, ) - def save(self, filename: Optional[str] = None, format=None): + def save(self, filename: Optional[str] = None, format=None, blocking: bool = True): """save the file as filename Args: @@ -516,11 +534,11 @@ def save(self, filename: Optional[str] = None, format=None): supported_formats=dict(torch=["pth", "pt"], yaml=["yaml"], json=["json"]), filename=filename, enforced_format=format, + blocking=blocking, ) logger.debug(f"Saved trainer to {filename}") - self.save_config() - self.save_model(self.last_model_path) + self.save_model(self.last_model_path, blocking=blocking) logger.debug(f"Saved last model to to {self.last_model_path}") return filename @@ -554,10 +572,10 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): append (bool): if True, append the old model files and append the same logfile """ - d = deepcopy(dictionary) + 
dictionary = deepcopy(dictionary) for code in [e3nn, nequip, torch]: - version = d.get(f"{code.__name__}_version", None) + version = dictionary.get(f"{code.__name__}_version", None) if version is not None and version != code.__version__: logging.warning( "Loading a pickled model created with different library version(s) may cause issues." @@ -567,14 +585,14 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): # update the restart and append option if append is not None: - d["append"] = append + dictionary["append"] = append model = None iepoch = -1 - if "model" in d: - model = d.pop("model") - elif "progress" in d: - progress = d["progress"] + if "model" in dictionary: + model = dictionary.pop("model") + elif "progress" in dictionary: + progress = dictionary["progress"] # load the model from file iepoch = progress["iepoch"] @@ -585,15 +603,17 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): raise AttributeError("model weights & bias are not saved") model, _ = Trainer.load_model_from_training_session( - traindir=load_path.parent, model_name=load_path.name + traindir=load_path.parent, + model_name=load_path.name, + config_dictionary=dictionary, ) logging.debug(f"Reload the model from {load_path}") - d.pop("progress") + dictionary.pop("progress") - state_dict = d.pop("state_dict", None) + state_dict = dictionary.pop("state_dict", None) - trainer = cls(model=model, **d) + trainer = cls(model=model, **dictionary) if state_dict is not None and trainer.model is not None: logging.debug("Reload optimizer and scheduler states") @@ -604,10 +624,11 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): trainer._initialized = True torch.set_rng_state(state_dict["rng_state"]) + trainer.dataset_rng.set_state(state_dict["dataset_rng_state"]) if torch.cuda.is_available(): torch.cuda.set_rng_state(state_dict["cuda_rng_state"]) - if "progress" in d: + if "progress" in dictionary: trainer.best_metrics = progress["best_metrics"] trainer.best_epoch = progress["best_epoch"] stop_arg = progress.pop("stop_arg", None) @@ -628,12 +649,18 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): @staticmethod def load_model_from_training_session( - traindir, model_name="best_model.pth", device="cpu" + traindir, + model_name="best_model.pth", + device="cpu", + config_dictionary: Optional[dict] = None, ) -> Tuple[torch.nn.Module, Config]: traindir = str(traindir) model_name = str(model_name) - config = Config.from_file(traindir + "/config.yaml") + if config_dictionary is not None: + config = Config.from_dict(config_dictionary) + else: + config = Config.from_file(traindir + "/config.yaml") if config.get("compile_model", False): model = torch.jit.load(traindir + "/" + model_name, map_location=device) @@ -709,8 +736,11 @@ def train(self): self.init_log() self.wall = perf_counter() - if self.iepoch == -1: - self.save() + with atomic_write_group(): + if self.iepoch == -1: + self.save() + if self.iepoch in [-1, 0]: + self.save_config() self.init_metrics() @@ -725,6 +755,7 @@ def train(self): self.final_log() self.save() + finish_all_writes() def batch_step(self, data, validation=False): # no need to have gradients from old steps taking up memory @@ -926,36 +957,36 @@ def end_of_epoch_save(self): """ save model and trainer details """ + with atomic_write_group(): + current_metrics = self.mae_dict[self.metrics_key] + if current_metrics < self.best_metrics: + self.best_metrics = current_metrics + self.best_epoch = self.iepoch - current_metrics = self.mae_dict[self.metrics_key] - 
if current_metrics < self.best_metrics: - self.best_metrics = current_metrics - self.best_epoch = self.iepoch + self.save_ema_model(self.best_model_path, blocking=False) - self.save_ema_model(self.best_model_path) - - self.logger.info( - f"! Best model {self.best_epoch:8d} {self.best_metrics:8.3f}" - ) + self.logger.info( + f"! Best model {self.best_epoch:8d} {self.best_metrics:8.3f}" + ) - if (self.iepoch + 1) % self.log_epoch_freq == 0: - self.save() + if (self.iepoch + 1) % self.log_epoch_freq == 0: + self.save(blocking=False) - if ( - self.save_checkpoint_freq > 0 - and (self.iepoch + 1) % self.save_checkpoint_freq == 0 - ): - ckpt_path = self.output.generate_file(f"ckpt{self.iepoch+1}.pth") - self.save_model(ckpt_path) + if ( + self.save_checkpoint_freq > 0 + and (self.iepoch + 1) % self.save_checkpoint_freq == 0 + ): + ckpt_path = self.output.generate_file(f"ckpt{self.iepoch+1}.pth") + self.save_model(ckpt_path, blocking=False) - if ( - self.save_ema_checkpoint_freq > 0 - and (self.iepoch + 1) % self.save_ema_checkpoint_freq == 0 - ): - ckpt_path = self.output.generate_file(f"ckpt_ema_{self.iepoch+1}.pth") - self.save_ema_model(ckpt_path) + if ( + self.save_ema_checkpoint_freq > 0 + and (self.iepoch + 1) % self.save_ema_checkpoint_freq == 0 + ): + ckpt_path = self.output.generate_file(f"ckpt_ema_{self.iepoch+1}.pth") + self.save_ema_model(ckpt_path, blocking=False) - def save_ema_model(self, path): + def save_ema_model(self, path, blocking: bool = True): if self.use_ema: # If using EMA, store the EMA validation model @@ -967,11 +998,10 @@ def save_ema_model(self, path): cm = contextlib.nullcontext() with cm: - self.save_model(path) + self.save_model(path, blocking=blocking) - def save_model(self, path): - self.save_config() - with atomic_write(path) as write_to: + def save_model(self, path, blocking: bool = True): + with atomic_write(path, blocking=blocking, binary=True) as write_to: if isinstance(self.model, torch.jit.ScriptModule): torch.jit.save(self.model, write_to) else: @@ -1100,7 +1130,7 @@ def set_dataset( ) if self.train_val_split == "random": - idcs = torch.randperm(total_n) + idcs = torch.randperm(total_n, generator=self.dataset_rng) elif self.train_val_split == "sequential": idcs = torch.arange(total_n) else: @@ -1116,10 +1146,12 @@ def set_dataset( if self.n_val > len(validation_dataset): raise ValueError("Not enough data in dataset for requested n_train") if self.train_val_split == "random": - self.train_idcs = torch.randperm(len(dataset))[: self.n_train] - self.val_idcs = torch.randperm(len(validation_dataset))[ - : self.n_val - ] + self.train_idcs = torch.randperm( + len(dataset), generator=self.dataset_rng + )[: self.n_train] + self.val_idcs = torch.randperm( + len(validation_dataset), generator=self.dataset_rng + )[: self.n_val] elif self.train_val_split == "sequential": self.train_idcs = torch.arange(self.n_train) self.val_idcs = torch.arange(self.n_val) @@ -1139,7 +1171,6 @@ def set_dataset( # https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#enable-async-data-loading-and-augmentation dl_kwargs = dict( batch_size=self.batch_size, - shuffle=self.shuffle, exclude_keys=self.exclude_keys, num_workers=self.dataloader_num_workers, # keep stuff around in memory @@ -1150,6 +1181,14 @@ def set_dataset( pin_memory=(self.torch_device != torch.device("cpu")), # avoid getting stuck timeout=(10 if self.dataloader_num_workers > 0 else 0), + # use the right randomness + generator=self.dataset_rng, + ) + self.dl_train = DataLoader( + dataset=self.dataset_train, + 
shuffle=self.shuffle, # training should shuffle + **dl_kwargs, ) - self.dl_train = DataLoader(dataset=self.dataset_train, **dl_kwargs) + # validation, on the other hand, shouldn't shuffle + # we still pass the generator just to be safe self.dl_val = DataLoader(dataset=self.dataset_val, **dl_kwargs) diff --git a/nequip/utils/__init__.py b/nequip/utils/__init__.py index 16ad1ee6..79e778a0 100644 --- a/nequip/utils/__init__.py +++ b/nequip/utils/__init__.py @@ -3,7 +3,13 @@ instantiate, get_w_prefix, ) -from .savenload import save_file, load_file, atomic_write +from .savenload import ( + save_file, + load_file, + atomic_write, + finish_all_writes, + atomic_write_group, +) from .config import Config from .output import Output from .modules import find_first_of_type @@ -16,6 +22,8 @@ save_file, load_file, atomic_write, + finish_all_writes, + atomic_write_group, Config, Output, find_first_of_type, diff --git a/nequip/utils/config.py b/nequip/utils/config.py index 72b896a1..d13e0546 100644 --- a/nequip/utils/config.py +++ b/nequip/utils/config.py @@ -262,6 +262,10 @@ def from_file(filename: str, format: Optional[str] = None, defaults: dict = {}): filename=filename, enforced_format=format, ) + return Config.from_dict(dictionary, defaults) + + @staticmethod + def from_dict(dictionary: dict, defaults: dict = {}): c = Config(defaults) c.update(dictionary) return c diff --git a/nequip/utils/git.py b/nequip/utils/git.py index 168f8921..d58923f3 100644 --- a/nequip/utils/git.py +++ b/nequip/utils/git.py @@ -9,7 +9,7 @@ def get_commit(module: str): path = str(Path(module.__file__).parents[0] / "..") retcode = subprocess.run( - "git show --oneline -s".split(), + "git show --oneline --abbrev=40 -s".split(), cwd=path, stdout=subprocess.PIPE, stderr=subprocess.PIPE, diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index 0980fef2..2f26c4c0 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -1,49 +1,203 @@ """ utilities that involve file searching and operations (i.e. save/load) """ -from typing import Union +from typing import Union, List, Tuple import sys import logging import contextlib +import contextvars +import tempfile from pathlib import Path -from os import makedirs -from os.path import isfile, isdir, dirname, realpath +import shutil +import os -@contextlib.contextmanager -def atomic_write(filename: Union[Path, str]): - filename = Path(filename) - tmp_path = filename.parent / (f".tmp-{filename.name}~") - # Create the temp file - open(tmp_path, "w").close() +# accumulate writes to group for renaming +_MOVE_SET = contextvars.ContextVar("_move_set", default=None) + + +def _delete_files_if_exist(paths): + # clean up + # better for python 3.8 > + if sys.version_info[1] >= 8: + for f in paths: + f.unlink(missing_ok=True) + else: + # race condition? 
+ for f in paths: + if f.exists(): + f.unlink() + + +def _process_moves(moves: List[Tuple[bool, Path, Path]]): + """blocking to copy (possibly across filesystems) to temp name; then atomic rename to final name""" try: - # do the IO - yield tmp_path - # move the temp file to the final output path, which also removes the temp file - tmp_path.rename(filename) + for _, from_name, to_name in moves: + # blocking copy to temp file in same filesystem + tmp_path = to_name.parent / (f".tmp-{to_name.name}~") + shutil.move(from_name, tmp_path) + # then atomic rename to overwrite + tmp_path.rename(to_name) finally: - # clean up - # better for python 3.8 > - if sys.version_info[1] >= 8: - tmp_path.unlink(missing_ok=True) + _delete_files_if_exist([m[1] for m in moves]) + + +# allow user to enable/disable depending on their filesystem +_ASYNC_ENABLED = os.environ.get("NEQUIP_ASYNC_IO", "false").lower() +assert _ASYNC_ENABLED in ("true", "false") +_ASYNC_ENABLED = _ASYNC_ENABLED == "true" + +if _ASYNC_ENABLED: + import threading + from queue import Queue + + _MOVE_QUEUE = Queue() + _MOVE_THREAD = None + + # Because we use a queue, later writes will always (correctly) + # overwrite earlier writes + def _moving_thread(queue): + while True: + moves = queue.get() + _process_moves(moves) + # logging is thread safe: https://stackoverflow.com/questions/2973900/is-pythons-logging-module-thread-safe + logging.debug(f"Finished writing {', '.join(m[2].name for m in moves)}") + queue.task_done() + + def _submit_move(from_name, to_name, blocking: bool): + global _MOVE_QUEUE + global _MOVE_THREAD + global _MOVE_SET + + # launch thread if its not running + if _MOVE_THREAD is None: + _MOVE_THREAD = threading.Thread( + target=_moving_thread, args=(_MOVE_QUEUE,), daemon=True + ) + _MOVE_THREAD.start() + + # check on health of copier thread + if not _MOVE_THREAD.is_alive(): + _MOVE_THREAD.join() # will raise exception + raise RuntimeError("Writer thread failed.") + + # submit this move + obj = (blocking, from_name, to_name) + if _MOVE_SET.get() is None: + # no current group + _MOVE_QUEUE.put([obj]) + # if it should be blocking, wait for it to be processed + if blocking: + _MOVE_QUEUE.join() + else: + # add and let the group submit and block (or not) + _MOVE_SET.get().append(obj) + + @contextlib.contextmanager + def atomic_write_group(): + global _MOVE_SET + if _MOVE_SET.get() is not None: + # nesting is a no-op + # submit along with outermost context manager + yield + return + token = _MOVE_SET.set(list()) + # run the saves + yield + _MOVE_QUEUE.put(_MOVE_SET.get()) # send it off + # if anyone is blocking, block the whole group: + if any(m[0] for m in _MOVE_SET.get()): + # someone is blocking + _MOVE_QUEUE.join() + # exit context + _MOVE_SET.reset(token) + + def finish_all_writes(): + global _MOVE_QUEUE + _MOVE_QUEUE.join() + # ^ wait for all remaining moves to be processed + + +else: + + def _submit_move(from_name, to_name, blocking: bool): + global _MOVE_SET + obj = (blocking, from_name, to_name) + if _MOVE_SET.get() is None: + # no current group just do it + _process_moves([obj]) else: - # race condition? 
- if tmp_path.exists(): - tmp_path.unlink() + # add and let the group do it + _MOVE_SET.get().append(obj) + + @contextlib.contextmanager + def atomic_write_group(): + global _MOVE_SET + if _MOVE_SET.get() is not None: + # don't nest them + yield + return + token = _MOVE_SET.set(list()) + yield + _process_moves(_MOVE_SET.get()) # do it + _MOVE_SET.reset(token) + + def finish_all_writes(): + pass # nothing to do since all writes blocked + + +@contextlib.contextmanager +def atomic_write( + filename: Union[Path, str, List[Union[Path, str]]], + blocking: bool = True, + binary: bool = False, +): + aslist: bool = True + if not isinstance(filename, list): + aslist = False + filename = [filename] + filename = [Path(f) for f in filename] + + with contextlib.ExitStack() as stack: + files = [ + stack.enter_context( + tempfile.NamedTemporaryFile( + mode="w" + ("b" if binary else ""), delete=False + ) + ) + for _ in filename + ] + try: + if not aslist: + yield files[0] + else: + yield files + except: # noqa + # ^ noqa cause we want to delete them no matter what if there was a failure + # only remove them if there was an error + _delete_files_if_exist([Path(f.name) for f in files]) + raise + + for tp, fname in zip(files, filename): + _submit_move(Path(tp.name), Path(fname), blocking=blocking) def save_file( - item, supported_formats: dict, filename: str, enforced_format: str = None + item, + supported_formats: dict, + filename: str, + enforced_format: str = None, + blocking: bool = True, ): """ Save file. It can take yaml, json, pickle, json, npz and torch save """ # check whether folder exist - path = dirname(realpath(filename)) - if not isdir(path): + path = os.path.dirname(os.path.realpath(filename)) + if not os.path.isdir(path): logging.debug(f"save_file make dirs {path}") - makedirs(path, exist_ok=True) + os.makedirs(path, exist_ok=True) format, filename = adjust_format_name( supported_formats=supported_formats, @@ -51,17 +205,25 @@ def save_file( enforced_format=enforced_format, ) - with atomic_write(filename) as write_to: + with atomic_write( + filename, + blocking=blocking, + binary={ + "json": False, + "yaml": False, + "pickle": True, + "torch": True, + "npz": True, + }[format], + ) as write_to: if format == "json": import json - with open(write_to, "w+") as fout: - json.dump(item, fout) + json.dump(item, write_to) elif format == "yaml": import yaml - with open(write_to, "w+") as fout: - yaml.dump(item, fout) + yaml.dump(item, write_to) elif format == "torch": import torch @@ -69,8 +231,7 @@ def save_file( elif format == "pickle": import pickle - with open(write_to, "wb") as fout: - pickle.save(item, fout) + pickle.dump(item, write_to) elif format == "npz": import numpy as np @@ -93,7 +254,7 @@ def load_file(supported_formats: dict, filename: str, enforced_format: str = Non else: format = enforced_format - if not isfile(filename): + if not os.path.isfile(filename): abs_path = str(Path(filename).resolve()) raise OSError(f"file {filename} at {abs_path} is not found") diff --git a/nequip/utils/test.py b/nequip/utils/test.py index da2fe4ef..ecfac943 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -1,8 +1,8 @@ -from typing import Union +from typing import Union, Optional, List import torch from e3nn import o3 -from e3nn.util.test import assert_equivariant +from e3nn.util.test import equivariance_error, FLOAT_TOLERANCE from nequip.nn import GraphModuleMixin from nequip.data import ( @@ -24,7 +24,9 @@ def _inverse_permutation(perm): def assert_permutation_equivariant( - func: 
GraphModuleMixin, data_in: AtomicDataDict.Type + func: GraphModuleMixin, + data_in: AtomicDataDict.Type, + tolerance: Optional[float] = None, ): r"""Test the permutation equivariance of ``func``. @@ -39,7 +41,10 @@ def assert_permutation_equivariant( # Prevent pytest from showing this function in the traceback # __tracebackhide__ = True - atol = PERMUTATION_FLOAT_TOLERANCE[torch.get_default_dtype()] + if tolerance is None: + atol = PERMUTATION_FLOAT_TOLERANCE[torch.get_default_dtype()] + else: + atol = tolerance data_in = data_in.copy() device = data_in[AtomicDataDict.POSITIONS_KEY].device @@ -119,9 +124,13 @@ def assert_permutation_equivariant( def assert_AtomicData_equivariant( func: GraphModuleMixin, - data_in: Union[AtomicData, AtomicDataDict.Type], + data_in: Union[ + AtomicData, AtomicDataDict.Type, List[Union[AtomicData, AtomicDataDict.Type]] + ], + permutation_tolerance: Optional[float] = None, + o3_tolerance: Optional[float] = None, **kwargs, -): +) -> str: r"""Test the rotation, translation, parity, and permutation equivariance of ``func``. For details on permutation testing, see ``assert_permutation_equivariant``. @@ -131,34 +140,43 @@ def assert_AtomicData_equivariant( Args: func: the module or model to test - data_in: the example input data to test with + data_in: the example input data(s) to test with. Only the first is used for permutation testing. **kwargs: passed to ``e3nn.util.test.assert_equivariant`` Returns: - Information on equivariance error from ``e3nn.util.test.assert_equivariant`` + A string description of the errors. """ # Prevent pytest from showing this function in the traceback __tracebackhide__ = True - if not isinstance(data_in, dict): - data_in = AtomicData.to_AtomicDataDict(data_in) + if not isinstance(data_in, list): + data_in = [data_in] + data_in = [AtomicData.to_AtomicDataDict(d) for d in data_in] # == Test permutation of graph nodes == - assert_permutation_equivariant( - func, - data_in, - ) + # TODO: since permutation is distinct, run only on one. + assert_permutation_equivariant(func, data_in[0], tolerance=permutation_tolerance) # == Test rotation, parity, and translation using e3nn == irreps_in = {k: None for k in AtomicDataDict.ALLOWED_KEYS} - irreps_in.update( - { - AtomicDataDict.POSITIONS_KEY: "cartesian_points", - AtomicDataDict.CELL_KEY: "3x1o", - } - ) irreps_in.update(func.irreps_in) - irreps_in = {k: v for k, v in irreps_in.items() if k in data_in} + irreps_in = {k: v for k, v in irreps_in.items() if k in data_in[0]} + irreps_out = func.irreps_out.copy() + # for certain things, we don't care what the given irreps are... 
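With the new signature, `assert_AtomicData_equivariant` accepts either a single frame or a list of frames, takes explicit tolerances, and returns a string report instead of the raw `e3nn` error dictionary. A small sketch of how a caller might wrap it; the model and frames are placeholders supplied by the caller:

```python
from typing import List

from nequip.data import AtomicData
from nequip.nn import GraphModuleMixin
from nequip.utils.test import assert_AtomicData_equivariant

def check_equivariance(model: GraphModuleMixin, frames: List[AtomicData]) -> str:
    """Run permutation and O(3) checks on a few frames and return the error report."""
    return assert_AtomicData_equivariant(
        model,
        frames,
        permutation_tolerance=1e-6,  # None falls back to the dtype-based default
        o3_tolerance=1e-5,           # None falls back to e3nn's FLOAT_TOLERANCE
    )
```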
+ # make sure that we test correctly for equivariance: + for irps in (irreps_in, irreps_out): + if AtomicDataDict.POSITIONS_KEY in irps: + # it should always have been 1o vectors + # since that's actually a valid Irreps + assert o3.Irreps(irps[AtomicDataDict.POSITIONS_KEY]) == o3.Irreps("1o") + irps[AtomicDataDict.POSITIONS_KEY] = "cartesian_points" + if AtomicDataDict.CELL_KEY in irps: + prev_cell_irps = irps[AtomicDataDict.CELL_KEY] + assert prev_cell_irps is None or o3.Irreps(prev_cell_irps) == o3.Irreps( + "3x1o" + ) + # must be this to actually rotate it + irps[AtomicDataDict.CELL_KEY] = "3x1o" def wrapper(*args): arg_dict = {k: v for k, v in zip(irreps_in, args)} @@ -175,25 +193,76 @@ def wrapper(*args): cell = arg_dict[AtomicDataDict.CELL_KEY] assert cell.shape[-2:] == (3, 3) arg_dict[AtomicDataDict.CELL_KEY] = cell.reshape(cell.shape[:-2] + (9,)) - return [output[k] for k in func.irreps_out] - - data_in = AtomicData.to_AtomicDataDict(data_in) - # cell is a special case - if AtomicDataDict.CELL_KEY in data_in: - # flatten - cell = data_in[AtomicDataDict.CELL_KEY] - assert cell.shape[-2:] == (3, 3) - data_in[AtomicDataDict.CELL_KEY] = cell.reshape(cell.shape[:-2] + (9,)) - - args_in = [data_in[k] for k in irreps_in] - - return assert_equivariant( - wrapper, - args_in=args_in, - irreps_in=list(irreps_in.values()), - irreps_out=list(func.irreps_out.values()), - **kwargs, - ) + return [output[k] for k in irreps_out] + + # prepare input data + for d in data_in: + # cell is a special case + if AtomicDataDict.CELL_KEY in d: + # flatten + cell = d[AtomicDataDict.CELL_KEY] + assert cell.shape[-2:] == (3, 3) + d[AtomicDataDict.CELL_KEY] = cell.reshape(cell.shape[:-2] + (9,)) + + errs = [ + equivariance_error( + wrapper, + args_in=[d[k] for k in irreps_in], + irreps_in=list(irreps_in.values()), + irreps_out=list(irreps_out.values()), + **kwargs, + ) + for d in data_in + ] + + # take max across errors + errs = {k: torch.max(torch.vstack([e[k] for e in errs]), dim=0)[0] for k in errs[0]} + + if o3_tolerance is None: + o3_tolerance = FLOAT_TOLERANCE[torch.get_default_dtype()] + anerr = next(iter(errs.values())) + if isinstance(anerr, float) or anerr.ndim == 0: + # old e3nn doesn't report which key + problems = {k: v for k, v in errs.items() if v > o3_tolerance} + + def _describe(errors): + return "\n".join( + "(parity_k={:d}, did_translate={}) -> max error={:.3e}".format( + int(k[0]), + bool(k[1]), + float(v), + ) + for k, v in errors.items() + ) + + if len(problems) > 0: + raise AssertionError( + "Equivariance test failed for cases:" + _describe(problems) + ) + + return _describe(errs) + else: + # it's newer and tells us which is which + all_errs = [] + for case, err in errs.items(): + for key, this_err in zip(irreps_out.keys(), err): + all_errs.append(case + (key, this_err)) + problems = [e for e in all_errs if e[-1] > o3_tolerance] + + def _describe(errors): + return "\n".join( + " (parity_k={:1d}, did_translate={:5}, field={:20}) -> max error={:.3e}".format( + int(k[0]), str(bool(k[1])), str(k[2]), float(k[3]) + ) + for k in errors + ) + + if len(problems) > 0: + raise AssertionError( + "Equivariance test failed for cases:\n" + _describe(problems) + ) + + return _describe(all_errs) _DEBUG_HOOKS = None diff --git a/setup.py b/setup.py index d01cf27e..f206bba9 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ "e3nn>=0.3.5,<0.5.0", "pyyaml", "contextlib2;python_version<'3.7'", # backport of nullcontext + 'contextvars;python_version<"3.7"', # backport of contextvars for savenload 
"typing_extensions;python_version<'3.8'", # backport of Final "torch-runstats>=0.2.0", "torch-ema>=0.3.0", diff --git a/tests/unit/data/test_AtomicData.py b/tests/unit/data/test_AtomicData.py index f5d6dc27..2bc50743 100644 --- a/tests/unit/data/test_AtomicData.py +++ b/tests/unit/data/test_AtomicData.py @@ -39,10 +39,17 @@ def test_to_ase_batches(atomic_batch): assert np.allclose(atoms.get_positions(), atomic_data.pos[mask]) assert atoms.get_atomic_numbers().shape == (len(atoms),) assert np.array_equal( - atoms.get_atomic_numbers(), atomic_data[AtomicDataDict.ATOM_TYPE_KEY][mask] + atoms.get_atomic_numbers(), + atomic_data[AtomicDataDict.ATOM_TYPE_KEY][mask].view(-1), ) - assert np.array_equal(atoms.get_cell(), atomic_data.cell[batch_idx]) - assert np.array_equal(atoms.get_pbc(), atomic_data.pbc[batch_idx]) + + assert ( + np.max(np.abs(atoms.get_cell()[:] - atomic_data.cell[batch_idx].numpy())) + == 0 + ) + assert not np.logical_xor( + atoms.get_pbc(), atomic_data.pbc[batch_idx].numpy() + ).all() def test_ase_roundtrip(CuFcc): diff --git a/tests/unit/data/test_dataset.py b/tests/unit/data/test_dataset.py index 98aa635c..01b89e2d 100644 --- a/tests/unit/data/test_dataset.py +++ b/tests/unit/data/test_dataset.py @@ -66,6 +66,15 @@ def root(): yield path +def test_type_mapper(): + tm = TypeMapper(chemical_symbol_to_type={"C": 1, "H": 0}) + atomic_numbers = torch.as_tensor([1, 1, 6, 1, 6, 6, 6]) + atom_types = tm.transform(atomic_numbers) + assert atom_types[0] == 0 + untransformed = tm.untransform(atom_types) + assert torch.equal(untransformed, atomic_numbers) + + class TestInit: def test_init(self): with pytest.raises(NotImplementedError) as excinfo: @@ -210,7 +219,9 @@ def test_per_graph_field( # get species count per graph Ns = [] for i in range(npz_dataset.len()): - Ns.append(torch.bincount(npz_dataset[i][AtomicDataDict.ATOM_TYPE_KEY])) + Ns.append( + torch.bincount(npz_dataset[i][AtomicDataDict.ATOM_TYPE_KEY].view(-1)) + ) n_spec = max(len(e) for e in Ns) N = torch.zeros(len(Ns), n_spec) for i in range(len(Ns)): diff --git a/tests/unit/model/test_builder_utils.py b/tests/unit/model/test_builder_utils.py new file mode 100644 index 00000000..dcb45d4d --- /dev/null +++ b/tests/unit/model/test_builder_utils.py @@ -0,0 +1,38 @@ +import pytest + +import torch + +from nequip.data import AtomicDataDict + +from nequip.model.builder_utils import add_avg_num_neighbors + + +def test_avg_num_neighbors(nequip_dataset): + # test basic options + annkey = "avg_num_neighbors" + config = {annkey: 3} + add_avg_num_neighbors(config, initialize=False, dataset=None) + assert config[annkey] == 3.0 # nothing should happen + assert isinstance(config[annkey], float) + + config = {annkey: 3} + # dont need dataset if config isn't auto + add_avg_num_neighbors(config, initialize=False, dataset=None) + with pytest.raises(ValueError): + # need if it is + config = {annkey: "auto"} + add_avg_num_neighbors(config, initialize=True, dataset=None) + + # compute dumb truth + num_neigh = [] + for i in range(len(nequip_dataset)): + frame = nequip_dataset[i] + num_neigh.append( + torch.bincount(frame[AtomicDataDict.EDGE_INDEX_KEY][0]).float() + ) + avg_num_neighbor_truth = torch.mean(torch.cat(num_neigh, dim=0)) + + # compare + config = {annkey: "auto"} + add_avg_num_neighbors(config, initialize=True, dataset=nequip_dataset) + assert config[annkey] == avg_num_neighbor_truth diff --git a/tests/unit/model/test_eng_force.py b/tests/unit/model/test_eng_force.py index 05fca69a..ced59987 100644 --- a/tests/unit/model/test_eng_force.py 
+++ b/tests/unit/model/test_eng_force.py @@ -21,6 +21,7 @@ COMMON_CONFIG = { "num_types": 3, "types_names": ["H", "C", "O"], + "avg_num_neighbors": None, } r_max = 3 minimal_config1 = dict( @@ -79,14 +80,8 @@ def config(request): @pytest.fixture( params=[ - ( - ["EnergyModel", "ForceOutput"], - AtomicDataDict.FORCE_KEY, - ), - ( - ["EnergyModel"], - AtomicDataDict.TOTAL_ENERGY_KEY, - ), + (["EnergyModel", "ForceOutput"], AtomicDataDict.FORCE_KEY,), + (["EnergyModel"], AtomicDataDict.TOTAL_ENERGY_KEY,), ] ) def model(request, config): @@ -138,9 +133,7 @@ def test_jit(self, model, atomic_batch, device): model_script = script(instance) assert torch.allclose( - instance(data)[out_field], - model_script(data)[out_field], - atol=1e-6, + instance(data)[out_field], model_script(data)[out_field], atol=1e-6, ) # - Try saving, loading in another process, and running - @@ -225,6 +218,43 @@ def test_numeric_gradient(self, config, atomic_batch, device, float_tolerance): numeric, analytical, rtol=5e-2 ) + def test_partial_forces(self, atomic_batch, device): + config = minimal_config1.copy() + config["model_builders"] = [ + "EnergyModel", + "ForceOutput", + ] + partial_config = config.copy() + partial_config["model_builders"] = [ + "EnergyModel", + "PartialForceOutput", + ] + model = model_from_config(config=config, initialize=True) + partial_model = model_from_config(config=partial_config, initialize=True) + model.to(device) + partial_model.to(device) + partial_model.load_state_dict(model.state_dict()) + data = atomic_batch.to(device) + output = model(AtomicData.to_AtomicDataDict(data)) + output_partial = partial_model(AtomicData.to_AtomicDataDict(data)) + # everything should be the same + # including the + for k in output: + assert k != AtomicDataDict.PARTIAL_FORCE_KEY + assert k in output_partial + if output[k].is_floating_point(): + assert torch.allclose( + output[k], + output_partial[k], + atol=1e-6 if k == AtomicDataDict.FORCE_KEY else 1e-8, + ) + else: + assert torch.equal(output[k], output_partial[k]) + n_at = data[AtomicDataDict.POSITIONS_KEY].shape[0] + partial_forces = output_partial[AtomicDataDict.PARTIAL_FORCE_KEY] + assert partial_forces.shape == (n_at, n_at, 3) + # TODO check sparsity? + class TestAutoGradient: def test_cross_frame_grad(self, config, nequip_dataset): @@ -273,7 +303,7 @@ def test_large_separation(self, model, config, molecules): atoms2.positions += 40.0 + np.random.randn(3) atoms_both = atoms1.copy() atoms_both.extend(atoms2) - tm = TypeMapper(chemical_symbol_to_type={"H": 0, "C": 1, "O": 2}) + tm = TypeMapper(chemical_symbols=["H", "C", "O"]) data1 = tm(AtomicData.from_ase(atoms1, r_max=r_max)) data2 = tm(AtomicData.from_ase(atoms2, r_max=r_max)) data_both = tm(AtomicData.from_ase(atoms_both, r_max=r_max))
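The new `test_partial_forces` test compares a standard `ForceOutput` model against one built with `PartialForceOutput`, which additionally returns an `(n_atoms, n_atoms, 3)` partial-force field. A rough sketch of building such a model outside the test suite, assuming `model_from_config` is importable from `nequip.model` and that `base_config` stands in for a full option dictionary like the `minimal_config*` fixtures used here:

```python
from nequip.data import AtomicData, AtomicDataDict
from nequip.model import model_from_config

def build_partial_force_model(base_config: dict):
    """Build an energy model that also reports partial forces."""
    config = dict(base_config)
    config["model_builders"] = ["EnergyModel", "PartialForceOutput"]
    return model_from_config(config=config, initialize=True)

def get_partial_forces(model, data: AtomicData):
    out = model(AtomicData.to_AtomicDataDict(data))
    # shape (n_atoms, n_atoms, 3), as asserted in test_partial_forces above
    return out[AtomicDataDict.PARTIAL_FORCE_KEY]
```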