jax.jit changes the key order of returned dictionaries #24398
Comments
I think this is working as expected, mainly because of two things:
Also, I forgot to mention that the current key-sorting approach raises errors for unorderable keys (keys that cannot be compared with each other):
$ python3 -c "import jax; print(jax.tree.map(lambda x: x, {'b': None, 1: None}))"
TypeError: '<' not supported between instances of 'int' and 'str'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 342, in tree_map
    leaves, treedef = tree_flatten(tree, is_leaf)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/carlos/venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 79, in tree_flatten
    return default_registry.flatten(tree, is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Comparator raised exception while sorting pytree dictionary keys.
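The underlying TypeError can be reproduced without JAX. Sorting keys requires '<' between them, and Python 3 defines no ordering between int and str, whereas iterating a dict in insertion order never compares keys. A minimal plain-Python sketch (it only mirrors the first error line of the traceback, not JAX's own code):

```python
# Keys from the report above: one str and one int.
d = {'b': None, 1: None}

# Insertion order requires no comparisons, so it works for any hashable keys.
insertion_order = list(d)
print(insertion_order)  # ['b', 1]

# Sorting has to compare keys with '<', which int and str do not support.
error_message = None
try:
    sorted(d)
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # '<' not supported between instances of 'int' and 'str'
```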
xref #15358 for the unorderable keys issue.
@yashk2810 Is there any way those issues could be resolved internally within JAX's machinery while respecting the key order at the user level?
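One possible direction for that question, sketched in plain Python under loudly stated assumptions: the hypothetical helpers flatten_keep_order and unflatten_keep_order below keep a sorted canonical key order for the leaves (so leaf order does not depend on insertion order) and store the user's insertion order as extra metadata that is used only when rebuilding. This is not how JAX works, and because it still sorts, it does not address the unorderable-key error.

```python
from typing import Any


def flatten_keep_order(d: dict) -> tuple[list[Any], tuple[tuple, tuple]]:
    """Hypothetical flatten: sorted leaves plus the original key order."""
    canonical = tuple(sorted(d))  # same leaf order regardless of insertion order
    original = tuple(d)           # the user's insertion order, kept as metadata
    leaves = [d[k] for k in canonical]
    return leaves, (canonical, original)


def unflatten_keep_order(aux: tuple[tuple, tuple], leaves: list[Any]) -> dict:
    """Hypothetical unflatten: match leaves to sorted keys, then emit in original order."""
    canonical, original = aux
    by_key = dict(zip(canonical, leaves))
    return {k: by_key[k] for k in original}


leaves1, aux1 = flatten_keep_order({'b': 1, 'a': 2})
leaves2, aux2 = flatten_keep_order({'a': 2, 'b': 1})
print(leaves1 == leaves2)                         # True: [2, 1] both times
print(list(unflatten_keep_order(aux1, leaves1)))  # ['b', 'a']
print(list(unflatten_keep_order(aux2, leaves2)))  # ['a', 'b']
```

The open question in such a design is whether the stored insertion order takes part in structure equality: if it does, two dicts with the same keys in different orders would no longer count as the same structure; if it does not, some rule is needed for which order to restore when several structures are combined.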
FYI:
Description
jax.jit changes the key order of returned dictionaries:
Dictionaries have been guaranteed to preserve insertion order since Python 3.7.
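A plain-Python model may help show why a key-sorting flatten changes the order. It is a simplification, not JAX's actual pytree code: the hypothetical helpers sorted_flatten and unflatten mimic a flatten step that visits dictionary keys in sorted order and an unflatten step that rebuilds the dict from those stored, sorted keys.

```python
from typing import Any


def sorted_flatten(d: dict) -> tuple[list[Any], list[Any]]:
    """Hypothetical flatten: visit keys in sorted order, as a key-sorting pytree does."""
    keys = sorted(d)  # comparison-based; fails for unorderable keys
    return [d[k] for k in keys], keys


def unflatten(keys: list[Any], leaves: list[Any]) -> dict:
    """Hypothetical unflatten: rebuild the dict in the stored (sorted) key order."""
    return dict(zip(keys, leaves))


original = {'b': 1, 'a': 2}
leaves, keys = sorted_flatten(original)
rebuilt = unflatten(keys, leaves)

print(list(original))       # ['b', 'a']  (insertion order)
print(list(rebuilt))        # ['a', 'b']  (sorted order)
print(rebuilt == original)  # True: dict equality ignores order
```

Because dict equality ignores order, the round trip loses nothing in value terms; what is lost is only the iteration order that Python has preserved since 3.7.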
Potentially related:
tree_util
#11871

Parenthetically, this is also true for jax.tree.map, as pointed out in the first issue above.
System info (python version, jaxlib version, accelerator, etc.)