-
-
Notifications
You must be signed in to change notification settings - Fork 130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Complex type support #326
Complex type support #326
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there's a couple of unnecessary changes. Other than that, LGTM!
return jsp.linalg.lu_factor(jax.jacfwd(curried)(flat)) | ||
return jsp.linalg.lu_factor( | ||
jax.jacfwd(curried, holomorphic=jnp.iscomplexobj(flat))(flat) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the guard against implicit solvers earlier, is this change necessary?
(I think we'd want a lot of tests to be sure that the nonlinear solvers handle complex numbers correctly, what with all the linear solves and such going on.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this change, implicit solvers are working for me locally. If anything, I'd remove the guard?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So as it happens, I'm currently in the process of upgrading Diffrax to use Lineax and Optimistix to handle the linear and nonlinear solving. This is going to be great from a feature point of view, but it does also mean that any changes we make here will be short-lived.
Other than that, I think we'd need some fairly comprehensive tests to check that things are working. There's subtleties here that can easily go wrong, e.g. transpose vs conjugate-transpose. (JAX does the former, PyTorch does the latter.)
To justify the above, btw -- the position I've been taking with the JAX sciML ecosystem is that it's better not to do something than it is to do it unreliably. Hence why complex support has been such a slow-moving thing to have happen!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with you regarding taking slow part. But this code otherwise will just crash, and it passes basic tests. If you insist I'll remove it, or reinstate the guard
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, I'd reinstate the guard for now. (As I say, the current implementation will be deleted in ~2 weeks anyway.)
I think getting complex support into Lineax will be the main thing needed to get proper support here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok I've put it back
Closing in favour of #330, which fixes the pre-commit failure present here. Thank you to all of @TimonHoess @Randl @andyElking for your efforts getting this in! |
Rebase of #197