mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 01:06:05 +00:00
parent
39b01564d4
commit
4a4304e08c
@ -711,7 +711,7 @@ def pmap(fun, axis_name=None):
|
||||
|
||||
@wraps(fun)
|
||||
def f_pmapped(*args, **kwargs):
|
||||
axis_size = _pmap_axis_size(args)
|
||||
axis_size = _pmap_axis_size((args, kwargs))
|
||||
f = lu.wrap_init(fun)
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
|
||||
|
@ -955,6 +955,14 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
result = vmap(f, (None, 0, None))(onp.zeros((10,)), onp.arange(10,), 1.)
|
||||
self.assertAllClose(result, onp.eye(10), check_dtypes=False)
|
||||
|
||||
def testIssue1170(self):
|
||||
def f(index1, index2):
|
||||
return np.arange(36).reshape(6, 6)[index1, index2]
|
||||
g = jax.jit(jax.pmap(f))
|
||||
ans = g(index1=onp.asarray([1]), index2=onp.asarray([2]))
|
||||
expected = g(onp.asarray([1]), onp.asarray([2]))
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user