mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
batching rule for cond
Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
e025fac1ad
commit
7830cedea6
@ -164,7 +164,7 @@ def while_loop(cond_fun, body_fun, init_val):
|
||||
raise TypeError(msg.format(cond_tree))
|
||||
if cond_jaxpr.out_avals != [ShapedArray((), onp.bool_)]:
|
||||
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
|
||||
raise TypeError(msg.format(coud_jaxpr.out_avals))
|
||||
raise TypeError(msg.format(cond_jaxpr.out_avals))
|
||||
if not treedef_children(in_tree) == [body_tree]:
|
||||
msg = "body_fun output pytree structure must match init_val, got {} and {}."
|
||||
raise TypeError(msg.format(body_tree, treedef_children(in_tree)[0]))
|
||||
@ -301,18 +301,6 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun):
|
||||
true_nconsts=len(true_consts), false_nconsts=len(false_consts))
|
||||
return tree_unflatten(out_tree, out)
|
||||
|
||||
def _cond_impl(pred, *args, **kwargs):
  """Evaluate `cond` eagerly by running exactly one of the two branch jaxprs.

  Args:
    pred: the predicate; its truthiness selects the branch.
    *args: flat operands laid out as
      (*true_consts, *true_ops, *false_consts, *false_ops).
    **kwargs: must carry "true_jaxpr", "false_jaxpr", "true_nconsts" and
      "false_nconsts".

  Returns:
    Whatever the selected branch jaxpr returns when applied to its consts
    followed by its operands.
  """
  param_names = ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"]
  true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
      kwargs, param_names)

  # The true branch's operand count is whatever is left of its inputs once
  # the consts are accounted for.
  n_true_operands = len(true_jaxpr.in_avals) - true_nconsts
  t_consts, t_operands, f_consts, f_operands = split_list(
      args, [true_nconsts, n_true_operands, false_nconsts])

  if pred:
    branch, branch_inputs = true_jaxpr, t_consts + t_operands
  else:
    branch, branch_inputs = false_jaxpr, f_consts + f_operands
  return core.jaxpr_as_fun(branch)(*branch_inputs)
