From f745feddcfeaa2b942a62fded6b6a06c42a62acb Mon Sep 17 00:00:00 2001
From: "Gonzales, Carmelo"
Date: Tue, 13 Aug 2024 15:53:51 -0700
Subject: [PATCH 1/3] feat: add checkpoint loading as an option in the
 experiment config. update readme

---
 experiments/README.md                  | 21 +++++++++++++++++++++
 experiments/task_config/task_config.py | 26 ++++++++++++++++++++++++--
 2 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/experiments/README.md b/experiments/README.md
index 986913bf..64eb5dca 100644
--- a/experiments/README.md
+++ b/experiments/README.md
@@ -21,6 +21,27 @@ dataset:
         - energy
 ```
 
+### Checkpoint Loading
+Pretrained model checkpoints may be loaded for use in downstream tasks. A model can be loaded and used *as-is*, or only its encoder may be reused.
+
+To load a checkpoint, add the `load_weights` field to the experiment config:
+```yaml
+model: egnn_dgl
+dataset:
+  oqmd:
+    - task: ScalarRegressionTask
+      targets:
+        - energy
+load_weights:
+  method: checkpoint
+  type: local
+  path: ./path/to/checkpoint
+```
+* `method` specifies whether to use the checkpointed model *as-is* (`checkpoint`) or *encoder-only* (`pretrained`).
+* `type` specifies where to load the checkpoint from. Currently only locally stored checkpoints may be used.
+* `path` points to the location of the checkpoint.
+
+
 In general, an experiment may then be launched by running:
 `python experiments/training_script.py --experiment_config ./experiments/configs/single_task.yaml`
 
diff --git a/experiments/task_config/task_config.py b/experiments/task_config/task_config.py
index 0ecbb5a8..32703ea9 100644
--- a/experiments/task_config/task_config.py
+++ b/experiments/task_config/task_config.py
@@ -6,7 +6,8 @@
 import pytorch_lightning as pl
 
 from matsciml.common.registry import registry
-from matsciml.models.base import MultiTaskLitModule
+from matsciml.models.base import MultiTaskLitModule, BaseTaskModule
+from matsciml.models import multitask_from_checkpoint
 
 from experiments.utils.configurator import configurator
 from experiments.utils.utils import instantiate_arg_dict, update_arg_dict
@@ -19,6 +20,7 @@ def setup_task(config: dict[str, Any]) -> pl.LightningModule:
     model = update_arg_dict("model", model, config["cli_args"])
     configured_tasks = []
     data_task_list = []
+    from_checkpoint = True if "load_weights" in config else False
     for dataset_name, tasks in data_task_dict.items():
         dset_args = deepcopy(configurator.datasets[dataset_name])
         dset_args = update_arg_dict("dataset", dset_args, config["cli_args"])
@@ -30,7 +32,7 @@ def setup_task(config: dict[str, Any]) -> pl.LightningModule:
             additonal_task_args = dset_args.get("task_args", None)
             if additonal_task_args is not None:
                 task_args.update(additonal_task_args)
-            configured_task = task_class(**task_args)
+            configured_task = task_class if from_checkpoint else task_class(**task_args)
             configured_tasks.append(configured_task)
             data_task_list.append(
                 [configurator.datasets[dataset_name]["dataset"], configured_task]
@@ -40,4 +42,24 @@ def setup_task(config: dict[str, Any]) -> pl.LightningModule:
         task = MultiTaskLitModule(*data_task_list)
     else:
         task = configured_tasks[0]
+    if "load_weights" in config:
+        task = load_from_checkpoint(task, config, task_args)
+    return task
+
+
+def load_from_checkpoint(
+    task: BaseTaskModule, config: dict[str, Any], task_args: dict[str, Any]
+) -> BaseTaskModule:
+    load_config = config["load_weights"]
+    ckpt = load_config["path"]
+    method = load_config["method"]
+    load_type = load_config["type"]
+    if not isinstance(task, MultiTaskLitModule):
+        if load_type == "local":
+            if method == "checkpoint":
+                task = task.load_from_checkpoint(ckpt)
+            if method == "pretrained":
+                task = task.from_pretrained_encoder(ckpt, **task_args)
+    else:
+        task = multitask_from_checkpoint(ckpt)
     return task
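The README example above only shows `method: checkpoint`. For the encoder-only path that the bullets describe, the same block would use `method: pretrained`; a minimal sketch (the checkpoint path is a placeholder):

```yaml
load_weights:
  method: pretrained   # reuse only the encoder weights from the checkpoint
  type: local
  path: ./path/to/checkpoint
```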
From a069e44b5e54515a803a6c444047aa11ebd8d382 Mon Sep 17 00:00:00 2001
From: "Gonzales, Carmelo"
Date: Wed, 14 Aug 2024 15:05:46 -0700
Subject: [PATCH 2/3] feat: adding ability to download checkpoint from wandb

---
 experiments/task_config/task_config.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/experiments/task_config/task_config.py b/experiments/task_config/task_config.py
index 32703ea9..24cb9cdd 100644
--- a/experiments/task_config/task_config.py
+++ b/experiments/task_config/task_config.py
@@ -2,6 +2,8 @@
 from copy import deepcopy
 from typing import Any
 
+from pathlib import Path
+
 
 import pytorch_lightning as pl
 
@@ -60,6 +62,22 @@ def load_from_checkpoint(
             task = task.load_from_checkpoint(ckpt)
         if method == "pretrained":
             task = task.from_pretrained_encoder(ckpt, **task_args)
+        if load_type == "wandb":
+            # creates lightning wandb logger object and a new run
+            wandb_logger = get_wandb_logger()
+            artifact = Path(wandb_logger.download_artifact(ckpt))
+            task = task.load_from_checkpoint(artifact.joinpath("model.ckpt"))
+
     else:
         task = multitask_from_checkpoint(ckpt)
     return task
+
+
+def get_wandb_logger():
+    trainer_args = configurator.trainer
+    for logger in trainer_args["loggers"]:
+        if "WandbLogger" in logger["class_path"]:
+            wandb_logger = instantiate_arg_dict(logger)
+            return wandb_logger
+    else:
+        raise KeyError("WandbLogger expected in trainer config but not found")
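The wandb path added in this patch resolves the checkpoint through a `WandbLogger` entry in the trainer config's `loggers` list; `get_wandb_logger` raises a `KeyError` when no such entry is present. A sketch of what that entry might look like, assuming the configurator's `class_path` layout (the `init_args` key and project name are assumptions, not defined by this patch):

```yaml
loggers:
  - class_path: pytorch_lightning.loggers.WandbLogger
    init_args:
      project: my-project   # placeholder project name
```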
From 2e1f584a096c7d5c740c725780b214084a68943e Mon Sep 17 00:00:00 2001
From: "Gonzales, Carmelo"
Date: Wed, 14 Aug 2024 15:12:49 -0700
Subject: [PATCH 3/3] refactor: checkpoint loading consolidation, add to readme

---
 experiments/README.md                  |  4 ++--
 experiments/task_config/task_config.py | 16 +++++++++-------
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/experiments/README.md b/experiments/README.md
index 64eb5dca..56fb266b 100644
--- a/experiments/README.md
+++ b/experiments/README.md
@@ -38,8 +38,8 @@ load_weights:
   path: ./path/to/checkpoint
 ```
 * `method` specifies whether to use the checkpointed model *as-is* (`checkpoint`) or *encoder-only* (`pretrained`).
-* `type` specifies where to load the checkpoint from. Currently only locally stored checkpoints may be used.
-* `path` points to the location of the checkpoint.
+* `type` specifies where to load the checkpoint from (`local` or `wandb`).
+* `path` points to the location of the checkpoint. WandB checkpoints are specified by pointing to the model artifact, typically of the form `entity-name/project-name/model-version:number`.
 
 In general, an experiment may then be launched by running:
 `python experiments/training_script.py --experiment_config ./experiments/configs/single_task.yaml`
 
diff --git a/experiments/task_config/task_config.py b/experiments/task_config/task_config.py
index 24cb9cdd..4d924c4a 100644
--- a/experiments/task_config/task_config.py
+++ b/experiments/task_config/task_config.py
@@ -57,17 +57,19 @@ def load_from_checkpoint(
     method = load_config["method"]
     load_type = load_config["type"]
     if not isinstance(task, MultiTaskLitModule):
-        if load_type == "local":
-            if method == "checkpoint":
-                task = task.load_from_checkpoint(ckpt)
-            if method == "pretrained":
-                task = task.from_pretrained_encoder(ckpt, **task_args)
         if load_type == "wandb":
             # creates lightning wandb logger object and a new run
             wandb_logger = get_wandb_logger()
             artifact = Path(wandb_logger.download_artifact(ckpt))
-            task = task.load_from_checkpoint(artifact.joinpath("model.ckpt"))
-
+            ckpt = artifact.joinpath("model.ckpt")
+        if method == "checkpoint":
+            task = task.load_from_checkpoint(ckpt)
+        elif method == "pretrained":
+            task = task.from_pretrained_encoder(ckpt, **task_args)
+        else:
+            raise Exception(
+                "Unsupported method for loading checkpoint. Must be 'checkpoint' or 'pretrained'"
+            )
     else:
         task = multitask_from_checkpoint(ckpt)
     return task
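With all three patches applied, loading a checkpoint from a WandB artifact combines the pieces above. A sketch, using the `entity-name/project-name/model-version:number` artifact pattern from the README (entity, project, and artifact names are placeholders):

```yaml
model: egnn_dgl
dataset:
  oqmd:
    - task: ScalarRegressionTask
      targets:
        - energy
load_weights:
  method: checkpoint
  type: wandb
  path: my-entity/my-project/model-a1b2c3:v0
```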