Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains authored Jan 22, 2025
2 parents daf6807 + f4a263f commit 872e7ba
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/test_titans.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
def exists(v):
return v is not None

def diff(x, y):
return (x - y).abs().amax()

@pytest.mark.parametrize('seq_len', (32, 1024, 77))
@pytest.mark.parametrize('silu', (False, True))
@pytest.mark.parametrize('learned_mem_model_weights', (False, True))
Expand Down Expand Up @@ -164,7 +167,7 @@ def test_neural_mem_inference(

sequential_retrieved = torch.cat(sequential_retrieved, dim = -2)

assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-6)
assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-5)

@pytest.mark.parametrize('seq_len', (1023, 17))
@pytest.mark.parametrize('sliding', (True, False))
Expand Down

0 comments on commit 872e7ba

Please sign in to comment.