From 8699f1c5e1acef72f7d8c3a3e6252d2526da5112 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sat, 24 Feb 2024 19:39:37 -0800 Subject: [PATCH 01/10] add meta_init --- torchtrain/meta_init.py | 48 ++++++++++++++++++++ torchtrain/parallelisms/parallelize_llama.py | 11 +++-- train.py | 7 ++- train_configs/debug_model.toml | 2 +- 4 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 torchtrain/meta_init.py diff --git a/torchtrain/meta_init.py b/torchtrain/meta_init.py new file mode 100644 index 00000000..3eafb012 --- /dev/null +++ b/torchtrain/meta_init.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +import torch +from torch import nn +from torch.distributed.fsdp._common_utils import _is_fsdp_flattened + +from contextlib import contextmanager + + +@contextmanager +def meta_model_init(): + """init model on meta device""" + saved_register_parameter = nn.Module.register_parameter + saved_register_buffer = nn.Module.register_buffer + + def register_meta_param(module, name, param): + saved_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + module._parameters[name] = param_cls( + module._parameters[name].to(torch.device("meta")), **kwargs + ) + + def register_meta_buffer(module, name, buffer): + saved_register_buffer(module, name, buffer) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(torch.device("meta")) + + try: + nn.Module.register_parameter = register_meta_param + nn.Module.register_buffer = register_meta_buffer + yield + finally: + nn.Module.register_parameter = saved_register_parameter + nn.Module.register_buffer = saved_register_buffer + + +@torch.no_grad() +def meta_to_real_init_fn(module: nn.Module): + for submodule in module.modules(): + for param_name, param in submodule.named_parameters(recurse=False): + if not _is_fsdp_flattened(param) and param.is_meta: + materialized_param = nn.Parameter( + torch.randn_like(param, device=torch.device("cuda")) + ) + setattr(submodule, param_name, materialized_param) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 805bfa87..35ec18c8 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -33,6 +33,7 @@ RowwiseParallel, ) from torchtrain.config_manager import JobConfig +from torchtrain.meta_init import meta_to_real_init_fn from torchtrain.logging_utils import rank0_log @@ -153,6 +154,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names + + + fsdp_config = { "mixed_precision": MixedPrecision( param_dtype=torch.bfloat16, @@ -164,23 +168,24 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): # When torch.compile is active, it requires us to set use_orig_params=True "use_orig_params": True, "device_mesh": dp_mesh, + "param_init_fn": meta_to_real_init_fn, } with enable_wrap(wrapper_cls=FSDP, **fsdp_config): for layer_id, transformer_block in enumerate(model.layers): # apply AC to each layer # before wrapping with FSDP, we need to make sure the layer is on GPU - transformer_block = transformer_block.cuda() + # transformer_block = transformer_block.cuda() transformer_block = checkpoint_wrapper(transformer_block, job_config) # 
Wraps each layer with FSDP model.layers[layer_id] = wrap(transformer_block) # wrap the rest layers with FSDP - model = wrap(model.cuda()) + model = wrap(model) # .cuda()) rank0_log("Applied FSDP to the model...") # redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used - model.cuda() + # model.cuda() return model diff --git a/train.py b/train.py index f3f3389e..1a1b92f4 100644 --- a/train.py +++ b/train.py @@ -28,6 +28,7 @@ from torchtrain.profiling import maybe_run_profiler from torchtrain.utils import dist_max, dist_mean +from torchtrain.meta_init import meta_model_init @dataclass @@ -113,8 +114,10 @@ def main(job_config: JobConfig): model_config = models_config[model_name][job_config.model.flavor] model_config.vocab_size = tokenizer.n_words - model = model_cls.from_model_args(model_config) - + # build model + with meta_model_init(): + model = model_cls.from_model_args(model_config) + model.reset_parameters() # log model size model_param_count = get_num_params(model) rank0_log( diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 1cca38b0..6602c048 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -3,7 +3,7 @@ dump_folder = "./outputs" [profiling] -run_profiler = true +run_profiler = false save_traces_folder = "profiling/traces" # profiling frequency - example: 10 means every 10th iter will be profiled profile_every_x_iter = 10 From f211961cc34079e9db2ec798f2596cd71d531233 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sat, 24 Feb 2024 20:07:58 -0800 Subject: [PATCH 02/10] add meta_init, handle rope embedding buffers --- torchtrain/metrics.py | 4 +-- torchtrain/models/llama/__init__.py | 2 +- train.py | 10 +++++--- train_configs/llama7B.toml | 39 +++++++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 6 deletions(-) create mode 100644 train_configs/llama7B.toml diff --git a/torchtrain/metrics.py b/torchtrain/metrics.py index b2ad3cc9..0680c41e 100644 --- a/torchtrain/metrics.py +++ b/torchtrain/metrics.py @@ -192,8 +192,8 @@ def get_num_params(model: nn.Module, only_trainable: bool = False) -> int: param_list = list(model.parameters()) if only_trainable: param_list = [p for p in param_list if p.requires_grad] - unique_params = {p.data_ptr(): p for p in param_list}.values() - return sum(p.numel() for p in unique_params) + # unique_params = {p.data_ptr(): p for p in param_list}.values() + return sum(p.numel() for p in param_list) class MetricLogger: diff --git a/torchtrain/models/llama/__init__.py b/torchtrain/models/llama/__init__.py index c1f87f89..b40801d0 100644 --- a/torchtrain/models/llama/__init__.py +++ b/torchtrain/models/llama/__init__.py @@ -10,7 +10,7 @@ "1B": ModelArgs(dim=1024, n_layers=16, n_heads=8), "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32), "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40), - "40B": ModelArgs(dim=5120, n_layers=80, n_heads=40), + "26B": ModelArgs(dim=5120, n_layers=80, n_heads=40), "70B": ModelArgs( dim=8192, n_layers=80, diff --git a/train.py b/train.py index 1a1b92f4..25741b4c 100644 --- a/train.py +++ b/train.py @@ -109,20 +109,20 @@ def main(job_config: JobConfig): ) # build model - # TODO: add meta initialization model_cls = model_name_to_cls[model_name] model_config = models_config[model_name][job_config.model.flavor] model_config.vocab_size = tokenizer.n_words - # build model + # build model using meta init with meta_model_init(): model = model_cls.from_model_args(model_config) - model.reset_parameters() + # log 
model size model_param_count = get_num_params(model) rank0_log( f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters" ) + gpu_metrics = GPUMemoryMonitor("cuda") rank0_log(f"GPU memory usage: {gpu_metrics}") @@ -131,6 +131,10 @@ def main(job_config: JobConfig): model, world_mesh, parallel_dims, job_config ) + # we have now moved from meta to device, + # reset parameters for proper initialization + model.reset_parameters() + # to use FSDP-customized gradient scaler and gradient clipping solutions assert isinstance(model, FSDP) diff --git a/train_configs/llama7B.toml b/train_configs/llama7B.toml new file mode 100644 index 00000000..2358ba1e --- /dev/null +++ b/train_configs/llama7B.toml @@ -0,0 +1,39 @@ +# TorchTrain Config.toml +[job] +dump_folder = "./outputs" + +[profiling] +run_profiler = false +save_traces_folder = "profiling/traces" +# profiling frequency - example: 10 means every 10th iter will be profiled +profile_every_x_iter = 10 + +[metrics] +enable_tensorboard = false +save_tb_folder = "tb" +log_freq = 10 + +[model] +name = "llama" +flavor = "7B" +tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model" + +[optimizer] +name = "AdamW" +lr = 8e-4 + + +[training] +batch_size = 8 +seq_len = 2048 +warmup_pct = 0.20 # lr scheduler warm up +max_norm = 1.0 # grad norm clipping +steps = 10 +data_parallel_degree = -1 +sequence_parallel_degree = 1 +pipeline_parallel_degree = 1 +compile = false +checkpoint_interval = 3600 +checkpoint_interval_type = "steps" +checkpoint_folder = "" +dataset = "alpaca" From 2b1871ed83c8aefa88627c4c3eee6afb0326002b Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sat, 24 Feb 2024 21:08:29 -0800 Subject: [PATCH 03/10] linting and updated license header --- torchtrain/meta_init.py | 6 +++--- torchtrain/parallelisms/parallelize_llama.py | 6 ++---- train.py | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/torchtrain/meta_init.py b/torchtrain/meta_init.py index 3eafb012..d67e6ef7 100644 --- a/torchtrain/meta_init.py +++ b/torchtrain/meta_init.py @@ -1,12 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +from contextlib import contextmanager import torch from torch import nn from torch.distributed.fsdp._common_utils import _is_fsdp_flattened -from contextlib import contextmanager - @contextmanager def meta_model_init(): diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 35ec18c8..058afcd1 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -33,9 +33,9 @@ RowwiseParallel, ) from torchtrain.config_manager import JobConfig -from torchtrain.meta_init import meta_to_real_init_fn from torchtrain.logging_utils import rank0_log +from torchtrain.meta_init import meta_to_real_init_fn logger = logging.getLogger(__name__) @@ -155,8 +155,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names - - fsdp_config = { "mixed_precision": MixedPrecision( param_dtype=torch.bfloat16, @@ -182,7 +180,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): model.layers[layer_id] = wrap(transformer_block) # wrap the rest layers with FSDP - model = wrap(model) # .cuda()) + model = wrap(model) # .cuda()) rank0_log("Applied FSDP to the model...") diff --git a/train.py b/train.py index 25741b4c..0e078fef 100644 --- a/train.py +++ b/train.py @@ -21,6 +21,7 @@ from torchtrain.datasets import create_tokenizer, dataloader_fn from torchtrain.logging_utils import init_logger, rank0_log from torchtrain.lr_scheduling import get_lr_scheduler +from torchtrain.meta_init import meta_model_init from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config @@ -28,7 +29,6 @@ from torchtrain.profiling import maybe_run_profiler from torchtrain.utils import dist_max, dist_mean -from torchtrain.meta_init import meta_model_init @dataclass From 6a15267db6d456150d296381f485b8e55b859d04 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 25 Feb 2024 07:38:06 -0800 Subject: [PATCH 04/10] remove commented out .cuda() refs --- torchtrain/parallelisms/parallelize_llama.py | 7 ++----- train_configs/debug_model.toml | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 058afcd1..dd47d2b7 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -171,19 +171,16 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): with enable_wrap(wrapper_cls=FSDP, **fsdp_config): for layer_id, transformer_block in enumerate(model.layers): + # apply AC to each layer - # before wrapping with FSDP, we need to make sure the layer is on GPU - # transformer_block = transformer_block.cuda() transformer_block = checkpoint_wrapper(transformer_block, job_config) # Wraps each layer with FSDP model.layers[layer_id] = wrap(transformer_block) # wrap the rest layers with FSDP - model = wrap(model) # .cuda()) + model = wrap(model) rank0_log("Applied FSDP to the model...") - # redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used - # model.cuda() return model diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 6602c048..1cca38b0 100644 --- a/train_configs/debug_model.toml +++ 
b/train_configs/debug_model.toml @@ -3,7 +3,7 @@ dump_folder = "./outputs" [profiling] -run_profiler = false +run_profiler = true save_traces_folder = "profiling/traces" # profiling frequency - example: 10 means every 10th iter will be profiled profile_every_x_iter = 10 From 8fea6749f6770d58ad7c82bdafe3003dd77bd543 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 28 Feb 2024 04:33:06 -0800 Subject: [PATCH 05/10] add correct 1B size --- torchtrain/models/llama/__init__.py | 3 ++- train_configs/debug_model.toml | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torchtrain/models/llama/__init__.py b/torchtrain/models/llama/__init__.py index b40801d0..e6175ca9 100644 --- a/torchtrain/models/llama/__init__.py +++ b/torchtrain/models/llama/__init__.py @@ -7,7 +7,8 @@ llama_configs = { "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16), - "1B": ModelArgs(dim=1024, n_layers=16, n_heads=8), + "271M": ModelArgs(dim=1024, n_layers=16, n_heads=8), + "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16), "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32), "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40), "26B": ModelArgs(dim=5120, n_layers=80, n_heads=40), diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 1cca38b0..f57e14f7 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -3,7 +3,7 @@ dump_folder = "./outputs" [profiling] -run_profiler = true +run_profiler = false save_traces_folder = "profiling/traces" # profiling frequency - example: 10 means every 10th iter will be profiled profile_every_x_iter = 10 @@ -15,7 +15,7 @@ log_freq = 10 [model] name = "llama" -flavor = "debugmodel" +flavor = "1B" tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model" [optimizer] From 8b20a62b6920d9d82683b6e7dc9bc19136505134 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 28 Feb 2024 06:08:50 -0800 Subject: [PATCH 06/10] move model.reset_params() to inside parallelize llama --- train.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/train.py b/train.py index 85658b3b..95bf4da0 100644 --- a/train.py +++ b/train.py @@ -145,10 +145,6 @@ def main(job_config: JobConfig): model, world_mesh, parallel_dims, job_config ) - # we have now moved from meta to device, - # reset parameters for proper initialization - model.reset_parameters() - # to use FSDP-customized gradient scaler and gradient clipping solutions assert isinstance(model, FSDP) From bef7cf921cad88db38e1ec7f77324bf074439f08 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 28 Feb 2024 06:19:08 -0800 Subject: [PATCH 07/10] move reset params to be inside parallelize_llama --- torchtrain/parallelisms/parallelize_llama.py | 5 +++++ train_configs/debug_model.toml | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index dd47d2b7..3bd3c3a9 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -182,5 +182,10 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): model = wrap(model) rank0_log("Applied FSDP to the model...") + else: + model.cuda() + # we have now moved from meta to device, + # reset parameters for proper initialization + model.reset_parameters() return model diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index f57e14f7..6602c048 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -15,7 +15,7 
@@ log_freq = 10 [model] name = "llama" -flavor = "1B" +flavor = "debugmodel" tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model" [optimizer] From 7b659350f3aa23638ef7dca263d741a7a52f8c6c Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Mon, 4 Mar 2024 17:35:58 -0800 Subject: [PATCH 08/10] remove llama7B as now llama_7b already added --- train_configs/debug_model.toml | 2 +- train_configs/llama7B.toml | 39 ---------------------------------- 2 files changed, 1 insertion(+), 40 deletions(-) delete mode 100644 train_configs/llama7B.toml diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index c4004c5f..d0f24431 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -4,7 +4,7 @@ dump_folder = "./outputs" description = "debug training" [profiling] -run_profiler = false +run_profiler = true save_traces_folder = "profiling/traces" # profiling frequency - example: 10 means every 10th iter will be profiled profile_every_x_iter = 10 diff --git a/train_configs/llama7B.toml b/train_configs/llama7B.toml deleted file mode 100644 index 2358ba1e..00000000 --- a/train_configs/llama7B.toml +++ /dev/null @@ -1,39 +0,0 @@ -# TorchTrain Config.toml -[job] -dump_folder = "./outputs" - -[profiling] -run_profiler = false -save_traces_folder = "profiling/traces" -# profiling frequency - example: 10 means every 10th iter will be profiled -profile_every_x_iter = 10 - -[metrics] -enable_tensorboard = false -save_tb_folder = "tb" -log_freq = 10 - -[model] -name = "llama" -flavor = "7B" -tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model" - -[optimizer] -name = "AdamW" -lr = 8e-4 - - -[training] -batch_size = 8 -seq_len = 2048 -warmup_pct = 0.20 # lr scheduler warm up -max_norm = 1.0 # grad norm clipping -steps = 10 -data_parallel_degree = -1 -sequence_parallel_degree = 1 -pipeline_parallel_degree = 1 -compile = false -checkpoint_interval = 3600 -checkpoint_interval_type = "steps" -checkpoint_folder = "" -dataset = "alpaca" From eee05e8c2cbf5335933b9a4ed0eb0486b94b6985 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Mon, 4 Mar 2024 18:56:33 -0800 Subject: [PATCH 09/10] remove call to reset_params in model init, with comments for how it interplays with meta_init and when it should be re-activated. --- torchtrain/models/llama/model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py index 2cd81a6c..af066fc9 100644 --- a/torchtrain/models/llama/model.py +++ b/torchtrain/models/llama/model.py @@ -466,7 +466,14 @@ def __init__(self, model_args: ModelArgs): self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) # init model weights - self.reset_parameters() + + # we are doing meta_init, which will call reset_parameters() after + # the model is moved to actual device. 
+ # If you modify and are not using meta_init, you will need to call + # reset_parameters() manually as below: + + # self.reset_parameters() + rank0_log(f"Model built with: {self.model_args}") def reset_parameters( From 4261c54fb44aea9cb9d053514f515078461e25d7 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Mon, 4 Mar 2024 21:38:07 -0800 Subject: [PATCH 10/10] remove init call to reset_params in rmsnorm init to be consistent --- torchtrain/models/llama/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py index af066fc9..1ba505cf 100644 --- a/torchtrain/models/llama/model.py +++ b/torchtrain/models/llama/model.py @@ -47,7 +47,9 @@ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.empty(dim)) - self.reset_parameters() + + # re-enable if not using meta-init + # self.reset_parameters() def _norm(self, x: torch.Tensor): """
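
Taken together, this series moves model construction onto the meta device and defers real weight initialization until after FSDP wrapping: build the model under meta_model_init(), let FSDP materialize still-meta parameters through param_init_fn=meta_to_real_init_fn, then call reset_parameters() on the wrapped model (as parallelize_llama now does). Below is a minimal standalone sketch of that flow — not part of any patch above — using a toy module in place of the llama Transformer; the toy model, its sizes, and the process-group setup are illustrative assumptions only.

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtrain.meta_init import meta_model_init, meta_to_real_init_fn


class ToyModel(nn.Module):
    """Stand-in for the llama Transformer; sizes are arbitrary."""

    def __init__(self, dim: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)
        # As in the patched llama model, reset_parameters() is NOT called here:
        # under meta init there is nothing real to initialize yet.

    def reset_parameters(self):
        # Real initialization, run only once parameters live on an actual device.
        for lin in (self.w1, self.w2):
            nn.init.trunc_normal_(lin.weight, std=0.02)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


# Launch with torchrun; requires at least one CUDA device per rank.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# 1. Construct on the meta device: parameters allocate no real memory on any rank.
with meta_model_init():
    model = ToyModel()

# 2. FSDP materializes each still-meta parameter on CUDA via param_init_fn
#    (random values for now) and shards it.
model = FSDP(model, param_init_fn=meta_to_real_init_fn, use_orig_params=True)

# 3. Proper initialization now that parameters are real, mirroring the
#    reset_parameters() call parallelize_llama makes after wrapping.
model.reset_parameters()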