[not for land] TE experiments
Summary:

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.use_te
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Dec 3, 2024
1 parent 87e2c09 commit 08fc333
Showing 7 changed files with 411 additions and 14 deletions.
113 changes: 113 additions & 0 deletions test/test_te.py
@@ -0,0 +1,113 @@
import copy

import torch
import torch.nn as nn

# path hack, TODO remove
import sys
sys.path.insert(0, '/home/vasiliy/local/torchtitan/')
import torchtitan.te_utils as te_utils
from torchtitan.models.norms import build_norm
from torchtitan.models.llama.model import FeedForward

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

# torch.use_deterministic_algorithms(True)
torch.manual_seed(0)

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)

def test_linear_module_swap():
    x = torch.randn(32, 32, device='cuda')

    m = nn.Sequential(nn.Linear(32, 32)).cuda()
    te_utils.swap_linear_to_te_linear(m)
    print(m)
    m = torch.compile(m)

    with maybe_te_float8_ctx:
        y = m(x)
    y.sum().backward()

    print('done')

# Subsection of TransformerBlock with only the ffn norm and the ffn
class NormFFNBlock(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of):
        super().__init__()
        self.ffn_norm = build_norm("rmsnorm", dim, eps=1e-12)
        self.feed_forward = FeedForward(dim, hidden_dim, multiple_of, None)

    def forward(self, h):
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

def SQNR(x, y):
    return 20 * torch.log10(
        torch.linalg.norm(x) / torch.linalg.norm(x - y)
    )

def test_norm_ffn_rewrite():
    dim = 256
    hidden_dim = 512
    multiple_of = 1

    x = torch.randn(1, 128, 256).cuda().bfloat16()
    x_copy = copy.deepcopy(x)

    m = NormFFNBlock(dim, hidden_dim, multiple_of).cuda().bfloat16()
    m_copy = copy.deepcopy(m)
    print(m)

    y = m(x)
    y.sum().backward()

    te_utils.swap_norm_ffn_to_te_friendly_norm_ffn(m_copy)
    print(m_copy)

    y_copy = m_copy(x_copy)
    y_copy.sum().backward()

    # TODO: debug why not an exact match
    print(torch.allclose(y, y_copy))
    print(SQNR(y, y_copy))

    # TODO test w13
    # assert torch.allclose(m.ffn.w2.grad, m_copy.ffn.w2.grad, atol=0, rtol=0)

    te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_linear(m_copy)
    print(m_copy)

    y_copy2 = m_copy(x_copy)
    print(torch.allclose(y_copy, y_copy2))
    print(SQNR(y_copy, y_copy2))

# works, so a bug in the swap above?
def test_split_linear():
    M, K, N = 32, 64, 128
    # M, K, N = 4, 6, 8

    x = torch.randn(M, K)

    fc1 = nn.Linear(K, N, bias=False)
    fc2 = nn.Linear(K, N, bias=False)

    fc3 = nn.Linear(K, N * 2, bias=False)
    fc3.weight = torch.nn.Parameter(
        torch.cat([copy.deepcopy(fc1.weight), copy.deepcopy(fc2.weight)], dim=0)
    )

    y1 = fc1(x)
    y2 = fc2(x)
    y3 = fc3(x)
    y3_1, y3_2 = torch.split(y3, fc3.out_features // 2, dim=-1)

    assert torch.allclose(y1, y3_1)
    assert torch.allclose(y2, y3_2)


if __name__ == '__main__':
    # run the tests defined above
    test_linear_module_swap()
    test_norm_ffn_rewrite()
    test_split_linear()
47 changes: 47 additions & 0 deletions torchtitan/config_manager.py
@@ -373,6 +373,53 @@ def __init__(self):
            action="store_true",
            help="Whether to compile the model",
        )
        self.parser.add_argument(
            "--training.compile_ln_linear",
            action="store_true",
            help="Whether to compile only the LNLinear blocks",
        )
        self.parser.add_argument(
            "--training.horizontally_fuse_fcs",
            action="store_true",
            help="""
            If true, fuses ffn.fc1 and ffn.fc3 into ffn.fc13. Note that this is required
            to use te.LayerNormLinear for FFNs.
            TODO also implement this for attention.
            """,
        )
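
The fusion described in the help text above works because the two FFN projections share an input, so their weights can be stacked into a single matmul whose output is split afterwards; `test_split_linear` in `test/test_te.py` checks exactly this identity. A minimal sketch of the idea, using a hypothetical `FusedSwiGLU` module rather than the commit's `te_utils` code, and assuming the SwiGLU form `w2(silu(w1(x)) * w3(x))` used by the Llama `FeedForward`:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedSwiGLU(nn.Module):
    """Illustrative only: fuse the w1/w3 projections of a SwiGLU FFN into one linear."""

    def __init__(self, w1: nn.Linear, w2: nn.Linear, w3: nn.Linear):
        super().__init__()
        hidden_dim, dim = w1.weight.shape
        # Single projection whose weight is [w1.weight; w3.weight] stacked along dim 0.
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        with torch.no_grad():
            self.w13.weight.copy_(torch.cat([w1.weight, w3.weight], dim=0))
        self.w2 = w2

    def forward(self, x):
        # One matmul produces both halves; chunk recovers the w1 and w3 outputs.
        x1, x3 = self.w13(x).chunk(2, dim=-1)
        return self.w2(F.silu(x1) * x3)
```
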
        self.parser.add_argument(
            "--training.te_swap_linear",
            action="store_true",
            help="""
            If true, swaps torch.nn.Linear with te.Linear
            (not for land)
            Note:
            * requires training.te_float8_autocast to use float8
            """,
        )
        self.parser.add_argument(
            "--training.te_swap_ln_linear",
            action="store_true",
            help="""
            If true, swaps NormFeedForward.norm_w13 from
            nn.Sequential(RMSNorm, nn.Linear) to te.LayerNormLinear
            (not for land)
            Note:
            * requires training.horizontally_fuse_fcs to enable this swap
            * this swap happens strictly before `training.te_swap_linear` if both are enabled
            * requires training.te_float8_autocast to use float8
            """,
        )
        self.parser.add_argument(
            "--training.te_float8_autocast",
            action="store_true",
            help="""
            If true, enables TE's float8 autocast context manager
            (not for land)
            """,
        )
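
Read together, these help strings pin down an ordering: fuse the FFN projections first, swap the fused norm plus linear to `te.LayerNormLinear`, then swap any remaining `nn.Linear` to `te.Linear`, and float8 only takes effect when the forward pass runs under `fp8_autocast`. A rough sketch of that sequencing, reusing the `te_utils` helpers and `DelayedScaling` recipe exercised in `test/test_te.py`; the `apply_te_experiments` function and the `cfg` fields are hypothetical, not the commit's actual wiring:

```
import contextlib

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

import torchtitan.te_utils as te_utils

def apply_te_experiments(model, cfg):
    # Ordering follows the flag descriptions: the ln_linear swap requires the
    # horizontal fusion and must happen before the plain nn.Linear swap.
    if cfg.horizontally_fuse_fcs:
        te_utils.swap_norm_ffn_to_te_friendly_norm_ffn(model)
    if cfg.te_swap_ln_linear:
        te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_linear(model)
    if cfg.te_swap_linear:
        te_utils.swap_linear_to_te_linear(model)

    recipe = DelayedScaling(
        fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max"
    )
    # Float8 only takes effect when the forward runs inside fp8_autocast.
    fp8_ctx = (
        te.fp8_autocast(enabled=True, fp8_recipe=recipe)
        if cfg.te_float8_autocast
        else contextlib.nullcontext()
    )
    return model, fp8_ctx
```
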
        self.parser.add_argument(
            "--training.gc_freq",
            type=int,
2 changes: 2 additions & 0 deletions torchtitan/float8.py
@@ -97,11 +97,13 @@ def convert_to_float8_training(self, model: nn.Module):
            model,
            config=self.config,
            module_filter_fn=lambda mod, fqn: fqn != "output",
            # module_filter_fn=lambda mod, fqn: fqn != "output" and "norm_w13" in fqn,
        )
        logger.info(
            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
            f"{self.config.enable_fsdp_float8_all_gather}"
        )
        print(model)

    def precompute_float8_dynamic_scale_for_fsdp(
        self, model: Union[nn.Module, List[nn.Module]]
4 changes: 3 additions & 1 deletion torchtitan/models/llama/__init__.py
@@ -29,10 +29,12 @@
}

llama3_configs = {
    "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
    # "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
    "debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16, rope_theta=500000),
    "8B": ModelArgs(
        dim=4096,
        n_layers=32,
        # n_layers=1,
        n_heads=32,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
60 changes: 53 additions & 7 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -22,6 +22,7 @@
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
    apply_activation_checkpointing,
)
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
@@ -75,7 +76,7 @@ def parallelize_llama(
                "fused_rmsnorm is not compatible with torch.compile yet. "
                "Please use rmsnorm or layernorm."
            )
        apply_compile(model)
        apply_compile(model, job_config)

    if (
        parallel_dims.dp_shard_enabled
@@ -243,15 +244,44 @@ def apply_tp(
        }


import transformer_engine.pytorch as te
rng_seed = 1234
torch.manual_seed(rng_seed)
torch.cuda.manual_seed(rng_seed)
CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
CUDA_RNG_STATES_TRACKER.add("model-parallel-rng", rng_seed)


def get_cuda_rng_tracker():
    return CUDA_RNG_STATES_TRACKER


def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
valid_ac_modes = ("full", "selective")
valid_ac_modes = ("full", "selective", "full_te")
if ac_config.mode not in valid_ac_modes:
raise ValueError(
f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
)

if ac_config.mode == "full":
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
elif ac_config.mode == "full_te":
# copy-paste from https://github.com/NVIDIA/TransformerEngine/blob/64126aa8c469b2a97ace01f925f3d5786d5fd1bb/examples/pytorch/fsdp/fsdp.py, apply_fsdp_checkpointing
# note:
# LLaMa 3 8B on 8 H100s with this option:
# 42.27 GiB, 4880 tps, strictly worse than PT-D's full AC. Have not done debugging
# on the cause yet.

wrapper = lambda m: ptd_checkpoint_wrapper(
m,
checkpoint_fn=te.distributed.checkpoint,
use_reentrant=False,
get_rng_state_tracker=get_cuda_rng_tracker,
)
def check_fn(submodule):
return True
apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
return module

assert ac_config.mode == "selective", f"{ac_config.mode}"
use_op_sac = ac_config.selective_ac_option == "op"
@@ -314,16 +344,32 @@ def apply_ac(model: nn.Module, ac_config):
    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def apply_compile(model: nn.Module):
def apply_compile(model: nn.Module, job_config):
"""
Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
repeated structure. Alternatively one can compile the whole model (after applying DP).
"""
for layer_id, transformer_block in model.layers.named_children():
transformer_block = torch.compile(transformer_block, fullgraph=True)
model.layers.register_module(layer_id, transformer_block)
if job_config.training.compile_ln_linear:
def _apply_compile(mod):
for name, child in mod.named_children():
# hacky check, but good enough for this use case
if isinstance(child, torch.nn.Sequential) and len(child) == 2:
new_child = torch.compile(child)
setattr(mod, name, new_child)
else:
_apply_compile(child)

logger.info("Compiling each LNLinear with torch.compile")
_apply_compile(model)
print(model)
# TODO also option for just linear
else:
for layer_id, transformer_block in model.layers.named_children():
# transformer_block = torch.compile(transformer_block, fullgraph=True)
transformer_block = torch.compile(transformer_block, fullgraph=False)
model.layers.register_module(layer_id, transformer_block)

logger.info("Compiling each TransformerBlock with torch.compile")
logger.info("Compiling each TransformerBlock with torch.compile")


def apply_fsdp(