Add rewrite to merge multiple SVD Ops with different settings (#769)
Co-authored-by: Ricardo Vieira <[email protected]>
HangenYuu and ricardoV94 authored Jun 28, 2024
1 parent a8d7638 commit 920b409
Showing 2 changed files with 128 additions and 1 deletion.
63 changes: 62 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
@@ -4,20 +4,25 @@

from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
node_rewriter,
)
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
det,
inv,
kron,
pinv,
svd,
)
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
@@ -377,3 +382,59 @@ def local_lift_through_linalg(
return [block_diag(*inner_matrices)]
else:
raise NotImplementedError # pragma: no cover


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def svd_uv_merge(fgraph, node):
"""If we have more than one `SVD` `Op`s and at least one has keyword argument
`compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere
and allow `pytensor` to re-use the decomposition outputs instead of recomputing.
"""
if not isinstance(node.op.core_op, SVD):
return

(x,) = node.inputs

if node.op.core_op.compute_uv:
        # compute_uv=True returns [u, s, v].
        # If either u or v is used, there is no need to rewrite this node.
if (
len(fgraph.clients[node.outputs[0]]) > 0
or len(fgraph.clients[node.outputs[2]]) > 0
):
return

        # Otherwise, replace the s of this node with the s of an SVD Op that has compute_uv=False.
        # First, iterate over the clients of x to see if an existing SVD Op can be reused.
for cl, _ in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
if not cl.op.core_op.compute_uv:
return {
node.outputs[1]: cl.outputs[0],
}

# If no SVD reusable, return a new one.
return {
node.outputs[1]: svd(
x, full_matrices=node.op.core_op.full_matrices, compute_uv=False
),
}

else:
# compute_uv=False returns [s].
        # We want to rewrite if there is another SVD Op on the same input with
        # compute_uv=True whose u or v outputs are used: in that case, just reuse its s.
for cl, _ in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
if cl.op.core_op.compute_uv and (
len(fgraph.clients[cl.outputs[0]]) > 0
or len(fgraph.clients[cl.outputs[2]]) > 0
):
return [cl.outputs[1]]
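
For orientation (not part of the committed diff), here is a minimal sketch of the kind of graph this rewrite targets, using only the public `pytensor` APIs that also appear in the test file below (`pt.matrix`, `svd`, `pytensor.function`, `fgraph.apply_nodes`); variable names are illustrative. It mirrors case 2 of the tests: two SVDs of the same input where `u` and `v` are never used should collapse into a single `compute_uv=False` decomposition.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import SVD, svd

a = pt.matrix("a")
# An SVD that only needs the singular values ...
s_only = svd(a, full_matrices=False, compute_uv=False)
# ... and another SVD of the same input that also requests u and v,
# even though only s is used downstream.
_, s_full, _ = svd(a, full_matrices=False, compute_uv=True)

f = pytensor.function([a], [s_only, s_full])

# After the rewrite, a single SVD node with compute_uv=False should remain,
# shared by both outputs.
svd_nodes = [n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, SVD)]
assert len(svd_nodes) == 1
assert not svd_nodes[0].op.compute_uv

x = np.random.normal(size=(4, 3)).astype(pytensor.config.floatX)
s1, s2 = f(x)
np.testing.assert_allclose(s1, s2)

Because the rewrite is registered with `register_canonicalize`, `register_stabilize`, and `register_specialize` above, it should apply under the default compilation mode without extra configuration.
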
66 changes: 66 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
@@ -15,11 +15,13 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import (
SVD,
Det,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
matrix_inverse,
svd,
)
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import (
@@ -390,3 +392,67 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]

np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)


def test_svd_uv_merge():
a = matrix("a")
s_1 = svd(a, full_matrices=False, compute_uv=False)
_, s_2, _ = svd(a, full_matrices=False, compute_uv=True)
_, s_3, _ = svd(a, full_matrices=True, compute_uv=True)
u_4, s_4, v_4 = svd(a, full_matrices=True, compute_uv=True)
    # `grad` will introduce an SVD Op with compute_uv=True.
    # full_matrices=True is not supported for the grad of svd.
gs = pt.grad(pt.sum(s_1), a)

# 1. compute_uv=False needs rewriting with compute_uv=True
f_1 = pytensor.function([a], gs)
nodes = f_1.maker.fgraph.apply_nodes
svd_counter = 0
for node in nodes:
if isinstance(node.op, SVD):
assert node.op.compute_uv
svd_counter += 1
assert svd_counter == 1

    # 2. compute_uv=True needs rewriting with compute_uv=False, reusing an existing node
f_2 = pytensor.function([a], [s_1, s_2])
nodes = f_2.maker.fgraph.apply_nodes
svd_counter = 0
for node in nodes:
if isinstance(node.op, SVD):
assert not node.op.compute_uv
svd_counter += 1
assert svd_counter == 1

    # 3. compute_uv=True needs rewriting with compute_uv=False, creating a new node;
    # full_matrices must retain its original value
f_3 = pytensor.function([a], [s_2])
nodes = f_3.maker.fgraph.apply_nodes
svd_counter = 0
for node in nodes:
if isinstance(node.op, SVD):
assert not node.op.compute_uv
svd_counter += 1
assert svd_counter == 1

    # Same as case 3, but with a different full_matrices value
f_4 = pytensor.function([a], [s_3])
nodes = f_4.maker.fgraph.apply_nodes
svd_counter = 0
for node in nodes:
if isinstance(node.op, SVD):
assert not node.op.compute_uv
assert node.op.full_matrices
svd_counter += 1
assert svd_counter == 1

# 4. No rewrite should happen
f_5 = pytensor.function([a], [u_4])
nodes = f_5.maker.fgraph.apply_nodes
svd_counter = 0
for node in nodes:
if isinstance(node.op, SVD):
assert node.op.full_matrices
assert node.op.compute_uv
svd_counter += 1
assert svd_counter == 1
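
Beyond the explicit cases above, a common way to hit the compute_uv=False branch of the rewrite in practice is differentiating through singular values: `pt.grad` introduces an SVD with compute_uv=True, and the rewrite lets the forward-pass SVD reuse its `s` output. A minimal, self-contained sketch under the same assumptions and illustrative names as the sketch after the rewrite itself:

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import SVD, svd

a = pt.matrix("a")
s = svd(a, full_matrices=False, compute_uv=False)
loss = pt.sum(s)
# The gradient graph introduces an SVD Op with compute_uv=True.
g = pt.grad(loss, a)
f = pytensor.function([a], [loss, g])

# After the rewrite, the forward-pass SVD should be merged into the one
# introduced by the gradient, leaving a single compute_uv=True node.
svd_nodes = [n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, SVD)]
assert len(svd_nodes) == 1
assert svd_nodes[0].op.compute_uv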
