Inconsistency in using == operator on a jax array and a list #24454

Open
gokul-uf opened this issue Oct 22, 2024 · 1 comment

gokul-uf commented Oct 22, 2024

Description

JAX's == operator handles a comparison between a JAX array and a Python sequence (such as a list or a range) differently from NumPy.
Small example:

import jax.numpy as jnp
import numpy as np

jnp_arr = jnp.asarray(range(1000))
np_arr = np.asarray(range(1000))

print(jnp_arr == range(1000))  # False (a single Python bool)

print(np_arr == range(1000))  # elementwise result: an array of 1000 True values

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.10.15 (main, Sep 27 2024, 06:07:40) [GCC 12.2.0]
jax.devices (4 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0) TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='<REDACTED>', release='6.1.100+', version='#1 SMP PREEMPT_DYNAMIC Sat Aug 24 16:19:44 UTC 2024', machine='x86_64')

gokul-uf added the bug label Oct 22, 2024
Collaborator

jakevdp commented Oct 22, 2024

Hi - thanks for the report!

This is working as intended, but I agree it's a bit of a strange corner case. The issue is that, across all its APIs, JAX does not implicitly convert list inputs to arrays, because when that was possible it led to silent and difficult-to-debug performance issues in practice. This is discussed at JAX sharp bits: non-array inputs.

So for example this works in NumPy:

>>> np.sum([1, 2, 3])
int64(6)

but this fails in JAX:

>>> jnp.sum([1, 2, 3])
...
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

Similarly, this works in NumPy:

>>> x = np.array([1, 1, 1])
>>> np.equal(x, [1, 2, 3])
array([ True, False, False])

but fails in JAX:

>>> jnp.equal(x, [1, 2, 3])
TypeError: equal requires ndarray or scalar arguments, got <class 'list'> at position 1.

So given this, we now have to make a decision about what == does when it encounters a JAX array and a list. One option would be to raise an error, as with jnp.equal, but this causes problems because there are situations where the Python interpreter expects __eq__ to not fail. So instead we opt to return NotImplemented from JAX, such that the equality check dispatches (in this case) via list.__eq__, and this returns False. Though perhaps raising an error would be more useful to users – I'm not sure.
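The dispatch described here can be illustrated with a pure-Python sketch. StrictArray is a hypothetical stand-in, not JAX's actual array class, and it only assumes what the comment states: __eq__ returns NotImplemented for non-array operands instead of raising. In this sketch the reflected list.__eq__ also declines an operand that is not a list, so Python falls back to its default identity comparison, and the final result is False.

```python
class StrictArray:
    """Hypothetical array-like type that refuses to compare with non-arrays."""

    def __init__(self, values):
        self.values = list(values)

    def __eq__(self, other):
        if not isinstance(other, StrictArray):
            # Decline rather than raise, so Python can try the other operand.
            return NotImplemented
        return self.values == other.values


arr = StrictArray([1, 2, 3])

# Step 1: the left operand declines the comparison.
direct = arr.__eq__([1, 2, 3])

# Step 2: Python tries the reflected call; list.__eq__ declines non-lists too.
reflected = [1, 2, 3].__eq__(arr)

# Step 3: with both sides declining, == falls back to identity (arr is list_obj).
result = arr == [1, 2, 3]

print(direct, reflected, result)
```

Returning NotImplemented keeps == from raising in places where the interpreter calls __eq__ implicitly, such as membership tests on a list that contains the array.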

Anyway, I hope that makes it clear why it's intended that JAX's behavior diverges from NumPy's behavior in this case. If you want to do array operations in JAX, you need to use arrays, not sequences like lists or tuples.
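As a minimal sketch of that advice, assuming jax is installed, the sequence can be converted explicitly before comparing. The variable names are illustrative only; the fallback result noted in the comment is the one reported in this issue for jax 0.4.30 and may differ in other versions.

```python
import jax.numpy as jnp

jnp_arr = jnp.arange(5)

# Comparing against a plain sequence does not broadcast; the issue above
# reports a single False for this kind of comparison on jax 0.4.30.
fallback = jnp_arr == range(5)

# Converting the sequence to an array first gives an elementwise comparison.
elementwise = jnp_arr == jnp.asarray(range(5))

# Reduce to one Python bool when a yes/no answer is needed.
all_equal = bool(jnp.all(elementwise))

print(elementwise.shape, all_equal)
```

The explicit jnp.asarray call makes the list-to-array conversion visible in the code, which is the cost the implicit conversion would otherwise hide.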

jakevdp self-assigned this Oct 22, 2024