
Commit

add non-parametric layernorm, FastLayerNorm
lessw2020 committed Feb 26, 2024
1 parent 2bbb124 commit 508c73c
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions torchtrain/models/llama/model.py
@@ -27,6 +27,29 @@ class ModelArgs:
     depth_init: bool = (
         True  # initialization uses each unique layer_id or total model layer count
     )
+    norm_type: str = "fastlayernorm"  # or "rmsnorm"
+
+
+class FastLayerNorm(torch.nn.Module):
+    """
+    Non-parametric (no learnable scaling) layer norm.
+    """
+
+    def __init__(
+        self,
+        size: Optional[int] = None,
+        eps: float = 1e-06,
+    ):
+        super().__init__()
+        self.eps = eps
+        self.normalized_shape = (size,)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.layer_norm(x, self.normalized_shape, eps=self.eps)
+
+    def reset_parameters(self):
+        pass
+
 
 class RMSNorm(torch.nn.Module):
@@ -393,8 +416,18 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
         )
         self.layer_id = layer_id
         self.num_layers = model_args.n_layers
-        self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        if model_args.norm_type == "rmsnorm":
+            self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+            self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        elif model_args.norm_type == "fastlayernorm":
+            self.attention_norm = FastLayerNorm(
+                size=model_args.dim, eps=model_args.norm_eps
+            )
+            self.ffn_norm = FastLayerNorm(size=model_args.dim, eps=model_args.norm_eps)
+        else:
+            raise NotImplementedError(
+                f"{model_args.norm_type} is not supported. Please use rmsnorm or fastlayernorm."
+            )
 
         if model_args.depth_init:
             self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
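
Note on the change (not part of the commit): FastLayerNorm is layer normalization with the learnable weight and bias left out, so it behaves like torch.nn.LayerNorm constructed with elementwise_affine=False. A minimal standalone sketch, with an illustrative dim and random input that are assumptions, not values from the commit:

import torch
import torch.nn.functional as F

dim, eps = 16, 1e-6          # illustrative values, not from the commit
x = torch.randn(2, 4, dim)   # toy (batch, seq, dim) input

# Non-parametric layer norm, as FastLayerNorm.forward does: no weight/bias passed.
fast = F.layer_norm(x, (dim,), eps=eps)

# Built-in LayerNorm with affine parameters disabled gives the same result.
ref = torch.nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
assert torch.allclose(fast, ref(x), atol=1e-6)

With norm_type="fastlayernorm" in ModelArgs, TransformerBlock uses this norm for both attention_norm and ffn_norm; "rmsnorm" keeps the previous behavior.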
