Add rewrite to merge multiple SVD Ops with different settings #769
Conversation
pytensor/tensor/rewriting/linalg.py
Outdated
```python
if svd_count > 1 and compute_uv:
    for cl in not_compute_uv_svd_list:
        cl.op.core_op.compute_uv = True
    return [cl.outputs[0] for cl in not_compute_uv_svd_list]
```
I think changing properties of the op in place might lead to problems... This rewrite function should run for each SVD node, so maybe it is easier to just locate an existing `compute_uv = True` node, and return that as replacement for each `compute_uv = False` node?
So something like (sketched below):

- If `compute_uv` is True, return and do nothing.
- Check if there is a `compute_uv = True` node in the graph with the same input. If not, return and do nothing.
- Return the existing output of that node as replacement for the current `compute_uv = False` node.
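A minimal sketch of that suggestion (names are illustrative, not from the PR; note that in PyTensor the `SVD` core op is wrapped in `Blockwise`):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD


@node_rewriter([Blockwise])
def local_reuse_svd_s(fgraph, node):
    # Only act on SVD nodes with compute_uv=False
    if not isinstance(node.op.core_op, SVD) or node.op.core_op.compute_uv:
        return None
    (x,) = node.inputs
    for client, _ in fgraph.clients[x]:
        if client == "output":
            continue  # terminal-output marker, not an Apply node
        op = getattr(client, "op", None)
        if (
            isinstance(op, Blockwise)
            and isinstance(op.core_op, SVD)
            and op.core_op.compute_uv
        ):
            # Reuse the singular values of the compute_uv=True twin
            return [client.outputs[1]]
    return None
```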
I wonder though if there could be bad interactions somewhere if there is a rewrite that replaces `compute_uv = False` nodes if they are not used? We don't want to run into any infinite cycles...
@ricardoV94 Do you know if there are any problems that could happen if a rewrite returns an existing variable instead of a new one?
I think there will be a problem only when a rewrite tries to replace a variable by another that depends on the original variable.
And yes, we shouldn't modify the properties in place. We should replace the smaller Op by the bigger one; just make sure the smaller one is not in the ancestors of the bigger one.
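A sketch of that ancestor check (`small_node` and `big_node` are hypothetical names for the `compute_uv=False` and `compute_uv=True` nodes):

```python
from pytensor.graph.basic import ancestors

(small_s,) = small_node.outputs  # [s]
u, s, v = big_node.outputs       # [u, s, v]

# Only replace if the replacement does not depend on the variable being
# replaced, which would create a cycle in the graph.
if small_s not in ancestors([u, s, v]):
    return {small_s: s}
```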
Otherwise, creating a new SVD should be simple: just call the user-facing constructor with the specific flags.
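For instance (a sketch; `x` stands for the node's input variable):

```python
import pytensor.tensor as pt

s = pt.linalg.svd(x, full_matrices=False, compute_uv=False)
```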
Sorry, I seem to have dumped information carelessly. The gist was:

- I updated the code logic to be a node rewriter.
- The rewrite is registered properly in optdb. However, I am having trouble coming up with a test case to show the effect of the rewrite. Perhaps @jessegrabowski can provide the original use case that led to you opening the issue "Add rewrite to merge multiple `SVD` `Op`s with different settings" #732?
It will arise in gradient graphs. For example, you can just do:

```python
import pytensor.tensor as pt

X = pt.dmatrix('X')
s = pt.linalg.svd(X, compute_uv=False)
g = pt.grad(s.sum(), X)
```
The graph for `g` will re-compute the SVD of `X` during the backward pass with `compute_uv = True`, because we require the matrices `U` and `V` to compute the gradient of `s` with respect to `X`. PyTensor then won't be able to see that these two computations are the same, and will end up computing the SVD twice.
```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor import matrix
from pytensor.tensor.nlinalg import svd

a_pt = matrix("a")
s = svd(a_pt, full_matrices=False, compute_uv=False)
gs = pt.grad(pt.sum(s), a_pt)
f = pytensor.function([a_pt], gs)
e = pytensor.graph.fg.FunctionGraph([a_pt], [gs], clone=False)
```
Thank you. I indeed received a graph for `gs` and `e` with two different SVDs. But for `f`, I receive a graph with just a single SVD (which seems to have been rewritten already with `compute_uv=True`). The rewritten graph of `f` will be used in computation if I run `f([[1, 2], [3, 4]])`. Does this satisfy your end goal already?
This is `f`'s summary profile:
```
Function profiling
==================
  Message: /tmp/ipykernel_1282122/871230895.py:10
  Time in 1 calls to Function.__call__: 3.448710e-02s
  Time in Function.vm.__call__: 0.03426380921155214s (99.353%)
  Time in thunks: 0.03424406051635742s (99.295%)
  Total compilation time: 4.109558e-02s
    Number of Apply nodes: 2
    PyTensor rewrite time: 2.893809e-02s
      PyTensor validate time: 2.457825e-04s
    PyTensor Linker time (includes C, CUDA code generation/compiling): 0.00876139895990491s
      C-cache preloading 5.506449e-03s
      Import time 8.061258e-04s
      Node make_thunk time 1.967770e-03s
        Node Dot22(SVD{full_matrices=False, compute_uv=True}.0, SVD{full_matrices=False, compute_uv=True}.2) time 1.942240e-03s
        Node SVD{full_matrices=False, compute_uv=True}(a) time 1.436425e-05s

Time in all call to pytensor.grad() 1.036228e-02s
Time since pytensor import 2.774s

Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  99.8%   99.8%   0.034s   3.42e-02s   Py   1   1   pytensor.tensor.nlinalg.SVD
   0.2%  100.0%   0.000s   6.60e-05s   C    1   1   pytensor.tensor.blas.Dot22
   ... (remaining 0 Classes account for 0.00%(0.00s) of the runtime)

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  99.8%   99.8%   0.034s   3.42e-02s   Py   1   1   SVD{full_matrices=False, compute_uv=True}
   0.2%  100.0%   0.000s   6.60e-05s   C    1   1   Dot22
   ... (remaining 0 Ops account for 0.00%(0.00s) of the runtime)

Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
  99.8%   99.8%   0.034s   3.42e-02s   1   0   SVD{full_matrices=False, compute_uv=True}(a)
   0.2%  100.0%   0.000s   6.60e-05s   1   1   Dot22(SVD{full_matrices=False, compute_uv=True}.0, SVD{full_matrices=False, compute_uv=True}.2)
   ... (remaining 0 Apply instances account for 0.00%(0.00s) of the runtime)
```
`pytensor.dprint` may be an easier way to introspect the graphs.
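For example, with the `f` compiled above:

```python
import pytensor

pytensor.dprint(f)  # print the (rewritten) graph of the compiled function
```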
Codecov Report. Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #769      +/-   ##
==========================================
+ Coverage   80.85%   80.98%   +0.13%
==========================================
  Files         162      169       +7
  Lines       47016    46985      -31
  Branches    11501    11494      -7
==========================================
+ Hits        38014    38052      +38
+ Misses       6751     6719      -32
+ Partials     2251     2214      -37
```
pytensor/tensor/rewriting/linalg.py
Outdated
```python
(x,) = node.inputs
compute_uv = False

for cl, _ in fgraph.clients[x]:
```
You have to be careful because, if the output of the SVD is an output of the function, one of the clients will be the string `"output"` and the call `cl.op` will fail.
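A guard like the following avoids that (the later revision of the rewrite includes it):

```python
for cl, _ in fgraph.clients[x]:
    if cl == "output":
        # FunctionGraph marks terminal outputs with the string "output";
        # it is not an Apply node and has no .op attribute.
        continue
```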
pytensor/tensor/rewriting/linalg.py
Outdated
```python
for cl, _ in fgraph.clients[x]:
    if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
        if (not compute_uv) and cl.op.core_op.compute_uv:
```
I don't think you need that first check?

```diff
- if (not compute_uv) and cl.op.core_op.compute_uv:
+ if cl.op.core_op.compute_uv:
```
You should check if the uv outputs of this node are actually used (i.e., they have clients of their own). If not, they are useless and the rewrite shouldn't happen. In fact, this or another rewrite should change the flag from True to False for those nodes.

I would break this rewrite into different logical parts.
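A sketch of that used-outputs check (assuming `fgraph.clients` maps an unused output to an empty list):

```python
u, s, v = node.outputs
uv_used = bool(fgraph.clients[u]) or bool(fgraph.clients[v])
if not uv_used:
    # u and v are computed for nothing; a compute_uv=False SVD suffices
    ...
```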
(force-pushed from de46cff to 0337e9d)
pytensor/tensor/rewriting/linalg.py
Outdated
```python
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([SVD])
def local_svd_uv_simplify(fgraph, node):
    """If we have more than one `SVD` `Op`s and at least one has keyword argument
    `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere
    and allow `pytensor` to re-use the decomposition outputs instead of recomputing.
    """
    (x,) = node.inputs

    if node.compute_uv:
        # compute_uv=True returns [u, s, v].
        # if at least u or v is used, no need to rewrite this node.
        if (
            fgraph.clients[node.outputs[0]] is not None
            or fgraph.clients[node.outputs[2]] is not None
        ):
            return

        # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False.
        # First, iterate to see if there is an SVD Op that can be reused.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if not cl.op.core_op.compute_uv:
                    return {fgraph.clients[node.outputs[1]]: cl.outputs[0]}

        # If no SVD reusable, return a new one.
        return [svd(x, full_matrices=node.full_matrices, compute_uv=False)]

    else:
        # compute_uv=False returns [s].
        # We want rewrite if there is another one with compute_uv=True.
        # For this case, just reuse the `s` from the one with compute_uv=True.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if cl.op.core_op.compute_uv:
                    return [cl.outputs[1]]
```
Thanks @ricardoV94. My understanding is like this: the `SVD` with `compute_uv == False` will return `[s]`, while the one with `compute_uv == True` will return `[u, s, v]`. We want to rewrite when there are two `SVD` `Op`s using the same input in the graph with different `compute_uv` values. Let's take the specific example of two `SVD` `Op`s: `svd_f`, which returns `[s_f]`, and `svd_t`, which returns `[u_t, s_t, v_t]`. Based on whether at least `u_t` or `v_t` is used (since we still have to calculate both even if we use just one of them for subsequent calculations), one of the following rewrites can happen:

- Case 1: If at least `u_t` or `v_t` is used: return `[s_t]` in place of `[s_f]`.
- Case 2: Else: return `[s_f]` in place of `[s_t]`.
- Case 3: Additionally, if there is just one `SVD` `Op` with `compute_uv == True`, but both `u` and `v` are not used, then it must be substituted with a new `SVD` `Op` with `compute_uv == False`.
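A compact sketch of this case analysis (illustrative, not the PR's final code; `svd_t` and `svd_f` follow the naming above, with `svd_f` set to `None` when no `compute_uv == False` twin exists):

```python
# svd_t: Apply node of the compute_uv=True SVD on input x
# svd_f: Apply node of the compute_uv=False twin, or None
u_t, s_t, v_t = svd_t.outputs

if fgraph.clients[u_t] or fgraph.clients[v_t]:
    if svd_f is not None:
        # Case 1: u/v are needed; reuse s_t in place of s_f
        replacements = {svd_f.outputs[0]: s_t}
elif svd_f is not None:
    # Case 2: u/v unused; reuse the twin's s in place of s_t
    replacements = {s_t: svd_f.outputs[0]}
else:
    # Case 3: u/v unused and no twin; build a fresh compute_uv=False SVD
    replacements = {s_t: svd(x, full_matrices=False, compute_uv=False)}
```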
Yup, that's it! When you write down the updated rewrite, feel free to add comments with as much explanation as you did here!
There could also be some weird cases where there are three SVDs: one with uv and full_matrices that actually doesn't use the uv, and one with uv and not full matrices that actually uses them (or vice versa). In that case we could replace one with the other, but perhaps that's too much to worry about and unlikely to happen. I don't see ignoring this causing any bug; I am just raising attention to it so we don't accidentally rewrite a full-matrices SVD into non-full-matrices outputs that are actually used.
For this one: is `return {fgraph.clients[node.outputs[1]]: cl.outputs[0]}` the correct syntax?
Yup, that tells the rewriter to replace the key variable by the value variable.
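For illustration, the dict form maps the variable to be replaced to its replacement (note the key should be the output variable itself, not its clients list):

```python
old_s = node.outputs[1]  # variable to be replaced
new_s = cl.outputs[0]    # replacement variable
return {old_s: new_s}
```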
pytensor/tensor/rewriting/linalg.py
Outdated
```python
if cl == "output":
    continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
    if cl.op.core_op.compute_uv:
```
We only want to do this if that other node is actually using the UV. If not, we would actually want to replace that node by this one.
That would be taken care of by the first half when it is that node's turn. As this is a local rewrite applied to every SVD node, each node will have its turn.
Even if you don't want to handle that other node, there's no reason to rewrite this node into it. In general it's better to do as few rewrites as possible, as every time a rewrite succeeds all other candidate rewrites are rerun (until an equilibrium is achieved and nothing changes anymore).
On second thought, I like your eager approach better; it's more readable. Since SVDs are rare, we don't need to over-optimize.
Co-authored-by: Ricardo Vieira <[email protected]>
I will be slower for the next two weeks. I am house hunting right now, which should be over by then. I don't expect it to resemble a wedding preparation like this, but it is what it is. For the changes you suggested @ricardoV94, I will edit them in a slot of free time tomorrow.
No worries and best of luck!
Thanks @ricardoV94 for your patience. Quick updates: I added your suggestions. The tests are not passing right now; I am looking at it. It seems that the rewrite does not happen for the second case:

```
=================================== FAILURES ===================================
______________________________ test_svd_uv_merge _______________________________

    def test_svd_uv_merge():
        a = matrix("a")
        s_1 = svd(a, full_matrices=False, compute_uv=False)
        _, s_2, _ = svd(a, full_matrices=False, compute_uv=True)
        _, s_3, _ = svd(a, full_matrices=True, compute_uv=True)
        u_4, s_4, v_4 = svd(a, full_matrices=False, compute_uv=True)

        # `grad` will introduce an SVD Op with compute_uv=True
        # (full_matrices=True is not supported for the grad of svd)
        gs = pt.grad(pt.sum(s_1), a)

        # 1. compute_uv=False needs rewriting with compute_uv=True
        f_1 = pytensor.function([a], gs)
        nodes = f_1.maker.fgraph.apply_nodes
        svd_counter = 0
        for node in nodes:
            if isinstance(node.op, SVD):
                assert node.op.compute_uv
                svd_counter += 1
        assert svd_counter == 1

        # 2. compute_uv=True needs rewriting with compute_uv=False, reuse node
        f_2 = pytensor.function([a], [s_1, s_2])
        nodes = f_2.maker.fgraph.apply_nodes
        svd_counter = 0
        for node in nodes:
            if isinstance(node.op, SVD):
>               assert not node.op.compute_uv
E               assert not True
E                +  where True = SVD(full_matrices=False,compute_uv=True).compute_uv
E                +  where SVD(full_matrices=False,compute_uv=True) = SVD{full_matrices=False, compute_uv=True}(a).op
```
It does one node at a time, but keeps applying all rewrites in the database to all compatible nodes, until no further changes take place (that is, until an equilibrium is achieved).
I tried commenting out my added code to the file, and reinstalling.
This looks pretty good, just some small questions regarding the test cases.
```python
for node in nodes:
    if isinstance(node.op, SVD):
        assert not node.op.compute_uv
        assert node.op.full_matrices
```
Same here, there's no point in worrying about whether we keep the same full_matrices or not, since they play no role when we don't compute_uv (right?)
I check the `full_matrices` parameter to make sure that the rewrite indeed reuses the `Op`.
@HangenYuu I think you're registering the rewrites correctly, but you're tracking the wrong `Op`. There's no way to track the `Op` referenced here: pytensor/pytensor/tensor/nlinalg.py, lines 716 to 736 in 75a9fd2.

Alternatively, we could create all the permutation versions in advance, return those from the helper function, and then track those directly, since there are only 3 of them.
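A sketch of that alternative (illustrative names; assumes `node_rewriter` can track specific `Op` instances and that `Blockwise` can infer the signature from the core op):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD

# Pre-build the three meaningful permutations once
# (full_matrices is irrelevant when compute_uv=False):
svd_full_uv = Blockwise(SVD(full_matrices=True, compute_uv=True))
svd_reduced_uv = Blockwise(SVD(full_matrices=False, compute_uv=True))
svd_s_only = Blockwise(SVD(full_matrices=False, compute_uv=False))


@node_rewriter([svd_full_uv, svd_reduced_uv, svd_s_only])
def local_svd_uv_merge(fgraph, node):
    ...
```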
@ricardoV94 Sorry for calling on you if you are busy. The tests pass now. Can you review the changes to see if anything else needs modifying, or whether the PR can be merged?
@HangenYuu I tested without the explicit "remove" and it seems to work. Also, I did a tiny refactor to reduce indentation. I pushed the commit now. I think it's ready to merge!
Thanks @HangenYuu |
Description

When there are two or more `SVD` `Op`s with the same inputs on a graph, differing only by `compute_uv`, `compute_uv = False` should be changed to `True` everywhere. This will allow PyTensor to see that these outputs are equivalent and re-use them, rather than computing the decomposition multiple times.

Related Issue

- Add rewrite to merge multiple `SVD` `Op`s with different settings #732