[Pallas TPU] Hoist prologue and epilogue outside of pipeline loop

PiperOrigin-RevId: 738038138
This commit is contained in:
jax authors 2025-03-18 09:38:12 -07:00
parent 30941480a1
commit 7c5871f464

View File

@ -1196,9 +1196,8 @@ def emit_pipeline(
schedule = map_brefs(
lambda _, x: get_pipeline_schedule(x), allocations, schedule)
def loop_body(step, indices):
nonlocal allocations
scheduler = Scheduler(
def make_scheduler(step, indices):
return Scheduler(
step,
indices,
grid,
@ -1208,13 +1207,15 @@ def emit_pipeline(
init_accumulators=init_accumulators,
trace_scopes=trace_scopes,
)
def loop_body(step, indices):
scheduler = make_scheduler(step, indices)
with scheduler.grid_env():
# prepare any local VMEM aliases
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
# loop input handling phase
map_brefs(scheduler.initialize, brefs, refs, schedule)
map_brefs(scheduler.copy_in, brefs, refs, schedule)
map_brefs(scheduler.wait_in, brefs, refs, schedule)
@ -1243,12 +1244,24 @@ def emit_pipeline(
lambda: None)
map_brefs(scheduler.swap_slots, brefs, refs, schedule)
map_brefs(scheduler.finalize, brefs, refs, schedule)
return _next_index(indices, grid)
# run pipeline
lax.fori_loop(0, num_steps, loop_body, (0,) * len(grid))
# Drive the whole pipeline (prologue, loop, epilogue) under one
# `pl.when(num_steps > 0)` guard.
# NOTE(review): presumably the guard keeps the epilogue's `num_steps - 1`
# step from being built for an empty grid -- confirm.
@pl.when(num_steps > 0)
def _():
# pipeline prologue
# Start from an all-zero index tuple with one entry per grid dimension and
# build a scheduler for step 0, so `initialize` is issued once here, ahead
# of the loop.
initial_indices = (0,) * len(grid)
scheduler = make_scheduler(0, initial_indices)
# Aliased refs are derived from `allocations` via the step-0 scheduler.
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
map_brefs(scheduler.initialize, brefs, refs, schedule)
# pipeline loop
# The loop carry is the index tuple; the value returned here is whatever
# `loop_body` produced on its last iteration.
next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices)
# pipeline epilogue
# Step the loop result back with `_prev_index` and pair it with step
# `num_steps - 1`.
# NOTE(review): assumes the loop result is one position past the last
# executed step -- confirm against `loop_body`'s return value.
final_indices = _prev_index(next_indices, grid)
scheduler = make_scheduler(num_steps - 1, final_indices)
# Re-derive the aliased refs with the final-step scheduler, then issue
# `finalize` once, after the loop.
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
map_brefs(scheduler.finalize, brefs, refs, schedule)
return pipeline