diff --git a/test/test_job_config.py b/test/test_job_config.py
index 23571f7d..5dcf7490 100644
--- a/test/test_job_config.py
+++ b/test/test_job_config.py
@@ -1,6 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import tempfile
+
 import pytest
 
 from torchtrain.config_manager import JobConfig
@@ -20,3 +22,9 @@ def test_job_file_does_not_exist(self):
         with pytest.raises(FileNotFoundError):
             config = JobConfig()
             config.parse_args(["--job.config_file", "ohno.toml"])
+
+    def test_empty_config_file(self):
+        with tempfile.NamedTemporaryFile() as fp:
+            config = JobConfig()
+            config.parse_args(["--job.config_file", fp.name])
+            assert config.job.description
diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py
index 77d8593f..41d7007b 100644
--- a/torchtrain/config_manager.py
+++ b/torchtrain/config_manager.py
@@ -233,7 +233,6 @@ def init_args_from_command_line(
         )
         parser.add_argument(
             "--training.enable_selective_ac",
-            default=False,
             action="store_true",
             help="whether to enable selective activation checkpointing",
         )
diff --git a/torchtrain/parallelisms/__init__.py b/torchtrain/parallelisms/__init__.py
index 1c9a5641..fdcd938d 100644
--- a/torchtrain/parallelisms/__init__.py
+++ b/torchtrain/parallelisms/__init__.py
@@ -23,6 +23,7 @@ class ParallelDims:
     sp: int
     pp: int
     world_size: int
+    enable_loss_parallel: bool
 
     def __post_init__(self):
         self._validate()
@@ -63,6 +64,10 @@ def sp_enabled(self):
     def pp_enabled(self):
         return self.pp > 1
 
+    @property
+    def loss_parallel_enabled(self):
+        return self.sp > 1 and self.enable_loss_parallel
+
     @cached_property
     def model_parallel_size(self):
         return self.sp * self.pp
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index cc82118e..34252b6b 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -122,15 +122,16 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
 
     NOTE: the model passed in preferrablably shoule be a meta device model,
     otherwise the model needs to be small enough on GPU or can fit into CPU.
-    # TODO: apply SP
     """
     # apply PTD parallelisms
    if parallel_dims.pp_enabled:
         raise NotImplementedError("PP not implemented yet.")
+
+    # First we apply Sequence Parallelism if it's enabled
     if parallel_dims.sp_enabled:
-        # First we apply Sequence Parallelism if it's enabled
-        tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
+        tp_mesh = world_mesh["sp"]
         sp_degree = job_config.training.sequence_parallel_degree
+
         # First:
         # 1. parallelize the first embedding and the last linear proj layer
         # 2. shard the first layer of transformer block
@@ -144,9 +145,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "output": ColwiseParallel(
                     input_layouts=Shard(0),
                     output_layouts=Shard(-1)
-                    if job_config.training.enable_loss_parallel
+                    if parallel_dims.loss_parallel_enabled
                     else Replicate(),
-                    use_local_output=not job_config.training.enable_loss_parallel,
+                    use_local_output=not parallel_dims.loss_parallel_enabled,
                 ),
                 "layers.0": PrepareModuleInput(
                     input_layouts=(Replicate(), None),
@@ -156,6 +157,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             },
         )
 
+        # shard the RMSNorm layer before last linear proj layer
+        distribute_rmsnorm(model.norm, tp_mesh)
+
         # apply sequence parallelism to every transformer block
         for layer_id, transformer_block in enumerate(model.layers):
             layer_plan = {
@@ -194,8 +198,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         rank0_log("Applied Sequence Parallelism to the model...")
 
     if parallel_dims.dp_enabled:
-        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+        dp_mesh = world_mesh["dp"]
 
         fsdp_config = {
             "mixed_precision": MixedPrecision(
@@ -227,6 +230,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
 
         rank0_log("Applied FSDP to the model...")
     else:
+        meta_to_real_init_fn(model)
         model.cuda()
 
     # we have now moved from meta to device,
diff --git a/train.py b/train.py
index 269ac522..80056ce1 100644
--- a/train.py
+++ b/train.py
@@ -94,6 +94,7 @@ def main(job_config: JobConfig):
         sp=job_config.training.sequence_parallel_degree,
         pp=job_config.training.pipeline_parallel_degree,
         world_size=world_size,
+        enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
     rank0_log(f"Starting job: {job_config.job.description}")
@@ -104,11 +105,13 @@ def main(job_config: JobConfig):
     tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
 
     # build dataloader
-    # need dp world size and rank
-    dp_mesh = world_mesh["dp"]
-    dp_degree = dp_mesh.size()
-    dp_rank = dp_mesh.get_local_rank()
     build_dataloader_fn = dataloader_fn[job_config.training.dataset]
+    if parallel_dims.dp_enabled:
+        dp_mesh = world_mesh["dp"]
+        dp_degree = dp_mesh.size()
+        dp_rank = dp_mesh.get_local_rank()
+    else:
+        dp_degree, dp_rank = 1, 0
     data_loader = build_dataloader_fn(
         tokenizer,
         job_config.training.batch_size,
@@ -217,10 +220,7 @@ def main(job_config: JobConfig):
 
             pred = model(input_ids)
 
-            loss_parallel_enabled = (
-                parallel_dims.sp_enabled and job_config.training.enable_loss_parallel
-            )
-            with loss_parallel() if loss_parallel_enabled else contextlib.nullcontext():
+            with loss_parallel() if parallel_dims.loss_parallel_enabled else contextlib.nullcontext():
                 loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
 
             # backward on scaled loss to create scaled gradients
@@ -259,10 +259,13 @@ def main(job_config: JobConfig):
                     np.mean(losses_since_last_log),
                     np.max(losses_since_last_log),
                 )
-                global_avg_loss, global_max_loss = (
-                    dist_mean(avg_loss, dp_mesh),
-                    dist_max(max_loss, dp_mesh),
-                )
+                if parallel_dims.dp_enabled:
+                    global_avg_loss, global_max_loss = (
+                        dist_mean(avg_loss, dp_mesh),
+                        dist_max(max_loss, dp_mesh),
+                    )
+                else:
+                    global_avg_loss, global_max_loss = avg_loss, max_loss
                 time_delta = timer() - time_last_log
                 wps = nwords_since_last_log / (
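
For reviewers who want to exercise the new gating locally, here is a minimal, illustrative sketch of the `ParallelDims.loss_parallel_enabled` property introduced in this patch. The `dp` field and the `__post_init__` validation it relies on are assumed from the existing `ParallelDims` dataclass (they are outside the hunks shown), so treat this as a sketch rather than a tested snippet.

# Sketch: loss parallel is only active when sequence parallelism is on AND the flag is set.
# Assumes ParallelDims also has a `dp` field and that __post_init__ accepts dp * sp * pp == world_size.
from torchtrain.parallelisms import ParallelDims

# 8 ranks split into 4-way data parallel x 2-way sequence parallel, loss parallel requested.
dims = ParallelDims(dp=4, sp=2, pp=1, world_size=8, enable_loss_parallel=True)
assert dims.loss_parallel_enabled  # sp > 1 and the flag is on

# Without sequence parallelism (sp=1), the property gates loss parallel off even if requested.
dims = ParallelDims(dp=8, sp=1, pp=1, world_size=8, enable_loss_parallel=True)
assert not dims.loss_parallel_enabled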