[Pallas TPU] Remove next_slot SMEM tensor from pipeline emitter

PiperOrigin-RevId: 735564365
Author: jax authors
Date: 2025-03-10 17:18:49 -07:00
Parent: 988a1208a9
Commit: 02505fa757

@@ -207,6 +207,8 @@ class BufferedRef:
is_accumulator: whether this BufferedRef is an accumulator.
is_input_output: whether this BufferedRef is an input/output without
automatic accumulation.
swap: tracks whether the BufferedRef's slots need to be swapped before the
next copy.
"""
spec: pl.BlockSpec # static metadata
dtype: Any # static metadata
@@ -214,9 +216,14 @@ class BufferedRef:
window_ref: REF | None
accum_ref: REF | None
current_slot: ArrayRef | None
# TODO(ramiroleal): Unused by class. Remove argument from
# BufferedRef instantiations.
next_slot: ArrayRef | None
sem_recvs: SemaphoreTuple | None
sem_sends: SemaphoreTuple | None
# TODO(ramiroleal): Improve prefetch/postyeet interface to avoid
# using this ref.
swap: ArrayRef | None
def tree_flatten(self):
return (
@@ -227,6 +234,7 @@ class BufferedRef:
self.next_slot,
self.sem_recvs,
self.sem_sends,
self.swap,
),
(self.spec, self.dtype, self.buffer_type),
)
@@ -240,7 +248,7 @@ class BufferedRef:
return BufferType
@classmethod
def create(cls, spec, dtype, buffer_type) -> BufferedRef:
def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef:
"""Create a BufferedRef.
Args:
@@ -248,6 +256,7 @@ class BufferedRef:
dtype: dtype for buffers.
buffer_type: enum indicating whether this is an input, output, or in/out
accumulator buffered reference.
needs_swap_ref: whether an SMEM ref tracking slot swaps needs to be allocated.
Returns:
Initialized BufferedRef
@@ -271,6 +280,7 @@ class BufferedRef:
next_slot=None,
sem_recvs=None,
sem_sends=None,
swap=None,
)
else:
memory_space = SMEM if spec.memory_space == SMEM else VMEM
@@ -281,7 +291,7 @@ class BufferedRef:
window_ref=memory_space((2,) + block_shape, dtype),
accum_ref=accum_ref,
current_slot=SMEM((1,), jnp.int32),
next_slot=SMEM((1,), jnp.int32),
next_slot=None,
sem_recvs=(
None
if buffer_type is BufferType.OUTPUT
@@ -292,23 +302,24 @@ class BufferedRef:
if buffer_type is BufferType.INPUT
else SemaphoreType.DMA((2,))
),
swap=SMEM((1,), jnp.bool) if needs_swap_ref else None,
)
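# A minimal sketch (hypothetical block shape, not part of this change) of
# what create() allocates above: the window gets a leading slot dimension of
# two for double buffering, plus scalar SMEM trackers.
block_shape = (8, 128)             # hypothetical block
window_shape = (2,) + block_shape  # two slots, as in window_ref above
assert window_shape == (2, 8, 128)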
@classmethod
def input(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.INPUT)
def input(cls, spec, dtype, needs_swap_ref=True):
return cls.create(spec, dtype, BufferType.INPUT, needs_swap_ref)
@classmethod
def output(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.OUTPUT)
def output(cls, spec, dtype, needs_swap_ref=True):
return cls.create(spec, dtype, BufferType.OUTPUT, needs_swap_ref)
@classmethod
def accumulator(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.ACCUMULATOR)
def accumulator(cls, spec, dtype, needs_swap_ref=True):
return cls.create(spec, dtype, BufferType.ACCUMULATOR, needs_swap_ref)
@classmethod
def input_output(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT)
def input_output(cls, spec, dtype, needs_swap_ref=True):
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, needs_swap_ref)
@property
def block_shape(self):
@@ -329,7 +340,7 @@ class BufferedRef:
if self.memory_space == VMEM:
return self.window_ref.at[buffer_slice]
else:
return self.window_ref.at[(self.current_slot[0], *buffer_slice)]
return self.window_ref.at[(self.current_slot_index, *buffer_slice)]
@property
def is_input(self):
@@ -355,6 +366,14 @@ class BufferedRef:
def is_input_output(self):
return self.buffer_type == BufferType.INPUT_OUTPUT
@property
def current_slot_index(self):
return self.current_slot[0]
@property
def next_slot_index(self):
return lax.rem(self.current_slot_index + 1, 2)
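# A minimal sketch of the slot arithmetic behind the two properties above:
# for a double buffer with slots {0, 1}, lax.rem(i + 1, 2) names the other
# slot, so next_slot_index always points at the slot not currently in use.
import jax.numpy as jnp
from jax import lax

for current in (jnp.int32(0), jnp.int32(1)):
    assert lax.rem(current + 1, 2) == 1 - current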
def bind_existing_ref(self, window_ref, indices):
"""For handling VMEM references, the pipeline aliases the existing ref."""
if self.memory_space == VMEM:
@@ -373,12 +392,15 @@ class BufferedRef:
"""Initialize slot indices."""
if self.memory_space == VMEM: return
self.current_slot[0] = 0
self.next_slot[0] = 0
if self.swap is not None:
self.swap[0] = False
def swap_slots(self):
"""Switch to the next slot."""
if self.memory_space == VMEM: return
self.current_slot[0] = self.next_slot[0]
self.current_slot[0] = self.next_slot_index
if self.swap is not None:
self.swap[0] = False
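# A minimal pure-Python model (hypothetical class, not part of this commit)
# of the bookkeeping above: copy_in/copy_out set `swap`, and the Scheduler
# below only triggers swap_slots while the flag is set, so repeated calls
# flip the current slot at most once per copy.
class _SlotModel:
    def __init__(self):
        self.current = 0      # mirrors current_slot[0]
        self.swap = False     # mirrors swap[0]
    def copy_started(self):   # what copy_in/copy_out do to the flag
        self.swap = True
    def maybe_swap_slots(self):  # Scheduler's @pl.when(swap) + swap_slots
        if self.swap:
            self.current = (self.current + 1) % 2
            self.swap = False

m = _SlotModel()
m.copy_started(); m.maybe_swap_slots(); m.maybe_swap_slots()
assert m.current == 1 and not m.swap  # the second call was a no-op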
def get_dma_slice(self, src_shape, src_dtype, grid_indices):
# We need to handle blocks that might go OOB in the src array. An in bounds
@@ -441,8 +463,9 @@ class BufferedRef:
"""Starts copy of HBM dma slice into the current slot."""
assert self.is_input
if self.memory_space == VMEM: return
next_slot = lax.rem(self.current_slot[0] + 1, 2)
self.next_slot[0] = next_slot
if self.swap is not None:
self.swap[0] = True
next_slot = self.next_slot_index
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
tpu_primitives.make_async_copy(
@@ -455,8 +478,9 @@ class BufferedRef:
"""Starts copy of HBM dma slice from the current slot."""
assert self.is_output
if self.memory_space == VMEM: return
slot = self.current_slot[0]
self.next_slot[0] = lax.rem(slot + 1, 2)
if self.swap is not None:
self.swap[0] = True
slot = self.current_slot_index
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
@@ -471,7 +495,7 @@ class BufferedRef:
if self.memory_space == VMEM: return
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
current_slot = self.current_slot[0]
current_slot = self.current_slot_index
tpu_primitives.make_async_copy(
src_ref.at[src_slice], # nb: doesn't matter
self.window_ref.at[current_slot].at[
@@ -484,7 +508,8 @@ class BufferedRef:
"""Waits for output copy to finish."""
assert self.is_output
if self.memory_space == VMEM: return
prev_slot = lax.rem(self.current_slot[0] + 1, 2)
# In a double buffer, the previous slot is the same as the next slot:
# (current + 1) % 2 names the only other slot.
prev_slot = self.next_slot_index
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
@@ -671,10 +696,7 @@ class Scheduler:
def _start():
if buffered_ref.is_input:
buffered_ref.copy_in(src_ref, self.indices)
# In the prologue this makes us wait on the prologue copy to finish.
# In other iterations this is the regular swap.
buffered_ref.swap_slots()
buffered_ref.swap_slots()
def wait_in(self, buffered_ref, src_ref, schedule=None):
if schedule is None:
@@ -780,9 +802,32 @@ class Scheduler:
@self._named_scope("ep_finalize")
def _end():
if buffered_ref.is_output:
buffered_ref.swap_slots() # formally correct, not actually necessary.
buffered_ref.wait_out(dst_ref, self.indices)
def swap_slots(self, buffered_ref, hbm_ref, schedule=None):
if buffered_ref.swap is not None:
swap = buffered_ref.swap[0]
else:
# If we are not using an SMEM `swap` tensor to keep track of
# swaps needed, then all the copies into and out of BufferedRefs
# are done by direct calls to the `copy_in` and `copy_out`
# methods in the pipeline loop. To determine if the BufferedRef
# needs a swap of slots, we recalculate the copy-in/copy-out
# conditions.
if schedule is None:
schedule = _default_schedule
pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref)
pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref)
copied_in = pred_in & buffered_ref.is_input & ~self.last_step
copied_out = pred_out & buffered_ref.is_output
swap = copied_in | copied_out
@pl.when(swap)
@self._named_scope("ep_swap")
def _swap():
buffered_ref.swap_slots()
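# A minimal sketch of the stateless branch above as a pure predicate, with
# the schedule predicates assumed to be already evaluated to booleans:
def _needs_swap(pred_in, pred_out, is_input, is_output, last_step):
    copied_in = pred_in and is_input and not last_step
    copied_out = pred_out and is_output
    return copied_in or copied_out

assert _needs_swap(True, False, True, False, last_step=False)
assert not _needs_swap(True, False, True, False, last_step=True)  # no copy-in on the last step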
# END SCHEDULE --------------------------------------------------------------
@@ -875,6 +920,7 @@ def make_pipeline_allocations(
in_specs=None,
out_specs=None,
should_accumulate_out=False,
needs_swap_ref=True,
):
"""Create BufferedRefs for the pipeline.
@@ -887,6 +933,7 @@ def make_pipeline_allocations(
out_specs: output pallas block specs
should_accumulate_out: booleans to indicate which outputs should be treated
as accumulators.
needs_swap_ref: whether an SMEM ref tracking slot swaps needs to be allocated.
Returns:
A list of BufferedRefs, one corresponding to each ref specified in the
@@ -905,12 +952,12 @@ def make_pipeline_allocations(
in_refs = refs[:num_in_specs]
out_refs = refs[num_in_specs:]
def make_input_bref(in_spec, in_ref):
return BufferedRef.input(in_spec, in_ref.dtype)
return BufferedRef.input(in_spec, in_ref.dtype, needs_swap_ref)
in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs)
def make_output_bref(out_spec, out_ref, accumulate):
if accumulate:
return BufferedRef.accumulator(out_spec, out_ref.dtype)
return BufferedRef.output(out_spec, out_ref.dtype)
return BufferedRef.accumulator(out_spec, out_ref.dtype, needs_swap_ref)
return BufferedRef.output(out_spec, out_ref.dtype, needs_swap_ref)
out_brefs = jax.tree.map(
make_output_bref, out_specs, out_refs, should_accumulate_out)
return (*in_brefs, *out_brefs)
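# A hedged usage sketch (hypothetical refs and specs) of the new flag:
# callers that never copy through prefetch/postyeet can skip allocating the
# SMEM swap trackers entirely.
def _allocate(my_in_ref, my_out_ref, my_in_spec, my_out_spec):
    return make_pipeline_allocations(
        my_in_ref, my_out_ref,
        in_specs=(my_in_spec,),
        out_specs=(my_out_spec,),
        should_accumulate_out=(False,),
        needs_swap_ref=False,  # no user-driven copies in this sketch
    )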
@@ -1109,6 +1156,14 @@ def emit_pipeline(
scratches = ()
if allocations is None:
# run with inline scoped allocations
# Prefetch and postyeet are arbitrary user functions that can copy
# into or out of any of the BufferedRefs. Thus, we need a ref that
# records when a prefetch or postyeet call performs a copy, so the
# scheduler knows the slots need to be swapped. Without prefetch and
# postyeet, the swapping logic can be performed statelessly.
needs_swap_ref = prefetch is not None or postyeet is not None
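# A hedged sketch (hypothetical helper, not part of this commit) of the case
# described above: a postyeet may start its own copy, and only the SMEM
# `swap` flag tells the scheduler that a slot flip must follow. A real
# postyeet takes (*brefs, scheduler) and would close over the HBM ref.
def _example_postyeet(in_bref, out_bref, scheduler, my_hbm_out_ref):
    # copy_out sets out_bref.swap[0] = True; Scheduler.swap_slots then
    # rotates the double buffer before the next iteration reuses it.
    out_bref.copy_out(my_hbm_out_ref, scheduler.indices)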
return primitives.run_scoped(
lambda allocations: pipeline(
*refs,
@@ -1125,7 +1180,9 @@ def emit_pipeline(
*refs,
in_specs=in_specs,
out_specs=out_specs,
should_accumulate_out=should_accumulate_out),
should_accumulate_out=should_accumulate_out,
needs_swap_ref=needs_swap_ref,
),
)
if isinstance(allocations, list):
allocations = tuple(allocations)
@@ -1184,6 +1241,8 @@ def emit_pipeline(
lax.cond(step == 0,
lambda: postyeet(*brefs, scheduler),
lambda: None)
map_brefs(scheduler.swap_slots, brefs, refs, schedule)
map_brefs(scheduler.finalize, brefs, refs, schedule)
return _next_index(indices, grid)