diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index 2cd81a6c..9952f2de 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -27,6 +27,29 @@ class ModelArgs:
     depth_init: bool = (
         True  # initialization uses each unique layer_id or total model layer count
     )
+    norm_type: str = "fastlayernorm"  # or "rmsnorm"
+
+
+class FastLayerNorm(torch.nn.Module):
+    """
+    class for non-parametric (no scaling) layer norm
+
+    """
+
+    def __init__(
+        self,
+        size: Optional[int] = None,
+        eps: float = 1e-06,
+    ):
+        super().__init__()
+        self.eps = eps
+        self.normalized_shape = (size,)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.layer_norm(x, self.normalized_shape, eps=self.eps)
+
+    def reset_parameters(self):
+        pass
 
 
 class RMSNorm(torch.nn.Module):
@@ -393,8 +416,18 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
         )
         self.layer_id = layer_id
         self.num_layers = model_args.n_layers
-        self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        if model_args.norm_type == "rmsnorm":
+            self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+            self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        elif model_args.norm_type == "fastlayernorm":
+            self.attention_norm = FastLayerNorm(
+                size=model_args.dim, eps=model_args.norm_eps
+            )
+            self.ffn_norm = FastLayerNorm(size=model_args.dim, eps=model_args.norm_eps)
+        else:
+            raise NotImplementedError(
+                f"{model_args.norm_type} is not supported. Please use rmsnorm or fastlayernorm."
+            )
 
         if model_args.depth_init:
             self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
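
Note (not part of the patch): the added FastLayerNorm is a non-parametric layer norm, i.e. normalization without learnable scale or bias, which should behave the same as torch.nn.LayerNorm constructed with elementwise_affine=False. Below is a minimal standalone sketch, restating the class from the hunk above, that checks this equivalence; the shapes and eps value are illustrative only.

from typing import Optional

import torch
import torch.nn.functional as F


class FastLayerNorm(torch.nn.Module):
    """Non-parametric (no scale/shift) layer norm, copied from the diff above."""

    def __init__(self, size: Optional[int] = None, eps: float = 1e-06):
        super().__init__()
        self.eps = eps
        self.normalized_shape = (size,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the last dimension; no weight/bias are passed.
        return F.layer_norm(x, self.normalized_shape, eps=self.eps)

    def reset_parameters(self):
        # No parameters to reset; kept as a no-op so callers can invoke it uniformly.
        pass


# Compare against PyTorch's built-in LayerNorm with affine parameters disabled.
x = torch.randn(2, 8, 256)
fast_ln = FastLayerNorm(size=256, eps=1e-6)
ref_ln = torch.nn.LayerNorm(256, eps=1e-6, elementwise_affine=False)
torch.testing.assert_close(fast_ln(x), ref_ln(x))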