mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #16645 from treyra:main
PiperOrigin-RevId: 546987247
This commit is contained in:
commit
f4eed78e90
@ -1275,6 +1275,11 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
|
||||
msg = f"{name} must have at least one non-None value in in_axes"
|
||||
raise ValueError(msg)
|
||||
|
||||
def _get_argument_type(x):
|
||||
try:
|
||||
return shaped_abstractify(x).str_short()
|
||||
except TypeError: #Catch all for user specified objects that can't be interpreted as a data type
|
||||
return "unknown"
|
||||
msg = [f"{name} got inconsistent sizes for array axes to be mapped:\n"]
|
||||
args, kwargs = tree_unflatten(tree, vals)
|
||||
try:
|
||||
@ -1283,15 +1288,15 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
|
||||
ba = None
|
||||
if ba is None:
|
||||
args_paths = [f'args{keystr(p)} '
|
||||
f'of type {shaped_abstractify(x).str_short()}'
|
||||
f'of type {_get_argument_type(x)}'
|
||||
for p, x in generate_key_paths(args)]
|
||||
kwargs_paths = [f'kwargs{keystr(p)} '
|
||||
f'of type {shaped_abstractify(x).str_short()}'
|
||||
f'of type {_get_argument_type(x)}'
|
||||
for p, x in generate_key_paths(kwargs)]
|
||||
key_paths = [*args_paths, *kwargs_paths]
|
||||
else:
|
||||
key_paths = [f'argument {name}{keystr(p)} '
|
||||
f'of type {shaped_abstractify(x).str_short()}'
|
||||
f'of type {_get_argument_type(x)}'
|
||||
for name, arg in ba.arguments.items()
|
||||
for p, x in generate_key_paths(arg)]
|
||||
all_sizes = [_get_axis_size(name, np.shape(x), d) if d is not None else None
|
||||
|
@ -1699,6 +1699,15 @@ class APITest(jtu.JaxTestCase):
|
||||
):
|
||||
jax.device_put((x, y, z), device=(s1, s2))
|
||||
|
||||
def test_vmap_inconsistent_sizes_constructs_proper_error_message(self):
|
||||
def f(x1, x2, g):
|
||||
return g(x1, x2)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped:"
|
||||
):
|
||||
jax.vmap(f, (0, 0, None))(jnp.ones(2), jnp.ones(3), jnp.add)
|
||||
|
||||
def test_device_get_scalar(self):
|
||||
x = np.arange(12.).reshape((3, 4)).astype("float32")
|
||||
|
Loading…
x
Reference in New Issue
Block a user