Add meta_init, enable it as default init process #84

Merged: 12 commits, Mar 5, 2024
48 changes: 48 additions & 0 deletions torchtrain/meta_init.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from contextlib import contextmanager

import torch
from torch import nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened


@contextmanager
def meta_model_init():
    """init model on meta device"""
    saved_register_parameter = nn.Module.register_parameter
    saved_register_buffer = nn.Module.register_buffer

    def register_meta_param(module, name, param):
        saved_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(
                module._parameters[name].to(torch.device("meta")), **kwargs
            )

    def register_meta_buffer(module, name, buffer):
        saved_register_buffer(module, name, buffer)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(torch.device("meta"))

    try:
        nn.Module.register_parameter = register_meta_param
        nn.Module.register_buffer = register_meta_buffer
        yield
    finally:
        nn.Module.register_parameter = saved_register_parameter
        nn.Module.register_buffer = saved_register_buffer


@torch.no_grad()
def meta_to_real_init_fn(module: nn.Module):
    for submodule in module.modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            if not _is_fsdp_flattened(param) and param.is_meta:
                materialized_param = nn.Parameter(
                    torch.randn_like(param, device=torch.device("cuda"))
                )
                setattr(submodule, param_name, materialized_param)
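
For context (not part of the diff), a minimal sketch of how the two helpers are meant to be combined: `meta_model_init()` patches parameter/buffer registration so construction allocates nothing, and `meta_to_real_init_fn` is handed to FSDP as `param_init_fn` so each still-meta parameter is materialized on CUDA before sharding. The tiny `nn.Sequential` model and the bare FSDP call are placeholder assumptions, and the snippet presumes `torch.distributed` and a CUDA device are already set up.

```python
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtrain.meta_init import meta_model_init, meta_to_real_init_fn

# 1) Construct the model under the context manager: every registered
#    parameter/buffer lands on the meta device, so no real memory is used.
with meta_model_init():
    model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))
assert all(p.is_meta for p in model.parameters())

# 2) Wrap with FSDP, passing the materialization hook; FSDP calls it on
#    each submodule, replacing meta parameters with CUDA tensors holding
#    random data before flattening/sharding them.
model = FSDP(model, param_init_fn=meta_to_real_init_fn)

# 3) Finally, run the model's own initialization on the real parameters,
#    e.g. model.reset_parameters() for the Transformer in this PR.
```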
4 changes: 2 additions & 2 deletions torchtrain/metrics.py
@@ -193,8 +193,8 @@ def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
     param_list = list(model.parameters())
     if only_trainable:
         param_list = [p for p in param_list if p.requires_grad]
-    unique_params = {p.data_ptr(): p for p in param_list}.values()
-    return sum(p.numel() for p in unique_params)
+    # unique_params = {p.data_ptr(): p for p in param_list}.values()
+    return sum(p.numel() for p in param_list)


class MetricLogger:
5 changes: 3 additions & 2 deletions torchtrain/models/llama/__init__.py
@@ -7,10 +7,11 @@

 llama_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
-    "1B": ModelArgs(dim=1024, n_layers=16, n_heads=8),
+    "271M": ModelArgs(dim=1024, n_layers=16, n_heads=8),
+    "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16),
     "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
     "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
-    "40B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
+    "26B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
"70B": ModelArgs(
dim=8192,
n_layers=80,
13 changes: 11 additions & 2 deletions torchtrain/models/llama/model.py
Expand Up @@ -47,7 +47,9 @@ def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.empty(dim))
-        self.reset_parameters()
+
+        # re-enable if not using meta-init
+        # self.reset_parameters()

     def _norm(self, x: torch.Tensor):
         """
@@ -466,7 +468,14 @@ def __init__(self, model_args: ModelArgs):
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

         # init model weights
-        self.reset_parameters()
+
+        # we are doing meta_init, which will call reset_parameters() after
+        # the model is moved to the actual device.
+        # If you modify this and are not using meta_init, you will need to
+        # call reset_parameters() manually, as below:
+
+        # self.reset_parameters()

         rank0_log(f"Model built with: {self.model_args}")

     def reset_parameters(
14 changes: 9 additions & 5 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -35,6 +35,7 @@
 )
 from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import rank0_log
+from torchtrain.meta_init import meta_to_real_init_fn

 logger = logging.getLogger(__name__)

@@ -193,6 +194,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+
         fsdp_config = {
             "mixed_precision": MixedPrecision(
                 param_dtype=torch.bfloat16,
@@ -204,12 +206,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             # When torch.compile is active, it requires us to set use_orig_params=True
             "use_orig_params": True,
             "device_mesh": dp_mesh,
+            "param_init_fn": meta_to_real_init_fn,
         }

         with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
             for layer_id, transformer_block in enumerate(model.layers):
-                # before wrapping with FSDP, we need to make sure the layer is on GPU
-                transformer_block = transformer_block.cuda()

                 # apply selective AC
                 transformer_block = checkpoint_wrapper(
@@ -220,10 +221,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 model.layers[layer_id] = wrap(transformer_block)

             # wrap the rest layers with FSDP
-            model = wrap(model.cuda())
+            model = wrap(model)

         rank0_log("Applied FSDP to the model...")
-    else:
-        model.cuda()

+    # redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used
+    model.cuda()
+    # we have now moved from meta to device,
+    # reset parameters for proper initialization
+    model.reset_parameters()
Reviewer comment (Contributor):
I apologize that FSDP meta-device init is confusing, but I think this might not be fully correct.

  1. The contract from PyTorch core is supposed to be that module.reset_parameters() only resets/initializes the parameters immediately owned by module, not those of its children/submodules. Here, model: Transformer is initializing all of its submodules' parameters. This contract is the only way for reset_parameters() to always work compositionally; otherwise (as in our case) we must assume Transformer is always the root module.
  2. When we call model.reset_parameters(), the parameters have already been flattened and sharded by FSDP. This means that any initialization method that depends on the tensor shape would be incorrect. That is why we would normally want users to do the correct initialization in the param_init_fn.

For this Llama case, it looks like the reset_parameters() implementations have perhaps been written so as not to depend on the tensor shape directly?
Author reply (Contributor Author):
Correct, the init does not depend on tensor shape directly.
I see your point about only resetting a module's own params and not its children's, but let's meet to discuss, as I also have questions on how meta_init should work for FSDP2 and we can get an updated impl.

     return model
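
To make the per-module reset_parameters() contract from the review thread concrete, here is a hedged sketch (not code from this PR; the Norm class and reset_all helper are hypothetical): each module's reset_parameters() touches only the parameters it owns directly, and re-initializing a whole model is then a loop over submodules rather than one method that reaches into its children.

```python
import torch
from torch import nn


class Norm(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(dim))

    def reset_parameters(self) -> None:
        # Resets only the parameter this module owns directly.
        nn.init.ones_(self.weight)


def reset_all(root: nn.Module) -> None:
    # Compositional re-init: visit every submodule and let each one reset
    # just its own parameters, honoring the per-module contract.
    for m in root.modules():
        if m is not root and hasattr(m, "reset_parameters"):
            m.reset_parameters()
```

The reviewer's second point still applies either way: once FSDP has flattened and sharded the parameters, a reset_parameters() that reads tensor shapes sees the sharded shapes, which is why shape-dependent initialization normally belongs in param_init_fn.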
7 changes: 5 additions & 2 deletions train.py
Expand Up @@ -22,6 +22,7 @@
from torchtrain.datasets import create_tokenizer, dataloader_fn
from torchtrain.logging_utils import init_logger, rank0_log
from torchtrain.lr_scheduling import get_lr_scheduler
from torchtrain.meta_init import meta_model_init
from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor

from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
@@ -115,15 +116,17 @@ def main(job_config: JobConfig):
     )

     # build model
-    # TODO: add meta initialization
     model_cls = model_name_to_cls[model_name]
     model_config = models_config[model_name][job_config.model.flavor]
     model_config.vocab_size = tokenizer.n_words

-    model = model_cls.from_model_args(model_config)
+    # build model using meta init
+    with meta_model_init():
+        model = model_cls.from_model_args(model_config)

     # log model size
     model_param_count = get_num_params(model)
+
     if _is_local_logging:
         rank0_log(
             f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"
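
One side note on the metrics.py change above (an illustration, not part of the PR): parameter counting only needs shapes, which meta tensors carry, so get_num_params() can be called on the meta-initialized model before anything is materialized, which is exactly how train.py uses it.

```python
import torch.nn as nn

from torchtrain.meta_init import meta_model_init
from torchtrain.metrics import get_num_params

with meta_model_init():
    model = nn.Linear(4096, 4096, bias=False)

# numel() is defined for meta tensors, so the count is available before
# any real memory is allocated: 4096 * 4096 = 16,777,216 parameters.
print(get_num_params(model))  # 16777216
```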