`vescale.checkpoint` is an automatic distributed checkpointing system for LLM training and inference.

[05/30/2024] We improved `vescale.checkpoint` with the following new features for fast checkpointing (the first three are built-in techniques that require no manual activation):
- Saving Plan Caching: During training, the program may save model and optimizer checkpoints every n steps. Once a saving plan is created, it remains unchanged as long as the model does. We implemented plan caching to avoid regenerating the plan when checkpointing a model or optimizer multiple times, reducing unnecessary compute and communication costs (a caching sketch follows this list). As of 05/30/2024, PyTorch DCP does not support plan caching.
- Saving Plan Load-Balancing: In data parallel training, models are replicated across GPUs with different data parallel ranks but the same pipeline and tensor parallel ranks. Existing PyTorch DCP (as of 05/30/2024) deduplicates replicated tensors using a simple algorithm, causing GPUs with data parallel rank 0 to save the entire model and leading to load imbalance. We implemented a load-balancing algorithm that spreads the deduplicated model tensors across data parallel ranks (see the load-balancing sketch after this list).
- D2H Tensor Copying via Pinned Memory: When copying tensors from GPU to host memory, `vescale.checkpoint` uses pinned host memory, reducing memory allocation costs each time a checkpoint is saved (see the pinned-memory sketch after this list). As of 05/30/2024, PyTorch DCP does not support pinned memory.
- Checkpoint Broadcasting: In data parallel training, models are replicated across GPUs with different data parallel ranks but the same pipeline and tensor parallel ranks. If `broadcast_checkpoint` is enabled, `vescale.checkpoint.load` lets GPUs with data parallel rank 0 load the model and broadcast it to other GPUs with higher data parallel ranks. If GPUs are connected with NCCL, broadcasting model tensors speeds up checkpoint loading compared to all GPUs loading models from persistent storage (see the broadcast sketch after this list). E.g.:

  ```python
  # prepare checkpoint state for the model and optimizer
  checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer }
  # load the checkpoint
  vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state, broadcast_checkpoint=True)
  ```
- Asynchronous Checkpointing: When `vescale.checkpoint.save` is called, it first generates a saving plan and then synchronously copies tensors from GPU to host memory. If `async_checkpoint` is enabled, the training program can continue after the D2H copying, while `vescale.checkpoint.save` continues to serialize tensors and dump the checkpoint to persistent storage asynchronously without blocking training (see the async sketch after this list). As of 05/30/2024, PyTorch DCP does not support asynchronous checkpointing. E.g.:

  ```python
  # prepare checkpoint state for the model and optimizer
  checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer }
  # save the checkpoint asynchronously
  vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state, async_checkpoint=True)
  ```
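The caching sketch below illustrates the saving-plan-caching idea from the feature list in isolation. It is a minimal sketch, not the actual `vescale.checkpoint` implementation: the `_plan_key` and `create_save_plan` names are hypothetical, and a real saving plan also carries sharding and storage-offset information.

```python
from typing import Dict, List, Tuple

import torch

# Hypothetical module-level cache: plan key -> previously generated plan.
_PLAN_CACHE: Dict[Tuple, List[Tuple[str, Tuple[int, ...], torch.dtype]]] = {}


def _plan_key(state_dict: Dict[str, torch.Tensor]) -> Tuple:
    # A saving plan depends only on tensor metadata (names, shapes, dtypes),
    # not on tensor values, so metadata is enough to tell whether a cached
    # plan is still valid.
    return tuple((name, tuple(t.shape), t.dtype) for name, t in state_dict.items())


def create_save_plan(state_dict: Dict[str, torch.Tensor]):
    key = _plan_key(state_dict)
    if key in _PLAN_CACHE:
        # Model unchanged since the last save: reuse the plan and skip the
        # plan-generation compute and the collective plan exchange.
        return _PLAN_CACHE[key]
    plan = [(name, tuple(t.shape), t.dtype) for name, t in state_dict.items()]
    _PLAN_CACHE[key] = plan
    return plan


if __name__ == "__main__":
    sd = {"w": torch.randn(4, 4), "b": torch.randn(4)}
    assert create_save_plan(sd) is create_save_plan(sd)  # second call hits the cache
```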
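The load-balancing sketch below shows one simple way to spread deduplicated replicated tensors across data parallel ranks instead of always assigning them to rank 0. It is only an illustration under simplified assumptions (every DP rank holds identical replicas, and balance is by item count rather than by byte size); the function names are hypothetical.

```python
from typing import Dict, List


def assign_writers(tensor_names: List[str], dp_world_size: int) -> Dict[str, int]:
    """Map each replicated tensor to the data parallel rank that will save it."""
    # Naive deduplication assigns every replicated tensor to DP rank 0;
    # a round-robin assignment spreads the D2H copy and write work instead.
    return {name: i % dp_world_size for i, name in enumerate(sorted(tensor_names))}


def items_to_save(tensor_names: List[str], dp_rank: int, dp_world_size: int) -> List[str]:
    owners = assign_writers(tensor_names, dp_world_size)
    return [name for name, owner in owners.items() if owner == dp_rank]


if __name__ == "__main__":
    names = [f"layers.{i}.weight" for i in range(8)]
    for rank in range(4):
        print(f"dp rank {rank} saves:", items_to_save(names, rank, dp_world_size=4))
```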
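The pinned-memory sketch below shows the D2H copy pattern: allocate page-locked host buffers once and reuse them across checkpoints, so each save avoids fresh host allocations and can use asynchronous DMA transfers. The buffer cache and the `d2h_copy` helper are hypothetical and only illustrate the technique.

```python
import torch

# Hypothetical cache of pinned host buffers, reused across checkpoint saves.
_PINNED_BUFFERS: dict = {}


def d2h_copy(name: str, gpu_tensor: torch.Tensor) -> torch.Tensor:
    buf = _PINNED_BUFFERS.get(name)
    if buf is None or buf.shape != gpu_tensor.shape or buf.dtype != gpu_tensor.dtype:
        # Allocate page-locked (pinned) host memory once; later saves reuse it,
        # avoiding a fresh allocation on every checkpoint.
        buf = torch.empty(gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True)
        _PINNED_BUFFERS[name] = buf
    # Pinned memory allows the GPU-to-host copy to run as an async DMA transfer.
    buf.copy_(gpu_tensor, non_blocking=True)
    return buf


if __name__ == "__main__" and torch.cuda.is_available():
    weight = torch.randn(1024, 1024, device="cuda")
    host = d2h_copy("weight", weight)
    torch.cuda.synchronize()  # wait for the copy before serializing the host buffer
    print(host.is_pinned())
```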
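The broadcast sketch below shows the mechanism behind `broadcast_checkpoint`: only data parallel rank 0 reads from persistent storage, and the other DP ranks receive the tensors over NCCL. It is a simplified sketch, assuming a recent PyTorch where `torch.distributed.get_global_rank` is available, that `torch.distributed` is already initialized, that `dp_group` contains the ranks replicating the same model shard, and that `load_from_storage` stands in for the real reading path.

```python
import torch.distributed as dist


def broadcast_load(state_dict, dp_group, load_from_storage):
    """Load on data parallel rank 0 and broadcast to the other DP ranks (sketch)."""
    if dist.get_rank(dp_group) == 0:
        # Only DP rank 0 touches persistent storage.
        loaded = load_from_storage()
        for name, tensor in state_dict.items():
            tensor.copy_(loaded[name])
    # Broadcasting over NCCL is usually much faster than having every DP rank
    # read the same bytes from persistent storage.
    src = dist.get_global_rank(dp_group, 0)  # global rank of DP rank 0 in this group
    for tensor in state_dict.values():
        dist.broadcast(tensor, src=src, group=dp_group)
```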
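The async sketch below shows the overlap idea behind `async_checkpoint`: the GPU-to-host snapshot is taken synchronously (so training may safely keep updating parameters afterwards), while serialization and the write to storage happen in a background thread. The real pipeline is more involved (pinned buffers, per-rank files, global metadata); the `async_save` helper and the use of `torch.save` here are only illustrative.

```python
import threading

import torch


def async_save(state_dict, path):
    # Step 1 (blocking): snapshot all tensors to host memory so that training
    # can keep updating the GPU tensors without corrupting the checkpoint.
    host_copy = {name: t.detach().cpu().clone() for name, t in state_dict.items()}
    # Step 2 (non-blocking): serialize and write the snapshot in the background.
    writer = threading.Thread(target=torch.save, args=(host_copy, path), daemon=True)
    writer.start()
    return writer  # caller can join() before starting the next save


if __name__ == "__main__":
    handle = async_save({"w": torch.randn(256, 256)}, "/tmp/sketch_ckpt.pt")
    # ... training would continue here while the file is being written ...
    handle.join()
```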
- Manually managing distributed checkpointing, such as writing model saving/loading/resharding scripts under complex distributed environments, is painful and error-prone.
- `torch.save` and `torch.load` lack the capability of managing checkpointing in distributed settings, let alone resharding checkpoints for different distributed settings. Although existing systems extend `torch.save` for saving checkpoints on multiple GPUs or machines, the saved checkpoints are heavily coupled with a single distributed setting, such as the degrees of data, tensor, and pipeline parallelism. Consequently, existing systems with `torch.load` fail to load checkpoints with varying degrees of parallelism, which is common in elastic training or when switching between training and fine-tuning.
- PyTorch Distributed Checkpoint indeed supports checkpoint resharding to some extent. Nonetheless, it currently only supports resharding for the simplest data parallelism, not for the complex tensor or pipeline parallelism commonly used in the 3D parallelism of LLM training. Furthermore, it does not support load-time resharding for Distributed Optimizer, nor does it provide decent performance optimizations.
`vescale.checkpoint` offers simple and straightforward APIs, enabling users to load and save distributed models (e.g., `DModule`) and optimizers (e.g., `DistributedOptimizer`) seamlessly, abstracting away the complexities of underlying details such as process rank and device mesh.

`vescale.checkpoint` supports load-time checkpoint resharding when varying the degrees of data, tensor, or pipeline parallelism for both the veScale model (e.g., `DModule`) and optimizer (e.g., `DistributedOptimizer`).

`vescale.checkpoint` incorporates fast checkpointing and various I/O optimization techniques, enhancing I/O efficiency during LLM training.

`vescale.checkpoint` is built on top of PyTorch Distributed Checkpoint, with significant differences as discussed above.
- Saving checkpoint:

  ```python
  # prepare checkpoint state for the model and optimizer
  checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer }
  # save the checkpoint
  vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state)
  ```
- Loading checkpoint (under different world size or 3D parallel sizes):

  ```python
  # prepare checkpoint state for the model and optimizer
  checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer }
  # load the checkpoint
  vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state)
  ```
- APIs can be found in: `<repo>/vescale/checkpoint/__init__.py`
- End-to-end example can be found in: `<repo>/examples/nanogpt_4D_finetune/finetune_4D.py`
- More examples can be found under `<repo>/test/checkpoint/*.py` and `<repo>/examples/`
- Original examples can be found in PyTorch Distributed Checkpoint