Commit 928caf83ee by Sergei Lebedev: [pallas:mosaic_gpu] copy_smem_to_gmem now allows skipping cp.async.commit_group
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.

PiperOrigin-RevId: 734553668
2025-03-07 07:43:54 -08:00


# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for emitting custom GPU pipelines within a Pallas kernel."""
from __future__ import annotations
from collections.abc import Callable, Sequence
import dataclasses
import functools
import itertools as it
import math
from typing import Any
import jax
from jax import api_util
from jax import lax
from jax._src import core
from jax._src import linear_util as lu
from jax._src import util
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives
from jax.experimental import pallas as pl
import jax.numpy as jnp
map = util.safe_map
zip = util.safe_zip


@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class BufferedRef:
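  """A GMEM reference paired with its (optionally multi-buffered) SMEM buffer.

  When ``smem_ref`` is present, its leading ``num_slots`` dimension lets
  copies for several pipeline steps be in flight at the same time.
  """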
spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True})
is_index_invariant: bool = dataclasses.field(metadata={"static": True})
gmem_ref: pallas_core.AbstractMemoryRef
# ``None`` if the ref is pinned to GMEM; otherwise, has shape
# [num_slots, *spec.block_shape].
smem_ref: pallas_core.AbstractMemoryRef | None
def get_ref_for_slot(
self, slot: int | jax.Array
) -> pallas_core.AbstractMemoryRef:
if self.smem_ref is None:
return self.gmem_ref
return self.smem_ref.at[slot]
def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
index_map = self.spec.index_map
assert index_map is not None
# We don't allow Python scalars here, because they are interpreted
# differently depending on the x32/x64 mode.
assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices)
return tuple(
pl.Slice(idx * size, size) # type: ignore[arg-type]
for idx, size in zip(
index_map(*grid_indices), self.spec.block_shape # type: ignore[arg-type]
)
)
def copy_in(self, slot, grid_indices, barrier_ref):
if not _in_smem(self.spec):
return
assert self.smem_ref is not None
gmem_slices = self.compute_gmem_slice(grid_indices)
gpu_primitives.copy_gmem_to_smem(
self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands
self.smem_ref.at[slot], # pytype: disable=unsupported-operands
barrier_ref.at[slot],
)
def copy_out(self, slot, grid_indices, predicate=None):
if not _in_smem(self.spec):
return
assert self.smem_ref is not None
gmem_slices = self.compute_gmem_slice(grid_indices)
gpu_primitives.copy_smem_to_gmem(
self.smem_ref.at[slot], # pytype: disable=unsupported-operands
self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands
predicate=predicate,
commit_group=False,
)
def _uses_arguments(
index_map: Callable[..., Any], num_args: int
) -> Sequence[bool]:
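  """Returns a mask of the positional arguments ``index_map`` actually reads."""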
if not num_args:
return ()
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(
index_map,
debug_info=api_util.debug_info("pallas index_map",
index_map,
(0,) * num_args, {})),
(core.ShapedArray((), jnp.int32),) * num_args
)
_, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))
return used_inputs
def _is_index_invariant(
spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid
) -> bool:
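  """Returns True if ``spec.index_map`` does not depend on any grid index."""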
if (index_map := spec.index_map) is None:
return True
return not any(_uses_arguments(index_map, len(grid)))
def _inc_grid_by_1(
indices: tuple[jax.Array, ...], grid: Sequence[int]
) -> tuple[jax.Array, ...]:
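  """Increments ``indices`` by one step over ``grid``.

  The last axis varies fastest; all indices wrap to zero after the final step.
  """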
next_indices = []
carry: bool | jax.Array = True
for idx, size in reversed(list(zip(indices, grid))):
next_idx = lax.select(carry, idx + 1, idx)
carry = next_idx == size
next_indices.append(
lax.select(carry, jnp.asarray(0, dtype=idx.dtype), next_idx)
)
return tuple(reversed(next_indices))
def _in_smem(spec: pallas_core.BlockSpec) -> bool:
return spec.memory_space in (None, gpu_core.SMEM)
# ``pl.Slice`` uses a different pytree encoding, depending on whether the
# start/size are static or dynamic. This leads to pytree structure mismatch
# in the pipeline body. So, we define a different ``Slice`` class below.
@dataclasses.dataclass(frozen=True)
class _Slice:
start: int | jax.Array
size: int | jax.Array
def __eq__(self, other: _Slice) -> jax.Array: # type: ignore
return lax.bitwise_and(self.start == other.start, self.size == other.size)
jax.tree_util.register_dataclass(
_Slice, data_fields=["start", "size"], meta_fields=[]
)


def emit_pipeline(
body: Callable[..., None],
*,
grid: pallas_core.StaticGrid,
in_specs: Sequence[pallas_core.BlockSpec] = (),
out_specs: Sequence[pallas_core.BlockSpec] = (),
max_concurrent_steps: int = 1,
delay_release: int = 0,
):
"""Creates a function to emit a manual pipeline within a Pallas kernel.
Args:
body: The pipeline body.
grid: The grid to use for the pipeline.
in_specs: The block specs for the inputs.
out_specs: The block specs for the outputs.
max_concurrent_steps: The maximum number of sequential stages that are
active concurrently. Defaults to 1.
delay_release: The number of steps to wait before reusing the input/output
references. Defaults to 0, and must be strictly smaller than
``max_concurrent_steps``. Generally, you'll want to set it to 1 if you
don't await the WGMMA in the body.
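
  Example (a minimal sketch: the surrounding kernel, the block shapes and the
  index maps below are illustrative assumptions, not requirements):

  ```
  def kernel(x_gmem, o_gmem):
    def body(x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    emit_pipeline(
        body,
        grid=(4,),
        in_specs=[pl.BlockSpec((128, 64), lambda i: (0, i))],
        out_specs=[pl.BlockSpec((128, 64), lambda i: (0, i))],
        max_concurrent_steps=2,
    )(x_gmem, o_gmem)
  ```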
"""
num_steps = math.prod(grid)
if max_concurrent_steps <= delay_release:
raise ValueError(
"max_concurrent_steps must be greater than delay_release, but"
f" {max_concurrent_steps=}, {delay_release=}"
)
  # Shrink ``max_concurrent_steps`` if the total number of steps is lower, to
  # reduce the size of the refs allocated in SMEM.
if max_concurrent_steps > num_steps:
max_concurrent_steps = num_steps
delay_release = 0 # No need to delay anything.
def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
in_smem_refs, out_smem_refs = util.split_list(
[
gpu_core.SMEM(
(max_concurrent_steps, *spec.block_shape), # type: ignore
ref.dtype,
transforms=tuple(
t.batch(1) for t in getattr(spec, "transforms", ())
),
)
if _in_smem(spec)
else None
for spec, ref in zip(it.chain(in_specs, out_specs), gmem_refs)
],
[len(in_specs)],
)
return pl.run_scoped(
functools.partial(
scoped_pipeline,
in_gmem_refs=in_gmem_refs,
out_gmem_refs=out_gmem_refs,
),
in_smem_refs=in_smem_refs,
out_smem_refs=out_smem_refs,
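        # One barrier per concurrent step; every GMEM->SMEM input copy arrives
        # on its step's barrier, hence the arrival count below.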
barrier_ref=gpu_core.Barrier(
# TODO(slebedev): Change this to arrive only once.
sum(map(_in_smem, in_specs)),
num_barriers=max_concurrent_steps,
),
)
def scoped_pipeline(
*, in_gmem_refs, out_gmem_refs, in_smem_refs, out_smem_refs, barrier_ref
):
in_brefs: Sequence[BufferedRef] = [
BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref)
for spec, gmem_ref, smem_ref in zip(
in_specs, in_gmem_refs, in_smem_refs
)
]
out_brefs: Sequence[BufferedRef] = [
BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref)
for spec, gmem_ref, smem_ref in zip(
out_specs, out_gmem_refs, out_smem_refs
)
]
for step, indices in enumerate(
it.islice(it.product(*map(range, grid)), max_concurrent_steps)
):
indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices))
map(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs)
# This is true if any of the outputs need to be transferred inside the loop.
copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs)
def loop_body(step, carry):
slot = lax.rem(step, max_concurrent_steps)
indices, fetch_indices, last_store_slices = carry
if in_specs:
# Wait for the current GMEM->SMEM copy to complete.
gpu_primitives.barrier_wait(barrier_ref.at[slot])
# Wait for the previous output SMEM->GMEM copy to complete.
if copies_out_in_loop:
gpu_primitives.wait_smem_to_gmem(
max_concurrent_steps - (1 + delay_release), wait_read_only=True
)
with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
body(*(
bref.get_ref_for_slot(slot)
for bref in it.chain(in_brefs, out_brefs)
))
if copies_out_in_loop:
gpu_primitives.commit_smem()
# Copy the output from SMEM to GMEM.
new_store_slices = last_store_slices[:]
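      # The store below is predicated on the output's GMEM slice having
      # changed since the previous step (or on this being the last step), so
      # revisited slices are not written again.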
for idx, bref in enumerate(out_brefs):
if bref.is_index_invariant:
assert last_store_slices[idx] is None
continue
assert last_store_slices[idx] is not None
new_store_slices[idx] = tuple(
_Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
)
are_same_slices = map(
lambda old, new: old == new,
last_store_slices[idx],
new_store_slices[idx],
)
slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
is_last_step = step == num_steps - 1
# TODO(apaszke,slebedev): This still diverges significantly from the
# TPU semantics in that it will move on to the next SMEM output slice
# even if it's not storing the previous one.
bref.copy_out(
slot,
indices,
predicate=lax.bitwise_or(slices_changed, is_last_step),
)
gpu_primitives.commit_smem_to_gmem_group()
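      # Prefetch the inputs for the step ``max_concurrent_steps - delay_release``
      # iterations ahead; ``fetch_indices`` already refers to that step.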
fetch_step = step + (max_concurrent_steps - delay_release)
fetch_slot = lax.rem(fetch_step, max_concurrent_steps)
def do_fetch():
for bref in in_brefs:
bref.copy_in(fetch_slot, fetch_indices, barrier_ref)
jax.lax.cond(
lax.bitwise_and(step >= delay_release, fetch_step < num_steps),
do_fetch,
lambda: None,
)
return (
_inc_grid_by_1(indices, grid),
_inc_grid_by_1(fetch_indices, grid),
new_store_slices,
)
# Invariant: ``indices`` and ``fetch_indices`` are always
# ``max_concurrent_steps-delay_release`` apart.
indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)
fetch_indices = indices
    for _ in range(max_concurrent_steps - delay_release):
fetch_indices = _inc_grid_by_1(fetch_indices, grid)
# TODO(justinfu): Only store base pointer instead of all indices.
last_store_slices = [
None
if bref.is_index_invariant
else (_Slice(-1, -1),) * len(bref.spec.block_shape)
for bref in out_brefs
]
last_indices, _, _ = lax.fori_loop(
0, num_steps, loop_body, (indices, fetch_indices, last_store_slices)
)
# Outputs invariant to the sequential axis are never written from inside the
# loop. This is the only place where we store them.
if not copies_out_in_loop:
gpu_primitives.commit_smem()
last_slot = lax.rem(num_steps - 1, max_concurrent_steps)
for bref in out_brefs:
if bref.is_index_invariant:
bref.copy_out(last_slot, last_indices, predicate=None)
gpu_primitives.commit_smem_to_gmem_group()
# Finalize the pipeline.
gpu_primitives.wait_smem_to_gmem(0)
return pipeline


def emit_pipeline_warp_specialized(
body: Callable[..., None],
*,
grid: pallas_core.StaticGrid,
memory_registers: int,
in_specs: Sequence[gpu_core.GPUBlockSpec] = (),
out_specs: Sequence[gpu_core.GPUBlockSpec] = (),
max_concurrent_steps: int = 2,
wg_axis: str,
num_compute_wgs: int,
manual_consumed_barriers: bool = False,
carry_coroutine: Any | None = None,
memory_thread_idx: int | None = None,
):
"""Creates a function to emit a warp-specialized pipeline.
The ``body`` function should have the following signature (without carry).
``consumed_barriers`` is an optional argument that is only passed if the
``manual_consumed_barriers`` argument is True.
```
def body(*input_refs, *output_refs, [consumed_barriers]) -> None:
```
or with a carries enabled (enabled via the ``carry_coroutine`` argument),
where the body returns the next carry:
```
def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry:
```
Args:
body: The pipeline body.
grid: The grid to use for the pipeline.
memory_registers: The number of registers to reserve for the memory thread.
For H100 GPUs, 40 is a reasonable value.
in_specs: The block specs for the inputs.
out_specs: The block specs for the outputs.
max_concurrent_steps: The maximum number of sequential stages that are
active concurrently. Defaults to 2.
wg_axis: The axis name for the warp group axis.
    num_compute_wgs: The number of compute warpgroups.
    manual_consumed_barriers: If True, consumed barriers are passed into the
      body function after the output refs, one barrier per input, in the same
      order as the inputs.
carry_coroutine: If specified, enables carries in the pipeline.
The signature of the body function will be modified such that the last
argument will be the current carry and it must return the next carry.
      The coroutine should yield the initial carry; once the pipeline
      finishes, the ``yield`` expression evaluates to the final carry.
memory_thread_idx: The index of the memory thread. If not specified,
defaults to the last thread.
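
  Example (a minimal sketch: the kernel wiring, the block specs and the
  ``plgpu`` alias for ``jax.experimental.pallas.mosaic_gpu`` are illustrative
  assumptions; the surrounding kernel is expected to be launched with
  ``num_compute_wgs + 1`` warpgroups along ``wg_axis``):

  ```
  def kernel(x_gmem, o_gmem):
    def body(x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    emit_pipeline_warp_specialized(
        body,
        grid=(4,),
        memory_registers=40,
        in_specs=[plgpu.GPUBlockSpec((128, 64), lambda i: (0, i))],
        out_specs=[plgpu.GPUBlockSpec((128, 64), lambda i: (0, i))],
        max_concurrent_steps=2,
        wg_axis="wg",
        num_compute_wgs=2,
    )(x_gmem, o_gmem)
  ```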
"""
# TODO(justinfu): Factor out common code between warp-specialized and
# normal pipelines.
if memory_thread_idx is None:
memory_thread_idx = num_compute_wgs
if memory_thread_idx != num_compute_wgs:
# TODO(justinfu): Indexing calculations for buffers assume the memory
# thread is the last thread.
raise NotImplementedError("Memory thread must be the last thread.")
has_carry = carry_coroutine is not None
# Trace the index maps to determine if they depend on the grid.
# Grid-independent values will not be multiple-buffered.
in_spec_has_seq_axis = [
~_is_index_invariant(spec, grid) for spec in in_specs]
out_spec_has_seq_axis = [
~_is_index_invariant(spec, grid) for spec in out_specs]
spec_has_seq_axis = [*in_spec_has_seq_axis, *out_spec_has_seq_axis]
num_pipeline_steps = math.prod(grid)
def _get_slot(step, has_seq_dim):
"""Returns the buffer slot given the pipeline step."""
if has_seq_dim:
return step
else:
return 0
  # Shrink ``max_concurrent_steps`` if the total number of steps is lower, to
  # reduce the size of the refs allocated in SMEM.
if max_concurrent_steps > num_pipeline_steps:
max_concurrent_steps = num_pipeline_steps
def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
if len(out_gmem_refs) != len(out_specs):
raise ValueError(
"Number of output refs does not match number of output specs."
)
smem_allocs = []
for spec, has_seq_dim, gmem_ref in zip(
it.chain(in_specs, out_specs),
spec_has_seq_axis,
gmem_refs):
slots = max_concurrent_steps if has_seq_dim else 1
smem_allocs.append(
gpu_core.SMEM(
(slots, *spec.block_shape), # type: ignore
gmem_ref.dtype,
transforms=spec.transforms,
)
)
in_smem_refs, out_smem_refs = util.split_list(
smem_allocs, [len(in_specs)])
in_smem_barriers = []
consumed_barriers = []
for has_seq_dim in in_spec_has_seq_axis:
num_barriers = max_concurrent_steps if has_seq_dim else 1
in_smem_barriers.append(
gpu_core.Barrier(
num_arrivals=1,
num_barriers=num_barriers))
if manual_consumed_barriers:
consumed_barriers.append(
gpu_core.Barrier(
num_arrivals=num_compute_wgs,
num_barriers=max_concurrent_steps,
)
)
if not manual_consumed_barriers:
      # We only allocate one consumed barrier for all inputs when using
      # automatic consumed-barrier management.
consumed_barriers = [
gpu_core.Barrier(
num_arrivals=num_compute_wgs,
num_barriers=max_concurrent_steps,
)
]
return pl.run_scoped(
functools.partial(
scoped_pipeline,
in_gmem_refs=in_gmem_refs,
out_gmem_refs=out_gmem_refs,
),
in_smem_refs=in_smem_refs,
out_smem_refs=out_smem_refs,
in_smem_barrier_refs=in_smem_barriers,
consumed_barrier_refs=consumed_barriers,
)
def scoped_pipeline(
*,
in_gmem_refs,
out_gmem_refs,
in_smem_refs,
out_smem_refs,
in_smem_barrier_refs,
consumed_barrier_refs,
):
in_brefs: Sequence[BufferedRef] = [
BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref)
for spec, has_seq_axis, gmem_ref, smem_ref in zip(
in_specs, in_spec_has_seq_axis, in_gmem_refs, in_smem_refs
)
]
out_brefs: Sequence[BufferedRef] = [
BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref)
for spec, has_seq_axis, gmem_ref, smem_ref in zip(
out_specs, out_spec_has_seq_axis, out_gmem_refs, out_smem_refs
)
]
def compute_block():
gpu_primitives.set_max_registers(
_compute_registers(memory_registers, num_compute_wgs),
action="increase")
# This is true if any of the outputs need to be transferred inside the loop.
copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs)
def compute_loop_body(step, carry):
indices, last_store_slices, prev_body_carry = carry
slot = lax.rem(step, max_concurrent_steps)
# Wait for the current GMEM->SMEM copies to complete.
for in_barrier, has_seq_dim in zip(
in_smem_barrier_refs, in_spec_has_seq_axis):
# TODO(justinfu): Use a single barrier with
# num_arrivals=len(in_smem_barrier_refs)
gpu_primitives.barrier_wait(
in_barrier.at[_get_slot(slot, has_seq_dim)])
# Wait for the previous output SMEM->GMEM copy to complete.
if copies_out_in_loop:
gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1)
with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
body_refs = []
for bref in it.chain(in_brefs, out_brefs):
buf_slot = _get_slot(slot, ~bref.is_index_invariant)
body_refs.append(bref.get_ref_for_slot(buf_slot))
body_args = body_refs
if manual_consumed_barriers:
body_args += [consumed_barrier_ref.at[slot] for consumed_barrier_ref in consumed_barrier_refs]
if has_carry:
body_args += [prev_body_carry]
next_body_carry = body(*body_args)
if not manual_consumed_barriers:
[consumed_barrier_ref] = consumed_barrier_refs
gpu_primitives.barrier_arrive(consumed_barrier_ref.at[slot])
# TODO(justinfu,apaszke): This should probably be done by the memory WG.
# Copy the output from SMEM to GMEM.
if copies_out_in_loop:
gpu_primitives.commit_smem()
new_store_slices = last_store_slices[:]
for idx, bref in enumerate(out_brefs):
if bref.is_index_invariant:
assert last_store_slices[idx] is None
continue
assert last_store_slices[idx] is not None
new_store_slices[idx] = tuple(
_Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
)
are_same_slices = map(
lambda old, new: old == new,
last_store_slices[idx],
new_store_slices[idx],
)
slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
bref.copy_out(_get_slot(slot, ~bref.is_index_invariant),
indices,
predicate=slices_changed)
gpu_primitives.commit_smem_to_gmem_group()
next_indices = _inc_grid_by_1(indices, grid)
return (next_indices, new_store_slices, next_body_carry)
init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)
# TODO(justinfu): Only store base pointer instead of all indices.
last_store_slices = [
None
if bref.is_index_invariant
else (_Slice(-1, -1),) * len(bref.spec.block_shape)
for bref in out_brefs
]
if has_carry:
_carry = carry_coroutine()
try:
carry_init = next(_carry)
except StopIteration:
raise ValueError("carry_coroutine must yield the initial carry.") # pylint: disable=raise-missing-from
else:
_carry = None
carry_init = None
init_loop_carry = (init_indices, last_store_slices, carry_init)
last_indices, _, final_body_carry = lax.fori_loop(0,
num_pipeline_steps,
compute_loop_body,
init_loop_carry)
if has_carry:
try:
_carry.send(final_body_carry) # pytype: disable=attribute-error
raise ValueError("carry_coroutine must only yield once.")
except StopIteration:
pass
# Handle index_invariant outputs after the loop. They are not
# written in the main pipeline loop.
if not copies_out_in_loop:
gpu_primitives.commit_smem()
last_slot = lax.rem(num_pipeline_steps - 1, max_concurrent_steps)
for bref in out_brefs:
if bref.is_index_invariant:
bref.copy_out(last_slot, last_indices, predicate=None)
gpu_primitives.commit_smem_to_gmem_group()
# Finalize the pipeline.
gpu_primitives.wait_smem_to_gmem(0)
# The memory thread executes this block which issues all pipelined DMAs.
def memory_block():
gpu_primitives.set_max_registers(memory_registers, action="decrease")
indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)
# Begin initial copies.
for step in range(max_concurrent_steps):
for bref, barrier in zip(in_brefs, in_smem_barrier_refs):
buf_slot = _get_slot(step, ~bref.is_index_invariant)
bref.copy_in(buf_slot, indices, barrier)
indices = _inc_grid_by_1(indices, grid)
def memory_loop_body(step, carry):
indices, = carry
slot = lax.rem(step, max_concurrent_steps)
fetch_slot = slot # (x + y) % y == x % y
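        # Before refilling a slot, wait for the compute warpgroups to signal
        # that they have consumed its current contents.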
if not manual_consumed_barriers:
# We only have one consumed barrier when using automatic consumed
# barrier management.
[consumed_barrier_ref] = consumed_barrier_refs
gpu_primitives.barrier_wait(consumed_barrier_ref.at[slot])
consumed_barrier_it = [None] * len(in_brefs)
else:
consumed_barrier_it = consumed_barrier_refs
for bref, barrier, consumed_barrier in zip(
in_brefs, in_smem_barrier_refs, consumed_barrier_it):
if manual_consumed_barriers:
gpu_primitives.barrier_wait(consumed_barrier.at[slot]) # pytype: disable=attribute-error
bref.copy_in(
_get_slot(fetch_slot, ~bref.is_index_invariant), indices, barrier)
next_indices = _inc_grid_by_1(indices, grid)
return (next_indices,)
lax.fori_loop(0, num_pipeline_steps - max_concurrent_steps,
memory_loop_body, (indices,))
wg_idx = lax.axis_index(wg_axis)
lax.cond(
wg_idx != memory_thread_idx,
compute_block,
memory_block
)
return pipeline


def _compute_registers(
memory_registers: int,
num_compute_wgs: int,
) -> int:
"""Returns the number of registers to use for the compute thread."""
# TODO(justinfu): Configure this per-platform.
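  # For example, with memory_registers=40 and num_compute_wgs=2:
  # (512 - 40) / 2 = 236, rounded down to a multiple of 8 -> 232.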
n_registers = (512 - memory_registers) / num_compute_wgs
# Round down to the nearest multiple of 8.
return int((n_registers // 8) * 8)