-
Notifications
You must be signed in to change notification settings - Fork 304
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added tests for scalar multiplication for FactoredMatrix * Added __mul__ and __rmul__ to FactoredMatrix * Tests for errors when multiplying by non-scalar * Added scalar.shape to error message * Fixed imports to make isort happy * Black Formatting * Changed to random.random and randint * Implementation dependent test for factored matrix A.
- Loading branch information
1 parent
5b26456
commit 090081f
Showing
2 changed files
with
75 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import random | ||
|
||
import pytest | ||
import torch | ||
from torch.testing import assert_close | ||
|
||
from transformer_lens import FactoredMatrix | ||
|
||
|
||
# This test function is parametrized with different types of scalars, including non-scalar tensors and arrays, to check that the correct errors are raised. | ||
# Considers cases with and without leading dimensions as well as left and right multiplication. | ||
@pytest.mark.parametrize( | ||
"scalar, error_expected", | ||
[ | ||
# Test cases with different types of scalar values. | ||
(torch.rand(1), None), # 1-element Tensor. No error expected. | ||
(random.random(), None), # float. No error expected. | ||
(random.randint(-100, 100), None), # int. No error expected. | ||
# Test cases with non-scalar values that are expected to raise errors. | ||
( | ||
torch.rand(2, 2), | ||
AssertionError, | ||
), # Non-scalar Tensor. AssertionError expected. | ||
(torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected. | ||
], | ||
) | ||
@pytest.mark.parametrize("leading_dim", [False, True]) | ||
@pytest.mark.parametrize("multiply_from_left", [False, True]) | ||
def test_multiply(scalar, leading_dim, multiply_from_left, error_expected): | ||
# Prepare a FactoredMatrix, with or without leading dimensions | ||
if leading_dim: | ||
a = torch.rand(6, 2, 3) | ||
b = torch.rand(6, 3, 4) | ||
else: | ||
a = torch.rand(2, 3) | ||
b = torch.rand(3, 4) | ||
|
||
fm = FactoredMatrix(a, b) | ||
|
||
if error_expected: | ||
# If an error is expected, check that the correct exception is raised. | ||
with pytest.raises(error_expected): | ||
if multiply_from_left: | ||
_ = fm * scalar | ||
else: | ||
_ = scalar * fm | ||
else: | ||
# If no error is expected, check that the multiplication results in the correct value. | ||
# Use FactoredMatrix.AB to calculate the product of the two factor matrices before comparing with the expected value. | ||
if multiply_from_left: | ||
assert_close((fm * scalar).AB, (a @ b) * scalar) | ||
else: | ||
assert_close((scalar * fm).AB, scalar * (a @ b)) | ||
# This next test is implementation dependant and can be broken and removed at any time! | ||
# It checks that the multiplication is performed on the A factor matrix. | ||
if multiply_from_left: | ||
assert_close((fm * scalar).A, a * scalar) | ||
else: | ||
assert_close((scalar * fm).A, scalar * a) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters