Commit 1bb76d4
add qk rmsnorm
lucidrains committed Jan 22, 2025
1 parent 3598ae2 commit 1bb76d4
Showing 4 changed files with 28 additions and 1 deletion.
12 changes: 12 additions & 0 deletions README.md
@@ -142,3 +142,15 @@ $ python train_mac.py
     year = {2024}
 }
 ```
+
+```bibtex
+@misc{wang2025testtimeregressionunifyingframework,
+    title = {Test-time regression: a unifying framework for designing sequence models with associative memory},
+    author = {Ke Alexander Wang and Jiaxin Shi and Emily B. Fox},
+    year = {2025},
+    eprint = {2501.12352},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2501.12352},
+}
+```
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.20"
+version = "0.1.21"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
3 changes: 3 additions & 0 deletions tests/test_titans.py
@@ -13,6 +13,7 @@ def exists(v):
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
 @pytest.mark.parametrize('attn_pool_chunks', (False, True))
 @pytest.mark.parametrize('momentum', (False, True))
+@pytest.mark.parametrize('qk_rmsnorm', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 def test_titans(
@@ -21,6 +22,7 @@ def test_titans(
     learned_mem_model_weights,
     attn_pool_chunks,
     momentum,
+    qk_rmsnorm,
     max_grad_norm,
     per_parameter_lr_modulation
 ):
@@ -31,6 +33,7 @@
         attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
         momentum = momentum,
+        qk_rmsnorm = qk_rmsnorm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         learned_mem_model_weights = learned_mem_model_weights
     )
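Not part of the commit itself, but worth noting: stacked @pytest.mark.parametrize decorators generate the cartesian product of their values, so the new qk_rmsnorm axis doubles the number of generated test cases. A minimal standalone illustration (a hypothetical test, not from the repo):

```python
import pytest

# two stacked boolean axes -> 2 x 2 = 4 generated cases;
# each extra (False, True) axis doubles the count
@pytest.mark.parametrize('qk_rmsnorm', (False, True))
@pytest.mark.parametrize('momentum', (False, True))
def test_flag_grid(momentum, qk_rmsnorm):
    assert isinstance(momentum, bool)
    assert isinstance(qk_rmsnorm, bool)
```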
12 changes: 12 additions & 0 deletions titans_pytorch/titans.py
@@ -365,6 +365,7 @@ def __init__(
         momentum = True,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        qk_rmsnorm = False,
         learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
@@ -389,6 +390,9 @@

         self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
 
+        self.q_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
+        self.k_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
+
         # maybe multi-headed
 
         dim_inner = dim_head * heads
@@ -577,6 +581,10 @@ def store_memories(

         batch = keys.shape[0]
 
+        # maybe qk rmsnorm
+
+        keys = self.k_norm(keys)
+
         # take care of chunking
 
         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
@@ -683,6 +691,10 @@ def retrieve_memories(

         queries = self.split_heads(queries)
 
+        # maybe qk rmsnorm
+
+        queries = self.q_norm(queries)
+
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
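For context, the new q_norm / k_norm modules reuse MultiheadRMSNorm, whose definition lies outside this diff's hunks. Below is a minimal sketch of a per-head RMSNorm consistent with how those modules are constructed, an assumption for illustration rather than the exact module defined in titans.py. It assumes inputs shaped (batch, heads, seq, dim_head) and a zero-initialized per-head gain so the norm starts out as a plain RMSNorm:

```python
import torch
from torch import nn

class MultiheadRMSNorm(nn.Module):
    """Sketch: RMSNorm over dim_head, plus a learned per-head gain."""

    def __init__(self, dim, heads):
        super().__init__()
        # normalize the feature dimension, with no learned scale here
        self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
        # per-head gain, zero-initialized so (gamma + 1) starts at identity
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        # x: (batch, heads, seq, dim_head)
        return self.rmsnorm(x) * (self.gamma + 1.)
```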

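And a hedged usage sketch of the new flag, assuming the README-style NeuralMemory entry point with a shape-preserving forward; the dim and chunk_size values are illustrative, not from this diff:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,          # illustrative model dimension
    chunk_size = 64,    # illustrative chunk size
    qk_rmsnorm = True   # the flag added by this commit: rmsnorm keys before
)                       # storing and queries before retrieving

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert retrieved.shape == seq.shape
```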