mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
f5079a6281
commit
cbadfd41ce
@ -52,7 +52,8 @@ def batch_fun(fun, in_vals, in_dims):
|
||||
@transformation_with_aux
|
||||
def batch_subtrace(master, in_dims, *in_vals):
|
||||
trace = BatchTrace(master, core.cur_sublevel())
|
||||
in_tracers = map(partial(BatchTracer, trace), in_vals, in_dims)
|
||||
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
|
||||
for val, dim in zip(in_vals, in_dims)]
|
||||
outs = yield in_tracers, {}
|
||||
out_tracers = map(trace.full_raise, outs)
|
||||
out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
|
||||
|
@ -1099,6 +1099,13 @@ class APITest(jtu.JaxTestCase):
|
||||
lambda: api.vmap(lambda x: x, in_axes=(0, 0))(np.ones(3)),
|
||||
ValueError, "axes specification must be a tree prefix")
|
||||
|
||||
def test_vmap_objects_issue_183(self):
|
||||
# https://github.com/google/jax/issues/183
|
||||
fun = lambda f, x: f(x)
|
||||
vfun = api.vmap(fun, (None, 0))
|
||||
ans = vfun(lambda x: x + 1, np.arange(3))
|
||||
self.assertAllClose(ans, onp.arange(1, 4), check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user