Allow Loading of Checkpoints in Experiments Pipeline #272

Merged
21 changes: 21 additions & 0 deletions experiments/README.md
@@ -21,6 +21,27 @@ dataset:
        - energy
```

### Checkpoint Loading
Pretrained model checkpoints may be loaded for use in downstream tasks. A model may be loaded and used *as-is*, or only its encoder may be reused.

To load a checkpoint, add the `load_weights` field to the experiment config:
```yaml
model: egnn_dgl
dataset:
  oqmd:
    - task: ScalarRegressionTask
      targets:
        - energy
load_weights:
  method: checkpoint
  type: local
  path: ./path/to/checkpoint
```
* `method` specifies whether to use the checkpointed model *as-is* (`checkpoint`) or only its *encoder* (`pretrained`).
* `type` specifies where to load the checkpoint from (`local` or `wandb`).
* `path` points to the location of the checkpoint. For WandB checkpoints, point to the model artifact, typically of the form `entity-name/project-name/model-version:number` (see the example below).
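
For example, a config that reuses only the encoder from a checkpoint stored as a WandB model artifact might look like the following sketch; the artifact path is a placeholder in the `entity-name/project-name/model-version:number` form and should be replaced with your own entity, project, and artifact names:
```yaml
model: egnn_dgl
dataset:
  oqmd:
    - task: ScalarRegressionTask
      targets:
        - energy
load_weights:
  method: pretrained
  type: wandb
  path: entity-name/project-name/model-version:number
```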


In general, an experiment may then be launched by running:
`python experiments/training_script.py --experiment_config ./experiments/configs/single_task.yaml`

46 changes: 44 additions & 2 deletions experiments/task_config/task_config.py
@@ -2,11 +2,14 @@

from copy import deepcopy
from typing import Any
from pathlib import Path


import pytorch_lightning as pl

from matsciml.common.registry import registry
from matsciml.models.base import MultiTaskLitModule
from matsciml.models.base import MultiTaskLitModule, BaseTaskModule
from matsciml.models import multitask_from_checkpoint

from experiments.utils.configurator import configurator
from experiments.utils.utils import instantiate_arg_dict, update_arg_dict
@@ -19,6 +22,7 @@ def setup_task(config: dict[str, Any]) -> pl.LightningModule:
    model = update_arg_dict("model", model, config["cli_args"])
    configured_tasks = []
    data_task_list = []
    from_checkpoint = True if "load_weights" in config else False
    for dataset_name, tasks in data_task_dict.items():
        dset_args = deepcopy(configurator.datasets[dataset_name])
        dset_args = update_arg_dict("dataset", dset_args, config["cli_args"])
@@ -30,7 +34,7 @@
            additonal_task_args = dset_args.get("task_args", None)
            if additonal_task_args is not None:
                task_args.update(additonal_task_args)
            configured_task = task_class(**task_args)
            configured_task = task_class if from_checkpoint else task_class(**task_args)
            configured_tasks.append(configured_task)
            data_task_list.append(
                [configurator.datasets[dataset_name]["dataset"], configured_task]
@@ -40,4 +44,42 @@
        task = MultiTaskLitModule(*data_task_list)
    else:
        task = configured_tasks[0]
    if "load_weights" in config:
        task = load_from_checkpoint(task, config, task_args)
    return task


def load_from_checkpoint(
    task: BaseTaskModule, config: dict[str, Any], task_args: dict[str, Any]
) -> BaseTaskModule:
    load_config = config["load_weights"]
    ckpt = load_config["path"]
    method = load_config["method"]
    load_type = load_config["type"]
    if not isinstance(task, MultiTaskLitModule):
        if load_type == "wandb":
            # creates lightning wandb logger object and a new run
            wandb_logger = get_wandb_logger()
            artifact = Path(wandb_logger.download_artifact(ckpt))
            ckpt = artifact.joinpath("model.ckpt")
        if method == "checkpoint":
            task = task.load_from_checkpoint(ckpt)
        elif method == "pretrained":
            task = task.from_pretrained_encoder(ckpt, **task_args)
        else:
            raise Exception(
                "Unsupported method for loading checkpoint. Must be 'checkpoint' or 'pretrained'"
            )
    else:
        task = multitask_from_checkpoint(ckpt)
    return task


def get_wandb_logger():
    trainer_args = configurator.trainer
    for logger in trainer_args["loggers"]:
        if "WandbLogger" in logger["class_path"]:
            wandb_logger = instantiate_arg_dict(logger)
            return wandb_logger
    else:
        # for-else: only reached if the loop completes without finding a WandbLogger
        raise KeyError("WandbLogger Expected in trainer config but not found")
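
As a rough usage sketch (not part of the PR diff), the dictionary below mirrors the README config above and shows how the new `load_weights` field would flow through `setup_task`. The module path and the empty `cli_args` are assumptions, and the experiments `configurator` must already be populated with the referenced model and dataset for this to run:

```python
# Hypothetical usage sketch: mirrors the YAML config from experiments/README.md.
# Assumes the import path below matches the file shown in this diff.
from experiments.task_config.task_config import setup_task

config = {
    "model": "egnn_dgl",
    "cli_args": {},  # assumption: no command-line overrides
    "dataset": {
        "oqmd": [
            {"task": "ScalarRegressionTask", "targets": ["energy"]},
        ],
    },
    "load_weights": {
        "method": "pretrained",   # reuse only the encoder
        "type": "local",          # read the checkpoint from disk
        "path": "./path/to/checkpoint",
    },
}

# With "load_weights" present, setup_task keeps the task class uninstantiated
# and delegates weight loading to load_from_checkpoint.
task = setup_task(config)
```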