|
||||
|
||||
def _cond_abstract_eval(*args, **kwargs):
|
||||
return kwargs["true_jaxpr"].out_avals
|
||||
|
||||
@ -339,10 +327,47 @@ def _cond_translation_rule(c, axis_env, pred, *args, **kwargs):
|
||||
|
||||
return c.Conditional(pred, true_op, true_c, false_op, false_c)
|
||||
|
||||
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  """Batching rule for the `cond` primitive.

  Args:
    args: flat operands laid out as
      (pred, *true_consts, *true_ops, *false_consts, *false_ops).
    dims: one entry per element of `args`: the batch axis of that operand, or
      `batching.not_mapped` if it is not batched.
    true_jaxpr: jaxpr of the true branch; its inputs are consts then operands.
    false_jaxpr: jaxpr of the false branch; its inputs are consts then operands.
    true_nconsts: number of leading consts among the true-branch inputs.
    false_nconsts: number of leading consts among the false-branch inputs.

  Returns:
    A pair `(outputs, out_dims)`. When the predicate is batched, both branches
    are evaluated and combined with `lax.select`, and every output is batched
    along axis 0. Otherwise the result is a `cond_p.bind` over the batched
    branch jaxprs, with axis 0 for batched outputs and `batching.not_mapped`
    for the rest.

  Raises:
    ValueError: from the single-element unpacking of `size` if the mapped
      operands do not agree on exactly one batch size.
  """
  # Read the batch size while `dims` still describes `args`: once the batch
  # axes have been moved to the front, `x.shape[d]` no longer refers to the
  # batch axis for any d != 0.
  size, = {x.shape[d] for x, d in zip(args, dims)
           if d is not batching.not_mapped}

  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])

  # Split the batched flags with the same layout as `args` (consts first,
  # then operands) so that `tconst_bat + t_bat` lines up with the jaxpr
  # inputs `true_consts + true_ops`.
  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
      orig_bat, [1, true_nconsts, true_nops, false_nconsts])

  # First pass: only the per-output batched flags of each branch are kept.
  # An output is treated as batched if either branch batches it.
  _, true_out_bat = batching.batch_jaxpr(
      true_jaxpr, size, tconst_bat + t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(
      false_jaxpr, size, fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  # Second pass: re-batch both branches against the shared `out_bat` so their
  # outputs agree (presumably the last argument forces those outputs to be
  # batched -- confirm against batching.batch_jaxpr).
  true_jaxpr_batched, _ = batching.batch_jaxpr(
      true_jaxpr, size, tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(
      false_jaxpr, size, fconst_bat + f_bat, out_bat)

  if pred_bat:
    # A batched predicate cannot pick a single branch: evaluate both and
    # choose per example with select.
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(
        *(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    # NOTE(review): `pred` is passed to select as-is; outputs with extra
    # trailing dimensions may need the predicate broadcast to their shape
    # first -- confirm against lax.select's shape requirements.
    return [lax.select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    # Unbatched predicate: keep a single cond over the batched branches.
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
        *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
        true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
        true_nconsts=len(true_consts), false_nconsts=len(false_consts)), out_dims
|
||||
|
||||
# Create the `cond` primitive and register its rules.
cond_p = lax.Primitive('cond')
cond_p.multiple_results = True
# Register a single implementation. An earlier `def_impl(_cond_impl)` call
# directly before this one was dead, since this registration follows it.
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_abstract_eval(_cond_abstract_eval)
batching.primitive_batchers[cond_p] = _cond_batching_rule
xla.initial_style_translations[cond_p] = _cond_translation_rule
|
||||
|
||||
|
||||
|
@ -475,6 +475,58 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertEqual(fun(4), cfun(4))
|
||||
self.assertEqual(cfun(4), (4, 2., 4.))
|
||||
|
||||
def testCondBatched(self):
  """Check vmap of lax.cond, both results and the form of the traced jaxpr.

  With an unbatched predicate the jaxpr must not mention "select"; with a
  batched predicate it must.
  """
  def fun(x, y, z):
    pred = lax.lt(x, 3)
    true_fun = lambda y: y
    false_fun = lambda z: lax.neg(z)
    return lax.cond(pred, y, true_fun, z, false_fun)

  # these cases stay as cond
  x = onp.array(2)
  y = onp.array([1, 2])
  z = onp.array([3, 4])
  ans = api.vmap(fun, (None, 0, 0))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
  expected = onp.array([1, 2])
  self.assertAllClose(ans, expected, check_dtypes=False)
  # Use TestCase assertions rather than bare `assert`: they are not stripped
  # under `python -O` and report the offending jaxpr text on failure.
  self.assertNotIn("select", str(jaxpr))

  x = onp.array(4)
  ans = api.vmap(fun, (None, 0, 0))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
  expected = onp.array([-3, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertNotIn("select", str(jaxpr))

  # From here on `fun` is the jitted version.
  fun = api.jit(fun)
  ans = api.vmap(fun, (None, 0, 0))(x, y, z)
  expected = onp.array([-3, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)

  z = onp.array(5)
  ans = api.vmap(fun, (None, 0, None))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, None)))(x, y, z)
  expected = onp.array([-5, -5])
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertNotIn("select", str(jaxpr))

  # these cases become select
  x = onp.array([2, 4])
  ans = api.vmap(fun, (0, 0, None))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (0, 0, None)))(x, y, z)
  expected = onp.array([1, -5])
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertIn("select", str(jaxpr))

  z = onp.array([3, 4])
  ans = api.vmap(fun)(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun))(x, y, z)
  expected = onp.array([1, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertIn("select", str(jaxpr))
|
||||
|
||||
def testIssue514(self):
|
||||
# just check this doesn't crash
|
||||
lax.cond(True,
|
||||
|
Loading…
x
Reference in New Issue
Block a user