diff --git a/setup.py b/setup.py
index a869ed1..09655cd 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.35.2',
+  version = '1.35.3',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 2ae5b5c..8687351 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -920,6 +920,7 @@ def __init__(
         kv_heads = None,
         shared_kv = False,
         value_dim_head = None,
+        dim_out = None,
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
@@ -1057,7 +1058,11 @@ def __init__(
         # attention on attention

         self.attn_on_attn = on_attn
-        self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
+
+        # output dimension by default same as input, but can be overridden
+
+        dim_out = default(dim_out, dim)
+        self.to_out = nn.Sequential(nn.Linear(out_dim, dim_out * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim_out, bias = False)

         # whether to rotate positions into values, for absolute positions in addition to relative
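
For reference, a minimal usage sketch of the new dim_out keyword introduced by this diff. The example itself is not part of the change; it assumes the Attention module's defaults and that forward returns an (output, intermediates) tuple, as in recent x-transformers versions.

import torch
from x_transformers.x_transformers import Attention

# sketch: override the output projection so the attention block
# maps dim = 512 inputs to dim_out = 256 outputs; previously the
# output dimension was always equal to dim
attn = Attention(dim = 512, dim_out = 256, heads = 8)

x = torch.randn(1, 1024, 512)

# assumption: forward returns (out, intermediates)
out, intermediates = attn(x)

print(out.shape)  # expected: torch.Size([1, 1024, 256])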