mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8403 from shoyer:empty-map-args-error
PiperOrigin-RevId: 408501394
This commit is contained in:
commit
e5e5bb3ac1
@ -1466,6 +1466,13 @@ def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None, axis_size=None) -> F:
|
||||
return batched_fun
|
||||
|
||||
def _mapped_axis_size(tree, vals, dims, name, *, kws=False):
|
||||
if not vals:
|
||||
args, kwargs = tree_unflatten(tree, vals)
|
||||
raise ValueError(
|
||||
f"{name} wrapped function must be passed at least one argument "
|
||||
f"containing an array, got empty *args={args} and **kwargs={kwargs}"
|
||||
)
|
||||
|
||||
def _get_axis_size(name: str, shape: Tuple[int, ...], axis: int):
|
||||
try:
|
||||
return shape[axis]
|
||||
|
@ -2244,6 +2244,20 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, "must be an int"):
|
||||
api.pmap(lambda x: x, in_axes=False)(jnp.zeros(1))
|
||||
|
||||
def test_vmap_empty_arguments(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap wrapped function must be passed at least one argument "
|
||||
r"containing an array, got empty \*args=\(\{\},\) and \*\*kwargs=\{\}"):
|
||||
api.vmap(lambda x: x)({})
|
||||
|
||||
def test_pmap_empty_arguments(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"pmap wrapped function must be passed at least one argument "
|
||||
r"containing an array, got empty \*args=\(\{\},\) and \*\*kwargs=\{\}"):
|
||||
api.pmap(lambda x: x)({})
|
||||
|
||||
def test_pmap_global_cache(self):
|
||||
def f(x, y):
|
||||
return x, y
|
||||
|
Loading…
x
Reference in New Issue
Block a user