From 7c5871f464df0db16c4dcc44b20a382885db1075 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 18 Mar 2025 09:38:12 -0700 Subject: [PATCH] [Pallas TPU] Hoist prologue and epilogue outside of pipeline loop PiperOrigin-RevId: 738038138 --- jax/_src/pallas/mosaic/pipeline.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 2044d3d18..184b1497a 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -1196,9 +1196,8 @@ def emit_pipeline( schedule = map_brefs( lambda _, x: get_pipeline_schedule(x), allocations, schedule) - def loop_body(step, indices): - nonlocal allocations - scheduler = Scheduler( + def make_scheduler(step, indices): + return Scheduler( step, indices, grid, @@ -1208,13 +1207,15 @@ def emit_pipeline( init_accumulators=init_accumulators, trace_scopes=trace_scopes, ) + + def loop_body(step, indices): + scheduler = make_scheduler(step, indices) with scheduler.grid_env(): # prepare any local VMEM aliases brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) # loop input handling phase - map_brefs(scheduler.initialize, brefs, refs, schedule) map_brefs(scheduler.copy_in, brefs, refs, schedule) map_brefs(scheduler.wait_in, brefs, refs, schedule) @@ -1243,12 +1244,24 @@ def emit_pipeline( lambda: None) map_brefs(scheduler.swap_slots, brefs, refs, schedule) - map_brefs(scheduler.finalize, brefs, refs, schedule) - return _next_index(indices, grid) - # run pipeline - lax.fori_loop(0, num_steps, loop_body, (0,) * len(grid)) + @pl.when(num_steps > 0) + def _(): + # pipeline prologue + initial_indices = (0,) * len(grid) + scheduler = make_scheduler(0, initial_indices) + brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) + map_brefs(scheduler.initialize, brefs, refs, schedule) + + # pipeline loop + next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices) + + # pipeline epilogue + final_indices = _prev_index(next_indices, grid) + scheduler = make_scheduler(num_steps - 1, final_indices) + brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) + map_brefs(scheduler.finalize, brefs, refs, schedule) return pipeline