diff --git a/torchtrain/meta_init.py b/torchtrain/meta_init.py
new file mode 100644
index 00000000..d67e6ef7
--- /dev/null
+++ b/torchtrain/meta_init.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from contextlib import contextmanager
+
+import torch
+from torch import nn
+from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
+
+
+@contextmanager
+def meta_model_init():
+    """init model on meta device"""
+    saved_register_parameter = nn.Module.register_parameter
+    saved_register_buffer = nn.Module.register_buffer
+
+    def register_meta_param(module, name, param):
+        saved_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            module._parameters[name] = param_cls(
+                module._parameters[name].to(torch.device("meta")), **kwargs
+            )
+
+    def register_meta_buffer(module, name, buffer):
+        saved_register_buffer(module, name, buffer)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(torch.device("meta"))
+
+    try:
+        nn.Module.register_parameter = register_meta_param
+        nn.Module.register_buffer = register_meta_buffer
+        yield
+    finally:
+        nn.Module.register_parameter = saved_register_parameter
+        nn.Module.register_buffer = saved_register_buffer
+
+
+@torch.no_grad()
+def meta_to_real_init_fn(module: nn.Module):
+    for submodule in module.modules():
+        for param_name, param in submodule.named_parameters(recurse=False):
+            if not _is_fsdp_flattened(param) and param.is_meta:
+                materialized_param = nn.Parameter(
+                    torch.randn_like(param, device=torch.device("cuda"))
+                )
+                setattr(submodule, param_name, materialized_param)
diff --git a/torchtrain/metrics.py b/torchtrain/metrics.py
index d56d80a3..91a1e184 100644
--- a/torchtrain/metrics.py
+++ b/torchtrain/metrics.py
@@ -193,8 +193,8 @@ def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
     param_list = list(model.parameters())
     if only_trainable:
         param_list = [p for p in param_list if p.requires_grad]
-    unique_params = {p.data_ptr(): p for p in param_list}.values()
-    return sum(p.numel() for p in unique_params)
+    # unique_params = {p.data_ptr(): p for p in param_list}.values()
+    return sum(p.numel() for p in param_list)
 
 
 class MetricLogger:
diff --git a/torchtrain/models/llama/__init__.py b/torchtrain/models/llama/__init__.py
index c1f87f89..e6175ca9 100644
--- a/torchtrain/models/llama/__init__.py
+++ b/torchtrain/models/llama/__init__.py
@@ -7,10 +7,11 @@
 
 llama_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
-    "1B": ModelArgs(dim=1024, n_layers=16, n_heads=8),
+    "271M": ModelArgs(dim=1024, n_layers=16, n_heads=8),
+    "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16),
     "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
     "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
-    "40B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
+    "26B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
     "70B": ModelArgs(
         dim=8192,
         n_layers=80,
diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index 2cd81a6c..1ba505cf 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -47,7 +47,9 @@ def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.empty(dim))
-        self.reset_parameters()
+
+        # re-enable if not using meta-init
+        # self.reset_parameters()
 
     def _norm(self, x: torch.Tensor):
         """
@@ -466,7 +468,14 @@ def __init__(self, model_args: ModelArgs):
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
 
         # init model weights
-        self.reset_parameters()
+
+        # we are doing meta_init, which will call reset_parameters() after
+        # the model is moved to actual device.
+        # If you modify and are not using meta_init, you will need to call
+        # reset_parameters() manually as below:
+
+        # self.reset_parameters()
+
         rank0_log(f"Model built with: {self.model_args}")
 
     def reset_parameters(
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index 698079a6..d11fac9f 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -35,6 +35,7 @@
 )
 from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import rank0_log
+from torchtrain.meta_init import meta_to_real_init_fn
 
 logger = logging.getLogger(__name__)
 
@@ -193,6 +194,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+
         fsdp_config = {
             "mixed_precision": MixedPrecision(
                 param_dtype=torch.bfloat16,
@@ -204,12 +206,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
            # When torch.compile is active, it requires us to set use_orig_params=True
            "use_orig_params": True,
            "device_mesh": dp_mesh,
+           "param_init_fn": meta_to_real_init_fn,
        }
 
        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            for layer_id, transformer_block in enumerate(model.layers):
-               # before wrapping with FSDP, we need to make sure the layer is on GPU
-               transformer_block = transformer_block.cuda()
 
                # apply selective AC
                transformer_block = checkpoint_wrapper(
@@ -220,10 +221,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                model.layers[layer_id] = wrap(transformer_block)
 
            # wrap the rest layers with FSDP
-           model = wrap(model.cuda())
+           model = wrap(model)
 
        rank0_log("Applied FSDP to the model...")
+    else:
+        model.cuda()
 
-    # redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used
-    model.cuda()
+    # we have now moved from meta to device,
+    # reset parameters for proper initialization
+    model.reset_parameters()
 
     return model
diff --git a/train.py b/train.py
index 3d4c3ae2..9c8e2f7b 100644
--- a/train.py
+++ b/train.py
@@ -22,6 +22,7 @@
 from torchtrain.datasets import create_tokenizer, dataloader_fn
 from torchtrain.logging_utils import init_logger, rank0_log
 from torchtrain.lr_scheduling import get_lr_scheduler
+from torchtrain.meta_init import meta_model_init
 from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor
 from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
 
@@ -115,15 +116,17 @@ def main(job_config: JobConfig):
     )
 
     # build model
-    # TODO: add meta initialization
     model_cls = model_name_to_cls[model_name]
     model_config = models_config[model_name][job_config.model.flavor]
     model_config.vocab_size = tokenizer.n_words
-    model = model_cls.from_model_args(model_config)
+
+    # build model using meta init
+    with meta_model_init():
+        model = model_cls.from_model_args(model_config)
 
     # log model size
     model_param_count = get_num_params(model)
+
     if _is_local_logging:
         rank0_log(
             f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"