Skip to content

Commit

Permalink
move init_device and buffer_device outside pp condition
Browse files Browse the repository at this point in the history
  • Loading branch information
mori360 committed Oct 28, 2024
1 parent 2ca9882 commit 7af331c
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,17 @@ def loss_fn(pred, labels):
if job_config.training.compile:
loss_fn = torch.compile(loss_fn)

# move sharded model to CPU/GPU and initialize weights via DTensor
if job_config.checkpoint.create_seed_checkpoint:
init_device = "cpu"
buffer_device = None
elif job_config.training.enable_cpu_offload:
init_device = "cpu"
buffer_device = "cuda"
else:
init_device = "cuda"
buffer_device = None

# apply parallelisms and initialization
if parallel_dims.pp_enabled:
# apply PT-D Pipeline Parallel
Expand All @@ -151,24 +162,13 @@ def loss_fn(pred, labels):
for m in model_parts:
# apply SPMD-style PT-D techniques
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
init_device = "cpu" if job_config.training.enable_cpu_offload else "cuda"
m.to_empty(device=init_device)
buffer_device = "cuda" if job_config.training.enable_cpu_offload else None
m.init_weights(buffer_device=buffer_device)
m.train()
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

# move sharded model to CPU/GPU and initialize weights via DTensor
init_device = (
"cpu"
if job_config.checkpoint.create_seed_checkpoint
or job_config.training.enable_cpu_offload
else "cuda"
)
model.to_empty(device=init_device)
buffer_device = "cuda" if job_config.training.enable_cpu_offload else None
model.init_weights(buffer_device=buffer_device)
model.train()

Expand Down

0 comments on commit 7af331c

Please sign in to comment.