mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 08:26:07 +00:00
parent
39b01564d4
commit
4a4304e08c
@ -711,7 +711,7 @@ def pmap(fun, axis_name=None):
|
|||||||
|
|
||||||
@wraps(fun)
|
@wraps(fun)
|
||||||
def f_pmapped(*args, **kwargs):
|
def f_pmapped(*args, **kwargs):
|
||||||
axis_size = _pmap_axis_size(args)
|
axis_size = _pmap_axis_size((args, kwargs))
|
||||||
f = lu.wrap_init(fun)
|
f = lu.wrap_init(fun)
|
||||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||||
flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
|
flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
|
||||||
|
@ -955,6 +955,14 @@ class BatchingTest(jtu.JaxTestCase):
|
|||||||
result = vmap(f, (None, 0, None))(onp.zeros((10,)), onp.arange(10,), 1.)
|
result = vmap(f, (None, 0, None))(onp.zeros((10,)), onp.arange(10,), 1.)
|
||||||
self.assertAllClose(result, onp.eye(10), check_dtypes=False)
|
self.assertAllClose(result, onp.eye(10), check_dtypes=False)
|
||||||
|
|
||||||
|
def testIssue1170(self):
|
||||||
|
def f(index1, index2):
|
||||||
|
return np.arange(36).reshape(6, 6)[index1, index2]
|
||||||
|
g = jax.jit(jax.pmap(f))
|
||||||
|
ans = g(index1=onp.asarray([1]), index2=onp.asarray([2]))
|
||||||
|
expected = g(onp.asarray([1]), onp.asarray([2]))
|
||||||
|
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user