batching rule for cond

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
James Bradbury 2019-08-23 11:02:19 -07:00
parent e025fac1ad
commit 7830cedea6
2 changed files with 91 additions and 14 deletions

View File

@ -164,7 +164,7 @@ def while_loop(cond_fun, body_fun, init_val):
raise TypeError(msg.format(cond_tree))
if cond_jaxpr.out_avals != [ShapedArray((), onp.bool_)]:
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(coud_jaxpr.out_avals))
raise TypeError(msg.format(cond_jaxpr.out_avals))
if not treedef_children(in_tree) == [body_tree]:
msg = "body_fun output pytree structure must match init_val, got {} and {}."
raise TypeError(msg.format(body_tree, treedef_children(in_tree)[0]))
@ -301,18 +301,6 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun):
true_nconsts=len(true_consts), false_nconsts=len(false_consts))
return tree_unflatten(out_tree, out)
def _cond_impl(pred, *args, **kwargs):
  """Eagerly evaluate `cond` by running only the branch selected by `pred`.

  Args:
    pred: value whose truthiness picks the branch.
    *args: flat inputs laid out as true-branch consts, true-branch operands,
      false-branch consts, false-branch operands.
    **kwargs: must provide `true_jaxpr`, `false_jaxpr`, `true_nconsts` and
      `false_nconsts`.

  Returns:
    Whatever the selected branch's jaxpr function returns.
  """
  true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
      kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
  # Inputs of the true jaxpr that are not consts are its operands.
  num_true_operands = len(true_jaxpr.in_avals) - true_nconsts
  true_consts, true_ops, false_consts, false_ops = split_list(
      args, [true_nconsts, num_true_operands, false_nconsts])

  if pred:
    branch_fun = core.jaxpr_as_fun(true_jaxpr)
    branch_args = true_consts + true_ops
  else:
    branch_fun = core.jaxpr_as_fun(false_jaxpr)
    branch_args = false_consts + false_ops
  return branch_fun(*branch_args)
def _cond_abstract_eval(*args, **kwargs):
return kwargs["true_jaxpr"].out_avals
@ -339,10 +327,47 @@ def _cond_translation_rule(c, axis_env, pred, *args, **kwargs):
return c.Conditional(pred, true_op, true_c, false_op, false_c)
def _cond_pred_bcast_select(pred, x, y):
  """Select between `x` and `y` using `pred` broadcast to the shape of `x`.

  `pred`'s existing axes are mapped onto the leading axes of `x`, so a
  per-example predicate can choose between outputs of any trailing shape.
  """
  bcast_pred = lax.broadcast_in_dim(
      pred, onp.shape(x), list(range(onp.ndim(pred))))
  return lax.select(bcast_pred, x, y)


def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  """Batching rule for `cond_p`.

  If the predicate is unbatched the result is still a `cond`, bound with
  batched versions of both branch jaxprs. If the predicate is batched, both
  branches are evaluated and the results are combined with a select.

  Args:
    args: flat inputs: the predicate, then true-branch consts, true-branch
      operands, false-branch consts and false-branch operands.
    dims: per-input batch axis, or `batching.not_mapped` for unbatched inputs.
    true_jaxpr: jaxpr of the true branch.
    false_jaxpr: jaxpr of the false branch.
    true_nconsts: number of leading true-branch inputs that are consts.
    false_nconsts: number of leading false-branch inputs that are consts.

  Returns:
    A pair `(outputs, out_dims)` with one batch-axis entry per output.
  """
  # Read the batch size while `dims` still describes the layout of `args`;
  # after the moveaxis below the batch axis sits at 0, not at `d`.
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])

  # Split the batched flags exactly like `args` so that each branch sees its
  # flags in the same order as its inputs: consts first, then operands.
  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
      orig_bat, [1, true_nconsts, true_nops, false_nconsts])
  true_in_bat = tconst_bat + t_bat
  false_in_bat = fconst_bat + f_bat

  # First pass: find out which outputs each branch would produce batched.
  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, true_in_bat, False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, false_in_bat, False)
  # An output is batched if either branch batches it, so both branches agree.
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  # Second pass: batch both branches against the common `out_bat`.
  true_jaxpr_batched, _ = batching.batch_jaxpr(
      true_jaxpr, size, true_in_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(
      false_jaxpr, size, false_in_bat, out_bat)

  if pred_bat:
    # A batched predicate may differ per example: run both branches and pick.
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    # The predicate only has the batch axis; broadcast it so the select also
    # works for outputs with trailing dimensions.
    return [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
        *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
        true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
        true_nconsts=len(true_consts), false_nconsts=len(false_consts)), out_dims
# Create the `cond` primitive and register its rules.
cond_p = lax.Primitive('cond')
# Mark the primitive as producing multiple results.
cond_p.multiple_results = True
# NOTE(review): def_impl is called twice in a row. These look like the removed
# and added sides of the same diff line; presumably the later call is the one
# that takes effect -- confirm and keep only the intended registration.
cond_p.def_impl(_cond_impl)
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_abstract_eval(_cond_abstract_eval)
# Register the batching rule and the XLA translation rule for cond.
batching.primitive_batchers[cond_p] = _cond_batching_rule
xla.initial_style_translations[cond_p] = _cond_translation_rule

View File

@ -475,6 +475,58 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(fun(4), cfun(4))
self.assertEqual(cfun(4), (4, 2., 4.))
def testCondBatched(self):
  """Check `vmap` of `lax.cond` for unbatched and batched predicates.

  With an unbatched predicate the printed jaxpr must not mention "select";
  with a batched predicate it must.
  """
  def fun(x, y, z):
    pred = lax.lt(x, 3)
    true_fun = lambda y: y
    false_fun = lambda z: lax.neg(z)
    return lax.cond(pred, y, true_fun, z, false_fun)

  # these cases stay as cond
  x = onp.array(2)
  y = onp.array([1, 2])
  z = onp.array([3, 4])
  ans = api.vmap(fun, (None, 0, 0))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
  expected = onp.array([1, 2])
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertNotIn("select", str(jaxpr))

  x = onp.array(4)
  ans = api.vmap(fun, (None, 0, 0))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
  expected = onp.array([-3, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertNotIn("select", str(jaxpr))

  # From here on `fun` is the jitted version.
  fun = api.jit(fun)
  ans = api.vmap(fun, (None, 0, 0))(x, y, z)
  expected = onp.array([-3, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)

  z = onp.array(5)
  ans = api.vmap(fun, (None, 0, None))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, None)))(x, y, z)
  expected = onp.array([-5, -5])
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertNotIn("select", str(jaxpr))

  # these cases become select
  x = onp.array([2, 4])
  ans = api.vmap(fun, (0, 0, None))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (0, 0, None)))(x, y, z)
  expected = onp.array([1, -5])
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertIn("select", str(jaxpr))

  z = onp.array([3, 4])
  ans = api.vmap(fun)(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun))(x, y, z)
  expected = onp.array([1, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertIn("select", str(jaxpr))
def testIssue514(self):
# just check this doesn't crash
lax.cond(True,