[Pallas TPU] Close over consts in while_loop lowering to avoid passing refs in/out of loop

PiperOrigin-RevId: 660238073
This commit is contained in:
Sharad Vikram 2024-08-06 22:32:46 -07:00 committed by jax authors
parent dd958adc39
commit 803453ed74

View File

@ -2160,7 +2160,7 @@ def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
def _scan_lowering_rule(
ctx: LoweringRuleContext,
*args,
jaxpr: jax_core.Jaxpr,
jaxpr: jax_core.ClosedJaxpr,
linear: tuple[bool, ...],
length: int,
reverse: bool,
@ -2241,7 +2241,7 @@ def _while_lowering_rule(
body_jaxpr,
):
# First try to lower via a simpler fori loop, which may optimize better.
fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
fori_jaxpr, _ = pallas_utils.pattern_match_while_to_fori_loop(
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
)
if fori_jaxpr is not None:
@ -2262,19 +2262,12 @@ def _while_lowering_rule(
cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = (
split_list(ctx.block_shapes, [cond_nconsts, body_nconsts])
)
cond_const_types = [a.type for a in cond_consts]
body_const_types = [a.type for a in body_consts]
carry_types = [a.type for a in carry]
all_types = [*cond_const_types, *body_const_types, *carry_types]
while_op = scf.WhileOp(all_types, args)
while_op = scf.WhileOp(carry_types, carry)
before_block = while_op.before.blocks.append(*all_types)
cond_consts_, _, carry_ = split_list(
before_block.arguments,
[cond_nconsts, body_nconsts],
)
cond_args = [*cond_consts_, *carry_]
before_block = while_op.before.blocks.append(*carry_types)
with ir.InsertionPoint.at_block_begin(before_block):
cond_args = [*cond_consts, *before_block.arguments]
[cond] = jaxpr_subcomp(
ctx.lowering_context.replace(
block_shapes=[*cond_const_block_shapes, *carry_block_shapes]
@ -2284,30 +2277,19 @@ def _while_lowering_rule(
)
scf.condition(cond, before_block.arguments)
after_block = while_op.after.blocks.append(*all_types)
cond_consts_, body_consts_, carry_ = split_list(
after_block.arguments,
[cond_nconsts, body_nconsts],
)
all_args = [*cond_consts_, *body_consts_, *carry_]
cond_const_args, body_const_args, carry_args = split_list(
all_args, [cond_nconsts, body_nconsts]
)
after_block = while_op.after.blocks.append(*carry_types)
with ir.InsertionPoint.at_block_begin(after_block):
body_args = [*body_consts, *after_block.arguments]
loop_out = jaxpr_subcomp(
ctx.lowering_context.replace(
block_shapes=[*body_const_block_shapes, *carry_block_shapes],
),
body_jaxpr.jaxpr,
*body_const_args,
*carry_args,
*body_args,
)
all_handles = [*cond_const_args, *body_const_args, *loop_out]
if all_handles:
scf.yield_(all_handles)
all_out = list(while_op.results_)
return all_out[cond_nconsts + body_nconsts :]
if loop_out:
scf.yield_(loop_out)
return list(while_op.results)
lowering_rules[lax.while_p] = _while_lowering_rule