Skip to content

Commit

Permalink
move value residual so that kv caching works properly
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 31, 2024
1 parent d88537d commit 881be6b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.42.3',
version = '1.42.4',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
Expand Down
26 changes: 15 additions & 11 deletions x_transformers/x_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,20 @@ def forward(

k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))

# if previous values passed in for residual, either invoke resformer or neutreno

orig_values = v

if exists(value_residual):
if self.neutreno_value_residual:
diff_values = (value_residual - v) * self.neutreno_alpha
diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
else:
# https://arxiv.org/abs/2410.17897v1
v = 0.5 * (v + value_residual)

# take care of caching

if exists(cache):
ck, cv = cache.cached_kv

Expand Down Expand Up @@ -1363,16 +1377,6 @@ def forward(
if exists(self.data_dependent_alibi):
attn_bias = self.data_dependent_alibi(x)

# if previous values passed in for residual, either invoke resformer or neutreno

if exists(value_residual):
if self.neutreno_value_residual:
diff_values = (value_residual - v) * self.neutreno_alpha
diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
else:
# https://arxiv.org/abs/2410.17897v1
v = 0.5 * (v + value_residual)

# attention is all we need

out, intermediates = self.attend(
Expand All @@ -1384,7 +1388,7 @@ def forward(

# store the values for resformer or Neutreno

intermediates.values = v
intermediates.values = orig_values

if exists(value_residual) and self.neutreno_value_residual:
out = out + diff_values
Expand Down

0 comments on commit 881be6b

Please sign in to comment.