raise error if vmap/pmap in_axes are booleans

fixes #6372
This commit is contained in:
Matthew Johnson 2021-04-09 14:43:13 -07:00
parent f1a6397948
commit 60828e9b19
2 changed files with 19 additions and 4 deletions

View File

@ -1259,14 +1259,12 @@ def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F:
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axes = tuple(in_axes)
in_axes_, out_axes_ = tree_leaves(in_axes), tree_leaves(out_axes)
if not all(isinstance(l, (type(None), int)) for l in in_axes_):
if not all(type(l) is int for l in tree_leaves(in_axes)):
raise TypeError("vmap in_axes must be an int, None, or (nested) container "
f"with those types as leaves, but got {in_axes}.")
if not all(isinstance(l, (type(None), int)) for l in out_axes_):
if not all(type(l) is int for l in tree_leaves(out_axes)):
raise TypeError("vmap out_axes must be an int, None, or (nested) container "
f"with those types as leaves, but got {out_axes}.")
del in_axes_, out_axes_
@wraps(fun, docstr=docstr)
@api_boundary
@ -1560,6 +1558,13 @@ def pmap(
donate_tuple = rebase_donate_argnums(_ensure_index_tuple(donate_argnums),
static_broadcasted_tuple)
if not all(type(l) is int for l in tree_leaves(in_axes)):
raise TypeError("pmap in_axes must be an int, None, or (nested) container "
f"with those types as leaves, but got {in_axes}.")
if not all(type(l) is int for l in tree_leaves(out_axes)):
raise TypeError("pmap out_axes must be an int, None, or (nested) container "
f"with those types as leaves, but got {out_axes}.")
@wraps(fun)
@api_boundary
def f_pmapped(*args, **kwargs):

View File

@ -1939,6 +1939,16 @@ class APITest(jtu.JaxTestCase):
foo, in_axes=((0, collections.OrderedDict([('a', 1), ('b', 2)])),))
self.assertEqual(vfoo(tree).shape, (6, 2, 5))
def test_vmap_in_axes_bool_error(self):
# https://github.com/google/jax/issues/6372
with self.assertRaisesRegex(TypeError, "must be an int"):
api.vmap(lambda x: x, in_axes=False)(jnp.zeros(3))
def test_pmap_in_axes_bool_error(self):
# https://github.com/google/jax/issues/6372
with self.assertRaisesRegex(TypeError, "must be an int"):
api.pmap(lambda x: x, in_axes=False)(jnp.zeros(1))
def test_pmap_global_cache(self):
def f(x, y):
return x, y