diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index 9952f2de..e294e389 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -401,6 +401,9 @@ class TransformerBlock(nn.Module):
         attention_norm (RMSNorm): Layer normalization for attention output.
         ffn_norm (RMSNorm): Layer normalization for feedforward output.
 
+    Raises:
+        NotImplementedError: If norm_type is not rmsnorm or fastlayernorm.
+
     """
 
     def __init__(self, layer_id: int, model_args: ModelArgs):
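
For context, the new "Raises" entry documents a dispatch on norm_type that fails loudly for unsupported values. A minimal sketch of what such a dispatch could look like is below; build_norm, the simplified RMSNorm, and the nn.LayerNorm stand-in for fastlayernorm are illustrative assumptions, not the actual torchtrain helpers.

# Hypothetical sketch of the norm-type dispatch described by the new
# "Raises" docstring; names here are illustrative, not torchtrain's API.
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Simplified RMSNorm for illustration."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square over the last dimension.
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


def build_norm(norm_type: str, dim: int, eps: float = 1e-6) -> nn.Module:
    """Return the normalization layer for norm_type.

    Raises:
        NotImplementedError: If norm_type is not rmsnorm or fastlayernorm.
    """
    if norm_type == "rmsnorm":
        return RMSNorm(dim, eps=eps)
    if norm_type == "fastlayernorm":
        # Stand-in for a fused/fast layer norm implementation.
        return nn.LayerNorm(dim, eps=eps)
    raise NotImplementedError(f"norm_type {norm_type!r} is not supported")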