Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complex type support #326

Closed
wants to merge 7 commits into from
Closed

Complex type support #326

wants to merge 7 commits into from

Conversation

Randl
Copy link
Contributor

@Randl Randl commented Oct 29, 2023

Rebase of #197

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's a couple of unnecessary changes. Other than that, LGTM!

diffrax/integrate.py Outdated Show resolved Hide resolved
return jsp.linalg.lu_factor(jax.jacfwd(curried)(flat))
return jsp.linalg.lu_factor(
jax.jacfwd(curried, holomorphic=jnp.iscomplexobj(flat))(flat)
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the guard against implicit solvers earlier, is this change necessary?

(I think we'd want a lot of tests to be sure that the nonlinear solvers handle complex numbers correctly, what with all the linear solves and such going on.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this change, implicit solvers are working for me locally. If anything, I'd remove the guard?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So as it happens, I'm currently in the process of upgrading Diffrax to use Lineax and Optimistix to handle the linear and nonlinear solving. This is going to be great from a feature point of view, but it does also mean that any changes we make here will be short-lived.

Other than that, I think we'd need some fairly comprehensive tests to check that things are working. There's subtleties here that can easily go wrong, e.g. transpose vs conjugate-transpose. (JAX does the former, PyTorch does the latter.)

To justify the above, btw -- the position I've been taking with the JAX sciML ecosystem is that it's better not to do something than it is to do it unreliably. Hence why complex support has been such a slow-moving thing to have happen!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you regarding taking slow part. But this code otherwise will just crash, and it passes basic tests. If you insist I'll remove it, or reinstate the guard

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I'd reinstate the guard for now. (As I say, the current implementation will be deleted in ~2 weeks anyway.)
I think getting complex support into Lineax will be the main thing needed to get proper support here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I've put it back

@patrick-kidger
Copy link
Owner

Closing in favour of #330, which fixes the pre-commit failure present here.

Thank you to all of @TimonHoess @Randl @andyElking for your efforts getting this in!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants