[Pallas TPU] Remove next_slot SMEM tensor from pipeline emitter

PiperOrigin-RevId: 735564365

commit 02505fa757
parent 988a1208a9
@@ -207,6 +207,8 @@ class BufferedRef:
     is_accumulator: whether this BufferedRef is an accumulator.
     is_input_output: whether this BufferedRef is an input/output without
       automatic accumulation.
+    swap: Tracks whether the BufferedRef slots need to be swapped before next
+      copy.
   """
   spec: pl.BlockSpec  # static metadata
   dtype: Any  # static metadata
@@ -214,9 +216,14 @@ class BufferedRef:
   window_ref: REF | None
   accum_ref: REF | None
   current_slot: ArrayRef | None
+  # TODO(ramiroleal): Unused by class. Remove argument from
+  # BufferedRef instantiations.
   next_slot: ArrayRef | None
   sem_recvs: SemaphoreTuple | None
   sem_sends: SemaphoreTuple | None
+  # TODO(ramiroleal): Improve prefetch/postyeet interface to avoid
+  # using this ref.
+  swap: ArrayRef | None

   def tree_flatten(self):
     return (
@@ -227,6 +234,7 @@ class BufferedRef:
         self.next_slot,
         self.sem_recvs,
         self.sem_sends,
+        self.swap,
       ),
       (self.spec, self.dtype, self.buffer_type),
     )
@@ -240,7 +248,7 @@ class BufferedRef:
     return BufferType

   @classmethod
-  def create(cls, spec, dtype, buffer_type) -> BufferedRef:
+  def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef:
     """Create a BufferedRef.

     Args:
@@ -248,6 +256,7 @@ class BufferedRef:
       dtype: dtype for buffers.
       buffer_type: enum indicating whether this is an input, output, or in/out
         accumulator buffered reference.
+      needs_swap_ref: whether a swap slots tracker needs to be allocated.

     Returns:
       Initialized BufferedRef
@@ -271,6 +280,7 @@ class BufferedRef:
           next_slot=None,
           sem_recvs=None,
           sem_sends=None,
+          swap=None,
       )
     else:
       memory_space = SMEM if spec.memory_space == SMEM else VMEM
@@ -281,7 +291,7 @@ class BufferedRef:
           window_ref=memory_space((2,) + block_shape, dtype),
           accum_ref=accum_ref,
           current_slot=SMEM((1,), jnp.int32),
-          next_slot=SMEM((1,), jnp.int32),
+          next_slot=None,
           sem_recvs=(
               None
               if buffer_type is BufferType.OUTPUT
@@ -292,23 +302,24 @@ class BufferedRef:
               if buffer_type is BufferType.INPUT
               else SemaphoreType.DMA((2,))
           ),
+          swap=SMEM((1,), jnp.bool) if needs_swap_ref else None,
       )

   @classmethod
-  def input(cls, spec, dtype):
-    return cls.create(spec, dtype, BufferType.INPUT)
+  def input(cls, spec, dtype, needs_swap_ref=True):
+    return cls.create(spec, dtype, BufferType.INPUT, needs_swap_ref)

   @classmethod
-  def output(cls, spec, dtype):
-    return cls.create(spec, dtype, BufferType.OUTPUT)
+  def output(cls, spec, dtype, needs_swap_ref=True):
+    return cls.create(spec, dtype, BufferType.OUTPUT, needs_swap_ref)

   @classmethod
-  def accumulator(cls, spec, dtype):
-    return cls.create(spec, dtype, BufferType.ACCUMULATOR)
+  def accumulator(cls, spec, dtype, needs_swap_ref=True):
+    return cls.create(spec, dtype, BufferType.ACCUMULATOR, needs_swap_ref)

   @classmethod
-  def input_output(cls, spec, dtype):
-    return cls.create(spec, dtype, BufferType.INPUT_OUTPUT)
+  def input_output(cls, spec, dtype, needs_swap_ref=True):
+    return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, needs_swap_ref)

   @property
   def block_shape(self):
@@ -329,7 +340,7 @@ class BufferedRef:
     if self.memory_space == VMEM:
       return self.window_ref.at[buffer_slice]
     else:
-      return self.window_ref.at[(self.current_slot[0], *buffer_slice)]
+      return self.window_ref.at[(self.current_slot_index, *buffer_slice)]

   @property
   def is_input(self):
@@ -355,6 +366,14 @@ class BufferedRef:
   def is_input_output(self):
     return self.buffer_type == BufferType.INPUT_OUTPUT

+  @property
+  def current_slot_index(self):
+    return self.current_slot[0]
+
+  @property
+  def next_slot_index(self):
+    return lax.rem(self.current_slot_index + 1, 2)
+
   def bind_existing_ref(self, window_ref, indices):
     """For handling VMEM references, the pipeline aliases the existing ref."""
     if self.memory_space == VMEM:
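
Note on the hunk above: the two new properties encode the double-buffer identity that, with exactly two slots, the "next" slot is always the other slot, and hence is also the previous slot (which wait_out relies on further down). A minimal plain-Python model of the arithmetic, not the Pallas code itself:

def next_slot_index(current_slot_index):
  # With a two-slot window, the other slot is (current + 1) mod 2; the real
  # code computes this with lax.rem on a traced SMEM index.
  return (current_slot_index + 1) % 2

assert next_slot_index(0) == 1 and next_slot_index(1) == 0
# With exactly two slots, "next" and "previous" coincide:
assert all(next_slot_index(s) == (s - 1) % 2 for s in (0, 1))
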
@@ -373,12 +392,15 @@ class BufferedRef:
     """Initialize slot indices."""
     if self.memory_space == VMEM: return
     self.current_slot[0] = 0
-    self.next_slot[0] = 0
+    if self.swap is not None:
+      self.swap[0] = False

   def swap_slots(self):
     """Switch to the next slot."""
     if self.memory_space == VMEM: return
-    self.current_slot[0] = self.next_slot[0]
+    self.current_slot[0] = self.next_slot_index
+    if self.swap is not None:
+      self.swap[0] = False

   def get_dma_slice(self, src_shape, src_dtype, grid_indices):
     # We need to handle blocks that might go OOB in the src array. An in bounds
@@ -441,8 +463,9 @@ class BufferedRef:
     """Starts copy of HBM dma slice into the current slot."""
     assert self.is_input
     if self.memory_space == VMEM: return
-    next_slot = lax.rem(self.current_slot[0] + 1, 2)
-    self.next_slot[0] = next_slot
+    if self.swap is not None:
+      self.swap[0] = True
+    next_slot = self.next_slot_index
     src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
     dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
     tpu_primitives.make_async_copy(
@@ -455,8 +478,9 @@ class BufferedRef:
     """Starts copy of HBM dma slice from the current slot."""
     assert self.is_output
     if self.memory_space == VMEM: return
-    slot = self.current_slot[0]
-    self.next_slot[0] = lax.rem(slot + 1, 2)
+    if self.swap is not None:
+      self.swap[0] = True
+    slot = self.current_slot_index
     dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
     src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
     tpu_primitives.make_async_copy(
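
The two hunks above change copy_in/copy_out from writing a next_slot tensor to merely setting the swap flag; the actual slot switch is deferred to swap_slots, which clears the flag. A rough, self-contained Python model of that protocol (SlotTrackerModel is an illustrative stand-in, not the real BufferedRef):

class SlotTrackerModel:
  """Plain-Python stand-in for the new swap bookkeeping (illustrative only)."""

  def __init__(self, needs_swap_ref=True):
    self.current_slot = 0
    # Models the optional SMEM bool ref; None when no tracker is allocated.
    self.swap = False if needs_swap_ref else None

  @property
  def next_slot_index(self):
    return (self.current_slot + 1) % 2

  def copy_in(self):
    # Starting a copy only marks that a swap is pending; nothing writes a
    # next_slot tensor anymore.
    if self.swap is not None:
      self.swap = True
    return self.next_slot_index  # the DMA destination slot

  def swap_slots(self):
    self.current_slot = self.next_slot_index
    if self.swap is not None:
      self.swap = False

m = SlotTrackerModel()
dst = m.copy_in()
assert dst == 1 and m.swap is True   # copy in flight, swap pending
m.swap_slots()
assert m.current_slot == 1 and m.swap is False
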
@@ -471,7 +495,7 @@ class BufferedRef:
     if self.memory_space == VMEM: return
     src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
     dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
-    current_slot = self.current_slot[0]
+    current_slot = self.current_slot_index
     tpu_primitives.make_async_copy(
         src_ref.at[src_slice],  # nb: doesn't matter
         self.window_ref.at[current_slot].at[
@@ -484,7 +508,8 @@ class BufferedRef:
     """Waits for output copy to finish."""
     assert self.is_output
     if self.memory_space == VMEM: return
-    prev_slot = lax.rem(self.current_slot[0] + 1, 2)
+    # In a double buffer, previous slot is the same as next slot.
+    prev_slot = self.next_slot_index
     dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
     src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
     tpu_primitives.make_async_copy(
@@ -671,10 +696,7 @@ class Scheduler:
     def _start():
       if buffered_ref.is_input:
         buffered_ref.copy_in(src_ref, self.indices)
-
-    # In the prologue this makes it so we wait on the prologue copy to finish.
-    # In other iterations this is the regular swap.
-    buffered_ref.swap_slots()
+        buffered_ref.swap_slots()

   def wait_in(self, buffered_ref, src_ref, schedule=None):
     if schedule is None:
@@ -780,9 +802,32 @@ class Scheduler:
     @self._named_scope("ep_finalize")
     def _end():
       if buffered_ref.is_output:
-        buffered_ref.swap_slots()  # formally correct, not actually necessary.
         buffered_ref.wait_out(dst_ref, self.indices)

+  def swap_slots(self, buffered_ref, hbm_ref, schedule=None):
+    if buffered_ref.swap is not None:
+      swap = buffered_ref.swap[0]
+    else:
+      # If we are not using an SMEM `swap` tensor to keep track of
+      # swaps needed, then all the copies into and out of BufferedRefs
+      # are done by direct calls to the `copy_in` and `copy_out`
+      # methods in the pipeline loop. To determine if the BufferedRef
+      # needs a swap of slots, we recalculate the copy-in/copy-out
+      # conditions.
+      if schedule is None:
+        schedule = _default_schedule
+      pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref)
+      pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref)
+
+      copied_in = pred_in & buffered_ref.is_input & ~self.last_step
+      copied_out = pred_out & buffered_ref.is_output
+      swap = copied_in | copied_out
+
+    @pl.when(swap)
+    @self._named_scope("ep_swap")
+    def _swap():
+      buffered_ref.swap_slots()
+
   # END SCHEDULE --------------------------------------------------------------


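
When a BufferedRef carries no swap ref, the new scheduler method above derives the swap decision from the schedule predicates instead of state. A self-contained sketch of that decision with plain booleans (needs_swap is a hypothetical name; the real code evaluates traced predicates under pl.when):

def needs_swap(pred_in, pred_out, is_input, is_output, last_step):
  # Mirrors the recalculated condition: a swap is due iff this step actually
  # copied in (inputs, except on the last step) or copied out (outputs).
  copied_in = pred_in and is_input and not last_step
  copied_out = pred_out and is_output
  return copied_in or copied_out

# An input ref that started a copy mid-pipeline must swap:
assert needs_swap(True, False, is_input=True, is_output=False, last_step=False)
# On the last step there is no further copy-in, so no swap:
assert not needs_swap(True, False, is_input=True, is_output=False, last_step=True)
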
@@ -875,6 +920,7 @@ def make_pipeline_allocations(
     in_specs=None,
     out_specs=None,
     should_accumulate_out=False,
+    needs_swap_ref=True,
 ):
   """Create BufferedRefs for the pipeline.

@@ -887,6 +933,7 @@ def make_pipeline_allocations(
     out_specs: output pallas block specs
     should_accumulate_out: booleans to indicate which outputs should be treated
       as accumulators.
+    needs_swap_ref: whether a swap slots tracker needs to be allocated.

   Returns:
     A list of BufferedRefs, one corresponding to each ref specified in the
@@ -905,12 +952,12 @@ def make_pipeline_allocations(
   in_refs = refs[:num_in_specs]
   out_refs = refs[num_in_specs:]
   def make_input_bref(in_spec, in_ref):
-    return BufferedRef.input(in_spec, in_ref.dtype)
+    return BufferedRef.input(in_spec, in_ref.dtype, needs_swap_ref)
   in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs)
   def make_output_bref(out_spec, out_ref, accumulate):
     if accumulate:
-      return BufferedRef.accumulator(out_spec, out_ref.dtype)
-    return BufferedRef.output(out_spec, out_ref.dtype)
+      return BufferedRef.accumulator(out_spec, out_ref.dtype, needs_swap_ref)
+    return BufferedRef.output(out_spec, out_ref.dtype, needs_swap_ref)
   out_brefs = jax.tree.map(
       make_output_bref, out_specs, out_refs, should_accumulate_out)
   return (*in_brefs, *out_brefs)
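
The factory calls above are applied leaf-wise over the spec/ref pytrees, with needs_swap_ref riding along via closure. A reduced, runnable illustration of the same jax.tree.map pattern with stand-in string leaves (no Pallas required; the leaf values are assumptions for illustration):

import jax

in_specs = {"x": "spec_x", "y": "spec_y"}   # stand-ins for pl.BlockSpec leaves
in_refs = {"x": "ref_x", "y": "ref_y"}      # stand-ins for HBM refs
needs_swap_ref = False

def make_input_bref(in_spec, in_ref):
  # Stand-in for BufferedRef.input(in_spec, in_ref.dtype, needs_swap_ref).
  return (in_spec, in_ref, needs_swap_ref)

in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs)
assert in_brefs["x"] == ("spec_x", "ref_x", False)
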
@@ -1109,6 +1156,14 @@ def emit_pipeline(
       scratches = ()
     if allocations is None:
       # run with inline scoped allocations
+
+      # Prefetch and postyeet are arbitrary functions that can copy
+      # into or out of any of the BufferedRefs. Thus, we need a ref
+      # for the scheduler to mark when the prefetch or postyeet
+      # functions perform a copy and the slots need to be
+      # swapped. Without prefetch and postyeet, the swapping logic can
+      # be performed without the need for state.
+      needs_swap_ref = prefetch is not None or postyeet is not None
       return primitives.run_scoped(
           lambda allocations: pipeline(
               *refs,
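
Note that needs_swap_ref here is an ordinary Python bool fixed at trace time, so the SMEM tracker is only allocated when a prefetch or postyeet callback could actually perform copies; for example:

prefetch = postyeet = None   # emit_pipeline called without user callbacks
needs_swap_ref = prefetch is not None or postyeet is not None
assert needs_swap_ref is False   # no SMEM `swap` tensor will be allocated
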
@@ -1125,7 +1180,9 @@ def emit_pipeline(
               *refs,
               in_specs=in_specs,
               out_specs=out_specs,
-              should_accumulate_out=should_accumulate_out),
+              should_accumulate_out=should_accumulate_out,
+              needs_swap_ref=needs_swap_ref,
+          ),
       )
     if isinstance(allocations, list):
       allocations = tuple(allocations)
@@ -1184,6 +1241,8 @@ def emit_pipeline(
         lax.cond(step == 0,
                  lambda: postyeet(*brefs, scheduler),
                  lambda: None)
+
+      map_brefs(scheduler.swap_slots, brefs, refs, schedule)
       map_brefs(scheduler.finalize, brefs, refs, schedule)

       return _next_index(indices, grid)