[Mosaic GPU] Prototype of a warp-specialized pipeline emitter for Mosaic GPU.

PiperOrigin-RevId: 708010809
Authored by Justin Fu on 2024-12-19 13:28:20 -08:00; committed by jax authors
parent 9f42b99a76
commit d129438548
2 changed files with 355 additions and 0 deletions


@@ -301,6 +301,7 @@ def emit_pipeline(
fetch_indices = indices
for _ in range(max_concurrent_steps):
fetch_indices = _inc_grid_by_1(fetch_indices, grid)
# TODO(justinfu): Only store base pointer instead of all indices.
last_store_slices = [
None
if bref.is_index_invariant
@@ -324,3 +325,251 @@ def emit_pipeline(
gpu_primitives.wait_smem_to_gmem(0)
return pipeline
def emit_pipeline_warp_specialized(
body: Callable[..., None],
*,
grid: pallas_core.StaticGrid,
memory_registers: int,
in_specs: Sequence[gpu_core.GPUBlockSpec] = (),
out_specs: Sequence[gpu_core.GPUBlockSpec] = (),
max_concurrent_steps: int = 2,
wg_axis: str,
num_compute_wgs: int,
memory_thread_idx: int | None = None,
):
"""Creates a function to emit a warp-specialized pipeline.
Args:
body: The pipeline body.
grid: The grid to use for the pipeline.
memory_registers: The number of registers to reserve for the memory thread.
For H100 GPUs, 40 is a reasonable value.
in_specs: The block specs for the inputs.
out_specs: The block specs for the outputs.
max_concurrent_steps: The maximum number of sequential stages that are
active concurrently. Defaults to 2.
wg_axis: The name of the warpgroup axis.
num_compute_wgs: The number of compute warpgroups.
memory_thread_idx: The index of the memory thread. If not specified,
defaults to the last thread.
"""
# TODO(justinfu): Factor out common code between warp-specialized and
# normal pipelines.
# TODO(justinfu): Allow body to return carries.
# TODO(justinfu): Allow passing consumed_barrier into body.
if memory_thread_idx is None:
memory_thread_idx = num_compute_wgs
if memory_thread_idx != num_compute_wgs:
# TODO(justinfu): Indexing calculations for buffers assume the memory
# thread is the last thread.
raise NotImplementedError("Memory thread must be the last thread.")
# Trace the index maps to determine if they depend on the grid.
# Grid-independent values will not be multiple-buffered.
in_spec_has_seq_axis = [
~_is_index_invariant(spec, grid) for spec in in_specs]
out_spec_has_seq_axis = [
~_is_index_invariant(spec, grid) for spec in out_specs]
spec_has_seq_axis = [*in_spec_has_seq_axis, *out_spec_has_seq_axis]
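# For example, a spec whose ``index_map`` ignores the grid indices (such as
# ``lambda i, j: (0, 0)``, used for the index-invariant output in the tests
# below) is grid-independent and therefore not multiple-buffered.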
num_pipeline_steps = math.prod(grid)
def _get_slot(step, has_seq_dim):
"""Returns the buffer slot given the pipeline step."""
if has_seq_dim:
return step
else:
return 0
# Shrink ``max_concurrent_steps`` if the total number of steps is lower to
# reduce the size of the refs allocated in SMEM.
if max_concurrent_steps > num_pipeline_steps:
max_concurrent_steps = num_pipeline_steps
def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
if len(out_gmem_refs) != len(out_specs):
raise ValueError(
"Number of output refs does not match number of output specs."
)
smem_allocs = []
for spec, has_seq_dim, gmem_ref in zip(
it.chain(in_specs, out_specs),
spec_has_seq_axis,
gmem_refs):
slots = max_concurrent_steps if has_seq_dim else 1
smem_allocs.append(
gpu_core.SMEM(
(slots, *spec.block_shape), # type: ignore
gmem_ref.dtype,
transforms=spec.transforms,
)
)
in_smem_refs, out_smem_refs = util.split_list(
smem_allocs, [len(in_specs)])
in_smem_barriers = []
for has_seq_dim in in_spec_has_seq_axis:
num_barriers = max_concurrent_steps if has_seq_dim else 1
in_smem_barriers.append(
gpu_core.Barrier(
num_arrivals=1,
num_barriers=num_barriers))
return pl.run_scoped(
functools.partial(
scoped_pipeline,
in_gmem_refs=in_gmem_refs,
out_gmem_refs=out_gmem_refs,
),
in_smem_refs=in_smem_refs,
out_smem_refs=out_smem_refs,
in_smem_barrier_refs=in_smem_barriers,
consumed_barrier_ref=gpu_core.Barrier(
num_arrivals=num_compute_wgs,
num_barriers=max_concurrent_steps,
),
)
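# ``scoped_pipeline`` below runs with the scoped allocations requested above:
# each input gets one barrier per SMEM slot, signaled when the corresponding
# GMEM->SMEM copy arrives, while ``consumed_barrier_ref`` (one barrier per
# concurrent step, with ``num_compute_wgs`` arrivals) is how the compute
# warpgroups signal that a slot has been consumed and may be refilled.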
def scoped_pipeline(
*,
in_gmem_refs,
out_gmem_refs,
in_smem_refs,
out_smem_refs,
in_smem_barrier_refs,
consumed_barrier_ref,
):
in_brefs: Sequence[BufferedRef] = [
BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref)
for spec, has_seq_axis, gmem_ref, smem_ref in zip(
in_specs, in_spec_has_seq_axis, in_gmem_refs, in_smem_refs
)
]
out_brefs: Sequence[BufferedRef] = [
BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref)
for spec, has_seq_axis, gmem_ref, smem_ref in zip(
out_specs, out_spec_has_seq_axis, out_gmem_refs, out_smem_refs
)
]
def compute_block():
gpu_primitives.set_max_registers(
_compute_registers(memory_registers, num_compute_wgs),
action="increase")
def compute_loop_body(step, carry):
indices, last_store_slices = carry
slot = step % max_concurrent_steps
# Wait for the current GMEM->SMEM copies to complete.
for in_barrier, has_seq_dim in zip(
in_smem_barrier_refs, in_spec_has_seq_axis):
# TODO(justinfu): Use a single barrier with
# num_arrivals=len(in_smem_barrier_refs)
gpu_primitives.barrier_wait(
in_barrier.at[_get_slot(slot, has_seq_dim)])
# Wait for the previous output SMEM->GMEM copy to complete.
gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1)
with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
body_refs = []
for bref in it.chain(in_brefs, out_brefs):
buf_slot = _get_slot(slot, ~bref.is_index_invariant)
body_refs.append(bref.get_ref_for_slot(buf_slot))
body(*body_refs)
gpu_primitives.barrier_arrive(consumed_barrier_ref.at[slot])
# Copy the output from SMEM to GMEM.
if not all(bref.is_index_invariant for bref in out_brefs):
gpu_primitives.commit_smem()
new_store_slices = last_store_slices[:]
for idx, bref in enumerate(out_brefs):
if bref.is_index_invariant:
assert last_store_slices[idx] is None
continue
assert last_store_slices[idx] is not None
new_store_slices[idx] = tuple(
_Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
)
are_same_slices = map(
lambda old, new: old == new,
last_store_slices[idx],
new_store_slices[idx],
)
slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
bref.copy_out(_get_slot(slot, ~bref.is_index_invariant),
indices,
predicate=slices_changed)
next_indices = _inc_grid_by_1(indices, grid)
return (next_indices, new_store_slices)
init_indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid)
# TODO(justinfu): Only store base pointer instead of all indices.
last_store_slices = [
None
if bref.is_index_invariant
else (_Slice(-1, -1),) * len(bref.spec.block_shape)
for bref in out_brefs
]
last_indices, _ = lax.fori_loop(0,
num_pipeline_steps,
compute_loop_body,
(init_indices, last_store_slices))
# Handle index-invariant outputs after the loop, since they are not
# copied out to GMEM in the main pipeline loop.
if all(bref.is_index_invariant for bref in out_brefs):
gpu_primitives.commit_smem()
last_slot = (num_pipeline_steps - 1) % max_concurrent_steps
for bref in out_brefs:
if bref.is_index_invariant:
bref.copy_out(last_slot, last_indices, predicate=None)
# Finalize the pipeline.
gpu_primitives.wait_smem_to_gmem(0)
return
# The memory thread executes this block, which issues all of the pipelined
# GMEM->SMEM input copies.
def memory_block():
gpu_primitives.set_max_registers(memory_registers, action="decrease")
indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid)
# Begin initial copies.
for step in range(max_concurrent_steps):
for bref, barrier in zip(in_brefs, in_smem_barrier_refs):
buf_slot = _get_slot(step, ~bref.is_index_invariant)
bref.copy_in(buf_slot, indices, barrier)
indices = _inc_grid_by_1(indices, grid)
def memory_loop_body(step, carry):
indices, = carry
slot = step % max_concurrent_steps
# The copy issued here is for pipeline step ``step + max_concurrent_steps``,
# which maps to the same slot since (x + y) % y == x % y.
fetch_slot = slot
gpu_primitives.barrier_wait(consumed_barrier_ref.at[slot])
for bref, barrier in zip(in_brefs, in_smem_barrier_refs):
bref.copy_in(
_get_slot(fetch_slot, ~bref.is_index_invariant), indices, barrier)
next_indices = _inc_grid_by_1(indices, grid)
return (next_indices,)
lax.fori_loop(0, num_pipeline_steps - max_concurrent_steps,
memory_loop_body, (indices,))
wg_idx = lax.axis_index(wg_axis)
lax.cond(
wg_idx != memory_thread_idx,
compute_block,
memory_block
)
return pipeline
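A minimal usage sketch, condensed from the ``test_pipelined_copy`` test added below in this commit (``plgpu.GPUMesh``, ``pl.core_map``, ``pl.run_state``, the 64x64 blocks, and ``approx_math=True`` are taken from that test; the sizes are illustrative):

block_spec = plgpu.GPUBlockSpec(
    block_shape=(64, 64), index_map=lambda i, j: (i, j), transforms=[])

def copy_kernel(x_smem, o_smem):
  o_smem[...] = x_smem[...]

pipeline = emit_pipeline_warp_specialized(
    copy_kernel,
    grid=(8, 8),  # A 512x512 array split into 64x64 blocks.
    memory_registers=40,
    num_compute_wgs=2,
    wg_axis="wg",
    in_specs=[block_spec],
    out_specs=[block_spec],
)

# num_compute_wgs compute warpgroups plus one memory warpgroup.
mesh = plgpu.GPUMesh(
    grid=(1,), num_threads=3, axis_names=("_", "wg"), approx_math=True)

def run(refs):
  @pl.core_map(mesh)
  def _kernel_entry():
    pipeline(*refs)

x = jnp.ones((512, 512), jnp.float16)
_, out = pl.run_state(run)((x, jnp.zeros_like(x)))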
def _compute_registers(
memory_registers: int,
num_compute_wgs: int,
) -> int:
"""Returns the number of registers to use for the compute thread."""
# TODO(justinfu): Configure this per-platform.
n_registers = (512 - memory_registers) / num_compute_wgs
# Round down to the nearest multiple of 8.
return int((n_registers // 8) * 8)
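For example, with ``memory_registers=40`` (the value the docstring above suggests for H100 and the tests below use) and ``num_compute_wgs=2``, this gives (512 - 40) / 2 = 236 registers per compute warpgroup, which rounds down to 232.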


@@ -25,6 +25,7 @@ from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
import jax.numpy as jnp
@@ -1434,6 +1435,111 @@ class PipelineTest(PallasTest):
np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
class WarpSpecializedPipelineTest(PallasTest):
def test_pipelined_copy(self, m=512, n=512):
x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16)
o = jnp.zeros((m, n), dtype=jnp.float16)
blk_m = blk_n = 64
o_last_block = jnp.zeros((blk_m, blk_n), dtype=jnp.float16)
def copy_kernel(x_smem, o_smem, o_last_block_smem):
# TODO(justinfu): Have each wg compute a separate slice
# after multiple-indexers are supported.
# This is currently a race, but the values written are the same.
o_smem[...] = x_smem[...]
o_last_block_smem[...] = x_smem[...]
block_spec = plgpu.GPUBlockSpec(
block_shape=(blk_m, blk_n),
index_map=lambda i, j: (i, j),
transforms=[],
)
pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
copy_kernel,
grid=(m // blk_m, n // blk_n),
memory_registers=40,
max_concurrent_steps=2,
num_compute_wgs=2,
wg_axis="wg",
in_specs=[block_spec],
out_specs=[block_spec,
# Create an index-invariant output.
plgpu.GPUBlockSpec(block_shape=(blk_m, blk_n),
index_map=lambda i, j: (0, 0))
],
)
mesh = plgpu.GPUMesh(
grid=(1,),
num_threads=3,
axis_names=("_", "wg",),
approx_math=True,
)
def run(refs):
@pl.core_map(mesh)
def _kernel_entry():
pipeline(*refs)
@jax.jit
def run_function(x, o, o_last_block):
_, out, out_last = pl.run_state(run)((x, o, o_last_block))
return (out, out_last)
out, out_last_block = run_function(x, o, o_last_block)
np.testing.assert_array_equal(out, x)
np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:])
def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2):
blk_m = blk_n = 64
x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32)
y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32)
o = jnp.zeros((m, n), dtype=jnp.float32)
def tiled_add_kernel(x_smem, y_smem, o_smem):
# TODO(justinfu): Have each wg compute a separate slice
# after multiple-indexers are supported.
# This is currently a race, but the values written are the same.
o_smem[...] = x_smem[...] + y_smem[...]
pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
tiled_add_kernel,
grid=(m // blk_m, n // blk_n),
max_concurrent_steps=2,
num_compute_wgs=num_compute_wgs,
memory_registers=40,
wg_axis="wg",
in_specs=[
plgpu.GPUBlockSpec(
block_shape=(blk_m, blk_n),
index_map=lambda i, j: (i, j),
transforms=[]),
plgpu.GPUBlockSpec(
block_shape=(blk_m, blk_n),
index_map=lambda i, j: (i, j),
transforms=[]),
],
out_specs=[
plgpu.GPUBlockSpec(
block_shape=(blk_m, blk_n),
index_map=lambda i, j: (i, j),
transforms=[])],
)
mesh = plgpu.GPUMesh(
grid=(1,),
num_threads=num_compute_wgs + 1,
axis_names=("_", "wg",),
approx_math=True,
)
def run(refs):
@pl.core_map(mesh)
def _kernel_entry():
pipeline(*refs)
@jax.jit
def run_function(x, y, o):
_, _, out = pl.run_state(run)((x, y, o))
return out
out = run_function(x, y, o)
reference = x + y
np.testing.assert_allclose(out, reference, atol=1e-4)
class CoreMapTest(PallasTest):
def test_multiple_wg(self):