1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

optimize while_loop by moving readonly carry components to be consts

also fix a bug in ordered effects in cond_fun lowering

fixes 
This commit is contained in:
Matthew Johnson 2025-04-11 00:25:42 +00:00
parent 9011d66a29
commit 6e52b1e95b
3 changed files with 58 additions and 11 deletions

@ -1461,9 +1461,34 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {disallowed_effects}')
# If the body forwards an input carry to an output carry, *and* it's not used
# by the cond fun, it can be moved to be a body const. Doing so can lead to
# efficiency wins: if e.g. we vmap the loop with a batched predicate, we batch
# the carry too, but not the body consts.
body_fwd = pe._jaxpr_forwarding(body_jaxpr.jaxpr)
_, carry_fwd = split_list(body_fwd, [len(body_consts)])
cond_jaxpr_, keep_cond = pe.dce_jaxpr(
cond_jaxpr.jaxpr, [True],
[True] * len(cond_consts) + [i != f for i, f in enumerate(body_fwd)])
_, keep_cond_carry = split_list(keep_cond, [len(cond_consts)])
move_to_const = [i == f and not k for i, (f, k)
in enumerate(zip(body_fwd, keep_cond_carry))]
if any(move_to_const):
cond_jaxpr = pe.close_jaxpr(cond_jaxpr_)
body_jaxpr = pe.prune_closed_jaxpr_outputs(
body_jaxpr, [not m for m in move_to_const])
body_jaxpr = pe.move_binders_to_front(
body_jaxpr, [False] * len(body_consts) + move_to_const)
init_vals, new_body_consts = partition_list(move_to_const, init_vals)
body_consts = [*new_body_consts, *body_consts]
outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
if any(move_to_const):
outs = pe.merge_lists(move_to_const, outs, new_body_consts)
return tree_unflatten(body_tree, outs)
@ -1839,18 +1864,19 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape))))
return pred
def body(args):
return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args))
return core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args)
def new_cond(pred_args):
pred, _ = pred_args
pred, *_ = pred_args
return pred
def new_body(pred_args):
_, args = pred_args
args = body(args)
pred = cond(args)
return pred, args
_, cond_consts, body_consts, carry = pred_args
carry = body((*body_consts, *carry))
pred = cond((*cond_consts, *carry))
return pred, cond_consts, body_consts, carry
def fun(*args):
pred = cond(args)
_, out = while_loop(new_cond, new_body, (pred, args))
cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
pred = cond((*cond_consts, *carry))
*_, out = while_loop(new_cond, new_body, (pred, cond_consts, body_consts, carry))
return out
return mlir.lower_fun(fun)(ctx, *args)

@ -492,8 +492,8 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def test_while_loop_body_and_cond_error(self):
def while_cond(val):
i, cond_val, _ = val
_ = jnp.sin(cond_val)
return i < 2
j = jnp.sin(cond_val)
return i + (0. * j) < 2 # don't let the sin value be dead code
def while_body(val):
i, cond_val, body_val = val

@ -2362,7 +2362,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
elif loop == "fori_inside_cond":
func = lambda x: lax.cond(
True,
x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x),
x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x),
1., lambda x: x)
elif loop == "fori_inside_scan":
func = lambda x: lax.scan(
@ -3122,6 +3122,27 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return x + y
jax.linearize(f, 1., 2.) # don't crash
def test_readonly_carry_optimization(self):
# https://github.com/google/flax/issues/4700
def foo(w, x, c_max):
def while_cond(val):
c, x, w = val
return c < c_max
def while_body(val):
c, x, w = val
return c + 1, x @ w, w
_, x, w = jax.lax.while_loop(while_cond, while_body, (0, x, w))
return w, x
w = jnp.ones((2, 2))
xs = jnp.ones((4, 2))
c_maxs = jnp.arange(4)
w_, _ = jax.vmap(foo, in_axes=(None, 0, 0), out_axes=(None, 0)
)(w, xs, c_maxs) # doesn't crash
self.assertAllClose(w, w_, check_dtypes=False)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())