diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index ae02d15b..698079a6 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -129,7 +129,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): if parallel_dims.sp_enabled: # First we apply Sequence Parallelism if it's enabled tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh - sp_degree = job_config.training.sequence_parallelism_degree + sp_degree = job_config.training.sequence_parallel_degree # First: # 1. parallelize the first embedding and the last linear proj layer # 2. shard the first layer of transformer block diff --git a/train.py b/train.py index 5145462b..56b7e160 100644 --- a/train.py +++ b/train.py @@ -102,9 +102,10 @@ def main(job_config: JobConfig): # build dataloader # need dp world size and rank - # TODO: dp might not always be 0 so we need to handle that more carefully - dp_degree = world_mesh.size(0) - dp_rank = world_mesh.get_local_rank(0) + dp_mesh = world_mesh["dp"] + dp_degree = dp_mesh.size() + dp_rank = dp_mesh.get_local_rank() + print("testtest: ", torch.distributed.get_rank(), dp_degree, dp_rank) build_dataloader_fn = dataloader_fn[job_config.training.dataset] data_loader = build_dataloader_fn( tokenizer, @@ -253,8 +254,8 @@ def main(job_config: JobConfig): np.max(losses_since_last_log), ) global_avg_loss, global_max_loss = ( - dist_mean(avg_loss, world_mesh), - dist_max(max_loss, world_mesh), + dist_mean(avg_loss, dp_mesh), + dist_max(max_loss, dp_mesh), ) time_delta = timer() - time_last_log