mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pallas mgpu] Lowering for while loops as long as they are secretly for loops.
PiperOrigin-RevId: 698427307
This commit is contained in:
parent
439d34da15
commit
8d84f28373
@ -1473,6 +1473,44 @@ def _scan_lowering_rule(
|
||||
return for_out
|
||||
|
||||
|
||||
@register_lowering_rule(lax.while_p)
def _while_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    cond_jaxpr,
    body_jaxpr,
    cond_nconsts,
    body_nconsts,
):
  """Lowers ``lax.while_p`` when it can be rewritten as a counted for loop.

  The while loop is handed to
  ``pallas_utils.pattern_match_while_to_fori_loop``; if that yields a fori
  jaxpr, the loop is emitted through ``_lower_jaxpr_to_for_loop``.

  Args:
    ctx: Lowering context; its ``avals_in``/``avals_out`` describe ``args`` and
      the while-loop results.
    *args: Loop operands. They are split below as ``cond_nconsts`` condition
      constants, ``body_nconsts`` body constants, then the lower bound, the
      upper bound and the remaining carried values.
    cond_jaxpr: Jaxpr of the loop condition.
    body_jaxpr: Jaxpr of the loop body.
    cond_nconsts: Number of leading operands that are condition constants.
    body_nconsts: Number of operands after those that are body constants.

  Returns:
    A tuple ``(ub, ub, *for_out)``: the upper bound twice (standing in for the
    first two while-loop results), followed by the for-loop results.

  Raises:
    NotImplementedError: If the loop does not match the fori pattern, or the
      matched fori jaxpr has constvars.
  """
  # First try to lower via a simpler fori loop, which may optimize better.
  fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
      cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
  )
  del cond_jaxpr, body_jaxpr
  if fori_jaxpr is None:
    raise NotImplementedError(err)

  if fori_jaxpr.constvars:
    raise NotImplementedError(
        "Lowering a while loop whose fori_loop jaxpr has constvars is not"
        " supported."
    )

  # The loop bounds come after *both* groups of constants. Index the avals
  # with the same offsets that `split_list` uses on `args` below, so the two
  # views cannot drift apart when `cond_nconsts` is non-zero.
  carry_start = cond_nconsts + body_nconsts
  lb_aval, ub_aval, *_ = ctx.avals_in[carry_start:]

  # Reflect the changes of the pattern matcher to the context: keep the body
  # constants, keep a single index in place of the (lb, ub) pair, and drop the
  # first two results.
  avals_in = (
      *ctx.avals_in[cond_nconsts:carry_start],
      ctx.avals_in[carry_start],  # the index
      *ctx.avals_in[carry_start + 2:],
  )
  avals_out = tuple(ctx.avals_out[2:])
  ctx = ctx.replace(avals_in=avals_in, avals_out=avals_out)

  _, consts, (lb, ub, *carry) = util.split_list(
      args, [cond_nconsts, body_nconsts]
  )

  lb = _ensure_ir_value(lb, lb_aval.dtype)
  ub = _ensure_ir_value(ub, ub_aval.dtype)
  # The for-loop helper takes a start and a trip count rather than two bounds.
  length = arith_dialect.subi(ub, lb)

  for_out = _lower_jaxpr_to_for_loop(
      ctx, fori_jaxpr, lb, length, consts, *carry, has_loop_index=True
  )
  # NOTE(review): presumably the loop counter equals `ub` on exit, which is why
  # `ub` is returned for both the index and the bound -- confirm for lb > ub.
  return (ub, ub, *for_out)
|
||||
|
||||
@register_lowering_rule(lax.cond_p)
|
||||
def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
|
||||
index_aval, *_arg_avals = ctx.avals_in
|
||||
|
@ -676,6 +676,22 @@ class PallasCallTest(PallasTest):
|
||||
|
||||
np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))
|
||||
|
||||
def test_fori_loop_dynamic_bounds(self):
  # Exercises a fori_loop whose bounds are not Python constants: both are
  # offset by a value read from `pl.program_id` inside the kernel.
  result_spec = jax.ShapeDtypeStruct([256], jnp.int32)

  @functools.partial(pl.pallas_call, out_shape=result_spec, grid=(1,))
  def kernel(o_ref):
    # With a single-program grid this is expected to be 0, so the bounds
    # below are effectively 2 and 4.
    offset = pl.program_id(0)
    # Equivalent to 2 + 3.
    total = jax.lax.fori_loop(
        2 + offset, 4 + offset, lambda i, acc: acc + i, 0
    )
    o_ref[...] = jax.lax.broadcast(total, o_ref.shape)

  actual = kernel()
  expected = jnp.full([256], 5, dtype=jnp.int32)
  np.testing.assert_array_equal(actual, expected)
|
||||
|
||||
def test_fori_loop_tuple(self):
|
||||
@functools.partial(
|
||||
pl.pallas_call,
|
||||
|
Loading…
x
Reference in New Issue
Block a user