Skip to content

Commit

Permalink
ability to return outputs from autoregressive wrapper forward
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 25, 2023
1 parent a00f542 commit 8795108
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.25.8',
version = '1.25.9',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
Expand Down
7 changes: 5 additions & 2 deletions x_transformers/autoregressive_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def generate(

return out

def forward(self, x, **kwargs):
def forward(self, x, return_outputs = False, **kwargs):
seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss

inp, target = x[:, :-1], x[:, 1:]
Expand Down Expand Up @@ -284,4 +284,7 @@ def forward(self, x, **kwargs):
if add_attn_z_loss:
loss = loss + cache.attn_z_loss

return loss
if not return_outputs:
return loss

return loss, (logits, cache)

0 comments on commit 8795108

Please sign in to comment.