Merge pull request #16645 from treyra:main

PiperOrigin-RevId: 546987247
This commit is contained in:
jax authors 2023-07-10 14:41:31 -07:00
commit f4eed78e90
2 changed files with 17 additions and 3 deletions

View File

@ -1275,6 +1275,11 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
msg = f"{name} must have at least one non-None value in in_axes"
raise ValueError(msg)
def _get_argument_type(x):
try:
return shaped_abstractify(x).str_short()
except TypeError: #Catch all for user specified objects that can't be interpreted as a data type
return "unknown"
msg = [f"{name} got inconsistent sizes for array axes to be mapped:\n"]
args, kwargs = tree_unflatten(tree, vals)
try:
@ -1283,15 +1288,15 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
ba = None
if ba is None:
args_paths = [f'args{keystr(p)} '
f'of type {shaped_abstractify(x).str_short()}'
f'of type {_get_argument_type(x)}'
for p, x in generate_key_paths(args)]
kwargs_paths = [f'kwargs{keystr(p)} '
f'of type {shaped_abstractify(x).str_short()}'
f'of type {_get_argument_type(x)}'
for p, x in generate_key_paths(kwargs)]
key_paths = [*args_paths, *kwargs_paths]
else:
key_paths = [f'argument {name}{keystr(p)} '
f'of type {shaped_abstractify(x).str_short()}'
f'of type {_get_argument_type(x)}'
for name, arg in ba.arguments.items()
for p, x in generate_key_paths(arg)]
all_sizes = [_get_axis_size(name, np.shape(x), d) if d is not None else None

View File

@ -1699,6 +1699,15 @@ class APITest(jtu.JaxTestCase):
):
jax.device_put((x, y, z), device=(s1, s2))
def test_vmap_inconsistent_sizes_constructs_proper_error_message(self):
def f(x1, x2, g):
return g(x1, x2)
with self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:"
):
jax.vmap(f, (0, 0, None))(jnp.ones(2), jnp.ones(3), jnp.add)
def test_device_get_scalar(self):
x = np.arange(12.).reshape((3, 4)).astype("float32")