allow unmapped vmap args to be arbitrary objects

fixes #183
This commit is contained in:
Matthew Johnson 2019-10-28 15:20:49 -07:00
parent f5079a6281
commit cbadfd41ce
2 changed files with 9 additions and 1 deletions

View File

@ -52,7 +52,8 @@ def batch_fun(fun, in_vals, in_dims):
@transformation_with_aux
def batch_subtrace(master, in_dims, *in_vals):
trace = BatchTrace(master, core.cur_sublevel())
in_tracers = map(partial(BatchTracer, trace), in_vals, in_dims)
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
for val, dim in zip(in_vals, in_dims)]
outs = yield in_tracers, {}
out_tracers = map(trace.full_raise, outs)
out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)

View File

@ -1099,6 +1099,13 @@ class APITest(jtu.JaxTestCase):
lambda: api.vmap(lambda x: x, in_axes=(0, 0))(np.ones(3)),
ValueError, "axes specification must be a tree prefix")
def test_vmap_objects_issue_183(self):
# https://github.com/google/jax/issues/183
fun = lambda f, x: f(x)
vfun = api.vmap(fun, (None, 0))
ans = vfun(lambda x: x + 1, np.arange(3))
self.assertAllClose(ans, onp.arange(1, 4), check_dtypes=False)
if __name__ == '__main__':
absltest.main()