diff --git a/Dockerfile.xpu b/Dockerfile.xpu new file mode 100644 index 00000000..d01e2caa --- /dev/null +++ b/Dockerfile.xpu @@ -0,0 +1,29 @@ +FROM amr-registry.caas.intel.com/aipg/kinlongk-pytorch:nightly-xpu +# for setup run with root +USER 0 +# needed to make sure mamba environment is activated +ARG MAMBA_DOCKERFILE_ACTIVATE=1 +# install packages, particularly xpu torch from nightly wheels +RUN pip install torch-geometric +# DGl is currently unsupported, and uses libtorch so we need to build it +WORKDIR /opt/matsciml +COPY . . +RUN pip install -e . +RUN git clone --recurse-submodules https://github.com/dmlc/dgl.git /opt/dgl +ENV DGL_HOME=/opt/dgl +WORKDIR /opt/dgl/build +RUN cmake -DUSE_CUDA=OFF -DPython3_EXECUTABLE=/opt/conda/bin/python .. && make +WORKDIR /opt/dgl/python +RUN pip install . +RUN micromamba clean --all --yes && rm -rf /opt/xpu-backend /var/lib/apt/lists/* +# make conda read-writable for user +RUN chown -R $MAMBA_USER:$MAMBA_USER /opt/matsciml && chown -R $MAMBA_USER:$MAMBA_USER /opt/conda +# change back to non-root user +USER $MAMBA_USER +LABEL org.opencontainers.image.authors="Kin Long Kelvin Lee" +LABEL org.opencontainers.image.vendor="Intel Labs" +LABEL org.opencontainers.image.base.name="amr-registry.caas.intel.com/aipg/kinlongk-pytorch:nightly" +LABEL org.opencontainers.image.title="kinlongk-pytorch" +LABEL org.opencontainers.image.description="XPU enabled PyTorch+Triton from Github artifact wheel builds." +HEALTHCHECK NONE +ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu new file mode 100644 index 00000000..d9f44c7c --- /dev/null +++ b/docker/Dockerfile.xpu @@ -0,0 +1,61 @@ +FROM mambaorg/micromamba:noble AS oneapi_xpu +SHELL ["/bin/bash", "-c"] +ENV DEBIAN_FRONTEND=noninteractive +ENV TZ=Etc/UTC +# for setup run with root +USER 0 +# needed to make sure mamba environment is activated +ARG MAMBA_DOCKERFILE_ACTIVATE=1 +# VGID as an environment variable specifying the group ID for render on host +# this needs to match, otherwise non-root users will not pick up cards +ARG VGID=993 +# install firmware, oneAPI, etc. 
+RUN apt-get update -y && apt-get install -y software-properties-common wget git make g++ gcc gpg-agent +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ + | gpg --dearmor > /usr/share/keyrings/intel-for-pytorch-gpu-dev-keyring.gpg +RUN echo "deb [signed-by=/usr/share/keyrings/intel-for-pytorch-gpu-dev-keyring.gpg] https://apt.repos.intel.com/intel-for-pytorch-gpu-dev all main" > /etc/apt/sources.list.d/intel-for-pytorch-gpu-dev.list +RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key \ + | gpg --yes --dearmor --output /usr/share/keyrings/intel-graphics.gpg +RUN echo "deb [arch=amd64 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu noble unified" > /etc/apt/sources.list.d/intel-gpu-noble.list +RUN apt-get update -y && \ + apt-get upgrade -y && \ + apt-get install -y git make g++ gcc gpg-agent wget \ + intel-for-pytorch-gpu-dev-0.5 \ + intel-pti-dev \ + cmake \ + tzdata \ + zlib1g zlib1g-dev \ + xpu-smi \ + intel-opencl-icd intel-level-zero-gpu libze1 intel-oneapi-mpi \ + intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \ + libegl-mesa0 libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ + libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ + mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo +# make sure oneAPI components are in environment variables +RUN source /opt/intel/oneapi/setvars.sh +# make it so you don't have to source oneAPI every time +COPY entrypoint.sh /usr/local/bin/entrypoint.sh +ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] +FROM oneapi_xpu +# set aliases so python and pip are met +RUN ln -s /opt/conda/bin/python /usr/local/bin/python && ln -s /opt/conda/bin/pip /usr/local/bin/pip +# clone matsciml into container and install +RUN git clone https://github.com/IntelLabs/matsciml /opt/matsciml +WORKDIR /opt/matsciml +# install packages, particularly xpu torch from nightly wheels +RUN micromamba install -y -n base -f conda.yml && \ + pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/xpu && \ + pip install -e './[all]' +RUN micromamba clean --all --yes && rm -rf /opt/xpu-backend /var/lib/apt/lists/* +# let non-root mamba user have access to GPUS +RUN groupadd -g $VGID render && usermod -a -G video,render $MAMBA_USER +# make conda read-writable for user +RUN chown -R $MAMBA_USER:$MAMBA_USER /opt/matsciml && chown -R $MAMBA_USER:$MAMBA_USER /opt/conda +# change back to non-root user +USER $MAMBA_USER +LABEL org.opencontainers.image.authors="Kin Long Kelvin Lee" +LABEL org.opencontainers.image.vendor="Intel Labs" +LABEL org.opencontainers.image.base.name="amr-registry.caas.intel.com/aipg/kinlongk-pytorch:nightly" +LABEL org.opencontainers.image.title="kinlongk-pytorch" +LABEL org.opencontainers.image.description="XPU enabled PyTorch+Triton from Github artifact wheel builds." 
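+# Example usage (illustrative only; adjust the image tag and paths to your setup): build with
+# VGID matched to the host's render group so the non-root user can access the GPUs, then run
+# with the DRI devices exposed, e.g.
+#   docker build -f docker/Dockerfile.xpu --build-arg VGID=$(getent group render | cut -d: -f3) -t matsciml:xpu docker/
+#   docker run --rm -it --device /dev/dri matsciml:xpu python -c "import torch; print(torch.xpu.is_available())"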
+HEALTHCHECK NONE
diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh
new file mode 100755
index 00000000..954bd826
--- /dev/null
+++ b/docker/entrypoint.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+# this sources oneAPI components and silences the output so we don't
+# have to see the wall of text every time we enter the container
+source /opt/intel/oneapi/setvars.sh > /dev/null 2>&1
+exec "$@"
diff --git a/docs/source/index.rst b/docs/source/index.rst
index aa66676a..8e71e586 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -12,6 +12,7 @@ The Open MatSciML Toolkit
    Getting started
    datasets
+   schema
    transforms
    models
    training
diff --git a/docs/source/schema.rst b/docs/source/schema.rst
new file mode 100644
index 00000000..4b3fcbf0
--- /dev/null
+++ b/docs/source/schema.rst
@@ -0,0 +1,181 @@
+Schema
+==========
+
+The Open MatSciML Toolkit places emphasis on reproducibility and on the general rule of
+"explicit is better than implicit" by defining schemas for data and other development
+concepts.
+
+The intention is to move away from hardcoded ``Dataset`` classes, which are rigid in
+that they require writing code, and which are not always reliably reproducible as the
+underlying data and frameworks change and evolve over time. Instead, the schemas
+provided in ``matsciml`` try to shift technical debt from maintaining code to
+**documenting** data, which, assuming a thorough and complete description, should
+in principle remain usable regardless of breaking API changes in frameworks that we rely
+on like ``pymatgen``, ``torch_geometric``, and so on. As a dataset is defined and packaged
+for distribution, the schema should also make the developer's intentions clear
+to the end user, e.g. which target labels are available, how they were calculated, and so on,
+to help subsequent reproduction efforts. In effect, this also streamlines development of
+``matsciml``, as it homogenizes field names (e.g. we can reliably expect ``cart_coords``
+to be available, and to contain cartesian coordinates).
+
+.. TIP::
+   You do not have to construct the objects contained in a schema if they are themselves
+   ``pydantic`` models: for example, a ``PeriodicBoundarySchema`` is required by
+   ``DataSampleSchema``, but you can alternatively just pass a dictionary with the
+   expected key/value mapping (i.e. ``{'x': True, 'y': True, 'z': False}``) for
+   the relevant field.
+
+
+Dataset schema reference
+########################
+
+This schema lays out what can be described as metadata for a dataset. We define all of
+the expected fields in ``targets``, and record checksums for each dataset split so
+that we can record which model was trained on which specific split. Currently, it is the
+responsibility of the dataset distributor to record this metadata for their dataset,
+and package it as a ``metadata.json`` file in the same folder as the HDF5 files.
+
+.. autoclass:: matsciml.datasets.schema.DatasetSchema
+    :members:
+
+Data sample schema reference
+############################
+
+This schema comprises a **single** data sample, providing standardized field names for
+a host of commonly used properties. Most properties are optional at construction,
+but we highly recommend perusing the fields shown below to find the attribute closest to
+the property being recorded: ``pydantic`` does not allow arbitrary attributes to be stored
+in schemas, but non-standard properties can be stashed away in ``extras`` as a dictionary of
+property names/values.
+
+.. autoclass:: matsciml.datasets.schema.DataSampleSchema
+    :members:
+
+
+Creating datasets with schema
+#############################
+
+.. NOTE::
+   This section is primarily for people interested in developing new datasets.
+
+The premise behind defining these schemas rigorously is to encourage reproducible workflows
+with (hopefully) less technical debt: we can safely rely on data to validate itself,
+catch mistakes when serializing/dumping data for others to use, and set reasonable expectations
+about what data will be available at which parts of training and evaluation. For those interested
+in creating their own datasets, this section lays out some preliminary notes on how to wrangle
+your data to adhere to the schema, and make the data ready to be used by the pipeline.
+
+Matching your data to the schema
+================================
+
+The first step in dataset creation is taking whatever primary data format you have and mapping
+it to the ``DataSampleSchema`` laid out above. The required keys include ``index``, ``cart_coords``,
+and so on, and by definition need to be provided. The code below shows an example loop
+over a list of data, which we convert into dictionaries with the same keys as expected by ``DataSampleSchema``:
+
+.. code-block:: python
+    :caption: Example abstract code for mapping your data to the schema
+
+    all_data = ...  # should be a list of samples
+    samples = []
+    for index, data in enumerate(all_data):
+        temp_dict = {}
+        temp_dict["cart_coords"] = data.positions
+        temp_dict['index'] = index
+        temp_dict['datatype'] = "OptimizationCycle"  # must be one of the enums
+        temp_dict['num_atoms'] = len(data.positions)
+        schema = DataSampleSchema(**temp_dict)
+        samples.append(schema)
+
+You end up with a list of ``DataSampleSchema`` objects, each of which has undergone all of the
+validation and consistency checks.
+
+Data splits
+======================
+
+At this point you could call it a day, but if you want to create uniform random
+training and validation splits, this is a good point to do so. The code below
+shows one way of generating the splits: keep in mind that this mechanism for
+splitting might not be appropriate for your data. To mitigate data leakage,
+you may need to consider using more sophisticated algorithms that consider chemical
+elements, de-correlate dynamics, etc. Treat the code below as boilerplate, and
+modify it as needed.
+
+.. code-block:: python
+    :caption: Example code showing how to generate training and validation splits
+
+    import numpy as np
+    import h5py
+    from matsciml.datasets.generic import write_data_to_hdf5_group
+
+    SEED = 73926  # this will be reused when generating the metadata
+    rng = np.random.default_rng(SEED)
+
+    all_indices = np.arange(len(samples))
+    val_split = int(len(samples) * 0.2)
+    rng.shuffle(all_indices)
+    train_indices = all_indices[val_split:]
+    val_indices = all_indices[:val_split]
+
+    # instantiate HDF5 files
+    train_h5 = h5py.File("./train.h5", mode="w")
+
+    for index in train_indices:
+        sample = samples[index]
+        # store each data sample as a group comprising array data
+        group = train_h5.create_group(str(index))
+        # takes advantage of pydantic serialization
+        for key, value in sample.model_dump(round_trip=True).items():
+            if value is not None:
+                write_data_to_hdf5_group(key, value, group)
+
+
+Repeat the loop above for your validation set, as sketched below.
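+
+The sketch below reuses the ``samples``, ``val_indices``, ``train_h5``, and
+``write_data_to_hdf5_group`` names from the snippet above, and writes to ``validation.h5``,
+since that is the filename the generic data module looks for when discovering splits.
+
+.. code-block:: python
+    :caption: Sketch of writing the validation split (assumes the objects defined above)
+
+    # open a separate HDF5 file for the validation split
+    val_h5 = h5py.File("./validation.h5", mode="w")
+
+    for index in val_indices:
+        sample = samples[index]
+        # mirror the training loop: one group per sample, keyed by its original index
+        group = val_h5.create_group(str(index))
+        for key, value in sample.model_dump(round_trip=True).items():
+            if value is not None:
+                write_data_to_hdf5_group(key, value, group)
+
+    # close both files so everything is flushed to disk
+    train_h5.close()
+    val_h5.close()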
+
+Dataset metadata
+==================
+
+Once we have created these splits, there is a fair amount of metadata associated
+with **how** we created the splits that we should record, so that at runtime
+there is no ambiguity about which data and splits are being used and where they
+came from.
+
+.. code-block:: python
+
+    from datetime import datetime
+    from matsciml.datasets.generic import MatSciMLDataset
+    from matsciml.datasets.schema import DatasetSchema
+
+    # use the datasets we created above; `strict_checksum` needs to be
+    # set to False here because we're going to be generating the checksum
+    train_dset = MatSciMLDataset("./train.h5", strict_checksum=False)
+    train_checksum = train_dset.blake2s_checksum
+
+    # fill in the dataset metadata schema
+    dset_schema = DatasetSchema(
+        name="My new dataset",
+        creation=datetime.now(),
+        split_blake2s={
+            "train": train_checksum,
+            "validation": ...,
+            "test": ...,  # these should be made the same way as the training set
+        },
+        targets=[...],  # see below
+        dataset_type="OptimizationCycle",  # choose one of the `DataSampleEnum` values
+        seed=SEED,  # from the first code snippet
+    )
+    # writes the schema where it's expected
+    dset_schema.to_json_file("metadata.json")
+
+
+Hopefully you can appreciate that the metadata is meant to lessen the burden
+on future users of the dataset (including yourself!). The last thing to cover
+here is the ``targets`` field, which was omitted in the snippet above: it is meant
+for you to record every property that is intended to be used throughout training,
+whether or not it is part of the standard ``DataSampleSchema``. This is
+the ``TargetSchema``: you must detail the name, expected shape, and a short
+description of every property (including the standard ones). The main motivation
+for this is that ``total_energy`` may mean something very different
+from one dataset to the next (electronic energy? thermodynamic corrections?),
+and specifying this for the end user removes any ambiguity.
+
+
+..
autoclass:: matsciml.datasets.schema.TargetSchema + :members: diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 00000000..954bd826 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,5 @@ +#!/bin/bash +# this sources oneAPI components and silences the output so we don't +# have to see the wall of text every time we enter the container +source /opt/intel/oneapi/setvars.sh > /dev/null 2>&1 +exec "$@" diff --git a/matsciml/datasets/__init__.py b/matsciml/datasets/__init__.py index 3b4e87a3..96ac3ff3 100644 --- a/matsciml/datasets/__init__.py +++ b/matsciml/datasets/__init__.py @@ -20,6 +20,8 @@ from matsciml.datasets.ocp_datasets import IS2REDataset, S2EFDataset from matsciml.datasets.oqmd import OQMDDataset from matsciml.datasets.symmetry import SyntheticPointGroupDataset +from matsciml.datasets.schema import DatasetSchema, DataSampleSchema +from matsciml.datasets.generic import MatSciMLDataset, MatSciMLDataModule __all__ = [ "AlexandriaDataset", @@ -34,4 +36,8 @@ "SyntheticPointGroupDataset", "MultiDataset", "ColabFitDataset", + "DatasetSchema", + "DataSampleSchema", + "MatSciMLDataModule", + "MatSciMLDataset", ] diff --git a/matsciml/datasets/generic.py b/matsciml/datasets/generic.py new file mode 100644 index 00000000..61413274 --- /dev/null +++ b/matsciml/datasets/generic.py @@ -0,0 +1,477 @@ +from __future__ import annotations + +from hashlib import blake2s +from functools import cache +from os import PathLike +from pathlib import Path +from typing import Any, Callable, Literal +from logging import getLogger + +import h5py +from torch.utils.data import DataLoader, Dataset +from lightning import pytorch as pl + +from matsciml.datasets.schema import ( + DatasetSchema, + DataSampleSchema, + collate_samples_into_batch_schema, +) + + +logger = getLogger("matsciml.datasets.MatSciMLDataset") + +__all__ = ["MatSciMLDataset", "MatSciMLDataModule"] + + +def write_data_to_hdf5_group(key: str, data: Any, h5_group: h5py.Group) -> None: + """ + Writes data recursively to an HDF5 group. + + For dictionary data, we will create a new group and call this + same function to write the subkey/values within the new group. + + For strings, the data is written to the ``attrs``. + + Parameters + ---------- + key : str + Key to write the data to; this will create a new + ``h5py.Dataset`` object under this name within the group. + data : Any + Any data to write to the HDF5 group. + h5_group : h5py.Group + Instance of an ``h5py.Group`` to write datasets to. + """ + if isinstance(data, dict): + subgroup = h5_group.create_group(key) + for subkey, subvalue in data.items(): + write_data_to_hdf5_group(subkey, subvalue, subgroup) + elif isinstance(data, str): + h5_group.attrs[key] = data + else: + h5_group[key] = data + + +def read_hdf5_data(h5_group: h5py.Group) -> dict[str, Any]: + """ + Recursively read in an HDF5 group's worth of data. + + This function loops over every key/value pair contained + in the group. For ``h5py.Dataset`` objects, we read in + all the data, whereas for groups, we recursively apply + this function. + + For primarily string-based data, we also peek into + the group's ``attrs`` storage and retrieve that data as well. + + Parameters + ---------- + h5_group : h5py.Group + Instance of an ``h5py.Group`` - intended usage is + to pass the top level group within an ``h5py.File``, + and retrieve all of the data pertaining to a sample. + + Returns + ------- + dict[str, Any] + Dictionary representation of the data, with matrix data + as ``np.ndarray``s. 
+ """ + output_dict = {} + for key, value in h5_group.items(): + if isinstance(value, h5py.Dataset): + # [()] is the catch-all for scalar and matrix data + value = value[()] + elif isinstance(value, h5py.Group): + # call function recursively if it's a group + value = read_hdf5_data(value) + output_dict[key] = value + # things like strings are primarily stored as attrs + for key, value in h5_group.attrs.items(): + if isinstance(value, dict): + output_dict[key] = {subkey: subvalue for subkey, subvalue in value.items()} + else: + output_dict[key] = value + return output_dict + + +class MatSciMLDataset(Dataset): + def __init__( + self, + filepath: PathLike, + transforms: list[Callable] | None = None, + strict_checksum: bool = True, + ): + """ + Dataset class for generic ``MatSciMLDataset``s that use + the data schema specifications. + + The main output of this class is mainly data loading from + HDF5 files, parsing ``DatasetSchema`` metadata that are + adjacent to HDF5 files, and returning data samples in the + form of ``DataSampleSchema`` objects, which in principle + should replace conventional ``DataDict`` (i.e. just plain + dictionaries with arbitrary key/value pairs) that were used + in earlier ``matsciml`` versions. + + Parameters + ---------- + filepath : PathLike + Filepath to a specific HDF5 file, containing a data split + of either ``train``, ``test``, ``validation``, or ``predict``. + transforms : list[Callable], optional + If provided, should be a list of Python callable objects + that will operate on the data. + strict_checksum : bool, default True + If ``True``, the dataset will refuse to run if it does not + match any of the checksums contained in the metadata. This + implementation does not **need** to know which split the data + is, but has to match at least one of the specified splits. + This can be disabled manually by setting to ``False``, but + means the dataset can be modified. + + Raises + ------ + RuntimeError: + If no checksums in the metadata match the current data + while ``strict_checksum`` is set to ``True``. + """ + super().__init__() + if not isinstance(filepath, Path): + filepath = Path(filepath) + self.filepath = filepath + self.transforms = transforms + if strict_checksum: + metadata = self.metadata + success = False + for key, value in metadata.split_blake2s.items(): + if value == self.blake2s_checksum: + success = True + logger.debug( + f"Matched dataset checksum with {key} split from metadata." + ) + if not success: + raise RuntimeError( + "Dataset checksum failed to validate against any splits in metadata." + ) + + @property + def metadata(self) -> DatasetSchema: + """ + Underlying ``DatasetSchema`` that should accompany all ``matsciml`` datasets. + + This schema should contain information about splits, target properties, + and if relevant, graph wiring and so on. + + Returns + ------- + DatasetSchema + Validated ``DatasetSchema`` object + + Raises + ------ + RuntimeError: + If there is no metadata + """ + if not hasattr(self, "_metadata"): + meta_target = self.filepath.parent.joinpath("metadata.json") + if not meta_target.exists(): + raise FileNotFoundError( + "No `metadata.json` specifying DatasetSchema found in dataset directory.." + ) + with open(meta_target) as read_file: + metadata = DatasetSchema.model_validate_json( + read_file.read(), strict=True + ) + self._metadata = metadata + return self._metadata + + @property + @cache + def blake2s_checksum(self) -> str: + """ + Computes the BLAKE2s hash for the current dataset. 
+ + This functions by opening the binary file for reading and + iterating over lines in the file. + + Returns + ------- + str + BLAKE2s hash for the HDF5 file of this dataset. + """ + with open(self.filepath, "rb") as read_file: + hasher = blake2s() + for line in read_file.readlines(): + hasher.update(line) + return hasher.hexdigest() + + def read_data(self) -> h5py.File: + return h5py.File(str(self.filepath.absolute()), mode="r") + + def write_data( + self, index: int, sample: DataSampleSchema, overwrite: bool = False + ) -> None: + """ + Writes a data sample at index to the current HDF5 file. + + Most likely not the most performant way to write data + since it's in serial, but is easily accessible. + + Parameters + ---------- + index : int + Index to write the data to. Must not already be + present in the dataset if ``overwrite`` is False. + sample : DataSampleSchema + A data sample defined by an instance of ``DataSampleSchema``. + overwrite : bool, default False + If False, if ``index`` already exists in the file + a ``RuntimeError`` will be raised. + """ + with h5py.File(str(self.filepath).absolute(), "w") as h5_file: + if overwrite and str(index) in h5_file: + del h5_file[str(index)] + group = h5_file.create_group(str(index)) + sample_data = sample.model_dump(round_trip=True) + for key, value in sample_data.items(): + write_data_to_hdf5_group(key, value, group) + + @cache + def __len__(self) -> int: + with self.read_data() as h5_data: + return len(h5_data.keys()) + + @property + @cache + def keys(self) -> list[str]: + with self.read_data() as h5_data: + return list(h5_data.keys()) + + def __getitem__(self, index: int) -> DataSampleSchema: + """ + Retrieves a sample from the present dataset. + + Data samples are organized into top-level ``h5py.Group``s, + and the passed ``index`` value maps onto the underlying + ``data_index`` which corresponds to the index the data sample + originally had before splits. + + We recursively read in data contained within a group, and + use the key/values to reconstruct and validate with a ``DataSampleSchema``. + Finally, we apply transforms if they are provided to this object, + and return it. + + Parameters + ---------- + index : int + Integer corresponding to a value within the range of + the dataset length. This may or may not coincide with + the actual ``h5py.Group`` keys, but refers to the ``keys`` + property of ``MatSciMLDataset`` to retrieve the 'real' + key to read with. + + Returns + ------- + DataSampleSchema + Data sample from disk after validation, and transforms + applied if relevant. + + Raises + ------ + KeyError: + If the data sample index is missing from the dataset. + KeyError: + If a target key is defined in the metadata, but missing + from the dictionary before passing into ``DataSampleSchema``. + RuntimeError: + If a transform was unable to be applied to the ``DataSampleSchema``. + """ + data_index = self.keys[index] + with self.read_data() as h5_data: + try: + sample_group = h5_data[data_index] + except KeyError as e: + raise KeyError(f"Data sample {data_index} missing from dataset.") from e + sample_data = read_hdf5_data(sample_group) + # validate expected data + for target in self.metadata.targets: + is_missing = True + if target.name in sample_data: + is_missing = False + if "extras" in sample_data: + if target.name in sample_data["extras"]: + is_missing = False + if is_missing: + raise KeyError( + f"Expected {target.name} in data sample but not found." 
+ ) + sample = DataSampleSchema(**sample_data) + # now try to apply transforms + if self.transforms: + for transform in self.transforms: + try: + sample = transform(sample) + except Exception as e: + raise RuntimeError( + f"Unable to apply {transform} on sample at index {index}." + ) from e + return sample + + +class MatSciMLDataModule(pl.LightningDataModule): + def __init__( + self, + filepath: PathLike, + transforms: list[Callable] | None = None, + strict_checksum: bool = False, + **loader_kwargs, + ): + """ + Initialize a ``MatSciMLDataModule`` that uses the HDF5 + binary data format. Provides a ``Lightning`` wrapper around + the dataset class, which individually handles splits whereas + this class handles a collection of HDF5 files. + + Parameters + ---------- + filepath : PathLike + Filepath to a root folder containing HDF5 files + for each split, and a metadata JSON file. + transforms : list[Callable], optional + List of transforms to process data samples after loading. + loader_kwargs + Additional keyword arguments that are passed to + dataloaders. + + Raises + ------ + RuntimeError: + If the provided filepath is not a directory, this method + will raise a ``RuntimeError``. + """ + loader_kwargs.setdefault("num_workers", 0) + loader_kwargs.setdefault("persistent_workers", False) + loader_kwargs.setdefault("batch_size", 8) + super().__init__() + if not isinstance(filepath, Path): + filepath = Path(filepath) + if not filepath.is_dir(): + raise RuntimeError(f"Expected filepath to be a directory; got {filepath}") + self.metadata = filepath.joinpath("metadata.json") + # add to things to save + hparams_to_save = { + "filepath": filepath, + "transforms": transforms, + "strict_checksum": strict_checksum, + "metadata": self.metadata.model_dump(), + "loader_kwargs": loader_kwargs, + } + self.save_hyperparameters(hparams_to_save) + + @property + def metadata(self) -> DatasetSchema: + return self._metadata + + @metadata.setter + def metadata(self, filepath: Path) -> None: + if not filepath.exists(): + raise RuntimeError( + "No metadata found in target directory. Expected a metadata.json to exist." + ) + with open(filepath, "r") as read_file: + metadata = DatasetSchema.model_validate_json(read_file.read(), strict=True) + self._metadata = metadata + + @property + def h5_files(self) -> dict[Literal["train", "test", "validation", "predict"], Path]: + """ + Returns a mapping of split to HDF5 filepaths within + the root folder. Entries will only be present if the file + can be found. + + Returns + ------- + dict[Literal["train", "test", "validation", "predict"], Path] + Available split to HDF5 filepath mapping. + """ + return self._h5_files + + @h5_files.setter + def h5_files(self, root_dir: Path) -> None: + """ + Given a root folder directory, discover subsplits of data + by matching the ``.h5`` file extension. + + Parameters + ---------- + root_dir : Path + Folder containing data splits and metadata. + + Raises + ------ + RuntimeError + If not ``.h5`` files were discovered within this folder, + we raise a ``RuntimeError``. 
+ """ + h5_files = {} + for prefix in ["train", "test", "validation", "predict"]: + h5_file = root_dir.joinpath(prefix).with_suffix(".h5") + if h5_file.exists(): + h5_files[prefix] = h5_file.absolute() + logger.debug(f"Found {h5_file} data.") + if len(h5_files) == 0: + raise RuntimeError("No .h5 files found in target directory.") + self._h5_files = h5_files + + def setup(self, stage: str | None = None) -> None: + # check and set the available HDF5 data files + self.h5_files = self.hparams.filepath + self.datasets = { + key: MatSciMLDataset( + path, self.hparams.transforms, self.hparams.strict_checksum + ) + for key, path in self.h5_files.items() + } + if stage == "fit": + assert "train" in self.datasets, "No training split available!" + if stage == "predict": + assert "predict" in self.datasets, "No predict split available!" + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.datasets["train"], + shuffle=True, + **self.hparams.loader_kwargs, + collate_fn=collate_samples_into_batch_schema, + ) + + def val_dataloader(self) -> DataLoader | None: + if "validation" not in self.datasets: + return None + return DataLoader( + self.datasets["validation"], + shuffle=False, + **self.hparams.loader_kwargs, + collate_fn=collate_samples_into_batch_schema, + ) + + def test_dataloader(self) -> DataLoader | None: + if "test" not in self.datasets: + return None + return DataLoader( + self.datasets["test"], + shuffle=False, + **self.hparams.loader_kwargs, + collate_fn=collate_samples_into_batch_schema, + ) + + def predict_dataloader(self) -> DataLoader | None: + if "predict" not in self.datasets: + return None + return DataLoader( + self.datasets["predict"], + shuffle=False, + **self.hparams.loader_kwargs, + collate_fn=collate_samples_into_batch_schema, + ) diff --git a/matsciml/datasets/schema.py b/matsciml/datasets/schema.py new file mode 100644 index 00000000..3de4800c --- /dev/null +++ b/matsciml/datasets/schema.py @@ -0,0 +1,1143 @@ +from __future__ import annotations + +from importlib import import_module +from enum import Enum +from datetime import datetime +from typing import Literal, Any, Self +from os import PathLike +from pathlib import Path +import re + +from ase import Atoms +from ase.geometry import cell_to_cellpar, cellpar_to_cell, complete_cell +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + create_model, + field_validator, + model_validator, + ValidationInfo, +) +from numpydantic import NDArray, Shape +from loguru import logger +import numpy as np +import torch + +from matsciml.common.packages import package_registry +from matsciml.common.inspection import get_all_args +from matsciml.common.types import Embeddings, ModelOutput +from matsciml.modules.normalizer import Normalizer +from matsciml.datasets.transforms import PeriodicPropertiesTransform +from matsciml.datasets.utils import cart_frac_conversion + +"""This module defines schemas pertaining to data, using ``pydantic`` models +to help with validation and (de)serialization. + +The driving principle behind this is to try and define a standardized data +format that is also relatively low maintenance. The ``DatasetSchema`` is +used to fully qualify a dataset, including file hashing, a list of expected +targets, and so on. 
The ``DataSampleSchema`` provides an "on-the-rails" +experience for both developers and users by defining a consistent set of +attribute names that should be +""" + +__all__ = [ + "DataSampleSchema", + "DataSampleEnum", + "DatasetSchema", + "SplitHashSchema", + "PeriodicBoundarySchema", + "NormalizationSchema", + "GraphWiringSchema", + "TargetSchema", + "collate_samples_into_batch_schema", +] + +# ruff: noqa: F722 + + +class MatsciMLSchema(BaseModel): + """ + Implements a base class with (de)serialization methods + for saving and loading JSON files. + """ + + @classmethod + def from_json_file(cls, json_path: PathLike) -> Self: + """ + Deserialize a JSON file, validating against the expected dataset + schema. + + Parameters + ---------- + json_path : PathLike + Path to a JSON metadata file. + + Returns + ------- + DataSampleSchema + Instance of a validated ``DatasetSchema``. + + Raises + ------ + FileNotFoundError + If the specified JSON file does not exist. + """ + if not isinstance(json_path, Path): + json_path = Path(json_path) + if not json_path.exists(): + raise FileNotFoundError( + f"{json_path} JSON metadata file could not be found." + ) + with open(json_path, "r") as read_file: + return cls.model_validate_json(read_file.read(), strict=True) + + def to_json_file(self, json_path: PathLike) -> None: + """ + Write out the schema to a JSON file. + + Parameters + ---------- + json_path : PathLike + Filepath to save the data to. If unspecified, the + suffix ``.json`` will be used. + """ + if not isinstance(json_path, Path): + json_path = Path(json_path) + with open(json_path.with_suffix(".json"), "w+") as write_file: + write_file.write(self.model_dump_json(round_trip=True, indent=2)) + + +class DataSampleEnum(str, Enum): + """ + An Enum for categorizing data samples, which implicitly + informs us what the sample is intended to be used for + by virtue of what data is available, as well as how data + samples may relate to one another. An example would be + ``OptimizationCycle``, which implies the dataset should + contain multiple samples per structure of atomic forces. + + These tend to map more directly to computational chemistry + workflows, although naturally some types of calculations + will have overlap between them (e.g. an excited state geometry + optimization). In those cases, the recommendation would be + to select the intended use case - i.e. ``OptimizationCycle`` + is the preferred enum for this example as it infers the + presence of atomic forces. + + Attributes + ---------- + scf : str + Describes data pertaining to a single SCF cycle, which + comprises energy values, population analyses, orbital + coefficients, spin states, convergence properties etc. + opt_trajectory : str + Describes data comprising a single optimization or relaxation + step, which includes atomic forces, (partial) Hessians, + and geometry convergence metrics. + e_property : str + Describes a specific electronic property calculation. This can range + from multipole moments, to polarization, etc. Choose this + category if your intention is to provide properties, even if + certain electronic properties come for 'free' with SCF calculations. + n_property : str + Describes a specific nuclear property calculation, such as nuclear + multipole moments (e.g. nitrogen quadrupole), magnetic moments, etc. + states : str + Describes an excited state calculation that does not involve geometry + optimizations. This may refer to oscillator strengths/transition + moments. 
+ """ + + scf = "SCFCycle" + opt_trajectory = "OptimizationCycle" + e_property = "ElectronicPropertyCalculation" + n_property = "NuclearPropertyCalculation" + states = "ExcitedStateCalculation" + + +class SplitHashSchema(MatsciMLSchema): + """ + Schema for defining a set of data splits, with associated + hashes for each split. + + This model will do a rudimentary check to make sure each + value resembles a 64-character long hash. This is intended + to work in tandem with the ``MatSciMLDataset.blake2s_checksum`` + property. For producers of datasets, you will need to be able + to load in the dataset and record the checksum, and add it + to this data structure. + + Attributes + ---------- + train : str + blake2s hash for the training split. + test + blake2s hash for the test split. + validation + blake2s hash for the validation split. + predict + blake2s hash for the predict split. + """ + + train: str | None = None + test: str | None = None + validation: str | None = None + predict: str | None = None + + @staticmethod + def string_is_hashlike(input_str: str) -> bool: + """ + Simple method for checking if a string looks like a hash, + which is just a string of lowercase alphanumerals. + + Parameters + ---------- + input_str : str + String to check if it looks like a hash. + + Returns + ------- + bool + True if the string appears to be a hash, False + otherwise. + """ + lookup = re.compile(r"[0-9a-f]{64}") + # returns None if there are no matches + if lookup.match(input_str): + return True + return False + + @field_validator("*") + @classmethod + def check_hash_like(cls, value: str, info: ValidationInfo) -> str: + if value is not None: + is_string_like_hash = SplitHashSchema.string_is_hashlike(value) + if not is_string_like_hash: + raise ValueError( + f"Entry for {info.field_name} does not appear to be a hash." + ) + return value + + @model_validator(mode="after") + def check_not_all_none(self) -> Self: + if not any( + [getattr(self, key) for key in ["train", "test", "validation", "predict"]] + ): + raise RuntimeError("No splits were defined.") + return self + + +class PeriodicBoundarySchema(BaseModel): + """ + Specifies periodic boundary conditions for each axis. + """ + + x: bool + y: bool + z: bool + + +class NormalizationSchema(MatsciMLSchema): + target_key: str + mean: float + std: float + + @field_validator("std") + @classmethod + def std_must_be_positive(cls, value: float) -> float: + if value < 0.0: + raise ValidationError("Standard deviation cannot be negative.") + return value + + def to_normalizer(self) -> Normalizer: + """ + Create ``Normalizer`` object for compatiability with training + pipelines. + + TODO refactor the ``Normalizer`` class to be needed, and just + use this class directly. + + Returns + ------- + Normalizer + Normalizer object used for computation. + """ + return Normalizer(mean=self.mean, std=self.std) + + +class GraphWiringSchema(MatsciMLSchema): + """ + Provides a specification for tracking how graphs within + a dataset are wired. Primarily, a ``cutoff_radius`` is + specified to artificially truncate a neighborhood, and + a package or algorithm is used to compute this neighborhood + function. + + The validation of this schema includes checking to ensure + that a specific package used for performing neighborhood + calculations matches the recorded version, to ensure that + there are no unexpected changes to the algorithm used. 
+ + Attributes + ---------- + cutoff_radius : float + Cutoff radius used to specify the neighborhood region, + typically assumed to be in angstroms for most algorithms. + algorithm : Literal['pymatgen', 'ase', 'custom'] + Algorithm used for computing the atom neighborhoods and + subsequently edges. If either ``pymatgen`` or ``ase`` are + specified, schema validation will include checking versions. + allow_mismatch : bool + If set to True, we will not perform the algorithm version + checking. If set to False, a mismatch in algorithm version + will throw a ``ValidationError``. + algo_version : str, optional + Version number of ``pymatgen`` or ``ase`` depending on which + is being used. If ``algorithm`` is 'custom', this is ignored + as this check has not yet been implemented. + algo_hash_path : str, optional + Nominally a path to a Python import that, when imported, returns + a hash used to match against ``algo_hash``. Currently not implemented. + algo_hash : str, optional + Version hash used for a custom algorithm. Currently not implemented. + kwargs : dict[str, Any], optional + Additional keyword arguments that might be passed + to custom algorithms. + """ + + cutoff_radius: float + algorithm: Literal["pymatgen", "ase", "custom"] + allow_mismatch: bool + adaptive_cutoff: bool + algo_version: str | None = None + algo_hash_path: str | None = None + algo_hash: str | None = None + max_neighbors: int = -1 + kwargs: dict[str, Any] = Field(default_factory=dict) + + @classmethod + def from_transform( + cls, pbc_transform: PeriodicPropertiesTransform, allow_mismatch: bool + ) -> Self: + package = pbc_transform.backend + version = cls._check_package_version(package) + return cls( + cutoff_radius=pbc_transform.cutoff_radius, + algorithm=pbc_transform.backend, + allow_mismatch=allow_mismatch, + algo_version=version, + adaptive_cutoff=pbc_transform.adaptive_cutoff, + max_neighbors=pbc_transform.max_neighbors, + kwargs={ + "is_cartesian": pbc_transform.is_cartesian, + "allow_self_loops": pbc_transform.allow_self_loops, + "convert_to_unit_cell": pbc_transform.convert_to_unit_cell, + }, + ) + + @staticmethod + def _check_package_version(backend: str) -> str | None: + """Simple function for checking the version of pymatgen/ase""" + if backend == "pymatgen": + pmg = import_module("pymatgen.core") + actual_version = pmg.__version__ + elif backend == "ase": + ase = import_module("ase") + actual_version = ase.__version__ + else: + logger.warning("Periodic backend unsupported and cannot check version.") + actual_version = None + return actual_version + + @model_validator(mode="after") + def check_algo_version(self): + if not self.algo_version and not self.algo_hash: + raise RuntimeError("At least one form of algorithm versioning is required.") + if self.algo_version: + actual_version = self._check_package_version(self.algorithm) + if actual_version is None: + return self + # throw validation error only if we don't allow mismatches + if self.algo_version != actual_version and not self.allow_mismatch: + raise RuntimeError( + f"GraphWiringSchema algorithm version mismatch for package {self.algorithm}" + f" installed {actual_version}, expected {self.algo_version}." + ) + elif self.algo_version != actual_version: + logger.warning( + f"GraphWiringSchema algorithm version mismatch for package {self.algorithm}" + f" installed {actual_version}, expected {self.algo_version}." + " `allow_mismatch` was set to True, turning this into a warning message instead of an exception." 
+ ) + return self + if self.algo_hash: + algo_path = getattr(self, "algo_hash_path", None) + if not algo_path: + raise RuntimeError( + "Graph wiring algorithm hash specified but no path to resolve." + ) + logger.warning( + "Hash checking for custom algorithms is not currently implemented." + ) + return self + + def to_transform(self) -> PeriodicPropertiesTransform: + """ + Generates the transform responsible for graph edge computation + based on the schema. + + Returns + ------- + PeriodicPropertiesTransform + Instance of the periodic properties transform with + schema settings mapped. + """ + if self.algorithm in ["pymatgen", "ase"]: + possible_kwargs = get_all_args(PeriodicPropertiesTransform) + valid_kwargs = { + key: value + for key, value in self.kwargs.items() + if key in possible_kwargs + } + return PeriodicPropertiesTransform( + cutoff_radius=self.cutoff_radius, backend=self.algorithm, **valid_kwargs + ) + else: + raise NotImplementedError( + "Custom backend for neighborhood algorithm not supported yet." + ) + + +class TargetSchema(MatsciMLSchema): + """ + Schema that specifies a target label or property. + + The intention is to provide sufficient fields to make targets + in a dataset fully documented to leave zero ambiguities. + + Attributes + ---------- + name : str + String name of the target. This will be used to look up + targets throughout the pipeline. + shape : str + Designated shape of the target. Use '*' to specify variable + dimensions, and integers for fixed dimensions separated + by commas. As an example, '*' could designate a number of + node features (since the number of nodes is variable), and + '*, 3' could represent a vector property also over nodes. + description : str + Long text description of what this target is and how it + was calculated. + units : str, optional + Expected units of this property. This is more for documentation + for now, but in the future it may be helpful to do unit + conversions with this field. + """ + + name: str + shape: str + description: str + units: str | None = None + + @model_validator(mode="after") + def check_shape_str(self) -> Self: + """This checks to make sure that the shape specification is valid for ``Shape``.""" + invalid_regex = re.compile(r"[^\d\,\*\s]+") + if invalid_regex.search(self.shape): + raise ValueError( + f"Target shape should be specified with digits, commas, and wildcard only. Got {self.shape}" + ) + return self + + +class DatasetSchema(MatsciMLSchema): + """ + A schema for defining a collection of data samples. + + This schema is to accompany all serialized datasets, which + simultaneously documents the data **and** improves its + reproducibility by defining + + Attributes + ---------- + name : str + Name of the dataset. + creation : datetime + An immutable ``datetime`` for when the dataset was + created. + targets : list[TargetSchema] + A list of ``TargetSchema`` objects or dictionaries that satisfy + the schema. This is used simultaneously for documentation as + well as for data loading to look specifically for these keys. + split_blake2s : SplitHashSchema + Schema representing blake2s checksums for each dataset split. + modified : datetime, optional + Datetime object for recording when the dataset was last + modified. + description : str, optional + An optional, but highly recommended string for describing the + nature and origins of this dataset. 
There is no limit to how + long this description is, but ideally should be readable by + humans and whatever is not obvious (such as what target key + represents what property) should be included here. + graph_schema : GraphWiringSchema, optional + A schema that defines how the dataset is intended to build + edges. This defines dictates how edges are created at runtime. + normalization : dict[str, NormalizationSchema], optional + Defines a collection of normalization mean/std for targets. + If not None, this schema will validate against ``target_keys`` + and raise an error if there are keys in ``normalization`` that + do not match ``target_keys``. + node_stats : NormalizationSchema, optional + Mean/std values for the nodes per data sample. + edge_stats : NormalizationSchema, optional + Mean/std values for the number of edges per data sample. + seed : int, optional + Random seed used to generate the splits. This is kept optional + in the case where splits are not randomly generated. + """ + + name: str + creation: datetime + targets: list[TargetSchema] + split_blake2s: SplitHashSchema + dataset_type: DataSampleEnum | list[DataSampleEnum] + modified: datetime | None = None + description: str | None = None + graph_schema: GraphWiringSchema | None = None + normalization: dict[str, NormalizationSchema] | None = None + node_stats: NormalizationSchema | None = None + edge_stats: NormalizationSchema | None = None + seed: int | None = None + + @classmethod + def from_json_file(cls, json_path: PathLike) -> Self: + """ + Deserialize a JSON file, validating against the expected dataset + schema. + + Parameters + ---------- + json_path : PathLike + Path to a JSON metadata file. + + Returns + ------- + DataSampleSchema + Instance of a validated ``DatasetSchema``. + + Raises + ------ + FileNotFoundError + If the specified JSON file does not exist. + """ + if not isinstance(json_path, Path): + json_path = Path(json_path) + if not json_path.exists(): + raise FileNotFoundError( + f"{json_path} JSON metadata file could not be found." + ) + with open(json_path, "r") as read_file: + return cls.model_validate_json(read_file.read(), strict=True) + + @field_validator("dataset_type") + @classmethod + def cast_dataset_type( + cls, value: str | DataSampleEnum | list[DataSampleEnum | str] + ) -> list[DataSampleEnum]: + """ + Validate and cast string values into enums. + + Returns + ------- + list[DataSampleEnum] + Regardless of the number of types specified, return + a list of enum(s). + """ + if isinstance(value, str): + value = DataSampleEnum(value) + assert value + if isinstance(value, list): + temp = [] + for subvalue in value: + if isinstance(subvalue, str): + subvalue = DataSampleEnum(subvalue) + assert subvalue + temp.append(subvalue) + value = temp + else: + value = [value] + return value + + @model_validator(mode="after") + def check_target_normalization(self) -> Self: + """Cross-check target normalization specification with defined targets.""" + if self.normalization is not None: + # first check every key is available as targets + target_keys = set([target.name for target in self.targets]) + norm_keys = set(self.normalization.keys()) + # check to see if we have unexpected norm keys + diff = norm_keys - target_keys + if len(diff) > 0: + raise ValidationError(f"Unexpected keys in normalization: {diff}") + return self + + +class DataSampleSchema(MatsciMLSchema): + """ + Defines a schema for a single data sample. 
+ + Includes fields for the most commonly used properties, particularly + for interatomic potentials, and includes additional ones that help + fully specify the state of a atomic structure/material, such as + isotopic masses, charges, and electronic states. + + This schema uses ``numpydantic`` for type and shape hinting; + it does not enforce what provides the array (e.g. ``torch`` + or ``numpy``), and when dumping to JSON will ensure that it + is serializable (i.e. converts it to a list first). We also implement + a consistency check after model creation to validate that per-atom + fields have the right number of atoms. + + Parameters + ---------- + index : int + Integer counter for the sample within the full dataset. + This value is used to uniquely identify the sample within the + dataset, and is helpful during debugging. + num_atoms : int + Specifies the number of atoms to be expected of this data + sample. Recording this explicitly makes it accessible, + instead of relying on determination after batching, etc. + with tricks like ``graph.ptr``. + cart_coords : NDArray[Shape['*, 3'], float] + A variable length array of 3D vectors of floating point numbers + corresponding to the cartesian coordinates of the structure. + atomic_numbers : NDArray[Shape['*'], int] + A variable length array of integers corresponding to the + atomic numbers of each species. + pbc : PeriodicBoundarySchema + A schema that specifies which axes are periodic. Can also + pass a dictionary of coordinate/``bool`` values instead of + constructing the ``PeriodicBoundarySchema`` ahead of time. + datatype : DataSampleEnum + Categorizes the data sample according to types defined in + the ``DataSampleEnum``. This is mainly for documentation, + but allows users/developers to expect certain data fields + to be populated. + alpha_electron_spins : NDArray[Shape['*'], float], optional + Specifies the alpha spin value per-atom as a variable length + array of floating point values. Assumes unrestricted/open + shell species; alternatively, specify the same number in + ``beta_electron_spins``. + beta_electron_spins : NDArray[Shape['*'], float], optional + Specifies the beta spin value per-atom as a variable length + array of floating point values. Assumes unrestricted/open + shell species; alternatively, specify the same number in + ``alpha_electron_spins``. + nuclear_spins : NDArray[Shape['*'], float], optional + Specifies the nuclear spin value per-atom as a variable + length array of floating point values. + isotopic_masses : NDArray[Shape['*'], float], optional + Specifies isotopic masses for each atom as a variable + length array of floating point values. + atomic_charges : NDArray[Shape['*'], float], optional + Specifies some characterization of charge for each atom + as a variable length array of floating point values. + atomic_energies : NDArray[Shape['*'], float], optional + Ascribes energy values - i.e. contributions from each atom - + to each atom in the sample as a variable length array of + floating point values. + atomic_labels : NDArray[Shape['*'], int], optional + Indices to 'tag' atoms - useful for classification tasks + and masking. Specified as a variable length array of integers. + total_energy : float, optional + Total energy of the system by whatever definition. If there + are multiple types of total energy values, we recommend writing + the most primitive type (e.g. total electronic energy) available, + and add others (e.g. corrections, etc.) to ``extra``. 
+ forces : NDArray[Shape['*, 3'], float], optional + Specifies atomic forces on each atom as a variable length + array of 3D vectors with floating point values. + stresses : NDArray[Shape['*, 3, 3'], float], optional + Specifies a stress tensor per atom as a variable length + array of 3x3 matrices of floating point values. + lattice_parameters : NDArray[Shape['6'], float], optional + Specifies a vector of lattice parameters in order of + ``a,b,c,alpha,beta,gamma``. Assumes angles ``alpha,beta,gamma`` + are in degrees. + lattice_matrix : NDArray[Shape['3, 3'], float], optional + Specifies the fully specified lattice matrix as a 3x3 matrix + of floating point values. If the choice is between this field + or ``lattice_parameters``, populate this field but ideally both + as the matrix generated from parameters may not be unique. + edge_index : NDArray[Shape['2, *'], int], optional + Indices to indicate edges between atoms as a variable length + array of 2D vectors. The variable length in this case corresponds + to the number of **edges**, not atoms/nodes. + charge : float, optional + Some characterization of charge for the whole system as a floating + point value. + multiplicity : float, optional + Electronic multiplicity of the system as a floating point value. + While not explicitly checked, this couples with ``electronic_state_index`` + to fully specify an electronic state. This value is defined as 2S+1, + with S being the number of unpaired electrons (or the total electron spin + angular momentum). + electronic_state_index : int, default 0 + Specifies the electronic state, with zero (default) being the + electronic ground state for a given multiplicity. The index + should be ordered by energy, i.e. the first singlet excited state + would be given by a value of 1, and a multiplicity of 0. + images : NDArray[Shape['*, 3'], int], optional + Variable length array of 3D vectors of integers that index + periodic images (i.e. neighboring unit cells). The length + is expected to match that of ``edge_index``. + offsets : NDArray[Shape['*, 3'], float], optional + Variable length array of 3D vectors of floating points that + can be used to shift the point of reference from the origin + unit cell to the corresponding periodic image. The length + should be the same as ``edge_index``/``images``. + unit_offsets : NDArray[Shape['*, 3'], float], optional + Builds on top of ``offsets``, including the difference in + positions between two atoms in fractional coordinates. + Expects the length to be the same as ``edge_index``/``images`` + and ``offsets``. + graph : Any, optional + This field is not intended to be serialized, but is included + to allow the field to be populated during runtime as a way + to store an arbitrary graph object. We do not want to serialize + the graph object directly, as reloading with version mismatches + can be made impossible with breaking API changes regardless of + the framework. Instead, opt to save ``edge_index``. + extras : dict[str, Any], optional + Provides a vehicle for out-of-spec data to be transported in + this schema. This is useful if the property you wish to save + does not fit under any of the currently defined fields, but + is not recommended as it bypasses any of the type and shape + validations that ``pydantic``/``numpydantic`` provides. + transform_store : dict[str, Any], optional + Dictionary storage for transform results. This is a way to organize + products of transforms, e.g. instead of overwriting properties. 
+ """ + + index: int + num_atoms: int + cart_coords: NDArray[Shape["*, 3"], float] + atomic_numbers: NDArray[Shape["*"], int] + pbc: PeriodicBoundarySchema + datatype: DataSampleEnum + alpha_electron_spins: NDArray[Shape["*"], float] | None = None + beta_electron_spins: NDArray[Shape["*"], float] | None = None + nuclear_spins: NDArray[Shape["*"], float] | None = ( + None # optional nuclear spin at atom + ) + isotopic_masses: NDArray[Shape["*"], float] | None = None + atomic_charges: NDArray[Shape["*"], float] | None = None + atomic_energies: NDArray[Shape["*"], float] | None = None + atomic_labels: NDArray[Shape["*"], int] | None = ( + None # allows atoms to be tagged with class labels + ) + total_energy: float | None = None + forces: NDArray[Shape["*, 3"], float] | None = None + stresses: NDArray[Shape["*, 3, 3"], float] | None = None + lattice_parameters: NDArray[Shape["6"], float] | None = None + lattice_matrix: NDArray[Shape["3, 3"], float] | None = None + edge_index: NDArray[Shape["2, *"], int] | None = ( + None # allows for precomputed edges + ) + frac_coords: NDArray[Shape["*, 3"], float] | None = None + charge: float | None = None # overall system charge + multiplicity: float | None = None # overall system multiplicity + electronic_state_index: int = 0 + images: NDArray[Shape["*, 3"], int] | None = None + offsets: NDArray[Shape["*, 3"], float] | None = None + unit_offsets: NDArray[Shape["*, 3"], float] | None = None + graph: Any = None + extras: dict[str, Any] | None = None + transform_store: dict[str, Any] = Field(default_factory=dict) + + model_config = ConfigDict(arbitrary_types_allowed=True, use_enum_values=True) + + def __getattr__(self, name: str) -> Any | None: + """Overrides the behavior of `getattr` to also look in `extras` if available""" + if name in self.__dir__(): + return self.__dict__[name] + if self.extras is not None and name in self.extras: + return self.extras[name] + return None + + def _exception_wrapper(self, exception: Exception): + """ + Re-raises an exception that uses this class, and chains the sample index. + + This is to make debugging more informative, as it allows + arbitrary exceptions to be raised while also informing us + which sample specifically is causing issues. + + Parameters + ---------- + exception : Exception + Any possible ``Exception``. The type of exception is + used to re-raise the exception including the sample index. + + Raises + ------ + exception_cls + Raises the same exception as the input one, with + an additional message. + """ + exception_cls = exception.__class__ + raise exception_cls( + f"Data schema validation failed at sample {self.index}." + ) from exception + + @field_validator("lattice_matrix") + @classmethod + def orthogonal_lattice_matrix(cls, values: NDArray[Shape["3, 3"], float] | None): + """ + Ensures that the lattice matrix comprises a complete + basis of orthogonal vectors. 
+ """ + + if values is not None: + values = complete_cell(values) + return values + + @model_validator(mode="before") + @classmethod + def convert_lattice_and_parameters(cls, values: Any) -> Any: + lattice_params = values.get("lattice_parameters", None) + lattice_matrix = values.get("lattice_matrix", None) + if lattice_params is None and lattice_matrix is not None: + lattice_params = cell_to_cellpar(lattice_matrix) + values["lattice_parameters"] = lattice_params + if lattice_params is not None and lattice_matrix is None: + lattice_matrix = cellpar_to_cell(lattice_params) + values["lattice_matrix"] = lattice_matrix + return values + + @field_validator("lattice_parameters") + @classmethod + def check_lattice_param_angles( + cls, values: NDArray[Shape["6"], np.floating] | None + ): + """ + Check to make sure that if lattice parameters are set, then the + angles are in degrees, not radians. + + The check is done by expecting a vector of six elements is passed, + with the latter three are the angles. We assume that angles are + in radians if none of the values exceed 2pi (i.e. 360 degrees). + + This check is only performed if the lattice parameters are provided. + + Parameters + ---------- + values : NDArray[Shape['6'], float] | None + Vector of lattice parameters. + + Raises + ------ + ValueError: + If all lattice angles are smaller than or equal to + 2pi, they are likely to be in radians. + """ + if values is not None: + all_are_radians = np.all(values[3:] <= 2 * np.pi) + if all_are_radians: + raise ValueError( + "Expected lattice angles to be in degrees. All input values are smaller than 2 * np.pi." + ) + return values + + @model_validator(mode="after") + def coordinate_consistency(self) -> Self: + """Sets fractional coordinates if parameters are available, and checks them""" + if self.frac_coords is None and self.lattice_parameters is not None: + self.frac_coords = cart_frac_conversion( + self.cart_coords, *self.lattice_parameters, to_fractional=True + ) + if isinstance(self.frac_coords, NDArray): + if self.frac_coords.shape != self.cart_coords.shape: + raise ValueError( + "Fractional coordinate dimensions do not match cartesians." + ) + # round coordinate values so that -1e-6 is just zero and doesn't fail the test + round_coords = np.round(self.frac_coords, decimals=2) + if np.any(np.logical_or(round_coords > 1.01, round_coords < 0.0)): + logger.warning( + f"Fractional coordinates are outside of [0, 1]: {round_coords}" + ) + return self + + @model_validator(mode="after") + def atom_count_consistency(self) -> Self: + for key in [ + "atomic_numbers", + "electron_spins", + "nuclear_spins", + "isotopic_masses", + "atomic_charges", + "atomic_energies", + "atomic_labels", + ]: + value = getattr(self, key, None) + if value is not None: + if len(value) != self.num_atoms: + self._exception_wrapper( + ValueError( + f"Inconsistent number of elements for {key}; expected {self.num_atoms}, got {len(value)}." + ) + ) + for key in ["forces", "stresses"]: + value = getattr(self, key, None) + if value is not None: + if value.shape[0] != self.num_atoms: + self._exception_wrapper( + ValueError( + f"Inconsistent number of elements for node property {key}; expected {self.num_atoms}, got {value.shape[0]}." + ) + ) + if self.edge_index is not None: + for key in ["images", "offsets", "unit_offsets"]: + value = getattr(self, key, None) + if value is not None: + if value.shape[0] != self.edge_index: + self._exception_wrapper( + ValueError( + f"Inconsistent number of elements for edge property {key}." 
+                            )
+                        )
+        return self
+
+    def __eq__(self, other: DataSampleSchema) -> bool:
+        """Overrides the equivalence test, including array allclose comparisons"""
+        assert isinstance(
+            other, DataSampleSchema
+        ), "Equal comparison can only be done against `DataSampleSchema`."
+        self_dict = self.model_dump()
+        other_dict = other.model_dump()
+        for key in self_dict.keys():
+            self_value = self_dict[key]
+            other_value = other_dict[key]
+            # skip None comparisons
+            if self_value is None and other_value is None:
+                continue
+            try:
+                if type(self_value) != type(other_value):
+                    return False
+                if isinstance(self_value, torch.Tensor):
+                    check = torch.allclose(self_value, other_value)
+                    if not check:
+                        return False
+                elif isinstance(self_value, np.ndarray):
+                    check = np.allclose(self_value, other_value)
+                    if not check:
+                        return False
+                else:
+                    # for everything else, str, int, float, etc. builtin types
+                    if not self_value == other_value:
+                        return False
+            except Exception:
+                # if at any point any exception is raised, they're
+                # not equal
+                logger.debug(f"Comparison failed on key {key}")
+                return False
+        return True
+
+    @model_validator(mode="after")
+    def check_edge_data(self) -> Self:
+        """Ensure that edge properties, if provided, are consistent with the number of edges."""
+        if self.edge_index is not None:
+            num_edges = self.edge_index.shape[1]
+            for key in ["images", "offsets", "unit_offsets"]:
+                value = getattr(self, key)
+                if value is not None:
+                    if value.shape[0] != num_edges:
+                        self._exception_wrapper(
+                            ValueError(
+                                f"Mismatch in edge property {key}. "
+                                "Expected the first dimension to match the number of edges."
+                            )
+                        )
+        return self
+
+    def to_ase_atoms(self) -> Atoms:
+        """
+        Provides a simple conversion to an ``ase.Atoms`` object.
+
+        This method does not strictly check that outputs are mapped
+        correctly, but at least maps the fields in the schema to
+        the intended attributes in the ``Atoms`` class.
+
+        Returns
+        -------
+        Atoms
+            Instance of an ``Atoms`` object constructed with
+            the current data sample.
+        """
+        pbc = [value for value in self.pbc.model_dump().values()]
+        return Atoms(
+            positions=self.cart_coords,
+            cell=self.lattice_matrix,
+            numbers=self.atomic_numbers,
+            tags=self.atomic_labels,
+            charges=self.atomic_charges,
+            masses=self.isotopic_masses,
+            pbc=pbc,
+        )
+
+    @property
+    def graph_backend(self) -> Literal["dgl", "pyg"] | None:
+        if not self.graph:
+            return None
+        # check each available backend independently; a graph that matches
+        # neither framework is treated as an error
+        if "pyg" in package_registry:
+            from torch_geometric.data import Data as PyGGraph
+
+            if isinstance(self.graph, PyGGraph):
+                return "pyg"
+        if "dgl" in package_registry:
+            from dgl import DGLGraph
+
+            if isinstance(self.graph, DGLGraph):
+                return "dgl"
+        self._exception_wrapper(
+            TypeError(f"Unexpected graph type: {type(self.graph)}")
+        )
+
+
+def _concatenate_data_list(all_data: list[Any]) -> list[Any] | torch.Tensor:
+    """
+    Concatenates an arbitrary list of data where possible.
+
+    For array-types (NumPy, PyTorch), we first convert all
+    of the sample data into tensors, followed by concatenation
+    along the first dimension (nodes or edges).
+
+    For scalar-types, we return a 1D tensor.
+
+    For all other types, we return the inputs unmodified.
+
+    Parameters
+    ----------
+    all_data : list[Any]
+        List of data to try and concatenate.
+
+    Returns
+    -------
+    list[Any] | torch.Tensor
+        If the concatenation was successful, returns a tensor.
+        Otherwise, returns the unmodified input.
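+
+    Examples
+    --------
+    A short sketch of the intended behavior (input values are arbitrary):
+
+    >>> _concatenate_data_list([torch.zeros(4, 3), torch.zeros(2, 3)]).shape
+    torch.Size([6, 3])
+    >>> _concatenate_data_list([3, 5])
+    tensor([3, 5])
+    >>> _concatenate_data_list(["a", "b"])
+    ['a', 'b']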
+ """ + sample = all_data[0] + if isinstance(sample, (np.ndarray, torch.Tensor)): + # homogenize all samples into tensors + all_data = [torch.Tensor(s) for s in all_data] + return torch.concat(all_data) + if isinstance(sample, (float, int)): + output = torch.Tensor(all_data) + if isinstance(sample, int): + output = output.long() + return output + else: + # leave the data as a list + return all_data + + +def collate_samples_into_batch_schema(samples: list[DataSampleSchema]) -> object: + """ + Function to collate a list of ``DataSampleSchema`` into a dynamically + generated ``BatchSchema``. + + The additional logic in this function is to handle graphs, copying + references to their respective data, and calling the respective framework + batching functions. + + The purpose of on-the-fly schema generation is to do some degree of + validation, but primarily provide regular structure for use in the + task pipeline side of things. Given that the schema should be serializable, + it may also make debugging more streamlined. + + Parameters + ---------- + samples : list[DataSampleSchema] + List of data samples that have been pre-validated. + + Returns + ------- + object + Instance of a ``BatchSchema`` object. This is not explicitly annotated + since the model/class is defined dynamically based off incoming data. + """ + ref_schema = samples[0].schema() + # initial keys are going to hold the main structure of the schema + schema_to_generate = { + "num_atoms": (NDArray[Shape["*"], int] | torch.LongTensor, ...), + "batch_size": (int, ...), + "graph": (Any | None, None), + "num_edges": (NDArray[Shape["*"], int] | torch.LongTensor | None, None), + "embeddings": (Embeddings | None, None), + "outputs": (ModelOutput | None, None), + } + collected_data = {} # holds all the data to unpack into the generated schema + # check to see if graphs are present + if samples[0].graph is not None: + graph_sample = samples[0].graph + if "pyg" in package_registry: + from torch_geometric.data import Batch, Data + + if isinstance(graph_sample, Data): + batched_graph = Batch.from_data_list( + [sample.graph for sample in samples] + ) + graph_type = Batch + for key in batched_graph.keys(): + data = getattr(batched_graph, key) + schema_to_generate[key] = (type(data), ...) + collected_data[key] = data + collected_data["num_edges"] = _concatenate_data_list( + [sample.graph.batch_num_edges() for sample in samples] + ).long() + else: + from dgl import DGLGraph, batch + + if isinstance(graph_sample, DGLGraph): + batched_graph = batch([sample.graph for sample in samples]) + graph_type = DGLGraph + for key, data in batched_graph.ndata.items(): + schema_to_generate[key] = (type(data), ...) + collected_data[key] = data + for key, data in batched_graph.edata.items(): + schema_to_generate[key] = (type(data), ...) + collected_data[key] = data + collected_data["num_edges"] = _concatenate_data_list( + [sample.graph.batch_num_edges() for sample in samples] + ).long() + collected_data["num_atoms"] = _concatenate_data_list( + [sample.num_atoms for sample in samples] + ).long() + collected_data["graph"] = batched_graph + schema_to_generate["graph"] = (graph_type, ...) + # for everything else that wasn't packed into the graph + for key in ref_schema["required"]: + if key not in schema_to_generate: + schema_to_generate[key] = (Any, ...) 
+ if key not in collected_data: + collected_data[key] = _concatenate_data_list( + [getattr(sample, key) for sample in samples] + ) + collected_data["batch_size"] = len(samples) + # generate the schema, then create the model + BatchSchema = create_model( + "BatchSchema", + **schema_to_generate, + __config__=ConfigDict(arbitrary_types_allowed=True, use_enum_values=True), + ) + return BatchSchema(**collected_data) diff --git a/matsciml/datasets/tests/test_schema.py b/matsciml/datasets/tests/test_schema.py new file mode 100644 index 00000000..27911327 --- /dev/null +++ b/matsciml/datasets/tests/test_schema.py @@ -0,0 +1,220 @@ +from hashlib import blake2s +from datetime import datetime + +import pytest +import numpy as np +import torch +from ase.geometry import cell_to_cellpar + +from matsciml.datasets import schema +from matsciml.datasets.transforms import PeriodicPropertiesTransform + +fake_hashes = { + key: blake2s(bytes(key.encode("utf-8"))).hexdigest() + for key in ["train", "test", "validation", "predict"] +} + + +def test_split_schema_pass(): + s = schema.SplitHashSchema(**fake_hashes) + assert s + + +def test_partial_schema_pass(): + s = schema.SplitHashSchema( + train=fake_hashes["train"], validation=fake_hashes["validation"] + ) + assert s + + +def test_bad_hash_fail(): + with pytest.raises(ValueError): + # chop off the end of the hash + s = schema.SplitHashSchema(train=fake_hashes["train"][:-2]) # noqa: F841 + + +def test_no_hashes(): + with pytest.raises(RuntimeError): + s = schema.SplitHashSchema() # noqa: F841 + + +def test_dataset_minimal_schema_pass(): + splits = schema.SplitHashSchema( + train=fake_hashes["train"], validation=fake_hashes["validation"] + ) + dset = schema.DatasetSchema( + name="GenericDataset", + creation=datetime.now(), + dataset_type="SCFCycle", + targets=[ + { + "name": "total_energy", + "shape": "0", + "description": "Total energy of the system.", + }, + { + "name": "forces", + "shape": "*,3", + "description": "Atomic forces per node.", + }, + ], + split_blake2s=splits, + ) + assert dset + + +def test_dataset_minimal_schema_roundtrip(): + """Make sure the dataset minimal schema can dump and reload""" + splits = schema.SplitHashSchema( + train=fake_hashes["train"], validation=fake_hashes["validation"] + ) + dset = schema.DatasetSchema( + name="GenericDataset", + creation=datetime.now(), + dataset_type="SCFCycle", + targets=[ + { + "name": "total_energy", + "shape": "0", + "description": "Total energy of the system.", + }, + { + "name": "forces", + "shape": "*,3", + "description": "Atomic forces per node.", + }, + ], + split_blake2s=splits, + ) + json_rep = dset.model_dump_json() + reloaded_dset = dset.model_validate_json(json_rep) + assert reloaded_dset == dset + + +@pytest.mark.parametrize("backend", ["pymatgen", "ase"]) +@pytest.mark.parametrize("cutoff_radius", [3.0, 6.0, 10.0]) +def test_graph_wiring_from_transform(backend, cutoff_radius): + transform = PeriodicPropertiesTransform(cutoff_radius, backend=backend) + s = schema.GraphWiringSchema.from_transform(transform, allow_mismatch=False) + recreate = s.to_transform() + assert transform.__dict__ == recreate.__dict__ + + +@pytest.mark.parametrize("backend", ["pymatgen", "ase"]) +def test_graph_wiring_version_mismatch(backend): + """Ensures that an exception is thrown when backend version does not match""" + with pytest.raises(RuntimeError): + s = schema.GraphWiringSchema( # noqa: F841 + cutoff_radius=10.0, + algorithm=backend, + allow_mismatch=False, + algo_version="fake_version", + 
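+            # "fake_version" is an arbitrary placeholder that cannot match the installed backend version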
adaptive_cutoff=False, + ) + + +@pytest.mark.parametrize("num_atoms", [5, 12, 16]) +@pytest.mark.parametrize("array_lib", ["numpy", "torch"]) +def test_basic_data_sample_schema(num_atoms, array_lib): + if array_lib == "numpy": + coords = np.random.rand(num_atoms, 3) + numbers = np.random.randint(1, 100, (num_atoms)) + else: + coords = torch.rand(num_atoms, 3) + numbers = torch.randint(1, 100, (num_atoms,)) + pbc = {"x": True, "y": True, "z": True} + data = schema.DataSampleSchema( + index=0, + num_atoms=num_atoms, + cart_coords=coords, + atomic_numbers=numbers, + pbc=pbc, + datatype="SCFCycle", + ) + assert data + + +@pytest.mark.parametrize("num_atoms", [5, 12, 16]) +@pytest.mark.parametrize("array_lib", ["numpy", "torch"]) +def test_basic_data_sample_roundtrip(num_atoms, array_lib): + if array_lib == "numpy": + coords = np.random.rand(num_atoms, 3) + numbers = np.random.randint(1, 100, (num_atoms)) + else: + coords = torch.rand(num_atoms, 3) + numbers = torch.randint(1, 100, (num_atoms,)) + pbc = {"x": True, "y": True, "z": True} + data = schema.DataSampleSchema( + index=0, + num_atoms=num_atoms, + cart_coords=coords, + atomic_numbers=numbers, + pbc=pbc, + datatype="SCFCycle", + ) + json = data.model_dump_json() + recreate = schema.DataSampleSchema.model_validate_json(json) + assert recreate == data + + +@pytest.mark.parametrize("num_atoms", [3, 10, 25]) +@pytest.mark.parametrize("array_lib", ["numpy", "torch"]) +def test_data_sample_fail_coord_shape(num_atoms, array_lib): + if array_lib == "numpy": + coords = np.random.rand(num_atoms, 5) + numbers = np.random.randint(1, 100, (num_atoms)) + else: + coords = torch.rand(num_atoms, 5) + numbers = torch.randint(1, 100, (num_atoms,)) + pbc = {"x": True, "y": True, "z": True} + with pytest.raises(ValueError): + data = schema.DataSampleSchema( # noqa: F841 + index=0, + num_atoms=num_atoms, + cart_coords=coords, + atomic_numbers=numbers, + pbc=pbc, + datatype="SCFCycle", + ) + + +def test_lattice_param_to_matrix_consistency(): + """Make sure that lattice parameters map to matrix correctly during validation""" + coords = np.random.rand(5, 3) + numbers = np.random.randint(1, 100, (5)) + data = schema.DataSampleSchema( + index=0, + num_atoms=5, + cart_coords=coords, + atomic_numbers=numbers, + pbc={"x": True, "y": True, "z": True}, + datatype="OptimizationCycle", + lattice_parameters=[5.0, 5.0, 5.0, 90.0, 90.0, 90.0], + ) + assert data.frac_coords is not None + assert data.lattice_matrix is not None + reconverted = cell_to_cellpar(data.lattice_matrix) + assert np.allclose(reconverted, data.lattice_parameters) + exact = np.array([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]) + assert np.allclose(exact, data.lattice_matrix) + + +def test_lattice_matrix_to_param_consistency(): + """Make sure that lattice parameters map to matrix correctly during validation""" + coords = np.random.rand(5, 3) + numbers = np.random.randint(1, 100, (5)) + data = schema.DataSampleSchema( + index=0, + num_atoms=5, + cart_coords=coords, + atomic_numbers=numbers, + pbc={"x": True, "y": True, "z": True}, + datatype="OptimizationCycle", + lattice_matrix=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + ) + assert data.frac_coords is not None + assert data.lattice_matrix is not None + converted = cell_to_cellpar(data.lattice_matrix) + assert np.allclose(converted, data.lattice_parameters) + exact = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + assert np.allclose(exact, data.lattice_matrix) diff --git a/matsciml/datasets/utils.py 
b/matsciml/datasets/utils.py index a580811d..f45da6e3 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -18,6 +18,7 @@ from pymatgen.core import Lattice, Structure from tqdm import tqdm import ase +from ase.geometry import Cell, cellpar_to_cell, complete_cell from ase.neighborlist import neighbor_list from matsciml.common import package_registry @@ -1030,52 +1031,12 @@ def cart_frac_conversion( Fractional coordinate representation """ - def cot(x: float) -> float: - """cotangent of x""" - return -np.tan(x + np.pi / 2) - - def csc(x: float) -> float: - """cosecant of x""" - return 1 / np.sin(x) - - # convert to radians if angles are passed as degrees - if angles_are_degrees: - alpha = alpha * np.pi / 180.0 - beta = beta * np.pi / 180.0 - gamma = gamma * np.pi / 180.0 - - # This matrix is normally for fractional to cart. Implements the matrix found in - # https://en.wikipedia.org/wiki/Fractional_coordinates#General_transformations_between_fractional_and_Cartesian_coordinates - rotation = torch.tensor( - [ - [ - a - * np.sin(beta) - * np.sqrt( - 1 - - ( - (cot(alpha) * cot(beta)) - - (csc(alpha) * csc(beta) * np.cos(gamma)) - ) - ** 2.0 - ), - 0.0, - 0.0, - ], - [ - a * csc(alpha) * np.cos(gamma) - a * cot(alpha) * np.cos(beta), - b * np.sin(alpha), - 0.0, - ], - [a * np.cos(beta), b * np.cos(alpha), c], - ], - dtype=coords.dtype, - ) + lattice_matrix = complete_cell(cellpar_to_cell([a, b, c, alpha, beta, gamma])) + cell = Cell(lattice_matrix) if to_fractional: - # invert elements for the opposite conversion - rotation = torch.linalg.inv(rotation) - output = coords @ rotation - return output + return cell.scaled_positions(coords) + else: + return cell.cartesian_positions(coords) def build_nearest_images(max_image_number: int) -> torch.Tensor: diff --git a/pyproject.toml b/pyproject.toml index 1f3b64a5..cda01f2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,9 @@ dependencies = [ "e3nn", "mace-torch==0.3.6", "monty==2024.2.2", - "loguru" + "loguru", + "numpydantic>=1.6.4", + "pydantic>=2.9.2", ] description = "PyTorch Lightning and Deep Graph Library enabled materials science deep learning pipeline" dynamic = ["version", "readme"]