
Commit

[WIP] Used per-parameter FSDP
awgu committed Feb 23, 2024
1 parent 55a6b0b commit a39226a
Showing 3 changed files with 27 additions and 53 deletions.
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -9,7 +9,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 # LOG_RANK=0,1 NGPU=4 SP=2 ./run_llama_train.sh
 
 MODEL=${MODEL:-"llama"}
-MODEL_CONF=${MODEL_CONF:-"debugmodel"}
+MODEL_CONF=${MODEL_CONF:-"7B"}
 NGPU=${NGPU:-"8"}
 PP=${PP:-"1"}
 SP=${SP:-"1"}
43 changes: 10 additions & 33 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -19,13 +19,7 @@
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
 )
-from torch.distributed.fsdp import (
-    BackwardPrefetch,
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp.wrap import enable_wrap, wrap
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,
@@ -157,32 +151,15 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-        fsdp_config = {
-            "mixed_precision": MixedPrecision(
-                param_dtype=torch.bfloat16,
-                # TODO: see whether we should expose a option to user
-                reduce_dtype=torch.float32,
-            ),
-            "sharding_strategy": ShardingStrategy.FULL_SHARD,
-            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
-            # When torch.compile is active, it requires us to set use_orig_params=True
-            "use_orig_params": True,
-            "device_mesh": dp_mesh,
-        }
-
-        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
-            for layer_id, transformer_block in enumerate(model.layers):
-                # apply AC to each layer
-                # before wrapping with FSDP, we need to make sure the layer is on GPU
-                transformer_block = transformer_block.cuda()
-                transformer_block = checkpoint_wrapper(transformer_block, args)
-
-                # Wraps each layer with FSDP
-                model.layers[layer_id] = wrap(transformer_block)
-
-            # wrap the rest layers with FSDP
-            model = wrap(model.cuda())
-
+        mp_policy = MixedPrecisionPolicy(
+            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+        )
+        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+        for layer_id, transformer_block in enumerate(model.layers):
+            transformer_block = checkpoint_wrapper(transformer_block, args)
+            fully_shard(transformer_block, **fsdp_config)
+            model.layers[layer_id] = transformer_block
+        model = fully_shard(model, **fsdp_config)
         rank0_log("Applied FSDP to the model...")
 
     # redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used
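
Note: for orientation, here is a minimal, self-contained sketch of the per-parameter FSDP pattern this hunk adopts: shard each transformer block with fully_shard, then shard the root module, using a MixedPrecisionPolicy with bf16 compute and fp32 gradient reduction. TinyBlock/TinyModel and the torchrun/NCCL setup are illustrative assumptions, not part of this repository, and the activation-checkpoint wrapping from the real code is omitted.

# Minimal sketch of the per-parameter FSDP pattern adopted above (illustration only,
# not this repository's code). Assumes a CUDA/NCCL job launched with torchrun;
# TinyBlock/TinyModel are made-up stand-ins for the llama transformer blocks.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh


class TinyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.w1, self.w2 = nn.Linear(dim, 4 * dim), nn.Linear(4 * dim, dim)

    def forward(self, x):
        return x + self.w2(torch.relu(self.w1(x)))


class TinyModel(nn.Module):
    def __init__(self, dim: int = 256, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(TinyBlock(dim) for _ in range(n_layers))
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.out(x)


def main():
    torch.distributed.init_process_group("nccl")
    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
    dp_mesh = init_device_mesh(
        "cuda", (torch.distributed.get_world_size(),), mesh_dim_names=("dp",)
    )

    model = TinyModel().cuda()
    # bf16 compute with fp32 gradient reduction, as in the hunk above
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    # each block becomes its own sharding/communication unit ...
    for block in model.layers:
        fully_shard(block, **fsdp_config)
    # ... and the root call picks up the remaining parameters (here, `out`)
    model = fully_shard(model, **fsdp_config)

    loss = model(torch.randn(8, 256, device="cuda")).sum()
    loss.backward()


if __name__ == "__main__":
    main()

Each fully_shard call shards that module's parameters as DTensors over the dp mesh and schedules its own all-gather/reduce-scatter, which is what replaces the old enable_wrap/wrap flow along with the use_orig_params, sharding_strategy, and backward_prefetch knobs deleted above.
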
35 changes: 16 additions & 19 deletions train.py
@@ -12,7 +12,6 @@
 # torch imports
 import torch
 import torch.nn.functional as F
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
 from torchtrain.checkpoint import CheckpointManager, IntervalType
@@ -50,26 +49,25 @@ def load_state_dict(self, state_dict) -> None:
 
 
 def build_optimizer(model, args):
-    # build optimizer
     if args.optimizer == "Adam":
-        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, foreach=True)
     elif args.optimizer == "AdamW":
-        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, foreach=True)
     else:
         raise NotImplementedError(f"optimizer {args.optimizer} not added")
 
     return optimizer
 
 
 def build_grad_scaler(model):
-    # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
-    if model.mixed_precision.param_dtype == torch.float16:
-        enable_grad_scaling = True
-        rank0_log("Enabling gradient scaling for mixed precision training.")
-    else:
-        enable_grad_scaling = False
-        rank0_log("Gradient scaling not enabled.")
-
+    # TODO: We do not expose the mixed precision attribute. This is low
+    # priority since we do not use fp16.
+    # if model.mixed_precision.param_dtype == torch.float16:
+    #     enable_grad_scaling = True
+    #     rank0_log("Enabling gradient scaling for mixed precision training.")
+    # else:
+    enable_grad_scaling = False
+    rank0_log("Gradient scaling not enabled.")
     return ShardedGradScaler(enabled=enable_grad_scaling)
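
Note: a small standalone illustration of the two changes in this hunk (a sketch, not the repo's code; the SimpleNamespace args and toy Linear model are stand-ins). foreach=True selects the multi-tensor Adam/AdamW update path, and a ShardedGradScaler built with enabled=False is a pass-through, which is why hard-coding enable_grad_scaling = False is harmless for the bf16-param / fp32-reduce setup configured above.

# Standalone sketch of the train.py changes above (not the repo's code).
# `args` is a made-up stand-in for the real argument parser output.
from types import SimpleNamespace

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

args = SimpleNamespace(optimizer="AdamW", lr=3e-4)  # stand-in for the parsed CLI args
model = torch.nn.Linear(16, 16)                     # stand-in for the parallelized model

# foreach=True batches the per-parameter updates into multi-tensor kernels
if args.optimizer == "Adam":
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, foreach=True)
elif args.optimizer == "AdamW":
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, foreach=True)
else:
    raise NotImplementedError(f"optimizer {args.optimizer} not added")

# A disabled scaler degrades to a no-op: scale() returns the loss unchanged and
# step() simply calls optimizer.step(), so fp16-style loss scaling is skipped.
scaler = ShardedGradScaler(enabled=False)
loss = model(torch.randn(4, 16)).sum()
assert scaler.scale(loss) is loss
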


@@ -121,9 +119,6 @@ def main(args):
     # apply PTD parallelisms + AC
     model = models_parallelize_fns[model_name](model, world_mesh, parallel_dims, args)
 
-    # to use FSDP-customized gradient scaler and gradient clipping solutions
-    assert isinstance(model, FSDP)
-
     # build optimizer after apply parallelisms to the model
     optimizer = build_optimizer(model, args)
     scheduler = get_lr_scheduler(optimizer, args)
@@ -135,9 +130,9 @@
     # torch.compile model for improved performance
     if args.compile:
         rank0_log(f"Compiling model {model_name} with torch.compile...")
-        model = torch.compile(
-            model,
-        )
+        model = torch.compile(model)
+
+    rank0_log("Disabling clip_grad_norm_() since it is not supported yet")
 
     train_state = TrainState()

@@ -187,7 +182,9 @@
 
         # clip gradients (after unscaling gradients of the optimizer's params)
         scaler.unscale_(optimizer)
-        model.clip_grad_norm_(args.max_norm)
+        # TODO: Disable `clip_grad_norm_()` until it is supported:
+        # https://github.com/pytorch/pytorch/pull/120238
+        # model.clip_grad_norm_(args.max_norm)
 
         # optimizer step
         # If gradients don't contain infs/NaNs, optimizer.step() is then called;
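
Note: a condensed sketch of where the now-disabled clipping sits in the step (placeholder model and data, not the repo's training loop). The commented call marks the spot the old model.clip_grad_norm_(args.max_norm) occupied; presumably it returns once a DTensor-aware clip-grad-norm utility lands (see the PR linked above).

# Sketch of the training-step ordering with clipping disabled (illustration only).
import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

model = torch.nn.Linear(8, 8)                                   # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4, foreach=True)
scaler = ShardedGradScaler(enabled=False)
max_norm = 1.0                                                  # stand-in for args.max_norm

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).pow(2).mean()
    scaler.scale(loss).backward()

    # clip gradients (after unscaling gradients of the optimizer's params)
    scaler.unscale_(optimizer)
    # clipping itself is skipped for now, mirroring the diff above; with plain
    # (non-DTensor) gradients it would be:
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    scaler.step(optimizer)   # calls optimizer.step() when scaling is disabled
    scaler.update()
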
