Commit

always fix head dimension at 64, so users do not trip up with improper ratio of dimension to heads
lucidrains committed Nov 3, 2020
1 parent 0dfba29 commit a398a5b
Showing 3 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions examples/enwik8_simple/train.py
@@ -82,7 +82,7 @@ def __len__(self):
     model.train()
 
     for __ in range(GRADIENT_ACCUMULATE_EVERY):
-        loss = model(next(train_loader), return_loss = True)
+        loss = model(next(train_loader))
         loss.backward()
 
     print(f'training loss: {loss.item()}')
@@ -93,7 +93,7 @@ def __len__(self):
     if i % VALIDATE_EVERY == 0:
         model.eval()
         with torch.no_grad():
-            loss = model(next(val_loader), return_loss = True)
+            loss = model(next(val_loader))
             print(f'validation loss: {loss.item()}')
 
     if i % GENERATE_EVERY == 0:
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
12 changes: 6 additions & 6 deletions x_transformers/x_transformers.py
@@ -49,17 +49,17 @@ def forward(self, x):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, causal = False, mask = None):
+    def __init__(self, dim, dim_head = 64, heads = 8, causal = False, mask = None):
         super().__init__()
-        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
-        self.scale = (dim // heads) ** -0.5
+        self.scale = dim_head ** -0.5
         self.heads = heads
         self.causal = causal
         self.mask = mask
 
-        self.to_q = nn.Linear(dim, dim, bias = False)
-        self.to_kv = nn.Linear(dim, dim * 2, bias = False)
-        self.to_out = nn.Linear(dim, dim)
+        inner_dim = dim_head * heads
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim)
 
     def forward(self, x, context = None, mask = None, context_mask = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
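To illustrate the effect of this change, here is a minimal sketch (not part of the repository) that reproduces only the new projection logic from the diff above: with inner_dim = dim_head * heads, the model dimension no longer has to be divisible by the number of heads. The class name AttentionProjections and the example sizes are assumptions chosen for the demo, assuming PyTorch is installed.

import torch
import torch.nn as nn

class AttentionProjections(nn.Module):
    # hypothetical, stripped-down module mirroring only the projection shapes in the diff above
    def __init__(self, dim, dim_head = 64, heads = 8):
        super().__init__()
        inner_dim = dim_head * heads            # e.g. 64 * 8 = 512, independent of dim
        self.scale = dim_head ** -0.5           # scaling now depends only on the fixed head dimension
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

# dim = 100 is not divisible by heads = 8; the removed assert would have rejected this
proj = AttentionProjections(dim = 100, heads = 8)
x = torch.randn(2, 16, 100)                     # (batch, sequence, dim)
q = proj.to_q(x)                                # (2, 16, 512)
k, v = proj.to_kv(x).chunk(2, dim = -1)         # each (2, 16, 512)
print(q.shape, k.shape, v.shape)

With the previous constructor, a configuration like dim = 100 with heads = 8 would have failed the divisibility assert; fixing dim_head at 64 sidesteps that, at the cost of the attention width no longer tracking dim.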
