Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scalar multiplication #355

Merged
62 changes: 62 additions & 0 deletions tests/unit/factored_matrix/test_multiply_by_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
import torch
from torch.testing import assert_close

from transformer_lens import FactoredMatrix


# This test function is parametrized with different types of scalars, including non-scalar tensors and arrays, to check that the correct errors are raised.
# Considers cases with and without leading dimensions as well as left and right multiplication.
@pytest.mark.parametrize(
"scalar, error_expected",
[
# Test cases with different types of scalar values.
(torch.rand(1), None), # 1-element Tensor. No error expected.
(float(torch.rand(1).item()), None), # float. No error expected.
matthiasdellago marked this conversation as resolved.
Show resolved Hide resolved
(int(torch.randint(1, 10, (1,)).item()), None), # int. No error expected.
matthiasdellago marked this conversation as resolved.
Show resolved Hide resolved

# Test cases with non-scalar values that are expected to raise errors.
(torch.rand(2, 2), AssertionError), # Non-scalar Tensor. AssertionError expected.
(torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected.

]
)
@pytest.mark.parametrize(
"leading_dim",
[
False,
True
]
)
@pytest.mark.parametrize(
"multiply_from_left",
[
False,
True
]
)
def test_multiply(scalar, leading_dim, multiply_from_left, error_expected):
# Prepare a FactoredMatrix, with or without leading dimensions
if leading_dim:
a = torch.rand(6, 2, 3)
b = torch.rand(6, 3, 4)
else:
a = torch.rand(2, 3)
b = torch.rand(3, 4)

fm = FactoredMatrix(a, b)

if error_expected:
# If an error is expected, check that the correct exception is raised.
with pytest.raises(error_expected):
if multiply_from_left:
_ = fm * scalar
else:
_ = scalar * fm
else:
# If no error is expected, check that the multiplication results in the correct value.
# Use FactoredMatrix.AB to calculate the product of the two factor matrices before comparing with the expected value.
if multiply_from_left:
assert_close((fm * scalar).AB, (a @ b) * scalar)
matthiasdellago marked this conversation as resolved.
Show resolved Hide resolved
else:
assert_close((scalar * fm).AB, (a @ b) * scalar)
20 changes: 20 additions & 0 deletions transformer_lens/FactoredMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,26 @@ def __rmatmul__(
return FactoredMatrix(other, self.AB)
elif isinstance(other, FactoredMatrix):
return other.A @ (other.B @ self)

def __mul__(
self,
scalar: Union[int, float, torch.Tensor]
) -> FactoredMatrix:
"""
Left scalar multiplication. Scalar multiplication distributes over matrix multiplication, so we can just multiply one of the factor matrices by the scalar.
"""
if isinstance(scalar, torch.Tensor):
assert scalar.numel() == 1, f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead."
return FactoredMatrix(self.A * scalar, self.B)

def __rmul__(
self,
scalar: Union[int, float, torch.Tensor]
) -> FactoredMatrix:
"""
Right scalar multiplication. For scalar multiplication from the right, we can reuse the __mul__ method.
"""
return self * scalar

@property
@typeguard_ignore
Expand Down