From 4dd71035844d5f2501c11e3d739b2b06eb824c5a Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Tue, 4 Jun 2024 09:51:35 -0700
Subject: [PATCH] Update (base update)

[ghstack-poisoned]
---
 torchtitan/checkpoint.py                     | 18 ++++++++++++++++++
 torchtitan/config_manager.py                 | 10 +++++++++-
 torchtitan/parallelisms/parallelize_llama.py |  7 +++----
 train.py                                     |  3 ++-
 4 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index fb7c41c8..b7f41bbc 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -7,6 +7,7 @@
 import enum
 import os
 import re
+import shutil
 import time
 from multiprocessing import get_context
 from typing import Any, Dict
@@ -110,6 +111,7 @@ def __init__(
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.keep_latest_k = ckpt_config.keep_latest_k
 
         if not self.enable_checkpoint:
             return
@@ -313,6 +315,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
         else:
             dcp.save(self.states, checkpoint_id=checkpoint_id)
             self.reset()
+            self._purge_stale_checkpoints()
 
         logger.info(
             "Finished saving the checkpoint (or staging if async is enabled)"
@@ -364,3 +367,18 @@ def load(self, step: int = -1) -> bool:
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
         )
         return True
+
+    def _purge_stale_checkpoints(self):
+        if self.keep_latest_k > 0:
+            discovered_checkpoints = []
+            for filename in os.listdir(self.folder):
+                match = re.search(r"step-(\d+)", filename)
+                path = os.path.join(self.folder, filename)
+                discovered_checkpoints.append((int(match.group(1)), path))
+
+            discovered_checkpoints.sort()
+            to_delete = discovered_checkpoints[: -1 * self.keep_latest_k]
+
+            for _, path in to_delete:
+                logger.info(f"Deleting old checkpoint {path}")
+                shutil.rmtree(path, ignore_errors=True)
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 6a730dcb..e901a184 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -408,7 +408,15 @@ def __init__(self):
                 "disabled" is the default mode.
             """,
         )
-
+        self.parser.add_argument(
+            "--checkpoint.keep_latest_k",
+            type=int,
+            default=0,
+            help="""
+                Keeps only the latest k checkpoints, purging older ones. If 0, all checkpoints are kept.
+                0 is the default value.
+ """, + ) # activation checkpointing configs self.parser.add_argument( "--activation_checkpoint.mode", diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 3617eb23..9260fa62 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -190,16 +190,15 @@ def pipeline_llama_manual( splits = job_config.experimental.pipeline_parallel_split_points start_layer = splits[stage_idx - 1] if stage_idx > 0 else None stop_layer = splits[stage_idx] if stage_idx < pp_size - 1 else None - if pp_rank > 0: model.tok_embeddings = None - drop_layers = True + drop_layers = start_layer is not None for name in list(model.layers.keys()): # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) - if start_layer is None or f"layers.{name}" == start_layer: + if f"layers.{name}" == start_layer: drop_layers = False - if stop_layer is not None and f"layers.{name}" == stop_layer: + if f"layers.{name}" == stop_layer: drop_layers = True if drop_layers: del model.layers[name] diff --git a/train.py b/train.py index 6a8512a4..a39bf5c1 100644 --- a/train.py +++ b/train.py @@ -221,7 +221,8 @@ def loss_fn(pred, labels): model, world_mesh, parallel_dims, job_config ) - model.to_empty(device="cuda") + init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda" + model.to_empty(device=init_device) if parallel_dims.pp_enabled: pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)