[pallas mgpu] Lowering for while loops as long as they are secretly for loops.

PiperOrigin-RevId: 698427307
This commit is contained in:
Christos Perivolaropoulos 2024-11-20 09:59:39 -08:00 committed by jax authors
parent 439d34da15
commit 8d84f28373
2 changed files with 54 additions and 0 deletions

View File

@ -1473,6 +1473,44 @@ def _scan_lowering_rule(
return for_out
@register_lowering_rule(lax.while_p)
def _while_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    cond_jaxpr,
    body_jaxpr,
    cond_nconsts,
    body_nconsts,
):
  """Lowers `lax.while_p` by rewriting the loop as a for loop.

  Only while loops accepted by
  `pallas_utils.pattern_match_while_to_fori_loop` are handled; anything else
  is rejected with `NotImplementedError`.

  Args:
    ctx: Lowering context. Its `avals_in` are split with the same boundaries
      as `args`, and the first two entries of `avals_out` are dropped before
      the for loop is lowered.
    *args: Loop operands. The first `cond_nconsts` values are discarded, the
      next `body_nconsts` values are forwarded as loop constants, and the
      remainder is the carry, whose first two entries are used as the lower
      and upper bound of the loop.
    cond_jaxpr: Jaxpr of the loop condition, passed to the pattern matcher.
    body_jaxpr: Jaxpr of the loop body, passed to the pattern matcher.
    cond_nconsts: Number of leading operands belonging to the condition.
    body_nconsts: Number of operands, after the condition's, belonging to
      the body.

  Returns:
    A tuple `(ub, ub, *for_out)`: the upper bound twice, followed by the
    results of the lowered for loop.

  Raises:
    NotImplementedError: If the loop does not pattern-match to a fori loop,
      or the matched jaxpr has constvars.
  """
  # First try to lower via a simpler fori loop, which may optimize better.
  fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
      cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
  )
  del cond_jaxpr, body_jaxpr
  if fori_jaxpr is None:
    raise NotImplementedError(err)
  if fori_jaxpr.constvars:
    raise NotImplementedError(
        "Lowering a while loop whose fori jaxpr has constvars is not"
        " supported."
    )

  # Split the operands and their avals with the *same* boundaries so the two
  # can never disagree about where the carry starts, whatever the value of
  # `cond_nconsts`.
  split_points = [cond_nconsts, body_nconsts]
  _, const_avals, (lb_aval, ub_aval, *carry_avals) = util.split_list(
      ctx.avals_in, split_points
  )
  _, consts, (lb, ub, *args) = util.split_list(args, split_points)

  # Reflect the changes of the pattern matcher to the context: the inputs
  # become the loop constants, the index, and the rest of the carry (the
  # upper bound is dropped); the first two outputs are dropped as well.
  ctx = ctx.replace(
      avals_in=(*const_avals, lb_aval, *carry_avals),
      avals_out=tuple(ctx.avals_out[2:]),
  )

  lb = _ensure_ir_value(lb, lb_aval.dtype)
  ub = _ensure_ir_value(ub, ub_aval.dtype)
  # NOTE(review): the trip count is `ub - lb`; confirm how
  # `_lower_jaxpr_to_for_loop` treats a negative length (ub < lb).
  length = arith_dialect.subi(ub, lb)

  for_out = _lower_jaxpr_to_for_loop(
      ctx, fori_jaxpr, lb, length, consts, *args, has_loop_index=True
  )
  # The two dropped outputs are both reported as `ub` -- presumably because
  # the loop index equals the upper bound once the loop has finished.
  return (ub, ub, *for_out)
@register_lowering_rule(lax.cond_p)
def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
index_aval, *_arg_avals = ctx.avals_in

View File

@ -676,6 +676,22 @@ class PallasCallTest(PallasTest):
np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))
def test_fori_loop_dynamic_bounds(self):
  """Checks `fori_loop` when both bounds are built from a traced value."""

  def add_index(i, acc):
    return acc + i

  def kernel(o_ref):
    # With a single-step grid the program id is 0, but it is not a Python
    # constant, so neither loop bound below is one either.
    base = pl.program_id(0)
    # Visits i = 2 and i = 3 starting from 0, i.e. computes 2 + 3.
    total = jax.lax.fori_loop(2 + base, 4 + base, add_index, 0)
    o_ref[...] = jax.lax.broadcast(total, o_ref.shape)

  run_kernel = pl.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
      grid=(1,),
  )
  expected = jnp.full([256], 5, dtype=jnp.int32)
  np.testing.assert_array_equal(run_kernel(), expected)
def test_fori_loop_tuple(self):
@functools.partial(
pl.pallas_call,