mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
Optimize while_loop by moving read-only carry components to be body consts.
Also fix a bug in the handling of ordered effects in the cond_fun lowering. Fixes google/flax#4700.
This commit is contained in:
parent
9011d66a29
commit
6e52b1e95b
@ -1461,9 +1461,34 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `while`: {disallowed_effects}')
|
||||
|
||||
# If the body forwards an input carry to an output carry, *and* it's not used
|
||||
# by the cond fun, it can be moved to be a body const. Doing so can lead to
|
||||
# efficiency wins: if e.g. we vmap the loop with a batched predicate, we batch
|
||||
# the carry too, but not the body consts.
|
||||
body_fwd = pe._jaxpr_forwarding(body_jaxpr.jaxpr)
|
||||
_, carry_fwd = split_list(body_fwd, [len(body_consts)])
|
||||
cond_jaxpr_, keep_cond = pe.dce_jaxpr(
|
||||
cond_jaxpr.jaxpr, [True],
|
||||
[True] * len(cond_consts) + [i != f for i, f in enumerate(body_fwd)])
|
||||
_, keep_cond_carry = split_list(keep_cond, [len(cond_consts)])
|
||||
move_to_const = [i == f and not k for i, (f, k)
|
||||
in enumerate(zip(body_fwd, keep_cond_carry))]
|
||||
if any(move_to_const):
|
||||
cond_jaxpr = pe.close_jaxpr(cond_jaxpr_)
|
||||
body_jaxpr = pe.prune_closed_jaxpr_outputs(
|
||||
body_jaxpr, [not m for m in move_to_const])
|
||||
body_jaxpr = pe.move_binders_to_front(
|
||||
body_jaxpr, [False] * len(body_consts) + move_to_const)
|
||||
init_vals, new_body_consts = partition_list(move_to_const, init_vals)
|
||||
body_consts = [*new_body_consts, *body_consts]
|
||||
|
||||
outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
|
||||
cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
|
||||
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
|
||||
|
||||
if any(move_to_const):
|
||||
outs = pe.merge_lists(move_to_const, outs, new_body_consts)
|
||||
return tree_unflatten(body_tree, outs)
|
||||
|
||||
|
||||
@ -1839,18 +1864,19 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape))))
|
||||
return pred
|
||||
def body(args):
|
||||
return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args))
|
||||
return core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args)
|
||||
def new_cond(pred_args):
|
||||
pred, _ = pred_args
|
||||
pred, *_ = pred_args
|
||||
return pred
|
||||
def new_body(pred_args):
|
||||
_, args = pred_args
|
||||
args = body(args)
|
||||
pred = cond(args)
|
||||
return pred, args
|
||||
_, cond_consts, body_consts, carry = pred_args
|
||||
carry = body((*body_consts, *carry))
|
||||
pred = cond((*cond_consts, *carry))
|
||||
return pred, cond_consts, body_consts, carry
|
||||
def fun(*args):
|
||||
pred = cond(args)
|
||||
_, out = while_loop(new_cond, new_body, (pred, args))
|
||||
cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
|
||||
pred = cond((*cond_consts, *carry))
|
||||
*_, out = while_loop(new_cond, new_body, (pred, cond_consts, body_consts, carry))
|
||||
return out
|
||||
return mlir.lower_fun(fun)(ctx, *args)
|
||||
|
||||
|
@ -492,8 +492,8 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def test_while_loop_body_and_cond_error(self):
|
||||
def while_cond(val):
|
||||
i, cond_val, _ = val
|
||||
_ = jnp.sin(cond_val)
|
||||
return i < 2
|
||||
j = jnp.sin(cond_val)
|
||||
return i + (0. * j) < 2 # don't let the sin value be dead code
|
||||
|
||||
def while_body(val):
|
||||
i, cond_val, body_val = val
|
||||
|
@ -2362,7 +2362,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
elif loop == "fori_inside_cond":
|
||||
func = lambda x: lax.cond(
|
||||
True,
|
||||
x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x),
|
||||
x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x),
|
||||
1., lambda x: x)
|
||||
elif loop == "fori_inside_scan":
|
||||
func = lambda x: lax.scan(
|
||||
@ -3122,6 +3122,27 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return x + y
|
||||
jax.linearize(f, 1., 2.) # don't crash
|
||||
|
||||
def test_readonly_carry_optimization(self):
  """Check a vmapped `while_loop` whose carry has a read-only component.

  Regression test for https://github.com/google/flax/issues/4700.

  The loop carries `(c, x, w)`. The body returns `w` unchanged and the cond
  function only looks at `c`. The loop is vmapped with a batched `c_max`
  (hence a batched predicate) and a batched `x`, while `w` is passed with
  `in_axes=None` and requested back with `out_axes=None`.

  Besides not crashing, the test checks that `w` comes back unchanged and
  that each batch row of `x` has been updated the expected number of times.
  """
  def foo(w, x, c_max):
    def while_cond(val):
      c, x, w = val
      return c < c_max

    def while_body(val):
      c, x, w = val
      # `w` is forwarded from input carry to output carry untouched.
      return c + 1, x @ w, w

    _, x, w = jax.lax.while_loop(while_cond, while_body, (0, x, w))
    return w, x

  w = jnp.ones((2, 2))
  xs = jnp.ones((4, 2))
  c_maxs = jnp.arange(4)
  w_, xs_ = jax.vmap(foo, in_axes=(None, 0, 0), out_axes=(None, 0)
                     )(w, xs, c_maxs)  # doesn't crash
  self.assertAllClose(w, w_, check_dtypes=False)

  # Row i iterates c_maxs[i] == i times. With w == ones((2, 2)), every
  # `x @ w` doubles the (equal) entries of x, so row i ends up as 2 ** i.
  expected_xs = jnp.broadcast_to(2.0 ** jnp.arange(4)[:, None], (4, 2))
  self.assertAllClose(expected_xs, xs_, check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user