Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rewrite to merge multiple SVD Ops with different settings #769

Merged
merged 18 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,5 @@ pytensor-venv/
testing-report.html
coverage.xml
.coverage.*
pics
*.ipynb
HangenYuu marked this conversation as resolved.
Show resolved Hide resolved
29 changes: 28 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@

from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
node_rewriter,
)
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
Expand Down Expand Up @@ -377,3 +381,26 @@
return [block_diag(*inner_matrices)]
else:
raise NotImplementedError # pragma: no cover


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([SVD])
def local_svd_uv_simplify(fgraph, node):
HangenYuu marked this conversation as resolved.
Show resolved Hide resolved
"""If we have more than one `SVD` `Op`s and at least one has keyword argument
`compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere
and allow `pytensor` to re-use the decomposition outputs instead of recomputing.
"""
(x,) = node.inputs
compute_uv = False

Check warning on line 396 in pytensor/tensor/rewriting/linalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/linalg.py#L395-L396

Added lines #L395 - L396 were not covered by tests

for cl, _ in fgraph.clients[x]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have to be careful because if the output of the SVD is an output of the function one of the clients will be a string "output" and the call cl.op will fail.

if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
if (not compute_uv) and cl.op.core_op.compute_uv:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need that first check?

Suggested change
if (not compute_uv) and cl.op.core_op.compute_uv:
if cl.op.core_op.compute_uv:

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should check if the uv outputs of this node are actually used (i.e., they have clients of their own). If not, they are useless and the rewrite shouldn't happen. In fact, this or another rewrite should change the flag from True to False for those nodes

compute_uv = True
break

Check warning on line 402 in pytensor/tensor/rewriting/linalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/linalg.py#L401-L402

Added lines #L401 - L402 were not covered by tests

if compute_uv and not node.op.compute_uv:
full_matrices = node.op.full_matrices
return [SVD(full_matrices=full_matrices, compute_uv=compute_uv)]

Check warning on line 406 in pytensor/tensor/rewriting/linalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/linalg.py#L405-L406

Added lines #L405 - L406 were not covered by tests
Loading