throw out yet another memory model, gated residual mlp variant
lucidrains committed Jan 20, 2025
1 parent 6d6721a commit 2a26f1f
Showing 3 changed files with 45 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "titans-pytorch"
-version = "0.1.5"
+version = "0.1.6"
description = "Titans"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
3 changes: 2 additions & 1 deletion titans_pytorch/__init__.py
@@ -2,7 +2,8 @@
NeuralMemory,
MemoryMLP,
MemoryAttention,
-FactorizedMemoryMLP
+FactorizedMemoryMLP,
+GatedResidualMemoryMLP
)

from titans_pytorch.mac_transformer import (
43 changes: 42 additions & 1 deletion titans_pytorch/titans.py
@@ -1,10 +1,11 @@
from __future__ import annotations
from typing import Callable

import math
from functools import partial

import torch
-from torch import nn, Tensor
+from torch import nn, cat, Tensor
import torch.nn.functional as F
from torch.nn import Linear, Module, Parameter, ParameterList
from torch.func import functional_call, vmap, grad
@@ -154,6 +155,46 @@ def forward(

return x

# memory mlp, but with gated residual + final projection

class GatedResidualMemoryMLP(Module):
def __init__(
self,
dim,
depth
):
super().__init__()
self.depth = depth

self.weights = ParameterList([
ParameterList([
Parameter(torch.randn(dim, dim)),
Parameter(torch.randn(dim * 2, dim)),
]) for _ in range(depth)
])

self.final_proj = Parameter(torch.randn(dim, dim))

for param in self.parameters():
nn.init.xavier_uniform_(param)

def forward(
self,
x
):
for weight, to_gates in self.weights:
res = x

x = x @ weight
x = F.silu(x)

# gated residual

gates = cat((x, res), dim = -1) @ to_gates
x = res.lerp(x, gates.sigmoid())

return x @ self.final_proj

# memory mlp with factorized weights
# so can tradeoff capacity for smaller chunk sizes


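Below is a minimal usage sketch (not part of the commit) for the new GatedResidualMemoryMLP. Each layer applies a SiLU-activated linear map, then mixes the result back into the residual with a learned sigmoid gate computed from the concatenation of the layer output and the residual; a final projection follows the loop. The constructor arguments (dim, depth) and the package-level import follow the diff above; the batch and sequence sizes are illustrative assumptions, and, like the other memory MLP variants in the repo, this module presumably serves as a drop-in memory model for NeuralMemory.

import torch
from titans_pytorch import GatedResidualMemoryMLP

mlp = GatedResidualMemoryMLP(dim = 64, depth = 2)

x = torch.randn(2, 16, 64)   # (batch, seq, dim) - shapes assumed for illustration

# per layer: x = silu(x @ W), then x = lerp(res, x, sigmoid([x, res] @ W_gate))
# after the loop: x @ final_proj
out = mlp(x)

assert out.shape == x.shape  # shape is preserved, as with the other memory MLPs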