From 36d229376eccc5063ed680add48a146fc6ab0264 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 22 Mar 2024 17:12:52 -0700 Subject: [PATCH 01/16] Update [ghstack-poisoned] --- run_llama_train.sh | 2 +- torchtrain/meta_init.py | 6 +++ torchtrain/models/llama/model.py | 19 ++++++---- torchtrain/parallelisms/parallelize_llama.py | 40 +++++++++++++++++--- train.py | 4 +- train_configs/debug_model.toml | 2 +- 6 files changed, 57 insertions(+), 16 deletions(-) diff --git a/run_llama_train.sh b/run_llama_train.sh index 13b66aea..a906b5cd 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -11,7 +11,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain} # e.g. # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh -NGPU=${NGPU:-"8"} +NGPU=${NGPU:-"2"} # by default log just rank 0 output, LOG_RANK=${LOG_RANK:-0} diff --git a/torchtrain/meta_init.py b/torchtrain/meta_init.py index d67e6ef7..a4d0a728 100644 --- a/torchtrain/meta_init.py +++ b/torchtrain/meta_init.py @@ -46,3 +46,9 @@ def meta_to_real_init_fn(module: nn.Module): torch.randn_like(param, device=torch.device("cuda")) ) setattr(submodule, param_name, materialized_param) + for param_name, param in submodule.named_buffers(recurse=False): + if param.is_meta: + materialized_param = nn.Parameter( + torch.randn_like(param, device=torch.device("cuda")) + ) + setattr(submodule, param_name, materialized_param) diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py index da6a4e14..e0a368c9 100644 --- a/torchtrain/models/llama/model.py +++ b/torchtrain/models/llama/model.py @@ -334,13 +334,16 @@ def __init__(self, model_args: ModelArgs): self.model_args = model_args self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) - self.freqs_cis = precompute_freqs_cis( - # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation - # of models is 4096. - # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training - # or fine-tuning. - self.model_args.dim // self.model_args.n_heads, - self.model_args.max_seq_len * 2, + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation + # of models is 4096. + # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training + # or fine-tuning. + self.model_args.dim // self.model_args.n_heads, + self.model_args.max_seq_len * 2, + ), ) def forward(self, tokens: torch.Tensor): @@ -355,7 +358,7 @@ def forward(self, tokens: torch.Tensor): """ _bsz, seqlen = tokens.shape h = self.tok_embeddings(tokens) - self.freqs_cis = self.freqs_cis.to(h.device) + # self.freqs_cis = self.freqs_cis.to(h.device) freqs_cis = self.freqs_cis[0:seqlen] return h, freqs_cis diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 38014e53..13194a6d 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -7,6 +7,7 @@ from collections import defaultdict import torch +from pippy import annotate_split_points, Pipe, PipeSplitWrapper from torch.distributed._tensor import Replicate, Shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -125,7 +126,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): """ # apply PTD parallelisms if parallel_dims.pp_enabled: - raise NotImplementedError("PP not implemented yet.") + pp_mesh = world_mesh["pp"] + stage_idx = pp_mesh.get_local_rank() + layers_per_rank = len(model.layers) // parallel_dims.pp + for i in range(1, parallel_dims.pp): + annotate_split_points( + model, + { + f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING + }, + ) + + # Get example input + label_shape = input_shape = (8, 2048) # TODO + input_ids = torch.randint( + model.vocab_size, input_shape, dtype=torch.int64, device="meta" + ) + labels = torch.randint( + model.vocab_size, label_shape, dtype=torch.int64, device="meta" + ) + print("input_ids: ", input_ids.shape, input_ids.dtype) + print("labels: ", labels.shape, labels.dtype) + + # Create a pipeline representation from the model + pipe = Pipe.from_tracing(model, parallel_dims.pp, example_args=(input_ids,)) + model = pipe.get_stage_module(stage_idx) # First we apply Sequence Parallelism if it's enabled if parallel_dims.sp_enabled: @@ -230,9 +255,14 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): meta_to_real_init_fn(model) model.cuda() - # we have now moved from meta to device, - # reset parameters for proper initialization - model.reset_parameters() - logger.info("Model fully initialized via reset_parameters") + if parallel_dims.pp_enabled: + setattr(pipe.split_gm, f"submod_{stage_idx}", model) + return pipe + else: + # TODO figure out PP compatible deferred initialization + # we have now moved from meta to device, + # reset parameters for proper initialization + model.reset_parameters() + logger.info("Model fully initialized via reset_parameters") return model diff --git a/train.py b/train.py index 184841f0..92744013 100644 --- a/train.py +++ b/train.py @@ -241,10 +241,12 @@ def main(job_config: JobConfig): input_ids = input_ids.cuda() labels = labels.cuda() - + print("i", input_ids.shape) + print("l", labels.shape) optimizer.zero_grad() # forward + # TODO - integrate pp batch splitter pred = model(input_ids) with loss_parallel() if parallel_dims.loss_parallel_enabled else contextlib.nullcontext(): diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 55cb93f9..1a1336f0 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 sequence_parallel_degree = 1 -pipeline_parallel_degree = 1 +pipeline_parallel_degree = 2 fp8_linear = "" compile = false checkpoint_interval = 3600 From 5fb7d1222f105ea32bdce8f31c2b73aeb8273411 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 2 Apr 2024 16:11:28 -0700 Subject: [PATCH 02/16] Update [ghstack-poisoned] --- train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index cbca73b7..ffa52709 100644 --- a/train.py +++ b/train.py @@ -310,8 +310,6 @@ def main(job_config: JobConfig): input_ids = input_ids.cuda() labels = labels.cuda() - print("i", input_ids.shape) - print("l", labels.shape) optimizer.zero_grad() if parallel_dims.pp_enabled: @@ -322,6 +320,12 @@ def main(job_config: JobConfig): pp_schedule.step(target=labels, losses=losses) else: schedule.step() + + # todo optimizer and scaler stuff + + # todo loss properly + current_loss = 10.0 + losses_since_last_log.append(current_loss) else: # forward pred = model(input_ids) From d4374611715de3b7a9c9a8b7e1991b1fcdebc97a Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 3 Apr 2024 15:40:41 -0700 Subject: [PATCH 03/16] Update [ghstack-poisoned] --- train.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index ffa52709..13c38f17 100644 --- a/train.py +++ b/train.py @@ -220,7 +220,10 @@ def main(job_config: JobConfig): group=pp_mesh.get_group(), ) pp_schedule = PipelineScheduleGPipe( - stage, n_microbatches=parallel_dims.pp, loss_fn=None + stage, + n_microbatches=parallel_dims.pp, + loss_fn=lambda output, target: output.sum() + + torch.tensor([123.0], device=output.device), ) model.to_empty(device="cuda") else: @@ -313,9 +316,11 @@ def main(job_config: JobConfig): optimizer.zero_grad() if parallel_dims.pp_enabled: + is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 + if pp_mesh.get_local_rank() == 0: pp_schedule.step(input_ids) - elif pp_mesh.get_local_rank() == pp_mesh.size() - 1: + elif is_last_stage: losses = [] pp_schedule.step(target=labels, losses=losses) else: @@ -324,7 +329,9 @@ def main(job_config: JobConfig): # todo optimizer and scaler stuff # todo loss properly - current_loss = 10.0 + current_loss = ( + torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0 + ) losses_since_last_log.append(current_loss) else: # forward From 522f93b024dc8ecffb188f25228b3e779158430e Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 3 Apr 2024 16:19:39 -0700 Subject: [PATCH 04/16] Update [ghstack-poisoned] --- train.py | 95 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 50 insertions(+), 45 deletions(-) diff --git a/train.py b/train.py index 13c38f17..e373cc8d 100644 --- a/train.py +++ b/train.py @@ -131,7 +131,9 @@ def main(job_config: JobConfig): world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, ) - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + # torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + torch.cuda.set_device(device) init_distributed(job_config) world_mesh = parallel_dims.build_mesh(device_type="cuda") @@ -150,6 +152,14 @@ def main(job_config: JobConfig): dp_rank = dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 + + if parallel_dims.pp_enabled: + pp_mesh = world_mesh["pp"] + pp_degree = pp_mesh.size() + pp_rank = pp_mesh.get_local_rank() + else: + pp_degree, pp_rank = 1, 0 + data_loader = build_dataloader_fn( job_config.training.dataset, job_config.training.dataset_path, @@ -199,20 +209,29 @@ def main(job_config: JobConfig): model = models_parallelize_fns[model_name]( model, world_mesh, parallel_dims, job_config ) + if parallel_dims.pp_enabled: + pipe_meta = model + model = pipe_meta.get_stage_module(pp_rank) + + # build grad scaler which is effective only when mixed precision training + # is enabled with fp16 param dtype under FSDP + scaler = build_grad_scaler(model) + + def loss_fn(pred, labels): + with ( + loss_parallel() + if parallel_dims.loss_parallel_enabled + else contextlib.nullcontext() + ): + loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1)) + + # backward on scaled loss to create scaled gradients + scaler.scale(loss) + return loss # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if # there are virtual stages if parallel_dims.pp_enabled: - pipe_meta = model - pp_mesh = world_mesh["pp"] - pp_degree = pp_mesh.size() - pp_rank = pp_mesh.get_local_rank() - logger.info( - f"{Color.blue}Extracting pipeline module for stage {pp_rank}{Color.reset}" - ) - device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") - - model = pipe_meta.get_stage_module(pp_rank) stage = PipelineStage( pipe=pipe_meta, stage_index=pp_rank, @@ -222,8 +241,7 @@ def main(job_config: JobConfig): pp_schedule = PipelineScheduleGPipe( stage, n_microbatches=parallel_dims.pp, - loss_fn=lambda output, target: output.sum() - + torch.tensor([123.0], device=output.device), + loss_fn=loss_fn, ) model.to_empty(device="cuda") else: @@ -239,11 +257,6 @@ def main(job_config: JobConfig): # build optimizer after applying parallelisms to the model optimizer = build_optimizer(model, job_config) scheduler = get_lr_scheduler(optimizer, job_config) - - # build grad scaler which is effective only when mixed precision training - # is enabled with fp16 param dtype under FSDP - scaler = build_grad_scaler(model) - metric_logger = build_metric_logger(job_config) # torch.compile model for improved performance @@ -316,6 +329,7 @@ def main(job_config: JobConfig): optimizer.zero_grad() if parallel_dims.pp_enabled: + # pipeline F/Loss/B is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 if pp_mesh.get_local_rank() == 0: @@ -326,44 +340,35 @@ def main(job_config: JobConfig): else: schedule.step() - # todo optimizer and scaler stuff - - # todo loss properly + # accumulate losses across pipeline microbatches current_loss = ( torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0 ) - losses_since_last_log.append(current_loss) else: - # forward + # non-pipeline F/Loss/B pred = model(input_ids) - with ( - loss_parallel() - if parallel_dims.loss_parallel_enabled - else contextlib.nullcontext() - ): - loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1)) + loss = loss_fn(pred, labels) + loss.backward() - # backward on scaled loss to create scaled gradients - scaler.scale(loss).backward() + current_loss = loss.item() - # clip gradients (after unscaling gradients of the optimizer's params) - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_( - model.parameters(), job_config.training.max_norm, foreach=True - ) + # clip gradients (after unscaling gradients of the optimizer's params) + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_( + model.parameters(), job_config.training.max_norm, foreach=True + ) - # optimizer step - # If gradients don't contain infs/NaNs, optimizer.step() is then called; - # otherwise, optimizer.step() is skipped. - scaler.step(optimizer) - scheduler.step() + # optimizer step + # If gradients don't contain infs/NaNs, optimizer.step() is then called; + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + scheduler.step() - # updates the scale for next iteration - scaler.update() + # updates the scale for next iteration + scaler.update() - current_loss = loss.item() - losses_since_last_log.append(current_loss) + losses_since_last_log.append(current_loss) # log metrics if ( From 1707df962b8102f3e1e31a455c6ce0691ae93d40 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 5 Apr 2024 11:45:59 -0700 Subject: [PATCH 05/16] Update [ghstack-poisoned] --- torchtrain/parallelisms/parallelize_llama.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 45acaee6..3c1b2c60 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -236,20 +236,24 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): ) ac_mode = job_config.activation_checkpoint.mode fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} - for layer_id, transformer_block in enumerate(model.layers): + for layer_name, transformer_block in model.layers.named_children(): if job_config.activation_checkpoint.mode in ("full", "selective"): transformer_block = checkpoint_wrapper( transformer_block, job_config.activation_checkpoint ) # As an optimization, do not reshard after forward for the last # transformer block since FSDP would prefetch it immediately - reshard_after_forward = layer_id < len(model.layers) - 1 + # reshard_after_forward = layer_id < len(model.layers) - 1 + # TODO(whc) need to fix correctly handle layer-ids on pp-split module + reshard_after_forward = True fully_shard( transformer_block, **fsdp_config, reshard_after_forward=reshard_after_forward, ) - model.layers[layer_id] = transformer_block + # model.layers[layer_id] = transformer_block + # TODO(whc) + setattr(model.layers, layer_name, transformer_block) model = fully_shard(model, **fsdp_config) if ac_mode in ("full", "selective"): logger.info(f"Applied {ac_mode} activation checkpointing to the model") From 72469d4f91063898e6236239d897b7a333925316 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 10 Apr 2024 14:57:17 -0700 Subject: [PATCH 06/16] Update [ghstack-poisoned] --- run_llama_train.sh | 2 +- torchtrain/parallelisms/parallelize_llama.py | 7 +++++++ train.py | 15 ++++++++------- train_configs/debug_model.toml | 2 +- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/run_llama_train.sh b/run_llama_train.sh index 12a13d4d..33f04305 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -26,4 +26,4 @@ fi torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ -train.py --job.config_file ${CONFIG_FILE} $overrides --training.checkpoint_folder /data/users/whc/torchtrain +train.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index e5785f96..956ca035 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -134,6 +134,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): the model must fit on GPU or CPU memory. """ if parallel_dims.pp_enabled: + + if job_config.model.norm_type == "fused_rmsnorm": + # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode + # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm + raise NotImplementedError( + "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm." + ) pp_mesh = world_mesh["pp"] stage_idx = pp_mesh.get_local_rank() layers_per_rank = len(model.layers) // parallel_dims.pp diff --git a/train.py b/train.py index 5a014605..2cda5862 100644 --- a/train.py +++ b/train.py @@ -15,7 +15,7 @@ import torch import torch.nn.functional as F -from pippy.PipelineSchedule import PipelineScheduleGPipe +from pippy.PipelineSchedule import ScheduleGPipe from pippy.PipelineStage import PipelineStage from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.tensor.parallel import loss_parallel @@ -225,7 +225,7 @@ def loss_fn(pred, labels): device=device, group=pp_mesh.get_group(), ) - pp_schedule = PipelineScheduleGPipe( + pp_schedule = ScheduleGPipe( stage, n_microbatches=parallel_dims.pp, loss_fn=loss_fn, @@ -317,7 +317,7 @@ def loss_fn(pred, labels): # pipeline parallel forward / backward inside step() call is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 - with loss_parallel_ctx: + with loss_parallel_ctx(): if pp_mesh.get_local_rank() == 0: pp_schedule.step(input_ids) elif is_last_stage: @@ -327,8 +327,10 @@ def loss_fn(pred, labels): schedule.step() # accumulate losses across pipeline microbatches - current_loss = ( - torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0 + loss = ( + torch.mean(torch.stack(losses)) + if is_last_stage + else torch.Tensor([-1.0]) ) else: # forward / backward @@ -337,7 +339,6 @@ def loss_fn(pred, labels): loss = loss_fn(pred, labels) loss.backward() # TODO(whc) rebase conflict, rewrite how loss is handled? - current_loss = loss.item() # clip gradients torch.nn.utils.clip_grad_norm_( @@ -348,7 +349,7 @@ def loss_fn(pred, labels): optimizer.step() scheduler.step() - losses_since_last_log.append(current_loss) + losses_since_last_log.append(loss) # log metrics if ( diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 6447e13b..4db45eda 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -18,7 +18,7 @@ save_tb_folder = "tb" [model] name = "llama" flavor = "debugmodel" -norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model" [optimizer] From 957193652427305590916dac20535ae98560a540 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 1 May 2024 18:36:42 -0700 Subject: [PATCH 07/16] Update [ghstack-poisoned] --- test_runner.py | 23 +++++++++++++++++++++++ train.py | 13 ++++++++++--- train_configs/debug_model.toml | 2 +- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/test_runner.py b/test_runner.py index 80d75ad8..cac64f8c 100755 --- a/test_runner.py +++ b/test_runner.py @@ -26,6 +26,7 @@ class OverrideDefinitions: override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) test_descr: str = "default" + requires_seed_ckpt: bool = False CONFIG_DIR = "./train_configs" @@ -85,6 +86,28 @@ class OverrideDefinitions: ], "Checkpoint Integration Test - Save Model Weights Only bf16", ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + "--training.pipeline_parallel_degree 4", + "--training.data_parallel_degree 1", + ], + ], + "PP 1D test", + requires_seed_ckpt=True, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + "--training.pipeline_parallel_degree 2", + "--training.data_parallel_degree 2", + ], + ], + "PP+DP 2D test", + requires_seed_ckpt=True, + ), ] diff --git a/train.py b/train.py index 69654cfb..cca4a735 100644 --- a/train.py +++ b/train.py @@ -19,10 +19,10 @@ import torch import torch.nn.functional as F -from torch.distributed import destroy_process_group -from torch.distributed.checkpoint.stateful import Stateful from pippy.PipelineSchedule import ScheduleGPipe from pippy.PipelineStage import PipelineStage +from torch.distributed import destroy_process_group +from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.tensor.parallel import loss_parallel @@ -153,6 +153,7 @@ def main(job_config: JobConfig): pp_mesh = world_mesh["pp"] pp_degree = pp_mesh.size() pp_rank = pp_mesh.get_local_rank() + else: pp_degree, pp_rank = 1, 0 @@ -289,7 +290,13 @@ def loss_fn(pred, labels): logger.info("Created seed checkpoint") return - checkpoint.load() + checkpoint_loaded = checkpoint.load() + + if parallel_dims.pp_enabled and not checkpoint_loaded: + raise RuntimeError( + "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. " + "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`" + ) # plot losses loaded from checkpoint (if any) to TensorBoard # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index b2a29deb..542dd8a7 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -36,7 +36,7 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -pipeline_parallel_degree = 2 +pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) From a12f5248ef3c08b3f6ee7a6966cf2e3c2d3f2f11 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 1 May 2024 20:26:29 -0700 Subject: [PATCH 08/16] Update [ghstack-poisoned] --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index b82120a6..ac350893 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ tensorboard sentencepiece tiktoken blobfile +# TODO remove pippy requirement after completing migration to pytorch +git+https://github.com/pytorch/pippy From 413fdc82b087ec6fd55409adcdbab337cb650685 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 1 May 2024 20:33:32 -0700 Subject: [PATCH 09/16] Update [ghstack-poisoned] --- test_runner.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/test_runner.py b/test_runner.py index cac64f8c..0d780c90 100755 --- a/test_runner.py +++ b/test_runner.py @@ -26,7 +26,7 @@ class OverrideDefinitions: override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) test_descr: str = "default" - requires_seed_ckpt: bool = False + requires_seed_checkpoint: bool = False CONFIG_DIR = "./train_configs" @@ -95,7 +95,7 @@ class OverrideDefinitions: ], ], "PP 1D test", - requires_seed_ckpt=True, + requires_seed_checkpoint=True, ), OverrideDefinitions( [ @@ -106,27 +106,35 @@ class OverrideDefinitions: ], ], "PP+DP 2D test", - requires_seed_ckpt=True, + requires_seed_checkpoint=True, ), ] +def _run_cmd(cmd): + return subprocess.run( + [cmd], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + shell=True, + ) + + def run_test(test_flavor: OverrideDefinitions, full_path: str): # run_test supports sequence of tests. for override_arg in test_flavor.override_args: + if test_flavor.requires_seed_checkpoint: + run_cmd( + f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh --checkpoint.folder {test_checkpoint_dir}" + ) cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh" if override_arg: cmd += " " + " ".join(override_arg) print( f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" ) - result = subprocess.run( - [cmd], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - shell=True, - ) + result = _run_cmd(cmd) print(result.stdout) if result.returncode != 0: raise Exception( From 918265e74166b5fe0ffd3e081fe7038dd71c7715 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 1 May 2024 21:59:55 -0700 Subject: [PATCH 10/16] Update [ghstack-poisoned] --- train.py | 5 ++--- train_configs/debug_model.toml | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index cca4a735..a2be23ed 100644 --- a/train.py +++ b/train.py @@ -358,12 +358,11 @@ def loss_fn(pred, labels): else torch.Tensor([-1.0]) ) else: - # forward / backward - with loss_parallel_ctx: + # Non-PP forward / backward + with loss_parallel_ctx(): pred = model(input_ids) loss = loss_fn(pred, labels) loss.backward() - # TODO(whc) rebase conflict, rewrite how loss is handled? # clip gradients torch.nn.utils.clip_grad_norm_( diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 542dd8a7..0133298b 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -20,7 +20,7 @@ save_tb_folder = "tb" [model] name = "llama3" flavor = "debugmodel" -norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm +norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm # test tokenizer.model, for debug purpose only tokenizer_path = "./test/assets/test_tiktoken.model" @@ -42,7 +42,7 @@ compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) [checkpoint] -enable_checkpoint = true +enable_checkpoint = false folder = "checkpoint" interval_type = "steps" interval = 5 From 07ae194fabf3afe726e9e3df48c0e49984319917 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 1 May 2024 22:21:03 -0700 Subject: [PATCH 11/16] Update [ghstack-poisoned] --- test_runner.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test_runner.py b/test_runner.py index 367693f1..b4362994 100755 --- a/test_runner.py +++ b/test_runner.py @@ -126,16 +126,21 @@ def _run_cmd(cmd): def run_test(test_flavor: OverrideDefinitions, full_path: str): # run_test supports sequence of tests. for override_arg in test_flavor.override_args: - if test_flavor.requires_seed_checkpoint: - run_cmd( - f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh --checkpoint.folder {test_checkpoint_dir}" - ) + cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh" if override_arg: cmd += " " + " ".join(override_arg) print( f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" ) + + if test_flavor.requires_seed_checkpoint: + print("Creating seed checkpoint") + result = run_cmd( + f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh --checkpoint.folder {test_checkpoint_dir}" + ) + print(result.stdout) + result = _run_cmd(cmd) print(result.stdout) if result.returncode != 0: From c7e1a7df8fa4bac9f33ff206fde6c98dd40220c3 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 2 May 2024 08:31:16 -0700 Subject: [PATCH 12/16] Update [ghstack-poisoned] --- create_seed_checkpoint.sh | 3 +++ train.py | 1 + 2 files changed, 4 insertions(+) diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh index 0c00145d..255eb25d 100755 --- a/create_seed_checkpoint.sh +++ b/create_seed_checkpoint.sh @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +echo "create_seed_checkpoint.sh top" + set -ex # libUV is a scalable backend for TCPStore which is used in processGroup @@ -27,6 +29,7 @@ overrides="" if [ $# -ne 0 ]; then overrides="$*" fi +echo "create_seed_checkpoint.sh call torchrun" torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ diff --git a/train.py b/train.py index 5b7ae1d7..a7940cb5 100644 --- a/train.py +++ b/train.py @@ -470,6 +470,7 @@ def loss_fn(pred, labels): if __name__ == "__main__": + print("train.py __main__") config = JobConfig() config.parse_args() main(config) From 6adc6dc33757b7ef4111f9708b99b50826f47198 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 2 May 2024 09:52:52 -0700 Subject: [PATCH 13/16] Update [ghstack-poisoned] --- create_seed_checkpoint.sh | 3 --- test_runner.py | 2 +- train.py | 1 - 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh index 255eb25d..0c00145d 100755 --- a/create_seed_checkpoint.sh +++ b/create_seed_checkpoint.sh @@ -5,8 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -echo "create_seed_checkpoint.sh top" - set -ex # libUV is a scalable backend for TCPStore which is used in processGroup @@ -29,7 +27,6 @@ overrides="" if [ $# -ne 0 ]; then overrides="$*" fi -echo "create_seed_checkpoint.sh call torchrun" torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ diff --git a/test_runner.py b/test_runner.py index b4362994..75d4405f 100755 --- a/test_runner.py +++ b/test_runner.py @@ -136,7 +136,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str): if test_flavor.requires_seed_checkpoint: print("Creating seed checkpoint") - result = run_cmd( + result = _run_cmd( f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh --checkpoint.folder {test_checkpoint_dir}" ) print(result.stdout) diff --git a/train.py b/train.py index a7940cb5..5b7ae1d7 100644 --- a/train.py +++ b/train.py @@ -470,7 +470,6 @@ def loss_fn(pred, labels): if __name__ == "__main__": - print("train.py __main__") config = JobConfig() config.parse_args() main(config) From 2bc6a9a71cc418757a0c790cc93809d0fe0afd89 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 3 May 2024 09:23:50 -0700 Subject: [PATCH 14/16] Update [ghstack-poisoned] --- test_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test_runner.py b/test_runner.py index 4cac5aa2..e13287bf 100755 --- a/test_runner.py +++ b/test_runner.py @@ -91,6 +91,7 @@ class OverrideDefinitions: [ [ "--checkpoint.enable_checkpoint", + f"--checkpoint.folder {test_checkpoint_dir}", "--training.pipeline_parallel_degree 2", "--training.data_parallel_degree 1", "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue @@ -104,6 +105,7 @@ class OverrideDefinitions: [ [ "--checkpoint.enable_checkpoint", + f"--checkpoint.folder {test_checkpoint_dir}", "--training.pipeline_parallel_degree 2", "--training.data_parallel_degree 2", "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue From a43fa7fd39707208a635ec688a7ba02053f27c30 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 3 May 2024 11:37:35 -0700 Subject: [PATCH 15/16] Update [ghstack-poisoned] --- test_runner.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test_runner.py b/test_runner.py index e13287bf..3ece3436 100755 --- a/test_runner.py +++ b/test_runner.py @@ -91,7 +91,7 @@ class OverrideDefinitions: [ [ "--checkpoint.enable_checkpoint", - f"--checkpoint.folder {test_checkpoint_dir}", + f"--checkpoint.folder {test_checkpoint_dir}_pp", "--training.pipeline_parallel_degree 2", "--training.data_parallel_degree 1", "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue @@ -105,7 +105,7 @@ class OverrideDefinitions: [ [ "--checkpoint.enable_checkpoint", - f"--checkpoint.folder {test_checkpoint_dir}", + f"--checkpoint.folder {test_checkpoint_dir}_pp_dp", "--training.pipeline_parallel_degree 2", "--training.data_parallel_degree 2", "--model.norm_type rmsnorm", # TODO fix fused_rmsnorm issue @@ -139,9 +139,16 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str): ) if test_flavor.requires_seed_checkpoint: + checkpoint_folder_arg = None + for arg in override_arg: + if "--checkpoint.folder" in arg: + checkpoint_folder_arg = arg + assert ( + checkpoint_folder_arg is not None + ), "Can't use seed checkpoint if folder is not specified" print("Creating seed checkpoint") result = _run_cmd( - f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh --checkpoint.folder {test_checkpoint_dir}" + f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {checkpoint_folder_arg}" ) print(result.stdout) From 6558428077a7031c9ac96bbd27c24fb4cda9503d Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 3 May 2024 12:07:58 -0700 Subject: [PATCH 16/16] Update [ghstack-poisoned] --- torchtitan/parallelisms/parallelize_llama.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index d4b01a1d..27f1f28e 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -157,13 +157,10 @@ def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: Job for i in range(1, parallel_dims.pp) } # Get example input - label_shape = input_shape = (8, 2048) # TODO + input_shape = (job_config.training.batch_size, job_config.training.seq_len) input_ids = torch.randint( model.vocab_size, input_shape, dtype=torch.int64, device="meta" ) - labels = torch.randint( - model.vocab_size, label_shape, dtype=torch.int64, device="meta" - ) # Create a pipeline representation from the model pipe = pipeline(