Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LAPACK overloads for all variants of pt.linalg.solve #1146

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Dec 31, 2024

Description

Goal of this PR is to give numba mode full coverage of scipy.linalg.solve options. We currently only support assume_a = "gen". If users select a different solver, they get incorrect results (see #422 ). This PR should fix that bug, plus add:

  • support for overwrite_a in numba mode
  • support for overwrite_b in numba mode
  • support for transposed argument (all modes)
  • lu_factor and lu_solve Ops (all modes)
  • support for assume_a = "sym" and assume_a = "pos" in numba mode
  • support for cho_solve in numba mode

We get the lu_factor and lu_solve Ops kind of "for free" because I'm adding overloads for dgetrs and dgetrf. We just have to write the Ops and do the JAX dispatch. JVP for lu_factor is here. Help wanted. I might decide that these Ops are out of scope for this PR and open another one.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1146.org.readthedocs.build/en/1146/

@jessegrabowski jessegrabowski added bug Something isn't working enhancement New feature or request numba SciPy compatibility linalg Linear algebra labels Dec 31, 2024
@jessegrabowski jessegrabowski marked this pull request as ready for review December 31, 2024 15:24
@jessegrabowski jessegrabowski requested review from aseyboldt and ricardoV94 and removed request for aseyboldt December 31, 2024 15:24
Copy link

codecov bot commented Dec 31, 2024

Codecov Report

Attention: Patch coverage is 52.49042% with 248 lines in your changes missing coverage. Please review.

Project coverage is 81.87%. Comparing base (4e85676) to head (fe97e5d).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/numba/dispatch/slinalg.py 42.62% 205 Missing and 5 partials ⚠️
pytensor/link/numba/dispatch/_LAPACK.py 75.32% 32 Missing and 6 partials ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1146      +/-   ##
==========================================
- Coverage   82.14%   81.87%   -0.28%     
==========================================
  Files         186      187       +1     
  Lines       48210    48617     +407     
  Branches     8678     8705      +27     
==========================================
+ Hits        39603    39804     +201     
- Misses       6440     6640     +200     
- Partials     2167     2173       +6     
Files with missing lines Coverage Δ
pytensor/link/numba/dispatch/basic.py 80.37% <ø> (-1.86%) ⬇️
pytensor/tensor/slinalg.py 93.49% <100.00%> (+0.01%) ⬆️
pytensor/link/numba/dispatch/_LAPACK.py 75.32% <75.32%> (ø)
pytensor/link/numba/dispatch/slinalg.py 44.42% <42.62%> (-8.38%) ⬇️

@jessegrabowski
Copy link
Member Author

This is pretty close. I just need some help with the destroy_map stuff on Solve. I guess this wasn't being before? The code is a bit hard to follow with all the subclassing.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean we were not doing inplace? It was already implemented for the default backend for Solve and Cholesky: ed6ca16#diff-f5e942a523e3c0402aa63824184c381f5a963867125a21f119e2dab97a110d95

Just not in Numba yet. You shouldn't change anything in how the Ops are created, it's handled by the inplace rewrites. You may need to explicitly trigger them in the numba tests (the default numba mode in compare_py_and_numba does not include them), but I am not sure that utility actually works with inplace stuff

pytensor/tensor/slinalg.py Outdated Show resolved Hide resolved
pytensor/tensor/slinalg.py Outdated Show resolved Hide resolved
tests/link/numba/test_slinalg.py Outdated Show resolved Hide resolved
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")

super().__init__(**kwargs)
self.assume_a = assume_a
self.transposed = transposed
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this require changes in the grads?

Copy link
Member Author

@jessegrabowski jessegrabowski Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good catch. There's some logic in the SolveBase L_Op that I thought was handling this, but it's actually handling the lower flag for SolveTriangular and ChoSolve I think. It should be similar/the same, though?

At any rate I'll add some gradient tests to make sure

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request linalg Linear algebra numba SciPy compatibility
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: pt.linalg.solve returns incorrect results when mode = "NUMBA"
2 participants