Skip to content

Commit

Permalink
Fix vmap error message when args passed by keyword
Browse files Browse the repository at this point in the history
See the new test for a case that used to produce the wrong message.

Fixes: #24406
  • Loading branch information
garymm committed Oct 22, 2024
1 parent ad53add commit a8fd651
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
32 changes: 18 additions & 14 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,21 +1044,25 @@ def _get_argument_type(x):
args, kwargs = tree_unflatten(tree, vals)
try:
ba = inspect.signature(fn).bind(*args, **kwargs)
signature_parameters: list[str] = list(ba.signature.parameters.keys())
except (TypeError, ValueError):
ba = None
if ba is None:
args_paths = [f'args{keystr(p)} '
f'of type {_get_argument_type(x)}'
for p, x in generate_key_paths(args)]
kwargs_paths = [f'kwargs{keystr(p)} '
f'of type {_get_argument_type(x)}'
for p, x in generate_key_paths(kwargs)]
key_paths = [*args_paths, *kwargs_paths]
else:
key_paths = [f'argument {name}{keystr(p)} '
f'of type {_get_argument_type(x)}'
for name, arg in ba.arguments.items()
for p, x in generate_key_paths(arg)]
signature_parameters = None

def arg_name(i, key_path):
if signature_parameters is None:
return f"args{keystr(key_path)}"
else:
return f"argument {signature_parameters[i]}"

args_paths = [
f"{arg_name(i, p)} of type {_get_argument_type(x)}"
for i, (p, x) in enumerate(generate_key_paths(args))
]
kwargs_paths = [
f"kwargs{keystr(p)} of type {_get_argument_type(x)}"
for p, x in generate_key_paths(kwargs)
]
key_paths = [*args_paths, *kwargs_paths]
all_sizes = [_get_axis_size(name, np.shape(x), d) if d is not None else None
for x, d in zip(vals, dims)]
size_counts = collections.Counter(s for s in all_sizes if s is not None)
Expand Down
14 changes: 14 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1919,6 +1919,20 @@ def f(x1, x2, g):
):
jax.vmap(f, (0, 0, None))(jnp.ones(2), jnp.ones(3), jnp.add)

def test_vmap_inconsistent_sizes_constructs_proper_error_message_kwargs(self):
def f(x1, x2, a3):
return jnp.add(x1, x2, a3)

with self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:\n"
r" \* most axes \(2 of them\) had size 2, e.g. axis 0 of argument x1 of type float32\[2\];\n"
r" \* one axis had size 1: axis 0 of kwargs\['a3'\] of type float32\[1\]",
):
# previously this test would fail when kwargs were passed in a different order
# that what's in the function signature: issue #24406
jax.vmap(f)(jnp.ones(2), a3=jnp.ones(1), x2=jnp.ones(2))

def test_device_get_scalar(self):
x = np.arange(12.).reshape((3, 4)).astype("float32")
x = api.device_put(x)
Expand Down

0 comments on commit a8fd651

Please sign in to comment.