diff --git a/pyproject.toml b/pyproject.toml
index 89622b1..04cabea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.17"
+version = "0.1.18"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
     "axial_positional_embedding>=0.3.9",
     "einops>=0.8.0",
     "einx>=0.3.0",
-    "hyper-connections>=0.1.8",
+    "hyper-connections>=0.1.9",
     "Ninja",
     "rotary-embedding-torch",
     "tensordict",
diff --git a/titans_pytorch/mac_transformer.py b/titans_pytorch/mac_transformer.py
index 0a4b1ab..ef385ba 100644
--- a/titans_pytorch/mac_transformer.py
+++ b/titans_pytorch/mac_transformer.py
@@ -217,7 +217,8 @@ def forward_flex(
         self,
         seq,
         value_residual = None,
-        flex_attn_fn: Callable | None = None
+        flex_attn_fn: Callable | None = None,
+        output_gating = None
     ):
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
@@ -267,6 +268,9 @@ def forward_flex(
 
         out = self.to_out(out)
 
+        if exists(output_gating):
+            out = out * output_gating
+
         return out, orig_v
 
     def forward(
@@ -274,10 +278,11 @@ def forward(
         seq,
         value_residual = None,
         flex_attn_fn: Callable | None = None,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        output_gating = None
     ):
         if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
-            return self.forward_flex(seq, value_residual, flex_attn_fn)
+            return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating)
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
 
@@ -361,50 +366,10 @@ def forward(
 
         out = inverse_segment(out)
 
-        return out, orig_v
-
-# Attention + Neural Memory gating configuration, as depicted in Figure 2
-
-class NeuralMemoryGatingWrapper(Module):
-    def __init__(
-        self,
-        dim,
-        attn: SegmentedAttention,
-        neural_mem: NeuralMemory | None = None,
-        gate_attn_output = True
-    ):
-        super().__init__()
-        self.attn = attn
-        self.neural_mem = neural_mem
-        self.gate_attn_output = gate_attn_output
-
-    def forward(
-        self,
-        seq,
-        *args,
-        **kwargs
-    ):
-        batch, seq_len = seq.shape[:2]
-        mem = self.neural_mem
-
-        if not exists(mem):
-            return self.attn(seq, *args, **kwargs), 0.
+        if exists(output_gating):
+            out = out * output_gating
 
-        # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
-
-        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
-
-        if not self.gate_attn_output:
-            seq = seq + retrieved
-
-        # attention
-
-        attn_out, values = self.attn(seq, *args, **kwargs)
-
-        if self.gate_attn_output:
-            attn_out = attn_out * retrieved.sigmoid()
-
-        return (attn_out, values), kv_aux_loss
+        return out, orig_v
 
 # MAC transformer
 
@@ -494,16 +459,10 @@ def __init__(
                     **neural_memory_kwargs
                 )
 
-            attn = NeuralMemoryGatingWrapper(
-                dim,
-                attn = attn,
-                neural_mem = mem,
-                gate_attn_output = neural_mem_gate_attn_output
-            )
-
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
+                init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
                 init_hyper_conn(dim = dim, branch = attn),
                 init_hyper_conn(dim = dim, branch = ff)
             ]))
@@ -512,6 +471,10 @@
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
+        # whether to gate the attention output with the retrieved memories
+
+        self.gate_attn_output = neural_mem_gate_attn_output
+
         # auxiliary loss on kv recon
 
         self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
@@ -652,19 +615,34 @@ def forward(
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for mem, attn, ff in self.layers:
+
+            retrieved = None
+            attn_out_gates = None
+
+            if exists(mem):
+                retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
+                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
-            (x, values), maybe_mem_kv_aux_loss = attn(
+                if self.gate_attn_output:
+                    attn_out_gates = retrieved.sigmoid()
+                else:
+                    seq = retrieved
+
+            # attention
+
+            x, values = attn(
                 x,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
-                flex_attn_fn = flex_attn_fn
+                flex_attn_fn = flex_attn_fn,
+                output_gating = attn_out_gates
             )
 
-            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
-
             value_residual = default(value_residual, values)
 
+            # feedforward
+
             x = ff(x)
 
         x = self.reduce_streams(x)
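For context, here is a minimal sketch of the two neural-memory integration modes this patch toggles with `neural_mem_gate_attn_output`: either the retrieved memories gate the attention output through a sigmoid (the new `output_gating` path inside `SegmentedAttention`), or they are added to the residual stream (the hyper-connection branch with `add_branch_out_to_residual = True`). This is not code from the patch; `integrate_memory`, the plain residual add, and the tensor shapes are illustrative assumptions standing in for the hyper-connection machinery.

```python
# Illustrative sketch only -- the function name and the plain residual add
# are assumptions, not the repo's hyper-connection implementation.

import torch
from torch import Tensor

def integrate_memory(
    x: Tensor,            # residual stream, (batch, seq, dim)
    attn_out: Tensor,     # attention branch output, same shape
    retrieved: Tensor,    # neural memory retrieval for the same positions
    gate_attn_output: bool = True
) -> Tensor:
    if gate_attn_output:
        # memories gate the attention output elementwise (output_gating path)
        return x + attn_out * retrieved.sigmoid()

    # otherwise the retrieved memories are added directly to the residual
    # stream and attention is applied ungated on top of it
    return x + retrieved + attn_out

x, attn_out, retrieved = (torch.randn(2, 16, 512) for _ in range(3))
assert integrate_memory(x, attn_out, retrieved).shape == x.shape
```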