Skip to content

Commit

Permalink
repro compiled rmsnorm error
Browse files Browse the repository at this point in the history
  • Loading branch information
tianyu-l committed May 15, 2024
1 parent 7f92f45 commit 2762a0a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
2 changes: 2 additions & 0 deletions run_llama_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

+TORCH_LOGS="+dynamo" \
+TORCHDYNAMO_VERBOSE=1 \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides
10 changes: 6 additions & 4 deletions torchtitan/models/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,18 @@ class RMSNorm(nn.Module):
"""

+    @staticmethod
+    def _norm(x: torch.Tensor, eps: float):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+        self.compiled_norm = torch.compile(RMSNorm._norm)

     def forward(self, x: torch.Tensor):
-        output = self._norm(x.float()).type_as(x)
+        output = self.compiled_norm(x.float(), self.eps).type_as(x)
         return output * self.weight

     def reset_parameters(self):
Expand Down
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "debugmodel"
-norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
 # test tokenizer.model, for debug purpose only
 tokenizer_path = "./test/assets/test_tiktoken.model"

Expand Down

0 comments on commit 2762a0a

Please sign in to comment.