Merge pull request #272 from melo-gonzo/263-feature-request-allow-weights-transfer-andor-restarting-from-earlier-checkpoint-with-experiment-cli

Allow Loading of Checkpoints in Experiments Pipeline
laserkelvin authored Aug 14, 2024
2 parents 70e25f1 + 2e1f584 commit 28557a5
Showing 2 changed files with 65 additions and 2 deletions.
21 changes: 21 additions & 0 deletions experiments/README.md
@@ -21,6 +21,27 @@ dataset:
- energy
```
### Checkpoint Loading
Pretrained model checkpoints may be loaded for use in downstream tasks. A model may be used *as-is*, or only its encoder may be reused.
To load a checkpoint, add the `load_weights` field to the experiment config:
```yaml
model: egnn_dgl
dataset:
oqmd:
- task: ScalarRegressionTask
targets:
- energy
load_weights:
method: checkpoint
type: local
path: ./path/to/checkpoint
```
* `method` specifies whether to use the full model in the checkpoint *as-is* (`checkpoint`) or only its encoder (`pretrained`).
* `type` specifies where to load the checkpoint from (`local` or `wandb`).
* `path` points to the location of the checkpoint. WandB checkpoints are referenced by their model artifact, typically of the form `entity-name/project-name/model-version:number`, as in the example below.
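
For example, a config that reuses only the encoder from a WandB artifact might look like the following (the artifact path is a placeholder in the same form as above):
```yaml
load_weights:
  method: pretrained
  type: wandb
  path: entity-name/project-name/model-version:number
```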


In general, an experiment may then be launched by running:
`python experiments/training_script.py --experiment_config ./experiments/configs/single_task.yaml`

46 changes: 44 additions & 2 deletions experiments/task_config/task_config.py
@@ -2,11 +2,14 @@

from copy import deepcopy
from typing import Any
from pathlib import Path


import pytorch_lightning as pl

from matsciml.common.registry import registry
from matsciml.models.base import MultiTaskLitModule
from matsciml.models.base import MultiTaskLitModule, BaseTaskModule
from matsciml.models import multitask_from_checkpoint

from experiments.utils.configurator import configurator
from experiments.utils.utils import instantiate_arg_dict, update_arg_dict
@@ -19,6 +19,7 @@ def setup_task(config: dict[str, Any]) -> pl.LightningModule:
    model = update_arg_dict("model", model, config["cli_args"])
    configured_tasks = []
    data_task_list = []
    from_checkpoint = True if "load_weights" in config else False
    for dataset_name, tasks in data_task_dict.items():
        dset_args = deepcopy(configurator.datasets[dataset_name])
        dset_args = update_arg_dict("dataset", dset_args, config["cli_args"])
@@ -30,7 +34,7 @@
            additonal_task_args = dset_args.get("task_args", None)
            if additonal_task_args is not None:
                task_args.update(additonal_task_args)
            configured_task = task_class(**task_args)
            configured_task = task_class if from_checkpoint else task_class(**task_args)
            configured_tasks.append(configured_task)
            data_task_list.append(
                [configurator.datasets[dataset_name]["dataset"], configured_task]
@@ -40,4 +44,42 @@
        task = MultiTaskLitModule(*data_task_list)
    else:
        task = configured_tasks[0]
    if "load_weights" in config:
        task = load_from_checkpoint(task, config, task_args)
    return task


def load_from_checkpoint(
    task: BaseTaskModule, config: dict[str, Any], task_args: dict[str, Any]
) -> BaseTaskModule:
    load_config = config["load_weights"]
    ckpt = load_config["path"]
    method = load_config["method"]
    load_type = load_config["type"]
    if not isinstance(task, MultiTaskLitModule):
        if load_type == "wandb":
            # creates lightning wandb logger object and a new run
            wandb_logger = get_wandb_logger()
            artifact = Path(wandb_logger.download_artifact(ckpt))
            ckpt = artifact.joinpath("model.ckpt")
        if method == "checkpoint":
            task = task.load_from_checkpoint(ckpt)
        elif method == "pretrained":
            task = task.from_pretrained_encoder(ckpt, **task_args)
        else:
            raise Exception(
                "Unsupported method for loading checkpoint. Must be 'checkpoint' or 'pretrained'"
            )
    else:
        task = multitask_from_checkpoint(ckpt)
    return task


def get_wandb_logger():
    trainer_args = configurator.trainer
    for logger in trainer_args["loggers"]:
        if "WandbLogger" in logger["class_path"]:
            wandb_logger = instantiate_arg_dict(logger)
            return wandb_logger
    else:
        # for/else: reached only if the loop finishes without returning a logger
        raise KeyError("WandbLogger expected in trainer config but not found")
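
As a rough usage sketch of the new loading path, assuming the experiment YAML above maps directly to a plain dict and that `cli_args` may be left empty, `setup_task` receives the parsed config and hands off to `load_from_checkpoint` when `load_weights` is present:

```python
# Hypothetical sketch only: key names mirror the single-task README example above;
# the exact config structure expected by setup_task is an assumption here.
from experiments.task_config.task_config import setup_task

config = {
    "model": "egnn_dgl",
    "dataset": {"oqmd": [{"task": "ScalarRegressionTask", "targets": ["energy"]}]},
    "cli_args": {},  # normally filled in from command-line overrides
    "load_weights": {
        "method": "checkpoint",  # use the checkpointed model as-is
        "type": "local",         # read the checkpoint from a local path
        "path": "./path/to/checkpoint",
    },
}

task = setup_task(config)  # returns a pl.LightningModule with restored weights
```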
