mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Pallas TPU] Close over consts in while_loop lowering to avoid passing refs in/out of loop
PiperOrigin-RevId: 660238073
This commit is contained in:
parent
dd958adc39
commit
803453ed74
@ -2160,7 +2160,7 @@ def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
|
||||
def _scan_lowering_rule(
|
||||
ctx: LoweringRuleContext,
|
||||
*args,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
jaxpr: jax_core.ClosedJaxpr,
|
||||
linear: tuple[bool, ...],
|
||||
length: int,
|
||||
reverse: bool,
|
||||
@ -2241,7 +2241,7 @@ def _while_lowering_rule(
|
||||
body_jaxpr,
|
||||
):
|
||||
# First try to lower via a simpler fori loop, which may optimize better.
|
||||
fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
|
||||
fori_jaxpr, _ = pallas_utils.pattern_match_while_to_fori_loop(
|
||||
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
|
||||
)
|
||||
if fori_jaxpr is not None:
|
||||
@ -2262,19 +2262,12 @@ def _while_lowering_rule(
|
||||
cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = (
|
||||
split_list(ctx.block_shapes, [cond_nconsts, body_nconsts])
|
||||
)
|
||||
cond_const_types = [a.type for a in cond_consts]
|
||||
body_const_types = [a.type for a in body_consts]
|
||||
carry_types = [a.type for a in carry]
|
||||
all_types = [*cond_const_types, *body_const_types, *carry_types]
|
||||
while_op = scf.WhileOp(all_types, args)
|
||||
while_op = scf.WhileOp(carry_types, carry)
|
||||
|
||||
before_block = while_op.before.blocks.append(*all_types)
|
||||
cond_consts_, _, carry_ = split_list(
|
||||
before_block.arguments,
|
||||
[cond_nconsts, body_nconsts],
|
||||
)
|
||||
cond_args = [*cond_consts_, *carry_]
|
||||
before_block = while_op.before.blocks.append(*carry_types)
|
||||
with ir.InsertionPoint.at_block_begin(before_block):
|
||||
cond_args = [*cond_consts, *before_block.arguments]
|
||||
[cond] = jaxpr_subcomp(
|
||||
ctx.lowering_context.replace(
|
||||
block_shapes=[*cond_const_block_shapes, *carry_block_shapes]
|
||||
@ -2284,30 +2277,19 @@ def _while_lowering_rule(
|
||||
)
|
||||
scf.condition(cond, before_block.arguments)
|
||||
|
||||
after_block = while_op.after.blocks.append(*all_types)
|
||||
cond_consts_, body_consts_, carry_ = split_list(
|
||||
after_block.arguments,
|
||||
[cond_nconsts, body_nconsts],
|
||||
)
|
||||
all_args = [*cond_consts_, *body_consts_, *carry_]
|
||||
cond_const_args, body_const_args, carry_args = split_list(
|
||||
all_args, [cond_nconsts, body_nconsts]
|
||||
)
|
||||
after_block = while_op.after.blocks.append(*carry_types)
|
||||
with ir.InsertionPoint.at_block_begin(after_block):
|
||||
body_args = [*body_consts, *after_block.arguments]
|
||||
loop_out = jaxpr_subcomp(
|
||||
ctx.lowering_context.replace(
|
||||
block_shapes=[*body_const_block_shapes, *carry_block_shapes],
|
||||
),
|
||||
body_jaxpr.jaxpr,
|
||||
*body_const_args,
|
||||
*carry_args,
|
||||
*body_args,
|
||||
)
|
||||
all_handles = [*cond_const_args, *body_const_args, *loop_out]
|
||||
if all_handles:
|
||||
scf.yield_(all_handles)
|
||||
|
||||
all_out = list(while_op.results_)
|
||||
return all_out[cond_nconsts + body_nconsts :]
|
||||
if loop_out:
|
||||
scf.yield_(loop_out)
|
||||
return list(while_op.results)
|
||||
|
||||
|
||||
lowering_rules[lax.while_p] = _while_lowering_rule
|
||||
|
Loading…
x
Reference in New Issue
Block a user