remove a bunch of unnecessary biases
lucidrains committed Jan 5, 2024
1 parent a549c55 commit 7916859
Showing 3 changed files with 26 additions and 12 deletions.
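Apart from the version bump in setup.py, every change below follows the same pattern: drop the additive bias from nn.Linear projections and from LayerNorm, keeping only the multiplicative weight/gain. A minimal before/after sketch of that pattern, using the bias keyword that nn.LayerNorm gained in PyTorch 2.1 as a stand-in for the custom LayerNorm class this commit introduces (sizes are placeholders, not taken from the library):

import torch.nn as nn

dim_in, dim = 512, 256  # placeholder sizes for illustration

# before: projection and norm both carry an additive bias
proj_before = nn.Linear(dim_in, dim)   # weight + bias
norm_before = nn.LayerNorm(dim)        # gamma + beta

# after: the same modules with the bias removed
proj_after = nn.Linear(dim_in, dim, bias = False)   # weight only
norm_after = nn.LayerNorm(dim, bias = False)        # gamma only (requires PyTorch >= 2.1)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(proj_before) - n_params(proj_after))  # 256 fewer parameters (one bias vector)
print(n_params(norm_before) - n_params(norm_after))  # 256 fewer parameters (one beta vector)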
setup.py: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
setup(
  name = 'x-transformers',
  packages = find_packages(exclude=['examples']),
-  version = '1.26.6',
+  version = '1.27.0',
  license='MIT',
  description = 'X-Transformers - Pytorch',
  author = 'Phil Wang',
Expand Down
x_transformers/continuous.py: 4 additions & 3 deletions
@@ -8,6 +8,7 @@
    AttentionLayers,
    ScaledSinusoidalEmbedding,
    AbsolutePositionalEmbedding,
+    LayerNorm,
    always,
    pad_at_dim
)
@@ -54,7 +55,7 @@ def __init__(
        else:
            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)

-        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+        self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
        self.emb_dropout = nn.Dropout(emb_dropout)

        # memory tokens
@@ -71,8 +72,8 @@ def __init__(

        # project in and out

-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
-        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+        self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
+        self.project_out = nn.Linear(dim, dim_out, bias = False) if exists(dim_out) else nn.Identity()

    def forward(
        self,
x_transformers/x_transformers.py: 21 additions & 8 deletions
@@ -304,7 +304,7 @@ def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):

        self.mlp.append(Sequential(
            nn.Linear(1, dim),
-            nn.LayerNorm(dim) if norm else None,
+            LayerNorm(dim) if norm else None,
            nn.SiLU()
        ))

@@ -498,6 +498,19 @@ def forward(self, x):
        norm = torch.norm(x, dim = -1, keepdim = True)
        return x / norm.clamp(min = self.eps) * self.g

+class LayerNorm(nn.Module):
+    def __init__(self, dim):
+        """
+        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
+        latest pytorch actually has a way to turn this off in nn.LayerNorm
+        """
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer("beta", torch.zeros(dim))
+
+    def forward(self, x):
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
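The new class registers beta as an all-zeros buffer rather than a parameter, so F.layer_norm still receives a bias argument but nothing is learned for it. As the docstring notes, recent PyTorch can express the same thing directly; a quick equivalence sketch (assuming PyTorch >= 2.1 for the bias keyword, and comparing at initialization where gamma is all ones):

import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 64
x = torch.randn(2, 8, dim)

# the commit's formulation: learned gamma, fixed zero beta buffer
gamma = nn.Parameter(torch.ones(dim))
beta = torch.zeros(dim)
out_custom = F.layer_norm(x, x.shape[-1:], gamma, beta)

# built-in equivalent at initialization: bias disabled entirely
ln = nn.LayerNorm(dim, bias = False)
out_builtin = ln(x)

print(torch.allclose(out_custom, out_builtin, atol = 1e-6))  # True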
@@ -634,7 +647,7 @@ def __init__(

        self.ff = Sequential(
            project_in,
-            nn.LayerNorm(inner_dim) if post_act_ln else None,
+            LayerNorm(inner_dim) if post_act_ln else None,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out, bias = not no_bias)
        )
@@ -1083,7 +1096,7 @@ def __init__(
        elif use_simple_rmsnorm:
            norm_class = SimpleRMSNorm
        else:
-            norm_class = nn.LayerNorm
+            norm_class = LayerNorm

        norm_fn = partial(norm_class, dim)

@@ -1415,12 +1428,12 @@ def __init__(
        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        self.patch_to_embedding = nn.Sequential(
-            nn.LayerNorm(patch_dim),
+            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
-            nn.LayerNorm(dim)
+            LayerNorm(dim)
        )

-        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+        self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers
@@ -1515,7 +1528,7 @@ def __init__(

        self.emb_frac_gradient = emb_frac_gradient

-        self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
+        self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@@ -1524,7 +1537,7 @@ def __init__(
        self.init_()

        logits_dim = default(logits_dim, num_tokens)
-        self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
+        self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper

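The last change keeps the untied logits head consistent with the tied one: the weight-tied path is a plain matmul against the embedding matrix and has no bias, so the nn.Linear path drops its bias as well. A toy illustration (nn.Embedding stands in for the library's TokenEmbedding wrapper; sizes are made up):

import torch
import torch.nn as nn

num_tokens, dim = 1000, 64   # placeholder sizes
token_emb = nn.Embedding(num_tokens, dim)

# untied head: a plain projection, now bias-free
to_logits_untied = nn.Linear(dim, num_tokens, bias = False)

# tied head: reuse the embedding weights, which never had a bias
to_logits_tied = lambda t: t @ token_emb.weight.t()

x = torch.randn(2, 8, dim)
print(to_logits_untied(x).shape)  # torch.Size([2, 8, 1000])
print(to_logits_tied(x).shape)    # torch.Size([2, 8, 1000])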