
Training documentation update #298

Merged 11 commits on Sep 30, 2024
89 changes: 89 additions & 0 deletions docs/source/best-practices.rst
@@ -181,6 +181,81 @@ the accelerator.
Training
--------

Target normalization
^^^^^^^^^^^^^^^^^^^^

Tasks can be provided with ``normalize_kwargs``, which are key/value mappings
that specify the mean and standard deviation of a target; an example is given below.

.. code-block:: python

    Task(
        ...,
        normalize_kwargs={
            "energy_mean": 0.0,
            "energy_std": 1.0,
        }
    )

The example above will normalize ``energy`` labels; the ``energy`` prefix can be substituted with
any other target key of interest (e.g. ``force``, ``bandgap``, etc.).
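
For example, to normalize ``force`` labels instead (the values below are placeholders, not
recommended statistics):

.. code-block:: python

    Task(
        ...,
        normalize_kwargs={
            "force_mean": 0.0,
            "force_std": 2.5,
        }
    )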

Target loss scaling
^^^^^^^^^^^^^^^^^^^

A common practice is to scale some targets relative to others (e.g. forces relative to
energies). To specify this, you can pass a ``task_loss_scaling`` dictionary to
any task module, which maps target keys to a floating point value that will be used
to multiply the corresponding target loss value before summation and backpropagation.

.. code-block:: python

    Task(
        ...,
        task_loss_scaling={
            "energy": 1.0,
            "force": 10.0
        }
    )


A related, but alternative, way to specify target scaling is to apply a *schedule* to
the training loss contributions: essentially, this provides a way to smoothly ramp
up (or down) the contribution of different targets, allowing for more complex training
curricula. To achieve this, you will need to use the ``LossScalingScheduler`` callback:

.. autoclass:: matsciml.lightning.callbacks.LossScalingScheduler
:members:


To specify this callback, you must pass subclasses of ``BaseScalingSchedule`` as arguments.
Each schedule type implements the functional form of a schedule, and currently
there are two concrete schedules. Composed together, an example would look like this:

.. code-block:: python

    import pytorch_lightning as pl
    from matsciml.lightning.callbacks import LossScalingScheduler
    from matsciml.lightning.loss_scaling import LinearScalingSchedule

    scheduler = LossScalingScheduler(
        LinearScalingSchedule("energy", initial_value=1.0, end_value=5.0, step_frequency="epoch")
    )
    trainer = pl.Trainer(callbacks=[scheduler])


The stepping schedule is determined during ``setup`` (as training begins), where the callback will
inspect ``Trainer`` arguments to determine how many steps will be taken. The ``step_frequency``
simply specifies how often the loss scaling value is updated.
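
Since schedules are passed as arguments, more than one can be composed in the same callback.
The sketch below builds on the imports from the previous example; the target keys and values
are illustrative placeholders rather than recommended settings.

.. code-block:: python

    # Ramp the energy contribution up while ramping the force contribution down.
    scheduler = LossScalingScheduler(
        LinearScalingSchedule("energy", initial_value=1.0, end_value=5.0, step_frequency="epoch"),
        LinearScalingSchedule("force", initial_value=10.0, end_value=1.0, step_frequency="epoch"),
    )
    trainer = pl.Trainer(callbacks=[scheduler])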


.. autoclass:: matsciml.lightning.loss_scaling.LinearScalingSchedule
:members:


.. autoclass:: matsciml.lightning.loss_scaling.SigmoidScalingSchedule
:members:


Quick debugging
^^^^^^^^^^^^^^^

@@ -223,6 +298,20 @@ inspired by observations made in LLM training research, where the breakdown of
assumptions in the convergent properties of ``Adam``-like optimizers causes large
spikes in the training loss. This callback can help identify these occurrences.

The ``devset``/``fast_dev_run`` approach detailed above is also useful for testing
engineering/infrastructure (e.g. accelerator offload), but not necessarily
for probing training dynamics. Instead, we recommend using the ``overfit_batches``
argument in ``pl.Trainer``:

.. code-block:: python

    import pytorch_lightning as pl

    trainer = pl.Trainer(overfit_batches=100)


This will disable shuffling in the training and validation splits (per the PyTorch Lightning
documentation), and ensure that the same batches are being reused every epoch.

.. _e3nn documentation: https://docs.e3nn.org/en/latest/

.. _IPEX installation: https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu
18 changes: 18 additions & 0 deletions docs/source/inference.rst
@@ -4,6 +4,24 @@ Inference
"Inference" can be a bit of an overloaded term, and this page is broken down into different possible
downstream use cases for trained models.

Task ``predict`` and ``forward`` methods
----------------------------------------

``matsciml`` tasks implement separate ``forward`` and ``predict`` methods. Both take a
``BatchDict`` as input, and the latter wraps the former. The difference, however, is that
``predict`` is intended for inference use, primarily because it also takes care of
reversing the normalization procedure (if normalization kwargs were provided during
training) and, perhaps more importantly, ensures that the exponential moving average
weights are used instead of the training weights.

In the special case of force prediction tasks (where forces are computed as a derivative of
the energy), you should only need to specify normalization ``kwargs`` for energy: the scale
value is taken automatically from the energy normalization and applied to the forces.

In short, if you are writing functionality that requires unnormalized outputs (e.g. ``ase`` calculators),
please ensure you are using ``predict`` instead of ``forward`` directly.
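
A minimal sketch of the distinction is given below; it assumes an already-trained task
instance ``task`` and a collated ``batch``, neither of which is constructed here.

.. code-block:: python

    # ``predict`` reverses any target normalization and uses the exponential moving
    # average weights; calling the task directly runs ``forward`` and returns raw,
    # training-space outputs.
    inference_outputs = task.predict(batch)
    training_outputs = task(batch)  # equivalent to task.forward(batch)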


Parity plots and model evaluations
----------------------------------

158 changes: 148 additions & 10 deletions docs/source/training.rst
@@ -1,14 +1,152 @@
Task abstraction
================

The Open MatSciML Toolkit uses PyTorch Lightning abstractions to manage the flow of
training: how data from a datamodule gets mapped, which loss terms are calculated, and
what gets logged are all defined in a base task class. From start to finish, this module
will take in the definition of an encoding architecture (through the ``encoder_class`` and
``encoder_kwargs`` keyword arguments), construct it, and, in concrete task implementations,
initialize the respective output heads for a set of provided or task-specific target keys.
The ``encoder_kwargs`` specification makes things a bit more verbose, but it ensures
that the hyperparameters are saved appropriately per the ``save_hyperparameters`` method
in PyTorch Lightning.
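
As a rough construction sketch, the encoder is passed as a class plus keyword arguments so
that ``save_hyperparameters`` can record them. The ``ToyEncoder`` class and the ``task_keys``
argument below are illustrative assumptions, not documented API.

.. code-block:: python

    import torch.nn as nn

    from matsciml.models.base import ScalarRegressionTask


    class ToyEncoder(nn.Module):
        # Stand-in encoder used only to show the class/kwargs pattern;
        # real workflows pass one of the matsciml encoder architectures.
        def __init__(self, hidden_dim: int = 64):
            super().__init__()
            self.layer = nn.Linear(3, hidden_dim)


    task = ScalarRegressionTask(
        encoder_class=ToyEncoder,           # the class itself, not an instance
        encoder_kwargs={"hidden_dim": 64},  # recorded via save_hyperparameters
        task_keys=["energy"],               # assumed name for the target key list
    )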

``BaseTaskModule`` API reference
--------------------------------

.. autoclass:: matsciml.models.base.BaseTaskModule
:members:


Multi-task reference
--------------------

One core piece of functionality in ``matsciml`` is the ability to compose multiple tasks
together, (almost) seamlessly extending the single-task case.

.. important::
The ``MultiTaskLitModule`` is not written in a particularly friendly way at
the moment, and may be subject to a significant refactor later!


.. autoclass:: matsciml.models.base.MultiTaskLitModule
:members:


``OutputHead`` API reference
----------------------------

While there is a singular ``OutputHead`` definition, the blocks that constitute
an ``OutputHead`` can be specified depending on the type of model architecture
being used. The default stack is based on simple ``nn.Linear`` layers; however,
for architectures like MACE, which may depend on preserving irreducible representations,
the ``IrrepOutputBlock`` allows users to specify transformations per representation.

.. autoclass:: matsciml.models.common.OutputHead
:members:


.. autoclass:: matsciml.models.common.OutputBlock
:members:


.. autoclass:: matsciml.models.common.IrrepOutputBlock
:members:


Scalar regression
-----------------

This task is primarily designed for problems adjacent to property prediction: you can
predict an arbitrary number of properties (one per output head), based on a shared
embedding (i.e. one structure maps to a single embedding, which is used by each head).
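
The sketch below illustrates the "one shared embedding, one head per property" idea in
plain PyTorch; it is not the actual ``OutputHead`` implementation, and the property keys
and dimensions are placeholders.

.. code-block:: python

    import torch
    import torch.nn as nn

    embedding = torch.randn(4, 128)  # one embedding per structure in a batch of 4
    heads = nn.ModuleDict(
        {
            "band_gap": nn.Linear(128, 1),
            "formation_energy": nn.Linear(128, 1),
        }
    )
    # each head consumes the same shared embedding
    predictions = {key: head(embedding) for key, head in heads.items()}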

A special case for using this class would be in tandem (as a multitask setup) with
the :ref:`gradfree_force` task, which treats energy/force prediction as two
separate output heads, albeit with the same shared embedding.

Please use continuous-valued loss metrics (e.g. ``nn.MSELoss``) for this task.


.. autoclass:: matsciml.models.base.ScalarRegressionTask
:members:


Binary classification
-----------------------

This task, as the name suggests, uses a shared embedding to perform one or more binary
classifications: this can be something like a ``stability`` label, as in the Materials
Project. Keep in mind, however, that a special class exists for crystal symmetry
classification (see :ref:`crystal_symmetry`).

.. autoclass:: matsciml.models.base.BinaryClassificationTask
:members:

.. _crystal_symmetry:

Crystal symmetry classification
-------------------------------

This task is a specialized class for what is essentially multiclass classification,
where given an embedding, we predict which crystal space group the structure belongs
to using ``nn.CrossEntropyLoss``. This can be a good potential pretraining task.


.. note::
This task expects that your data includes a ``spacegroup`` target key.

.. autoclass:: matsciml.models.base.CrystalSymmetryClassificationTask
:members:


Force regression task
---------------------

This task implements energy/force regression, where an ``OutputHead`` is used to first
predict the energy, followed by taking its derivative with respect to the input coordinates.
From a developer perspective, this task is quite mechanically different due to the need
for manual ``autograd``, which is not normally supported by PyTorch Lightning workflows.
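
The snippet below is a generic sketch of the energy-derivative idea in plain PyTorch, not
the task's internals: a scalar energy is predicted from the coordinates, and forces are
taken as the negative gradient with respect to the positions.

.. code-block:: python

    import torch

    pos = torch.randn(8, 3, requires_grad=True)    # atomic positions
    energy = (pos ** 2).sum()                      # stand-in for an OutputHead energy
    forces = -torch.autograd.grad(energy, pos)[0]  # one (3,) force vector per atom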


.. note::
This task expects that your data includes a ``force`` target key.

.. autoclass:: matsciml.models.base.ForceRegressionTask
:members:


.. _gradfree_force:

Gradient-free force regression task
-----------------------------------

This task implements force prediction as a direct output head property prediction,
as opposed to taking the derivative of an energy value with ``autograd``.

.. note::
This task expects that your data includes a ``force`` target key.

.. autoclass:: matsciml.models.base.GradFreeForceRegressionTask
:members:


Node denoising task
-------------------

This task implements a powerful and increasingly popular pre-training strategy
for graph neural networks. The premise is quite simple: the encoder learns as a denoising
autoencoder by taking in a perturbed structure and attempting to predict the amount of
noise added to the 3D coordinates.

This task requires the following data transform; you are able to specify the scale of the
noise added to the positions, and intuitively, the larger the scale, the more difficult
the task.
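
As a plain-PyTorch illustration of the training signal (not the transform's internals; the
scale and tensor shapes are placeholders):

.. code-block:: python

    import torch

    scale = 0.1
    coords = torch.randn(16, 3)               # clean atomic positions
    noise = torch.randn_like(coords) * scale  # larger scale -> harder task
    noisy_coords = coords + noise             # what the encoder sees
    target = noise                            # what the denoising task predicts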

.. autoclass:: matsciml.datasets.transforms.pretraining.NoisyPositions
:members:


.. autoclass:: matsciml.models.base.NodeDenoisingTask
:members: