From ae590792b472564833fdb053937fe0ae2d9a985a Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 25 Feb 2024 20:18:58 -0800 Subject: [PATCH] linting - document raised NotImplementedError --- torchtrain/models/llama/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py index 9952f2de..e294e389 100644 --- a/torchtrain/models/llama/model.py +++ b/torchtrain/models/llama/model.py @@ -401,6 +401,9 @@ class TransformerBlock(nn.Module): attention_norm (RMSNorm): Layer normalization for attention output. ffn_norm (RMSNorm): Layer normalization for feedforward output. + Raises: + NotImplementedError: If norm_type is not rmsnorm or fastlayernorm. + """ def __init__(self, layer_id: int, model_args: ModelArgs):