diff --git a/pyproject.toml b/pyproject.toml index 25eeef9..7e5ec13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "titans-pytorch" -version = "0.1.36" +version = "0.1.37" description = "Titans" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/titans_pytorch/neural_memory.py b/titans_pytorch/neural_memory.py index c081ae7..4faddbf 100644 --- a/titans_pytorch/neural_memory.py +++ b/titans_pytorch/neural_memory.py @@ -723,6 +723,8 @@ def forward_inference( prev_layer_updates = TensorDict(prev_layer_updates) prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:]) + values = None + if store_seq_cache_len == self.chunk_size: next_updates, next_states, values = self.store_memories( @@ -770,7 +772,7 @@ def forward( if seq_len < self.retrieve_chunk_size: out = self.init_empty_memory_embed(batch, seq_len) - next_store_state = (seq_len, seq, None, None) + next_store_state = NeuralMemCache(seq_len, seq, None, None) out = (out, next_store_state)