jax.jit changes the key order of returned dictionaries #24398

Open
carlosgmartin opened this issue Oct 18, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@carlosgmartin
Contributor

Description

jax.jit changes the key order of returned dictionaries:

$ python3 -c "import jax; print(jax.jit(lambda: {'b': None, 'a': None})())"
{'a': None, 'b': None}

Dictionaries are guaranteed to be ordered since Python 3.7.

Potentially related:

Parenthetically, this is also true for jax.tree.map, as pointed out in the first issue above.

$ python3 -c "import jax; print(jax.tree.map(lambda x: x, {'b': None, 'a': None}))"
{'a': None, 'b': None}

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  1.26.4
python: 3.12.7 (main, Oct  1 2024, 02:05:46) [Clang 15.0.0 (clang-1500.3.9.4)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='MacBook-Pro-2.local', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:46 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6031', machine='arm64')
@carlosgmartin carlosgmartin added the bug Something isn't working label Oct 18, 2024
@yashk2810
Collaborator

yashk2810 commented Oct 18, 2024

I think this is working as expected.

Mainly because of 2 things:

  • We need to sort the dictionary keys because otherwise we would get cache misses. For example, jit(f)({'a': 1, 'b': 2}) and jit(f)({'b': 2, 'a': 1}) should both hit the same cache entry, but they won't if we don't sort.

  • Second, and probably most important: in multi-controller JAX, if the key orders differ by mistake, you can get hangs, which are very bad and much harder to debug. If dictionaries are sorted, this problem cannot occur by construction.
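
To make the cache-miss argument concrete, here is a minimal pure-Python sketch. It does not use JAX, and `structure_key` and `make_cached` are hypothetical helpers invented for illustration; they are not how JAX builds a PyTreeDef or its compilation cache. The sketch assumes the cache is keyed on the structure of the argument, and shows that sorting keys makes both insertion orders map to one entry.

```python
def structure_key(tree, sort_keys):
    """Toy stand-in for a pytree structure: a hashable summary of dict keys.

    Hypothetical helper for illustration only; not JAX's implementation.
    """
    if isinstance(tree, dict):
        keys = sorted(tree) if sort_keys else list(tree)
        return ("dict", tuple((k, structure_key(tree[k], sort_keys)) for k in keys))
    return "*"  # any non-dict value is treated as a leaf


def make_cached(fn, sort_keys):
    """Cache fn on the structure of its argument (a toy model of a jit cache)."""
    cache = {}

    def wrapper(tree):
        key = structure_key(tree, sort_keys)
        if key not in cache:
            cache[key] = fn  # stands in for an expensive trace/compile step
        return cache[key](tree)

    wrapper.cache = cache
    return wrapper


def total(d):
    return d["a"] + d["b"]


sorted_f = make_cached(total, sort_keys=True)
sorted_f({"a": 1, "b": 2})
sorted_f({"b": 2, "a": 1})
print(len(sorted_f.cache))  # 1: both orderings share one cache entry

unsorted_f = make_cached(total, sort_keys=False)
unsorted_f({"a": 1, "b": 2})
unsorted_f({"b": 2, "a": 1})
print(len(unsorted_f.cache))  # 2: each insertion order gets its own entry
```

In this toy model, the unsorted variant "compiles" twice for what is logically the same input structure, which is the cache miss described above.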

@carlosgmartin
Contributor Author

carlosgmartin commented Oct 18, 2024

Also forgot to add that the current key-sorting approach causes errors for incomparable keys:

$ python3 -c "import jax; print(jax.tree.map(lambda x: x, {'b': None, 1: None}))"
TypeError: '<' not supported between instances of 'int' and 'str'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 342, in tree_map
    leaves, treedef = tree_flatten(tree, is_leaf)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 79, in tree_flatten
    return default_registry.flatten(tree, is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Comparator raised exception while sorting pytree dictionary keys.
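
For illustration, one possible way to get a deterministic order for mixed-type keys is to fall back to ordering by type name first. The sketch below is pure Python; `sort_keys_with_fallback` is a hypothetical helper, not something JAX provides, and it still fails if two keys of the same type cannot be compared with each other.

```python
def sort_keys_with_fallback(keys):
    """Sort dict keys; if they are not mutually comparable, order by type first.

    Hypothetical helper for illustration only; not part of JAX.
    """
    keys = list(keys)
    try:
        return sorted(keys)
    except TypeError:
        # Group keys by fully qualified type name, then compare only
        # keys that share a type. Tuple comparison stops at the first
        # differing element, so an int is never compared with a str.
        return sorted(
            keys,
            key=lambda k: (type(k).__module__, type(k).__qualname__, k),
        )


print(sort_keys_with_fallback({'b': None, 'a': None}))  # ['a', 'b']
print(sort_keys_with_fallback({'b': None, 1: None}))    # [1, 'b']
```

The ordering is arbitrary across types ('int' sorts before 'str' only because of the type names), but it is deterministic, which is the property the sorting is meant to provide.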

@jakevdp
Collaborator

jakevdp commented Oct 19, 2024

xref #15358 for the unorderable keys issue.

@carlosgmartin
Contributor Author

@yashk2810 Is there any way those issues could be resolved internally within JAX's machinery while respecting the key order at the user level?

@XuehaiPan
Contributor

> Is there any way those issues could be resolved internally within JAX's machinery while respecting the key order at the user level?

FYI:

One solution is to store the input dict keys in insertion order in Node during flatten, and update the PyTreeDef.unflatten method to respect the key order while reconstructing the output pytree.

import jax

leaves, treedef = jax.tree_util.tree_flatten({'b': 2, 'a': 1})
leaves   # [1, 2]
treedef  # PyTreeDef({'a': *, 'b': *})
# Proposed behavior (currently this returns {'a': 11, 'b': 22}):
treedef.unflatten([11, 22])  # {'b': 22, 'a': 11}, respecting the original key order
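
The idea can be sketched in pure Python as follows. This is a toy model under stated assumptions, not JAX's PyTreeDef: `flatten`, `unflatten`, and `same_structure` are hypothetical functions that handle only one-level dicts. Leaves are still produced in sorted-key order, structure comparison ignores insertion order, and only reconstruction consults the recorded insertion order.

```python
def flatten(tree):
    """Flatten a one-level dict into leaves (sorted-key order) and a treedef.

    Hypothetical toy model: the treedef records the sorted keys (used for
    leaf order and structure equality) and the original insertion order
    (used only when rebuilding the dict).
    """
    sorted_keys = tuple(sorted(tree))
    leaves = [tree[k] for k in sorted_keys]
    treedef = {"sorted_keys": sorted_keys, "insertion_keys": tuple(tree)}
    return leaves, treedef


def same_structure(treedef_a, treedef_b):
    """Compare structures by sorted keys only, so insertion order is ignored."""
    return treedef_a["sorted_keys"] == treedef_b["sorted_keys"]


def unflatten(treedef, leaves):
    """Rebuild the dict, emitting keys in the recorded insertion order."""
    by_key = dict(zip(treedef["sorted_keys"], leaves))
    return {k: by_key[k] for k in treedef["insertion_keys"]}


leaves, treedef = flatten({'b': 2, 'a': 1})
print(leaves)                        # [1, 2]
print(unflatten(treedef, [11, 22]))  # {'b': 22, 'a': 11}

_, other = flatten({'a': 1, 'b': 2})
print(same_structure(treedef, other))  # True
```

Because `same_structure` looks only at the sorted keys, a cache keyed on it would still treat both insertion orders as the same structure in this toy model.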

Ref:

4 participants