From 2d26af6a169bd820fb0377ede125a32d1c2b58a2 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 30 Sep 2024 10:25:35 -0700
Subject: [PATCH] offer l2 distance attention for starters and cite

---
 README.md                        | 10 ++++++++++
 setup.py                         |  2 +-
 tests/test_x_transformers.py     | 19 +++++++++++++++++++
 x_transformers/attend.py         | 36 +++++++++++++++++++++++++++++++++---
 x_transformers/x_transformers.py |  2 ++
 5 files changed, 65 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 3cc6baa..aee0f57 100644
--- a/README.md
+++ b/README.md
@@ -2239,6 +2239,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
     booktitle = {Neural Information Processing Systems},
     year = {2018},
     url = {https://api.semanticscholar.org/CorpusID:44064935}
+```
+
+```bibtex
+@article{Kim2020TheLC,
+    title = {The Lipschitz Constant of Self-Attention},
+    author = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
+    journal = {ArXiv},
+    year = {2020},
+    volume = {abs/2006.04710},
+    url = {https://api.semanticscholar.org/CorpusID:219530837}
 }
 ```
 
diff --git a/setup.py b/setup.py
index 1ad4f3c..4abe9a9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.37.4',
+  version = '1.37.6',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/tests/test_x_transformers.py b/tests/test_x_transformers.py
index 0138579..bc54168 100644
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -301,3 +301,22 @@ def test_sigsoftmax():
     model.eval()
 
     eval_logits = model(x)
+
+@pytest.mark.parametrize('attn_one_kv_head', (True, False))
+def test_l2_distance(attn_one_kv_head):
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            attn_l2_distance = True,
+            attn_one_kv_head = attn_one_kv_head,
+        )
+    )
+
+    x = torch.randint(0, 256, (1, 1024))
+
+    model(x)
diff --git a/x_transformers/attend.py b/x_transformers/attend.py
index 5e931ca..1557fa4 100644
--- a/x_transformers/attend.py
+++ b/x_transformers/attend.py
@@ -13,7 +13,7 @@ from packaging import version
 
 from dataclasses import dataclass
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, pack, unpack
 
 # constants
 
@@ -39,9 +39,16 @@ def default(val, d):
 def compact(arr):
     return [*filter(exists, arr)]
 
-def softclamp(t, value):
+@torch.jit.script
+def softclamp(t: Tensor, value: float):
     return (t / value).tanh() * value
 
+def pack_one(t, pattern):
+    return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+    return unpack(t, ps, pattern)[0]
+
 def once(fn):
     called = False
     @wraps(fn)
@@ -55,6 +62,18 @@ def inner(x):
 
 print_once = once(print)
 
+# alternative distance functions
+
+def qk_l2_dist_squared(q, k):
+    if k.ndim == 3:
+        k = repeat(k, 'b j d -> b h j d', h = q.shape[1])
+
+    q, packed_shape = pack_one(q, '* i d')
+    k, _ = pack_one(k, '* j d')
+
+    l2_dist_squared = torch.cdist(q, k) ** 2
+    return unpack_one(l2_dist_squared, packed_shape, '* i j')
+
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
 
@@ -80,6 +99,7 @@ def __init__(
         sparse_topk = None,
         scale = None,
         qk_norm = False,
+        l2_distance = False,
         flash = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
@@ -123,6 +143,11 @@ def __init__(
         assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
         self.sigsoftmax = sigsoftmax
 
+        # l2 distance attention
+
+        assert not (flash and l2_distance), 'l2 distance attention does not work with flash attention just yet'
+        self.l2_distance = l2_distance
+
         # add a key / value token composed of zeros
         # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html
 
@@ -325,7 +350,12 @@ def forward(
 
         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
 
-        sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
+        if not self.l2_distance:
+            sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
+        else:
+            sim = -qk_l2_dist_squared(q, k)
+
+        sim = sim * scale
 
         if exists(prev_attn):
             sim = sim + prev_attn
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index f181b8e..374ec76 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -923,6 +923,7 @@ def __init__(
         qk_norm_groups = 1,
         qk_norm_scale = 10,
         qk_norm_dim_scale = False,
+        l2_distance = False,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1037,6 +1038,7 @@
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
+            l2_distance = l2_distance,
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
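
A quick usage sketch for review, not part of the patch itself: the new flag is reached from the public `TransformerWrapper` / `Decoder` API through the usual `attn_` kwarg prefix, as exercised by `test_l2_distance` above. Hyperparameters and tensor shapes below are illustrative only, and the helper import assumes `qk_l2_dist_squared` stays a module-level function in `x_transformers/attend.py` as in this diff.

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.attend import qk_l2_dist_squared

# decoder whose attention logits are negative squared l2 distances
# between queries and keys, scaled and softmaxed as usual
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_l2_distance = True    # forwarded to Attend(l2_distance = True)
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)                  # (1, 1024, 20000)

# the helper added to attend.py computes ||q - k||^2 per head via torch.cdist,
# broadcasting a single kv head when k is 3-dimensional
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 32, 64)         # one kv head, as covered by attn_one_kv_head in the test
dist = qk_l2_dist_squared(q, k)    # (2, 8, 16, 32)
```

Since ||q - k||^2 = ||q||^2 + ||k||^2 - 2 q . k, negating the squared distance keeps a scaled dot-product term and only adds query and key norm penalties, so it drops into the existing masking and softmax path unchanged; the Kim et al. 2020 citation added above motivates the variant by showing dot-product self-attention is not Lipschitz while an L2-based formulation can be.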