Skip to content

Commit

Permalink
resume training
Browse files Browse the repository at this point in the history
  • Loading branch information
beneisner committed Jun 6, 2024
1 parent 7582b5e commit 064a8b9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 2 additions & 0 deletions configs/train_ndf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ training:
check_val_every_n_epoch: 5

seed: 0
resume_ckpt: null

resources:
num_workers: 8

wandb:
group: Null
run_id_override: Null
23 changes: 22 additions & 1 deletion scripts/train_residual_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from taxpose.training.flow_equivariance_training_module_nocentering import (
EquivarianceTrainingModule,
)
from taxpose.utils.load_model import get_weights_path


def load_emb_weights(checkpoint_reference, wandb_cfg=None, run=None):
Expand Down Expand Up @@ -53,6 +54,24 @@ def main(cfg):
# torch.set_float32_matmul_precision("medium")
TESTING = "PYTEST_CURRENT_TEST" in os.environ

if cfg.resume_ckpt:
print("Resuming from checkpoint")
print(cfg.resume_ckpt)
resume_ckpt = get_weights_path(cfg.resume_ckpt, cfg.wandb)

# Resume the wandb run
if cfg.resume_ckpt.startswith(cfg.wandb.entity):
# Get the run_id from the checkpoint
resume_run_id = cfg.resume_ckpt.split("/")[2].split("-")[1].split(":")[0]
elif cfg.wandb.run_id_override is not None:
resume_run_id = cfg.wandb.run_id_override
else:
resume_run_id = None

else:
resume_ckpt = None
resume_run_id = None

pl.seed_everything(cfg.seed)
logger = WandbLogger(
entity=cfg.wandb.entity,
Expand All @@ -62,6 +81,7 @@ def main(cfg):
job_type=cfg.job_type,
save_code=True,
log_model=True,
id=resume_run_id,
config=omegaconf.OmegaConf.to_container(cfg, resolve=True),
)
# logger.log_hyperparams(cfg)
Expand Down Expand Up @@ -174,7 +194,8 @@ def main(cfg):
cfg.model.pretraining.anchor.ckpt_path
)
)
trainer.fit(model, dm)

trainer.fit(model, dm, ckpt_path=resume_ckpt)

# Print the run id of the current run
print("Run ID: {} ".format(logger.experiment.id))
Expand Down

0 comments on commit 064a8b9

Please sign in to comment.