# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for emitting custom TPU pipelines within a Pallas call."""
from __future__ import annotations
from collections.abc import Sequence
import dataclasses
import enum
import functools
import itertools
import operator
from typing import Union, Any
import jax
from jax import lax
from jax import tree_util
from jax._src import util as jax_util
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
SMEM = tpu_core.TPUMemorySpace.SMEM
VMEM = tpu_core.TPUMemorySpace.VMEM
DMA = tpu_core.SemaphoreType.DMA
REF = tpu_core.MemoryRef
SemaphoreType = tpu_core.SemaphoreType
ArrayRef = Union[REF, jax.Array]
GridIndices = tuple[jax.Array, ...]
CondVal = Union[jax.Array, bool]
PipelineBlockSpecs = Union[Sequence[pallas_core.BlockSpec], Any]
PipelineRefs = Union[Sequence[REF], Any]
# TODO(sharadmv): make this a parameter and make it queryable from the Device.
_TILING = (8, 128)
def _broadcast_pytree_to(from_pytree, to_pytree):
"""Broadcast a prefix pytree to a given full tree."""
proxy = object()
treedef = tree_util.tree_structure(to_pytree)
broadcast_leaves = []
def add_leaves(i, x):
broadcast_leaves.extend(
[i] * tree_util.tree_structure(x).num_leaves)
try:
tree_util.tree_map(add_leaves, from_pytree, to_pytree,
is_leaf=lambda x: x is None)
except ValueError:
raise ValueError(f"Cannot broadcast tree {from_pytree} "
f"to full tree structure {treedef}.") from None
broadcast_leaves = [None if a is proxy else a for a in broadcast_leaves]
assert len(broadcast_leaves) == treedef.num_leaves
return tree_util.tree_unflatten(treedef, broadcast_leaves)
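# For example, broadcasting a scalar prefix over a nested spec structure (as is
# done for `should_accumulate_out` in `emit_pipeline` below) replicates the
# value once per leaf. The specs here are hypothetical placeholders:
#
#   _broadcast_pytree_to(False, (spec_a, (spec_b, spec_c)))
#   # -> (False, (False, False))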
def _get_tpu_generation() -> int:
kind = jax.devices()[0].device_kind
if kind.endswith(' lite'):
kind = kind[:-len(' lite')]
assert kind[:5] == "TPU v", kind
return int(kind[5])
def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]:
  # For an n-dimensional shape, returns a tiling of up to (8, 128) for the last
  # 2 dimensions (the exact second-minor tiling depends on dtype and TPU
  # generation) and 1 for the leading n - 2. For example,
  # (256, 256) -> (8, 128) and (2, 3, 128, 128) -> (1, 1, 8, 128).
if len(shape) < 2:
raise ValueError(f"Shape must have at least 2 dimensions: {shape=}")
leading_dims, final_dims = shape[:-2], shape[-2:]
# We want to find the minimum power of 2 that fits the second-minor dimension
# of shape, with maximum value 8.
second_minor, _ = final_dims
packing = 4 // dtype.itemsize
max_tiling = _TILING[0]
second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing
while second_minor_tiling < min(second_minor, max_tiling):
second_minor_tiling *= 2
return (*(1,) * len(leading_dims), second_minor_tiling, _TILING[1])
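# For example, on a TPU v4 with float32 data (packing == 1), the second-minor
# tiling starts at 1 and doubles until it covers that dimension, capped at 8:
#
#   _make_tiling((600, 600), np.dtype(np.float32))        # -> (8, 128)
#   _make_tiling((4, 512), np.dtype(np.float32))          # -> (4, 128)
#   _make_tiling((2, 3, 128, 128), np.dtype(np.float32))  # -> (1, 1, 8, 128)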
def _mod(a, n):
""""Calculates a mod n for positive and negative a with |a| <= n."""
return lax.rem(a + n, n)
def _round_up_to_nearest_multiple(s: int, multiple: int) -> int:
if s % multiple == 0:
return s
# Subtract off the remainder, then add multiple
return s - s % multiple + multiple
def _make_ds(
idx: jax.Array | int, size: jax.Array | int
) -> pl.Slice:
"""Make a DMA slice with mosaic size hints."""
out = pl.ds(idx * size, size)
assert isinstance(out, pl.Slice)
return out
def _make_block_slice(
block_index: jax.Array, block_size: int, size: int, tiling: int
) -> pl.Slice | slice:
  # Computes a slice given a block index and block size. In the default case,
  # we return slice(block_index * block_size, (block_index + 1) * block_size).
  # However, if the block size does not evenly divide the total size of the ref
  # and we are selecting the last block, we need to pick the lowest tiling-size
  # multiple that contains the block.
if size % block_size == 0:
return _make_ds(block_index, block_size)
if block_size % tiling != 0:
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
num_blocks = pl.cdiv(size, block_size)
is_last = block_index == num_blocks - 1
rounded_size = jnp.where(
is_last,
_round_up_to_nearest_multiple(size % block_size, tiling),
block_size,
)
rounded_size = pl.multiple_of(rounded_size, tiling)
return pl.ds(block_index * block_size, rounded_size)
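# For example, with size=601, block_size=256 and tiling=8, the last of the
# three blocks covers only 601 - 512 = 89 rows, so its slice length is rounded
# up to the next tile multiple (96) instead of the static block size (256):
#
#   _make_block_slice(block_index, 256, 601, 8)
#   # -> pl.ds(block_index * 256, 96) on the last block,
#   #    pl.ds(block_index * 256, 256) otherwise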
def _tuples_differ(xs, ys):
"""Dynamic index-tuple comparison calculation."""
differences = jax.tree.map(lambda x, y: x != y, xs, ys)
return functools.reduce(lambda x, y: x | y, differences, False)
def _grid_size(grid):
"""Dynamic grid size calculation."""
size = jnp.array(1, jnp.int32)
for dim in grid:
size *= dim
return size
def _get_indices(step, grid, offsets):
"""Get indices for a given step and grid."""
extended_grid = grid + (1,)
strides = tuple(
itertools.accumulate(extended_grid[::-1], func=operator.mul))[::-1]
indices = tuple(
lax.div(lax.rem(step, a), b)
for a, b in zip(strides[:-1], strides[1:])
)
return tuple(a + b for a, b in zip(indices, offsets, strict=True))
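# Steps traverse the grid in row-major order, so, for example, with a (2, 3)
# grid:
#
#   _get_indices(4, (2, 3), (0, 0))  # -> (1, 1)
#
# Nonzero `offsets` shift the returned indices; this is used below when
# partitioning the grid over megacore TensorCores.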
class BufferType(enum.Enum):
"""Buffer type for the arguments to an emitted pipeline."""
INPUT = 1
OUTPUT = 2
ACCUMULATOR = 3
@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class BufferedRef:
"""A helper class to automate VMEM double buffering in pallas pipelines.
Attributes:
spec: pallas blockspec.
dtype: dtype for buffers.
buffer_type: enum indicating whether this is an input, output, or in/out
accumulator buffered reference.
    vmem_ref: a double buffer holding a working buffer and a dirty buffer that
      DMAs copy into and out of. In the case of a BufferedRef targeting a VMEM
      reference, this simply points to the existing ref.
accum_ref: accumulating buffer used by accumulator BufferedRefs.
current_slot: current slot index to the working buffer.
next_slot: slot that will point to the working buffer in the next iteration.
sem_recv: semaphore for input DMAs.
sem_send: semaphore for output DMAs.
block_shape: passthrough property for the BlockSpec's block_shape.
compute_index: passthrough property for the BlockSpec's compute_index.
memory_space: passthrough property for the BlockSpec's memory_space.
current_ref: points to the current working slice of the double-buffer.
is_input: whether this BufferedRef acts as a pipeline input.
is_output: whether this BufferedRef acts as a pipeline output.
is_accumulator: whether this BufferedRef is an accumulator.
"""
spec: pl.BlockSpec # static metadata
dtype: Any # static metadata
buffer_type: BufferType # static metadata
vmem_ref: REF | None
accum_ref: REF | None
current_slot: ArrayRef | None
next_slot: ArrayRef | None
sem_recv: SemaphoreType | None
sem_send: SemaphoreType | None
def tree_flatten(self):
return ((self.vmem_ref, self.accum_ref, self.current_slot,
self.next_slot, self.sem_recv, self.sem_send),
(self.spec, self.dtype, self.buffer_type))
@classmethod
def tree_unflatten(cls, meta, data):
return cls(*meta, *data)
@classmethod
def create(cls, spec, dtype, buffer_type) -> BufferedRef:
"""Create a BufferedRef.
Args:
spec: pallas blockspec.
dtype: dtype for buffers.
buffer_type: enum indicating whether this is an input, output, or in/out
accumulator buffered reference.
Returns:
Initialized BufferedRef
"""
block_shape = tuple([1 if x is None else x for x in spec.block_shape])
if spec.memory_space == VMEM:
      # We don't need to do any double-buffering in the case that our pipeline
      # reference is already in VMEM; we just need to allocate the accumulation
      # buffer, and we will refer to the original reference slices directly.
return cls(
spec=spec, dtype=dtype,
buffer_type=buffer_type,
vmem_ref=None, # to be bound to existing ref by the pipeline routine
accum_ref=(VMEM(block_shape, dtype)
if buffer_type is BufferType.ACCUMULATOR else None),
current_slot=None, next_slot=None, sem_recv=None, sem_send=None)
else:
return cls(
spec=spec, dtype=dtype,
buffer_type=buffer_type,
vmem_ref=VMEM((2,) + block_shape, dtype),
accum_ref=(VMEM(block_shape, dtype)
if buffer_type is BufferType.ACCUMULATOR else None),
current_slot=SMEM((1,), jnp.int32),
next_slot=SMEM((1,), jnp.int32),
sem_recv=(None if buffer_type is BufferType.OUTPUT
else SemaphoreType.DMA),
sem_send=(None if buffer_type is BufferType.INPUT
else SemaphoreType.DMA),)
@classmethod
def input(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.INPUT)
@classmethod
def output(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.OUTPUT)
@classmethod
def accumulator(cls, spec, dtype):
return cls.create(spec, dtype, BufferType.ACCUMULATOR)
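  # For example, for a hypothetical (512, 512) float32 input block whose spec
  # does not target VMEM, `BufferedRef.input(spec, jnp.float32)` describes:
  #   vmem_ref     -> VMEM((2, 512, 512), float32)   (the double buffer)
  #   current_slot -> SMEM((1,), int32)
  #   next_slot    -> SMEM((1,), int32)
  #   sem_recv     -> SemaphoreType.DMA, sem_send -> None
  # These are allocation *types* that `run_scoped` later materializes into
  # real references (see `make_pipeline_allocations` below).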
@property
def block_shape(self):
return self.spec.block_shape
@property
def compute_index(self):
return self.spec.compute_index
@property
def memory_space(self):
return self.spec.memory_space
@property
def current_ref(self):
buffer_slice = tuple(
[0 if x is None else slice(None) for x in self.block_shape])
if self.memory_space == VMEM:
return self.vmem_ref.at[buffer_slice]
else:
return self.vmem_ref.at[(self.current_slot[0], *buffer_slice)]
@property
def is_input(self):
return self.buffer_type in [BufferType.INPUT, BufferType.ACCUMULATOR]
@property
def is_output(self):
return self.buffer_type in [BufferType.OUTPUT, BufferType.ACCUMULATOR]
@property
def is_accumulator(self):
return self.buffer_type == BufferType.ACCUMULATOR
def bind_existing_ref(self, vmem_ref, indices):
"""For handling VMEM references, the pipeline aliases the existing ref."""
if self.memory_space == VMEM:
return dataclasses.replace(
self, vmem_ref=vmem_ref.at[self.compute_slice(indices)])
return self
def compute_slice(self, grid_indices):
"""Compute DMA slice from grid indices."""
block_shape = tuple([1 if x is None else x for x in self.block_shape])
indices = self.compute_index(*grid_indices)
return jax.tree.map(_make_ds, indices, block_shape)
def init_slots(self):
"""Initialize slot indices."""
if self.memory_space == VMEM: return
self.current_slot[0] = 0
self.next_slot[0] = 0
def swap_slots(self):
"""Switch to the next slot."""
if self.memory_space == VMEM: return
self.current_slot[0] = self.next_slot[0]
def get_dma_slice(self, src_shape, src_dtype, grid_indices):
# We need to handle blocks that might go OOB in the src array. An in bounds
# block looks like this (for array shape (600, 600) and block shape
# (256, 256)):
#
# +--------------+------------------|
# | Block (0,0) | |
# | (256, 256) | |
# +--------------+ |
# | A (600, 600) |
# | |
# +---------------------------------+
#
# For in-bounds blocks, we don't need to do anything special.
# An out-of-bounds block looks like this:
#
# +--------------+------------------|
# | |
# | |
# + |
# | A (600, 600) |
# +--------------+ |
# | Block (2,0) | |
# + --------------------------------|
# | XXXXXXXXXX |
# +--------------+
# where the X's indicate where the block is out of bounds.
#
# When we have an out of bounds block like this, we need to truncate it to
# a tile boundary (tiles are (8, 128) along the two minormost dimensions).
# In this case, we'll have a block that is indexing the
# 512:768 elements of A along the first dimension. We need to convert 768
# into 600 (600 % 8 == 0), so our indexing will look like this:
# +--------------+------------------|
# | |
# | |
# + |
# | A (600, 600) |
# +--------------+ |
# | Block (2,0) | |
# + --------------------------------|
# where it is now a (88, 256) sized block.
#
# Suppose A is now (601, 600), instead of picking a (88, 256)-sized block
# for the last iteration on that dimension, we will pick the next highest
# tile multiple, i.e. (96, 256).
if len(src_shape) < 2:
raise NotImplementedError("Must use >1D values.")
tiling = _make_tiling(src_shape, src_dtype)
block_shape = tuple(1 if b is None else b for b in self.block_shape)
block_indices = self.compute_index(*grid_indices)
return jax.tree.map(
_make_block_slice, block_indices, block_shape, src_shape, tiling
)
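  # Continuing the example above: for a (601, 600) float32 source with
  # (256, 256) blocks, and assuming the spec's index map sends these grid
  # indices to block (2, 0), the resulting DMA slices are
  #
  #   (pl.ds(512, 96), pl.ds(0, 256))
  #
  # i.e. the trailing block is rounded up from the raw remainder (89 rows) to
  # the nearest (8, 128) tile multiple.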
def copy_in(self, src_ref, grid_indices):
"""Starts copy of HBM dma slice into the current slot."""
assert self.is_input
if self.memory_space == VMEM: return
next_slot = lax.rem(self.current_slot[0] + 1, 2)
self.next_slot[0] = next_slot
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
tpu_primitives.make_async_copy(
src_ref.at[src_slice],
self.vmem_ref.at[next_slot].at[dst_slice],
self.sem_recv).start()
def copy_out(self, dst_ref, grid_indices):
"""Starts copy of HBM dma slice from the current slot."""
assert self.is_output
if self.memory_space == VMEM: return
slot = self.current_slot[0]
self.next_slot[0] = lax.rem(slot + 1, 2)
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
self.vmem_ref.at[slot].at[src_slice],
dst_ref.at[dst_slice],
self.sem_send).start()
def wait_in(self, src_ref, grid_indices):
"""Waits for input copy to finish."""
assert self.is_input
if self.memory_space == VMEM: return
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
tpu_primitives.make_async_copy(
src_ref.at[src_slice], # nb: doesn't matter
self.vmem_ref.at[self.current_slot[0]].at[dst_slice], # only dst shape is important
self.sem_recv).wait()
def wait_out(self, dst_ref, grid_indices):
"""Waits for output copy to finish."""
assert self.is_output
if self.memory_space == VMEM: return
prev_slot = lax.rem(self.current_slot[0] + 1, 2)
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
tpu_primitives.make_async_copy(
self.vmem_ref.at[prev_slot].at[src_slice], # nb: doesn't matter
dst_ref.at[dst_slice], # only dst shape is important
self.sem_send).wait()
# Accumulator methods
#
# Accumulating inline in VMEM saves half the HBM<->VMEM bandwidth cost of
# doing another full loop around HBM to do a reduction, at the current cost
# of allocating another VMEM buffer.
#
  # NB: there's no actual need for an additional accumulation buffer: if we
  # rewrote inner kernels to handle the initial zero-init and the output
  # reduction themselves, we wouldn't need to waste the extra VMEM. Consider
  # removing this magic init and reduce support.
def set_accumulator(self, init=False):
"""Set accumulator or zero it out to initialize."""
assert self.is_accumulator
if self.accum_ref is not None:
def _init():
self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...])
def _set():
        self.accum_ref[...] = self.current_ref[...].astype(
            self.accum_ref.dtype)
lax.cond(init, _init, _set)
def accumulate(self):
"""Add into the current slot."""
assert self.is_accumulator
if self.accum_ref is not None:
accum_dtype = jnp.float32
if self.vmem_ref.dtype == jnp.int32:
accum_dtype = jnp.int32
# TODO(levskaya): we could generalize init and reduction functions,
# could it ever be useful to support more generic monoids?
self.current_ref[...] = (
self.current_ref[...].astype(accum_dtype) +
self.accum_ref[...].astype(accum_dtype)
).astype(self.vmem_ref.dtype)
# Helper to tree map over BufferedRefs as leaves.
map_brefs = functools.partial(
jax.tree.map,
is_leaf=lambda x: isinstance(x, BufferedRef))
class Scheduler:
"""Sequences input and output copies and waits for a pipeline."""
def __init__(self,
step: jax.Array,
grid: tuple[int | jax.Array, ...],
grid_offsets: tuple[int | jax.Array, ...],
first_cycle=None,
last_cycle=None,
init_accumulators=None,
):
"""Initializes scheduler.
Args:
step: inner step number.
grid: pallas grid for BufferedRefs.
grid_offsets: offsets for grid indices (used for megacore).
first_cycle: whether this is the first invocation of the pipeline.
last_cycle: whether this is the last invocation of the pipeline.
      init_accumulators: whether to zero-initialize accumulator state for this
        invocation of the pipeline.
"""
self.step = step
self.grid = grid
self.first_cycle = first_cycle
self.last_cycle = last_cycle
self.init_accumulators = init_accumulators
# Total number of linear steps.
self.num_steps = _grid_size(grid)
# First and last inner step conditionals.
self.first_step = step == 0
self.last_step = step == self.num_steps - 1
# First and last total step conditionals.
self.first_step_ever = first_cycle & self.first_step
self.last_step_ever = last_cycle & self.last_step
# Cyclic steps
self.prev_step = _mod(step - 1, self.num_steps)
self.next_step = _mod(step + 1, self.num_steps)
# Derived grid indices for present, previous, and next steps.
self.indices = _get_indices(step, grid, grid_offsets)
self.prev_indices = _get_indices(
self.prev_step, grid, grid_offsets
)
self.next_indices = _get_indices(
self.next_step, grid, grid_offsets
)
def grid_env(self):
return pallas_core.grid_env(
list(map(pallas_core.GridAxis, self.indices, self.grid)))
def has_changed(self, buffered_ref):
indices = buffered_ref.compute_index(*self.indices)
prev_indices = buffered_ref.compute_index(*self.prev_indices)
return _tuples_differ(indices, prev_indices)
def will_change(self, buffered_ref):
indices = buffered_ref.compute_index(*self.indices)
next_indices = buffered_ref.compute_index(*self.next_indices)
return _tuples_differ(indices, next_indices)
def alias_local_refs(self, buffered_ref, ref):
return buffered_ref.bind_existing_ref(ref, self.indices)
# SCHEDULE ----------------------------------------------------------------
# Below is the sequence of conditional waits and copies used for inputs,
# outputs, and in-out accumulators.
def initialize(self, buffered_ref, src_ref, schedule=None):
pred = self.first_step_ever
if schedule is not None:
pred = schedule['prologue_copy_in'](self, buffered_ref, src_ref)
with jax.named_scope("ep_initialize"):
@pl.when(self.first_step_ever)
def _init_slots():
buffered_ref.init_slots()
@pl.when(pred)
def _start():
if buffered_ref.is_input:
buffered_ref.copy_in(src_ref, self.indices)
buffered_ref.swap_slots()
def wait_in(self, buffered_ref, src_ref, schedule=None):
pred = self.has_changed(buffered_ref) | self.first_step
if schedule is not None:
pred = schedule['wait_in'](self, buffered_ref, src_ref)
@jax.named_scope("ep_wait_in")
def _wait():
if buffered_ref.is_input:
buffered_ref.wait_in(src_ref, self.indices)
if buffered_ref.is_accumulator:
buffered_ref.set_accumulator(self.init_accumulators)
@jax.named_scope("ep_set_accum")
def _no_wait():
if buffered_ref.is_accumulator:
@pl.when(self.first_step)
def _set_accumulator():
buffered_ref.set_accumulator(self.init_accumulators)
lax.cond(pred, _wait, _no_wait)
def copy_in(self, buffered_ref, src_ref, schedule=None):
pred = self.will_change(buffered_ref) & ~self.last_step_ever
if schedule is not None:
pred = schedule['copy_in'](self, buffered_ref, src_ref)
@pl.when(pred)
@jax.named_scope("ep_copy_in")
def _send():
if buffered_ref.is_input:
@pl.when(~self.last_step)
def _copy_in():
buffered_ref.copy_in(src_ref, self.next_indices)
# --> Call prefetch here to grab the first inputs of next cycle.
# convenience method for prefetch callbacks.
def prefetch(self, buffered_ref, src_ref, schedule=None):
pred = ((self.will_change(buffered_ref) | self.last_step) &
~self.last_step_ever)
if schedule is not None:
pred = schedule['prefetch'](self, buffered_ref, src_ref)
@pl.when(pred)
@jax.named_scope("ep_prefetch")
def _send():
if buffered_ref.is_input:
@pl.when(self.last_step)
def _prefetch_in():
buffered_ref.copy_in(src_ref, self.next_indices)
def wait_out(self, buffered_ref, dst_ref, schedule=None):
pred = ((self.has_changed(buffered_ref) | self.first_step) &
~self.first_step_ever)
if schedule is not None:
pred = schedule['wait_out'](self, buffered_ref, dst_ref)
@pl.when(pred)
@jax.named_scope("ep_wait_out")
def _wait():
if buffered_ref.is_output:
buffered_ref.wait_out(dst_ref, self.prev_indices)
# --> Call "postyeet" here, after last output copy is finished from previous
# cycle
def copy_out(self, buffered_ref, dst_ref, schedule=None):
pred = self.will_change(buffered_ref) | self.last_step
if schedule is not None:
pred = schedule['copy_out'](self, buffered_ref, dst_ref)
@jax.named_scope("ep_copy_out")
def _copy_out_and_accumulate():
if buffered_ref.is_accumulator:
buffered_ref.accumulate()
if buffered_ref.is_output:
buffered_ref.copy_out(dst_ref, self.indices)
@jax.named_scope("ep_accum")
def _just_accumulate():
if buffered_ref.is_accumulator:
@pl.when(self.last_step)
def _accumulate():
buffered_ref.accumulate()
lax.cond(pred, _copy_out_and_accumulate, _just_accumulate)
def finalize(self, buffered_ref, dst_ref, schedule=None):
pred = self.last_step_ever
if schedule is not None:
pred = schedule['epilogue_wait_out'](self, buffered_ref, dst_ref)
@pl.when(pred)
@jax.named_scope("ep_finalize")
def _end():
if buffered_ref.is_output:
buffered_ref.swap_slots() # formally correct, not actually necessary.
buffered_ref.wait_out(dst_ref, self.indices)
# END SCHEDULE --------------------------------------------------------------
# Scheduling overrides.
# When trying to fuse across pipelines that use accumulator arguments, we
# sometimes need to mess with the default scheduling above to avoid data-races
# or to maximize performance. A schedule is simply a set of functions that
# calculate predicates for whether or not the pipeline input and output
# BufferedRefs should do copies and waits.
# Copy of the default pipeline schedule. The default schedule tacitly assumes
# that the source and target HBM Refs change with each cycle.
_default_schedule = dict(
prologue_copy_in=lambda s, bref, _: s.first_step_ever,
wait_in=lambda s, bref, _: s.has_changed(bref) | s.first_step,
copy_in=lambda s, bref, _: s.will_change(bref) & ~s.last_step_ever,
prefetch=lambda s, bref, _: (
(s.will_change(bref) | s.last_step) & ~s.last_step_ever),
wait_out=lambda s, bref, _: (
(s.has_changed(bref) | s.first_step) & ~s.first_step_ever),
copy_out=lambda s, bref, _: s.will_change(bref) | s.last_step,
epilogue_wait_out=lambda s, bref, _: s.last_step_ever,
)
# Alternative schedule needed for accumulators reading and writing to a fixed
# HBM reference to avoid HBM data races for trivially small grids: only
# read/write when tiles change or at the very beginning or end of a fused
# pipeline schedule.
_fixed_schedule = dict(
prologue_copy_in=lambda s, bref, _: s.first_step_ever,
wait_in=lambda s, bref, _: s.has_changed(bref) | s.first_step_ever,
copy_in=lambda s, bref, _: s.will_change(bref) & ~s.last_step_ever,
prefetch=lambda s, bref, _: s.will_change(bref) & ~s.last_step_ever,
wait_out=lambda s, bref, _: s.has_changed(bref) & ~s.first_step_ever,
copy_out=lambda s, bref, _: s.will_change(bref) | s.last_step_ever,
epilogue_wait_out=lambda s, bref, _: s.last_step_ever,
)
def get_pipeline_schedule(schedule) -> Any:
"""Retrieve a named pipeline schedule or pass through fully specified one."""
predefined_schedules = {
'default': _default_schedule,
'fixed': _fixed_schedule
}
if isinstance(schedule, str):
return predefined_schedules[schedule].copy()
return schedule
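# A custom schedule is just a dict with the keys above mapping to predicate
# functions of (scheduler, buffered_ref, hbm_ref). A hypothetical override that
# starts from the default schedule but only waits on an input when its block
# indices actually change (or on the very first step ever) could look like:
#
#   my_schedule = get_pipeline_schedule('default')
#   my_schedule['wait_in'] = (
#       lambda s, bref, _: s.has_changed(bref) | s.first_step_ever)
#
# The result is passed through the `schedule` argument of the emitted pipeline
# below, either as one dict for all BufferedRefs or as a tuple with one entry
# per ref.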
# Main pipeline methods
def make_pipeline_allocations(
*refs,
in_specs=None,
out_specs=None,
should_accumulate_out=False,
):
"""Create BufferedRefs for the pipeline.
This function creates buffered refs for an inner pipeline that can be
created at the top-level of a pallas call such that they may be reused across
multiple invocations of the inner pipeline.
Args:
in_specs: input pallas block specs
out_specs: output pallas block specs
should_accumulate_out: booleans to indicate which outputs should be treated
as accumulators.
Returns:
A list of BufferedRefs, one corresponding to each ref specified in the
in_specs and out_specs.
"""
# TODO(levskaya): generalize argument tree handling here and in emit_pipeline.
num_in_specs = len(in_specs)
if not isinstance(in_specs, (list, tuple)):
in_specs = (in_specs,)
if not isinstance(out_specs, (list, tuple)):
out_specs = (out_specs,)
if isinstance(in_specs, list):
in_specs = tuple(in_specs)
if isinstance(out_specs, list):
out_specs = tuple(out_specs)
in_refs = refs[:num_in_specs]
out_refs = refs[num_in_specs:]
def make_input_bref(in_spec, in_ref):
return BufferedRef.input(in_spec, in_ref.dtype)
in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs)
def make_output_bref(out_spec, out_ref, accumulate):
if accumulate:
return BufferedRef.accumulator(out_spec, out_ref.dtype)
return BufferedRef.output(out_spec, out_ref.dtype)
out_brefs = jax.tree.map(
make_output_bref, out_specs, out_refs, should_accumulate_out)
return (*in_brefs, *out_brefs)
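# For a hypothetical single-input, single-output pipeline this returns
# (BufferedRef.input(x_spec, x_hbm_ref.dtype),
#  BufferedRef.output(o_spec, o_hbm_ref.dtype)), e.g.:
#
#   brefs = make_pipeline_allocations(
#       x_hbm_ref, o_hbm_ref,
#       in_specs=(x_spec,), out_specs=(o_spec,),
#       should_accumulate_out=(False,))
#
# These are typically materialized once via `tpu_primitives.run_scoped` and
# passed to the emitted pipeline through its `allocations` argument (see the
# sketch after `emit_pipeline_with_allocations` at the end of this file).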
class GridDimensionSemantics:
pass
PARALLEL = GridDimensionSemantics()
ARBITRARY = GridDimensionSemantics()
def _partition_grid(
grid: tuple[int | jax.Array, ...],
core_axis: int | None,
dimension_semantics: tuple[GridDimensionSemantics, ...] | None,
) -> tuple[tuple[int | jax.Array, ...], tuple[int | jax.Array, ...]]:
if core_axis is None:
# We aren't partitioning the grid
return grid, (0,) * len(grid)
num_cores = pl.num_programs(core_axis)
# Check that num_cores is statically known
if not isinstance(num_cores, int):
raise NotImplementedError(
f"Cannot partition grid over dynamic number of cores: {core_axis=}"
)
if num_cores == 1:
# We aren't partitioning the grid
return grid, (0,) * len(grid)
# If dimension_semantics aren't provided, we assume it is all arbitrary.
if dimension_semantics is None:
dimension_semantics = (ARBITRARY,) * len(grid)
if len(dimension_semantics) != len(grid):
raise ValueError("dimension_semantics must be the same length as grid.")
parallel_dimensions = {i for i, d in enumerate(dimension_semantics)
if d == PARALLEL}
# If there are no parallel dimensions, we can't partition the grid
if not parallel_dimensions:
# TODO(sharadmv): enable running kernel on just one core
raise NotImplementedError(
"Cannot partition over cores without parallel grid dimensions:"
f" {dimension_semantics=}"
)
if all(not isinstance(grid[i], int) for i in parallel_dimensions):
raise NotImplementedError(
f"Cannot partition cores over only dynamic grid dimensions: {grid=}"
)
# Try to find a divisible dimension to partition the grid on
divisible_dimensions = {
i for i in parallel_dimensions
if isinstance(grid[i], int) and grid[i] % num_cores == 0
}
if divisible_dimensions:
first_divisible_dimension, *_ = (
i for i in range(len(dimension_semantics)) if i in divisible_dimensions
)
partitioned_dim_size = grid[first_divisible_dimension] // num_cores
partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size
new_grid = jax_util.tuple_update(
grid, first_divisible_dimension, partitioned_dim_size
)
offsets = jax_util.tuple_update(
(0,) * len(grid), first_divisible_dimension, partitioned_dim_offset
)
else:
# No divisible dimensions, so we can't evenly partition the grid. Let's pick
# the largest dimension and try to divide it as evenly as possible.
# TODO(sharadmv): take the product of many nondivisible dimensions to
# potentially divide it more evenly
largest_parallel_dimension = max(grid[i] for i in parallel_dimensions
if isinstance(grid[i], int)) # type: ignore
partition_dimension, *_ = (
i
for i, d in enumerate(grid)
if isinstance(d, int) and d == largest_parallel_dimension
)
base_num_iters, rem = divmod(grid[partition_dimension], num_cores)
assert rem > 0, rem
# We have some remainder iterations that we need to assign somewhere. We
# know that rem < num_cores, so we can assign one extra iteration to each
# core except for the last (num_cores - rem).
core_index = pl.program_id(core_axis)
num_iters = jnp.where(core_index < rem, base_num_iters + 1,
base_num_iters)
new_grid = jax_util.tuple_update(grid, partition_dimension, num_iters)
# Ordinarily, we would compute the offset as:
# grid_offset = pl.program_id(core_axis) * num_iters
# However, since we have some cores that don't have an extra iteration, we
# need to adjust the offset by `rem`.
grid_offset = jnp.where(
core_index < rem,
core_index * num_iters,
core_index * base_num_iters + rem,
)
offsets = jax_util.tuple_update(
(0,) * len(grid), partition_dimension, grid_offset
)
return new_grid, offsets
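# For example, assuming two TensorCores along core_axis=0, an (8, 128) grid
# with a PARALLEL leading dimension is split evenly, each core being offset
# along that dimension by its core index:
#
#   _partition_grid((8, 128), core_axis=0,
#                   dimension_semantics=(PARALLEL, ARBITRARY))
#   # -> new_grid = (4, 128), offsets = (4 * pl.program_id(0), 0)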
def emit_pipeline(
body,
*,
grid: tuple[int | jax.Array, ...],
in_specs=None,
out_specs=None,
should_accumulate_out=False,
core_axis: int | None = None,
dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None
):
"""Creates a function to emit a manual pallas pipeline.
This has the same semantics as pallas_call but is meant to be called inside
pallas_call for nesting grids. This is useful when you need to have separate
windowing strategies for communication and computation.
  The `should_accumulate_out` argument can be used to specify which outputs
  should be accumulated into automatically within and across pipeline
  invocations.
Args:
body: pallas kernel to set up pipeline for.
grid: a pallas grid definition.
in_specs: input pallas block specs
out_specs: output pallas block specs
should_accumulate_out: booleans to indicate which outputs should be treated
as accumulators.
    core_axis: optional int, the axis of the enclosing pallas grid that is
      mapped over TPU cores; if provided, the pipeline's grid is partitioned
      across cores along one of its parallel dimensions.
dimension_semantics: optional tuple of GridDimensionSemantics (e.g. PARALLEL
or ARBITRARY).
"""
if any(not isinstance(d, (int, jax.Array)) for d in grid):
grid_types = tuple(type(d) for d in grid)
raise ValueError(
f"Grid must consist of Python integers and JAX Arrays: {grid_types}"
)
grid, grid_offsets = _partition_grid(grid, core_axis, dimension_semantics)
num_steps = _grid_size(grid)
if not isinstance(in_specs, (list, tuple)):
in_specs = (in_specs,)
if not isinstance(out_specs, (list, tuple)):
out_specs = (out_specs,)
if isinstance(in_specs, list):
in_specs = tuple(in_specs)
if isinstance(out_specs, list):
out_specs = tuple(out_specs)
should_accumulate_out = _broadcast_pytree_to(should_accumulate_out, out_specs)
def pipeline(
*refs: Any,
scratches=None,
allocations=None,
first_cycle: CondVal = True,
last_cycle: CondVal = True,
init_accumulators: CondVal = False,
prefetch=None,
postyeet=None,
schedule=None,
):
"""
Run the pipeline.
Args:
      *refs: a list of pallas refs (or more generally a list of pytrees of
        pallas refs)
scratches: scratch buffers for the inner kernel
allocations: a list of BufferedRefs, one corresponding to each ref
first_cycle: boolean indicating if this is the first invocation of the
inner pipeline cycle.
last_cycle: boolean indicating if this is the last invocation of the
inner pipeline cycle.
init_accumulators: whether to zero-init accumulators during this cycle.
      prefetch: callback called as fn(*brefs, scheduler) that is used to fetch
        the next cycle invocation's first inputs. Called during the inputs
        phase in the final inner step.
postyeet: callback called as fn(*brefs, scheduler) that is used to finish
any writes or transfers from the last output of the previous cycle.
Called during the outputs phase in the first inner step.
schedule: manually specified pipeline schedules for brefs, None indicates
default schedule.
"""
if scratches is None:
scratches = ()
if allocations is None:
# run with inline scoped allocations
return tpu_primitives.run_scoped(
lambda allocations: pipeline(
*refs,
scratches=scratches,
allocations=allocations,
first_cycle=first_cycle,
last_cycle=last_cycle,
init_accumulators=init_accumulators,
prefetch=prefetch,
postyeet=postyeet,
schedule=schedule,
),
make_pipeline_allocations(
*refs,
in_specs=in_specs,
out_specs=out_specs,
should_accumulate_out=should_accumulate_out),
)
if isinstance(allocations, list):
allocations = tuple(allocations)
# Normalize custom schedule arguments.
if schedule is None:
schedule = map_brefs(lambda x: None, allocations)
if not isinstance(schedule, (list, tuple)):
schedule = map_brefs(lambda x: schedule, allocations)
if isinstance(schedule, list):
schedule = tuple(schedule)
schedule = map_brefs(
lambda _, x: get_pipeline_schedule(x), allocations, schedule)
def loop_body(step, _):
nonlocal allocations
scheduler = Scheduler(
step,
grid,
grid_offsets=grid_offsets,
first_cycle=first_cycle,
last_cycle=last_cycle,
init_accumulators=init_accumulators)
# prepare any local VMEM aliases
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
# loop input handling phase
map_brefs(scheduler.initialize, brefs, refs, schedule)
map_brefs(scheduler.wait_in, brefs, refs, schedule)
map_brefs(scheduler.copy_in, brefs, refs, schedule)
# prefetch inputs for the *next* invocation of this pipeline
with jax.named_scope("ep_prefetch"):
if prefetch is not None:
lax.cond(step == num_steps - 1,
lambda: prefetch(*brefs, scheduler),
lambda: None)
# run the kernel!
current_refs = map_brefs(lambda x: x.current_ref, brefs)
with jax.named_scope("ep_run_kernel"):
with scheduler.grid_env():
body(*current_refs, *scratches)
# loop output handling phase
map_brefs(scheduler.wait_out, brefs, refs, schedule)
# handle writes for the *last* invocation of this pipeline's outputs
with jax.named_scope("ep_postyeet"):
if postyeet is not None:
lax.cond(step == 0,
lambda: postyeet(*brefs, scheduler),
lambda: None)
map_brefs(scheduler.copy_out, brefs, refs, schedule)
map_brefs(scheduler.finalize, brefs, refs, schedule)
return ()
# run pipeline
lax.fori_loop(0, num_steps, loop_body, ())
return pipeline
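# A minimal usage sketch (hypothetical shapes, block sizes and kernel; not part
# of the public API): the emitted pipeline is invoked inside a pallas_call
# kernel whose operands live in HBM/ANY memory, and it double-buffers
# (512, 512) blocks through VMEM while running the inner kernel on each block.
def _example_inner_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] * 2.0


def _example_outer_kernel(x_hbm_ref, o_hbm_ref):
  block_spec = pl.BlockSpec(
      block_shape=(512, 512), index_map=lambda i, j: (i, j))
  pipeline = emit_pipeline(
      _example_inner_kernel,
      grid=(4, 4),
      in_specs=[block_spec],
      out_specs=[block_spec],
  )
  pipeline(x_hbm_ref, o_hbm_ref)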
def emit_pipeline_with_allocations(
body,
*,
grid,
in_specs=None,
out_specs=None,
should_accumulate_out=False,
):
"""Creates pallas pipeline and top-level allocation preparation functions.
Args:
body: pallas kernel to set up pipeline for.
grid: a pallas grid definition.
in_specs: input pallas block specs
out_specs: output pallas block specs
should_accumulate_out: booleans to indicate which outputs should be treated
as accumulators.
Returns:
(emit_pipeline, make_allocations) function pair, where:
emit_pipeline is the pallas pipeline function.
make_allocations is a function to create buffered refs for the inner
pipeline that can be created at the top-level of a pallas call to be
reused across multiple invocations of the inner pipeline.
"""
make_allocations = functools.partial(make_pipeline_allocations,
in_specs=in_specs,
out_specs=out_specs,
should_accumulate_out=should_accumulate_out)
pipeline = emit_pipeline(
body,
grid=grid,
in_specs=in_specs,
out_specs=out_specs,
should_accumulate_out=should_accumulate_out)
return pipeline, make_allocations
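# An illustrative sketch (hypothetical refs, specs and kernel body): allocate
# the BufferedRefs once at the top level with `make_allocations` and pass them
# to the emitted pipeline via `allocations`. Further invocations in the same
# scope can reuse the same allocations, typically together with the
# `prefetch`/`postyeet` callbacks to overlap work across cycles.
def _example_pipeline_with_allocations(x_hbm_ref, o_hbm_ref, x_spec, o_spec,
                                       body):
  pipeline, make_allocations = emit_pipeline_with_allocations(
      body,
      grid=(4, 4),
      in_specs=(x_spec,),
      out_specs=(o_spec,),
      should_accumulate_out=(False,),  # mirrors the out_specs structure
  )
  tpu_primitives.run_scoped(
      lambda allocations: pipeline(
          x_hbm_ref, o_hbm_ref,
          allocations=allocations,
          first_cycle=True,
          last_cycle=True,
      ),
      make_allocations(x_hbm_ref, o_hbm_ref),
  )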