Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable more complex tests, fix related errors #392

Merged
merged 4 commits into from
May 4, 2024

Conversation

Randl
Copy link
Contributor

@Randl Randl commented Mar 23, 2024

Similarly to patrick-kidger/lineax#89 many of fixes are just type casts due to strict typing. Should consider which of these need to be casts and which aren't

Also I included changes from #391 for simplicity

diffrax/_brownian/tree.py Outdated Show resolved Hide resolved
diffrax/_brownian/tree.py Outdated Show resolved Hide resolved
diffrax/_global_interpolation.py Outdated Show resolved Hide resolved

if jnp.iscomplexobj(A) and isinstance(solver, diffrax.AbstractImplicitSolver):
return
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for this exclusion? I think ideally we do expect to be able to do this (once all the relevant upstream things have landed).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently,
https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_integrate.py#L638-L642
I agree that we will want to remove it in the future

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It used to work (somehow) when we merged the basic support, but now I think strict typing broke it in multiple places.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added some related fixes (also see patrick-kidger/optimistix#53), but I think enabling it should come in a separate PR

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, sounds reasonable!

test/test_integrate.py Outdated Show resolved Hide resolved
diffrax/_global_interpolation.py Outdated Show resolved Hide resolved
diffrax/_global_interpolation.py Outdated Show resolved Hide resolved
diffrax/_root_finder/_verychord.py Outdated Show resolved Hide resolved
@Randl Randl changed the base branch from main to dev April 22, 2024 23:20
Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, this LGTM! Let me know whether the points I've flagged up are intentional or not. If they are then I'm happy to merge this as-is.



@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders())
def test_sde_strong_order(solver_ctr, noise, theoretical_order):
@pytest.mark.parametrize("dtype", [jnp.float64])
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intentionally leaving off complex support?

@pytest.mark.parametrize(
"dtype",
(jnp.float64,),
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise here.

@@ -274,7 +274,8 @@ def order(self, terms):


# Essentially used as a check that our general IMEX implementation is correct.
def test_sil3():
@pytest.mark.parametrize("dtype", (jnp.float64,))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here!

@Randl
Copy link
Contributor Author

Randl commented May 4, 2024

Yep all three are intentional since there are still failing tests and i prefer to fix them separately, keeping this pr only to type errors and test modifications. I'll open a separate PR for these

@patrick-kidger patrick-kidger merged commit c4deca4 into patrick-kidger:dev May 4, 2024
2 checks passed
@patrick-kidger
Copy link
Owner

Okay then! Great stuff, I've just merged this.

patrick-kidger pushed a commit that referenced this pull request May 19, 2024
* Fix complex tests

* Fix more complex tests

* New sde related fixes

* New sde related fixes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants