Merge pull request #8403 from shoyer:empty-map-args-error

PiperOrigin-RevId: 408501394
This commit is contained in:
jax authors 2021-11-08 19:39:04 -08:00
commit e5e5bb3ac1
2 changed files with 21 additions and 0 deletions

View File

@ -1466,6 +1466,13 @@ def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None, axis_size=None) -> F:
return batched_fun
def _mapped_axis_size(tree, vals, dims, name, *, kws=False):
if not vals:
args, kwargs = tree_unflatten(tree, vals)
raise ValueError(
f"{name} wrapped function must be passed at least one argument "
f"containing an array, got empty *args={args} and **kwargs={kwargs}"
)
def _get_axis_size(name: str, shape: Tuple[int, ...], axis: int):
try:
return shape[axis]

View File

@ -2244,6 +2244,20 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, "must be an int"):
api.pmap(lambda x: x, in_axes=False)(jnp.zeros(1))
def test_vmap_empty_arguments(self):
with self.assertRaisesRegex(
ValueError,
"vmap wrapped function must be passed at least one argument "
r"containing an array, got empty \*args=\(\{\},\) and \*\*kwargs=\{\}"):
api.vmap(lambda x: x)({})
def test_pmap_empty_arguments(self):
with self.assertRaisesRegex(
ValueError,
"pmap wrapped function must be passed at least one argument "
r"containing an array, got empty \*args=\(\{\},\) and \*\*kwargs=\{\}"):
api.pmap(lambda x: x)({})
def test_pmap_global_cache(self):
def f(x, y):
return x, y