From 2762a0a22461b6301ef000b5559151a3d586f5cd Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 15 May 2024 14:27:06 -0700
Subject: [PATCH] repro compiled rmsnorm error

---
 run_llama_train.sh             |  2 ++
 torchtitan/models/norms.py     | 10 ++++++----
 train_configs/debug_model.toml |  2 +-
 3 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/run_llama_train.sh b/run_llama_train.sh
index 33aaf79b..d973b602 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -29,6 +29,8 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCH_LOGS="+dynamo" \
+TORCHDYNAMO_VERBOSE=1 \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides
diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index e29338d9..0451b545 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -82,16 +82,18 @@ class RMSNorm(nn.Module):
 
     """
 
+    @staticmethod
+    def _norm(x: torch.Tensor, eps: float):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+        self.compiled_norm = torch.compile(RMSNorm._norm)
 
     def forward(self, x: torch.Tensor):
-        output = self._norm(x.float()).type_as(x)
+        output = self.compiled_norm(x.float(), self.eps).type_as(x)
         return output * self.weight
 
     def reset_parameters(self):
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 4541fec7..b8ec566f 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -20,7 +20,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "debugmodel"
-norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
 # test tokenizer.model, for debug purpose only
 tokenizer_path = "./test/assets/test_tiktoken.model"
 