mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
f1a6397948
commit
60828e9b19
13
jax/api.py
13
jax/api.py
@ -1259,14 +1259,12 @@ def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F:
|
||||
# rather than raising an error. https://github.com/google/jax/issues/2367
|
||||
in_axes = tuple(in_axes)
|
||||
|
||||
in_axes_, out_axes_ = tree_leaves(in_axes), tree_leaves(out_axes)
|
||||
if not all(isinstance(l, (type(None), int)) for l in in_axes_):
|
||||
if not all(type(l) is int for l in tree_leaves(in_axes)):
|
||||
raise TypeError("vmap in_axes must be an int, None, or (nested) container "
|
||||
f"with those types as leaves, but got {in_axes}.")
|
||||
if not all(isinstance(l, (type(None), int)) for l in out_axes_):
|
||||
if not all(type(l) is int for l in tree_leaves(out_axes)):
|
||||
raise TypeError("vmap out_axes must be an int, None, or (nested) container "
|
||||
f"with those types as leaves, but got {out_axes}.")
|
||||
del in_axes_, out_axes_
|
||||
|
||||
@wraps(fun, docstr=docstr)
|
||||
@api_boundary
|
||||
@ -1560,6 +1558,13 @@ def pmap(
|
||||
donate_tuple = rebase_donate_argnums(_ensure_index_tuple(donate_argnums),
|
||||
static_broadcasted_tuple)
|
||||
|
||||
if not all(type(l) is int for l in tree_leaves(in_axes)):
|
||||
raise TypeError("pmap in_axes must be an int, None, or (nested) container "
|
||||
f"with those types as leaves, but got {in_axes}.")
|
||||
if not all(type(l) is int for l in tree_leaves(out_axes)):
|
||||
raise TypeError("pmap out_axes must be an int, None, or (nested) container "
|
||||
f"with those types as leaves, but got {out_axes}.")
|
||||
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
def f_pmapped(*args, **kwargs):
|
||||
|
@ -1939,6 +1939,16 @@ class APITest(jtu.JaxTestCase):
|
||||
foo, in_axes=((0, collections.OrderedDict([('a', 1), ('b', 2)])),))
|
||||
self.assertEqual(vfoo(tree).shape, (6, 2, 5))
|
||||
|
||||
def test_vmap_in_axes_bool_error(self):
|
||||
# https://github.com/google/jax/issues/6372
|
||||
with self.assertRaisesRegex(TypeError, "must be an int"):
|
||||
api.vmap(lambda x: x, in_axes=False)(jnp.zeros(3))
|
||||
|
||||
def test_pmap_in_axes_bool_error(self):
|
||||
# https://github.com/google/jax/issues/6372
|
||||
with self.assertRaisesRegex(TypeError, "must be an int"):
|
||||
api.pmap(lambda x: x, in_axes=False)(jnp.zeros(1))
|
||||
|
||||
def test_pmap_global_cache(self):
|
||||
def f(x, y):
|
||||
return x, y
|
||||
|
Loading…
x
Reference in New Issue
Block a user