Unofficial implementation of Titans in PyTorch. This repository will also contain some explorations into architectures beyond the simple 1-4 layer MLP used for the neural memory module in the paper, if they work well to any degree.
$ pip install titans-pytorch
import torch
from titans_pytorch import NeuralMemory

# neural memory module: dim is the model dimension, chunk_size is the span
# over which test-time memorization is carried out, and pre_rmsnorm applies
# an RMSNorm to the input before storing / retrieving
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True
).cuda()

seq = torch.randn(2, 1024, 384).cuda()  # (batch, sequence length, dim)

# retrieved memories have the same shape as the input sequence
retrieved = mem(seq)

assert seq.shape == retrieved.shape
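Because the retrieved memories share the input's shape, the module drops naturally into a residual branch next to attention. Below is only a minimal sketch of that idea, assuming nothing beyond the API shown above; MemoryAugmentedBlock and the plain nn.MultiheadAttention branch are illustrative stand-ins, not the memory-as-context architecture from the paper.

import torch
from torch import nn
from titans_pytorch import NeuralMemory

# hypothetical block: neural memory as an extra residual branch beside attention
class MemoryAugmentedBlock(nn.Module):
    def __init__(self, dim = 384, heads = 8, chunk_size = 64):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)
        self.mem = NeuralMemory(dim = dim, chunk_size = chunk_size, pre_rmsnorm = True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)   # attention branch (no causal mask, for brevity)
        retrieved = self.mem(x)            # neural memory branch, same shape as x
        return self.norm(x + attn_out + retrieved)

block = MemoryAugmentedBlock().cuda()
seq = torch.randn(2, 1024, 384).cuda()
out = block(seq)  # (2, 1024, 384)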
To run the experiments, first install the extra dependencies

$ pip install .[examples]
For the SOTA linear attention, you will also need to run
$ pip install -r requirements.txt
Then modify train.py and run it to query nature
$ python train.py
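For a quick sanity check of the memory module on its own, outside of train.py, a toy loop like the one below can be pieced together from the usage shown earlier. It is only an illustrative sketch on random data with a made-up reconstruction objective, not the repository's experiment.

import torch
import torch.nn.functional as F
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64, pre_rmsnorm = True).cuda()
optim = torch.optim.Adam(mem.parameters(), lr = 3e-4)

for step in range(100):
    seq = torch.randn(2, 1024, 384).cuda()
    retrieved = mem(seq)

    # toy objective: have the retrieved memories reconstruct the input
    loss = F.mse_loss(retrieved, seq)
    loss.backward()

    optim.step()
    optim.zero_grad()

    if step % 10 == 0:
        print(f'step {step}: loss {loss.item():.4f}')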
@inproceedings{Behrouz2024TitansLT,
    title   = {Titans: Learning to Memorize at Test Time},
    author  = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:275212078}
}

@software{Kyrylov_Accelerated_Scan_2024,
    author  = {Kyrylov, Volodymyr},
    doi     = {10.5281/zenodo.10600962},
    title   = {Accelerated Scan},
    version = {0.1.2},
    year    = {2024}
}