-
Notifications
You must be signed in to change notification settings - Fork 115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add LAPACK overloads for all variants of pt.linalg.solve
#1146
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1146 +/- ##
==========================================
- Coverage 82.14% 81.87% -0.28%
==========================================
Files 186 187 +1
Lines 48210 48617 +407
Branches 8678 8705 +27
==========================================
+ Hits 39603 39804 +201
- Misses 6440 6640 +200
- Partials 2167 2173 +6
|
fe97e5d
to
86dd9cb
Compare
This is pretty close. I just need some help with the destroy_map stuff on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean we were not doing inplace? It was already implemented for the default backend for Solve and Cholesky: ed6ca16#diff-f5e942a523e3c0402aa63824184c381f5a963867125a21f119e2dab97a110d95
Just not in Numba yet. You shouldn't change anything in how the Ops are created, it's handled by the inplace rewrites. You may need to explicitly trigger them in the numba tests (the default numba mode in compare_py_and_numba
does not include them), but I am not sure that utility actually works with inplace stuff
if assume_a not in ("gen", "sym", "her", "pos"): | ||
raise ValueError(f"{assume_a} is not a recognized matrix structure") | ||
|
||
super().__init__(**kwargs) | ||
self.assume_a = assume_a | ||
self.transposed = transposed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't this require changes in the grads?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, good catch. There's some logic in the SolveBase L_Op
that I thought was handling this, but it's actually handling the lower
flag for SolveTriangular
and ChoSolve
I think. It should be similar/the same, though?
At any rate I'll add some gradient tests to make sure
Description
Goal of this PR is to give numba mode full coverage of
scipy.linalg.solve
options. We currently only supportassume_a = "gen"
. If users select a different solver, they get incorrect results (see #422 ). This PR should fix that bug, plus add:overwrite_a
in numba modeoverwrite_b
in numba modetransposed
argument (all modes)lu_factor
andlu_solve
Ops
(all modes)assume_a = "sym"
andassume_a = "pos"
in numba modecho_solve
in numba modeWe get the
lu_factor
andlu_solve
Ops kind of "for free" because I'm adding overloads fordgetrs
anddgetrf
. We just have to write the Ops and do the JAX dispatch. JVP forlu_factor
is here. Help wanted. I might decide that these Ops are out of scope for this PR and open another one.Related Issue
pt.linalg.solve
returns incorrect results whenmode = "NUMBA"
#422Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1146.org.readthedocs.build/en/1146/