make pmap read axis size from kwargs

fixes #1170
This commit is contained in:
Matthew Johnson 2019-08-12 18:03:25 -07:00
parent 39b01564d4
commit 4a4304e08c
2 changed files with 9 additions and 1 deletions

View File

@ -711,7 +711,7 @@ def pmap(fun, axis_name=None):
@wraps(fun)
def f_pmapped(*args, **kwargs):
axis_size = _pmap_axis_size(args)
axis_size = _pmap_axis_size((args, kwargs))
f = lu.wrap_init(fun)
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun_leafout(f, in_tree)

View File

@ -955,6 +955,14 @@ class BatchingTest(jtu.JaxTestCase):
result = vmap(f, (None, 0, None))(onp.zeros((10,)), onp.arange(10,), 1.)
self.assertAllClose(result, onp.eye(10), check_dtypes=False)
def testIssue1170(self):
def f(index1, index2):
return np.arange(36).reshape(6, 6)[index1, index2]
g = jax.jit(jax.pmap(f))
ans = g(index1=onp.asarray([1]), index2=onp.asarray([2]))
expected = g(onp.asarray([1]), onp.asarray([2]))
self.assertAllClose(ans, expected, check_dtypes=True)
if __name__ == '__main__':
absltest.main()