mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[Pallas TPU] Hoist prologue and epilogue outside of pipeline loop
PiperOrigin-RevId: 738038138
This commit is contained in:
parent
30941480a1
commit
7c5871f464
@ -1196,9 +1196,8 @@ def emit_pipeline(
|
||||
schedule = map_brefs(
|
||||
lambda _, x: get_pipeline_schedule(x), allocations, schedule)
|
||||
|
||||
def loop_body(step, indices):
|
||||
nonlocal allocations
|
||||
scheduler = Scheduler(
|
||||
def make_scheduler(step, indices):
|
||||
return Scheduler(
|
||||
step,
|
||||
indices,
|
||||
grid,
|
||||
@ -1208,13 +1207,15 @@ def emit_pipeline(
|
||||
init_accumulators=init_accumulators,
|
||||
trace_scopes=trace_scopes,
|
||||
)
|
||||
|
||||
def loop_body(step, indices):
|
||||
scheduler = make_scheduler(step, indices)
|
||||
with scheduler.grid_env():
|
||||
|
||||
# prepare any local VMEM aliases
|
||||
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
|
||||
|
||||
# loop input handling phase
|
||||
map_brefs(scheduler.initialize, brefs, refs, schedule)
|
||||
map_brefs(scheduler.copy_in, brefs, refs, schedule)
|
||||
map_brefs(scheduler.wait_in, brefs, refs, schedule)
|
||||
|
||||
@ -1243,12 +1244,24 @@ def emit_pipeline(
|
||||
lambda: None)
|
||||
|
||||
map_brefs(scheduler.swap_slots, brefs, refs, schedule)
|
||||
map_brefs(scheduler.finalize, brefs, refs, schedule)
|
||||
|
||||
return _next_index(indices, grid)
|
||||
|
||||
# run pipeline
|
||||
lax.fori_loop(0, num_steps, loop_body, (0,) * len(grid))
|
||||
@pl.when(num_steps > 0)
def _():
  """Drive the pipeline: one-time prologue, the step loop, one-time epilogue.

  The whole body is guarded by `pl.when(num_steps > 0)`. Buffer
  initialization happens once before the loop and finalization once after
  it, instead of being issued from inside `loop_body`.
  """
  # Prologue: build a scheduler for step 0 at the all-zero grid position,
  # resolve the local ref aliases, and initialize every buffered ref.
  start_indices = (0,) * len(grid)
  prologue_scheduler = make_scheduler(0, start_indices)
  prologue_brefs = map_brefs(
      prologue_scheduler.alias_local_refs, allocations, refs
  )
  map_brefs(prologue_scheduler.initialize, prologue_brefs, refs, schedule)

  # Loop: run `loop_body` for steps [0, num_steps), carrying the grid
  # indices; the carry that comes back is what `loop_body` returned last.
  indices_after_loop = lax.fori_loop(0, num_steps, loop_body, start_indices)

  # Epilogue: `loop_body` returns `_next_index(...)`, so the carry is one
  # position past the last executed step; `_prev_index` is applied to it
  # (presumably the inverse of `_next_index` -- confirm against its
  # definition) before building the scheduler for step `num_steps - 1`
  # that finalizes every buffered ref.
  last_indices = _prev_index(indices_after_loop, grid)
  epilogue_scheduler = make_scheduler(num_steps - 1, last_indices)
  epilogue_brefs = map_brefs(
      epilogue_scheduler.alias_local_refs, allocations, refs
  )
  map_brefs(epilogue_scheduler.finalize, epilogue_brefs, refs, schedule)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user