[Mosaic GPU] Prototype of a warp-specialized pipeline emitter for Mosaic GPU.
PiperOrigin-RevId: 708010809
commit d129438548 (parent 9f42b99a76)
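
For context, here is a minimal usage sketch distilled from the tests added in this commit (the shapes, block sizes, and the "wg" axis name are choices made by those tests, not requirements of the API):

    import jax
    import jax.numpy as jnp
    from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu

    m = n = 256
    blk_m = blk_n = 64

    def add_kernel(x_smem, y_smem, o_smem):
      o_smem[...] = x_smem[...] + y_smem[...]

    block_spec = plgpu.GPUBlockSpec(
        block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j),
        transforms=[])
    pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
        add_kernel,
        grid=(m // blk_m, n // blk_n),
        memory_registers=40,  # reasonable H100 value per the docstring
        max_concurrent_steps=2,
        num_compute_wgs=2,
        wg_axis="wg",
        in_specs=[block_spec, block_spec],
        out_specs=[block_spec],
    )
    # num_threads = num_compute_wgs + 1: the extra warpgroup acts as the
    # memory thread that issues all the copies.
    mesh = plgpu.GPUMesh(
        grid=(1,), num_threads=3, axis_names=("_", "wg"), approx_math=True)

    def run(refs):
      @pl.core_map(mesh)
      def _kernel_entry():
        pipeline(*refs)

    x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32)
    y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32)
    o = jnp.zeros((m, n), dtype=jnp.float32)

    @jax.jit
    def run_function(x, y, o):
      _, _, out = pl.run_state(run)((x, y, o))
      return out
    out = run_function(x, y, o)  # out should equal x + y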
@@ -301,6 +301,7 @@ def emit_pipeline(
    fetch_indices = indices
    for _ in range(max_concurrent_steps):
      fetch_indices = _inc_grid_by_1(fetch_indices, grid)
    # TODO(justinfu): Only store base pointer instead of all indices.
    last_store_slices = [
        None
        if bref.is_index_invariant
@@ -324,3 +325,251 @@ def emit_pipeline(
    gpu_primitives.wait_smem_to_gmem(0)

  return pipeline


def emit_pipeline_warp_specialized(
    body: Callable[..., None],
    *,
    grid: pallas_core.StaticGrid,
    memory_registers: int,
    in_specs: Sequence[gpu_core.GPUBlockSpec] = (),
    out_specs: Sequence[gpu_core.GPUBlockSpec] = (),
    max_concurrent_steps: int = 2,
    wg_axis: str,
    num_compute_wgs: int,
    memory_thread_idx: int | None = None,
):
  """Creates a function to emit a warp-specialized pipeline.

  Args:
    body: The pipeline body.
    grid: The grid to use for the pipeline.
    memory_registers: The number of registers to reserve for the memory thread.
      For H100 GPUs, 40 is a reasonable value.
    in_specs: The block specs for the inputs.
    out_specs: The block specs for the outputs.
    max_concurrent_steps: The maximum number of sequential stages that are
      active concurrently. Defaults to 2.
    wg_axis: The axis name for the warp group axis.
    num_compute_wgs: The number of compute warpgroups.
    memory_thread_idx: The index of the memory thread. If not specified,
      defaults to the last thread.
  """
  # TODO(justinfu): Factor out common code between warp-specialized and
  # normal pipelines.
  # TODO(justinfu): Allow body to return carries.
  # TODO(justinfu): Allow passing consumed_barrier into body.

  if memory_thread_idx is None:
    memory_thread_idx = num_compute_wgs
  if memory_thread_idx != num_compute_wgs:
    # TODO(justinfu): Indexing calculations for buffers assume the memory
    # thread is the last thread.
    raise NotImplementedError("Memory thread must be the last thread.")

  # Trace the index maps to determine if they depend on the grid.
  # Grid-independent values will not be multiple-buffered.
  in_spec_has_seq_axis = [
      ~_is_index_invariant(spec, grid) for spec in in_specs]
  out_spec_has_seq_axis = [
      ~_is_index_invariant(spec, grid) for spec in out_specs]
  spec_has_seq_axis = [*in_spec_has_seq_axis, *out_spec_has_seq_axis]
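  # For example, a spec whose index map ignores the grid indices (e.g.
  # ``index_map=lambda i, j: (0, 0)``, like the index-invariant output in
  # ``test_pipelined_copy``) has no sequence axis: it gets a single SMEM slot
  # and is only copied out once, after the pipeline loop.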

  num_pipeline_steps = math.prod(grid)

  def _get_slot(step, has_seq_dim):
    """Returns the buffer slot given the pipeline step."""
    if has_seq_dim:
      return step
    else:
      return 0

  # Shrink ``max_concurrent_steps`` if the total number of steps is lower to
  # reduce the size of the refs allocated in SMEM.
  if max_concurrent_steps > num_pipeline_steps:
    max_concurrent_steps = num_pipeline_steps

  def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
    in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
    if len(out_gmem_refs) != len(out_specs):
      raise ValueError(
          "Number of output refs does not match number of output specs."
      )
    smem_allocs = []
    for spec, has_seq_dim, gmem_ref in zip(
        it.chain(in_specs, out_specs),
        spec_has_seq_axis,
        gmem_refs):
      slots = max_concurrent_steps if has_seq_dim else 1
      smem_allocs.append(
          gpu_core.SMEM(
              (slots, *spec.block_shape),  # type: ignore
              gmem_ref.dtype,
              transforms=spec.transforms,
          )
      )
    in_smem_refs, out_smem_refs = util.split_list(
        smem_allocs, [len(in_specs)])

    in_smem_barriers = []
    for has_seq_dim in in_spec_has_seq_axis:
      num_barriers = max_concurrent_steps if has_seq_dim else 1
      in_smem_barriers.append(
          gpu_core.Barrier(
              num_arrivals=1,
              num_barriers=num_barriers))

    return pl.run_scoped(
        functools.partial(
            scoped_pipeline,
            in_gmem_refs=in_gmem_refs,
            out_gmem_refs=out_gmem_refs,
        ),
        in_smem_refs=in_smem_refs,
        out_smem_refs=out_smem_refs,
        in_smem_barrier_refs=in_smem_barriers,
        consumed_barrier_ref=gpu_core.Barrier(
            num_arrivals=num_compute_wgs,
            num_barriers=max_concurrent_steps,
        ),
    )

  def scoped_pipeline(
      *,
      in_gmem_refs,
      out_gmem_refs,
      in_smem_refs,
      out_smem_refs,
      in_smem_barrier_refs,
      consumed_barrier_ref,
  ):
    in_brefs: Sequence[BufferedRef] = [
        BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref)
        for spec, has_seq_axis, gmem_ref, smem_ref in zip(
            in_specs, in_spec_has_seq_axis, in_gmem_refs, in_smem_refs
        )
    ]
    out_brefs: Sequence[BufferedRef] = [
        BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref)
        for spec, has_seq_axis, gmem_ref, smem_ref in zip(
            out_specs, out_spec_has_seq_axis, out_gmem_refs, out_smem_refs
        )
    ]

    def compute_block():
      gpu_primitives.set_max_registers(
          _compute_registers(memory_registers, num_compute_wgs),
          action="increase")

      def compute_loop_body(step, carry):
        indices, last_store_slices = carry
        slot = step % max_concurrent_steps
        # Wait for the current GMEM->SMEM copies to complete.
        for in_barrier, has_seq_dim in zip(
            in_smem_barrier_refs, in_spec_has_seq_axis):
          # TODO(justinfu): Use a single barrier with
          # num_arrivals=len(in_smem_barrier_refs)
          gpu_primitives.barrier_wait(
              in_barrier.at[_get_slot(slot, has_seq_dim)])

        # Wait for the previous output SMEM->GMEM copy to complete.
        gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1)

        with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
          body_refs = []
          for bref in it.chain(in_brefs, out_brefs):
            buf_slot = _get_slot(slot, ~bref.is_index_invariant)
            body_refs.append(bref.get_ref_for_slot(buf_slot))
          body(*body_refs)
        gpu_primitives.barrier_arrive(consumed_barrier_ref.at[slot])
        # Copy the output from SMEM to GMEM.
        if not all(bref.is_index_invariant for bref in out_brefs):
          gpu_primitives.commit_smem()
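
        # Track the GMEM slice each non-invariant output was last written to;
        # if this step's slice is unchanged (its index map did not advance),
        # the redundant SMEM->GMEM copy below is predicated out.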
        new_store_slices = last_store_slices[:]
        for idx, bref in enumerate(out_brefs):
          if bref.is_index_invariant:
            assert last_store_slices[idx] is None
            continue
          assert last_store_slices[idx] is not None
          new_store_slices[idx] = tuple(
              _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
          )
          are_same_slices = map(
              lambda old, new: old == new,
              last_store_slices[idx],
              new_store_slices[idx],
          )
          slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
          bref.copy_out(_get_slot(slot, ~bref.is_index_invariant),
                        indices,
                        predicate=slices_changed)
        next_indices = _inc_grid_by_1(indices, grid)
        return (next_indices, new_store_slices)
      init_indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid)
      # TODO(justinfu): Only store base pointer instead of all indices.
      last_store_slices = [
          None
          if bref.is_index_invariant
          else (_Slice(-1, -1),) * len(bref.spec.block_shape)
          for bref in out_brefs
      ]
      last_indices, _ = lax.fori_loop(0,
                                      num_pipeline_steps,
                                      compute_loop_body,
                                      (init_indices, last_store_slices))

      # Handle index-invariant outputs after the loop. They are not
      # written in the main pipeline loop.
      if all(bref.is_index_invariant for bref in out_brefs):
        gpu_primitives.commit_smem()
      last_slot = (num_pipeline_steps - 1) % max_concurrent_steps
      for bref in out_brefs:
        if bref.is_index_invariant:
          bref.copy_out(last_slot, last_indices, predicate=None)

      # Finalize the pipeline.
      gpu_primitives.wait_smem_to_gmem(0)
      return

    # The memory thread executes this block, which issues all pipelined DMAs.
    def memory_block():
      gpu_primitives.set_max_registers(memory_registers, action="decrease")
      indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid)

      # Begin initial copies.
      for step in range(max_concurrent_steps):
        for bref, barrier in zip(in_brefs, in_smem_barrier_refs):
          buf_slot = _get_slot(step, ~bref.is_index_invariant)
          bref.copy_in(buf_slot, indices, barrier)
        indices = _inc_grid_by_1(indices, grid)

      def memory_loop_body(step, carry):
        indices, = carry
        slot = step % max_concurrent_steps
        # The step being fetched here is ``step + max_concurrent_steps``,
        # which maps to the same slot: (x + y) % y == x % y.
        fetch_slot = slot
        gpu_primitives.barrier_wait(consumed_barrier_ref.at[slot])
        for bref, barrier in zip(in_brefs, in_smem_barrier_refs):
          bref.copy_in(
              _get_slot(fetch_slot, ~bref.is_index_invariant), indices, barrier)
        next_indices = _inc_grid_by_1(indices, grid)
        return (next_indices,)
      lax.fori_loop(0, num_pipeline_steps - max_concurrent_steps,
                    memory_loop_body, (indices,))

    wg_idx = lax.axis_index(wg_axis)
    lax.cond(
        wg_idx != memory_thread_idx,
        compute_block,
        memory_block
    )
  return pipeline


def _compute_registers(
    memory_registers: int,
    num_compute_wgs: int,
) -> int:
  """Returns the number of registers to use for the compute thread."""
  # TODO(justinfu): Configure this per-platform.
  n_registers = (512 - memory_registers) / num_compute_wgs
  # Round down to the nearest multiple of 8.
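  # For example, with the documented H100 value memory_registers=40 and
  # num_compute_wgs=2: (512 - 40) / 2 = 236, which rounds down to 232
  # registers per compute warpgroup.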
  return int((n_registers // 8) * 8)
@@ -25,6 +25,7 @@ from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
import jax.numpy as jnp
@@ -1434,6 +1435,111 @@ class PipelineTest(PallasTest):
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)


class WarpSpecializedPipelineTest(PallasTest):

  def test_pipelined_copy(self, m=512, n=512):
    x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16)
    o = jnp.zeros((m, n), dtype=jnp.float16)
    blk_m = blk_n = 64
    o_last_block = jnp.zeros((blk_m, blk_n), dtype=jnp.float16)

    def copy_kernel(x_smem, o_smem, o_last_block_smem):
      # TODO(justinfu): Have each wg compute a separate slice
      # after multiple-indexers are supported.
      # This is currently a race, but the values written are the same.
      o_smem[...] = x_smem[...]
      o_last_block_smem[...] = x_smem[...]
    block_spec = plgpu.GPUBlockSpec(
        block_shape=(blk_m, blk_n),
        index_map=lambda i, j: (i, j),
        transforms=[],
    )
    pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
        copy_kernel,
        grid=(m // blk_m, n // blk_n),
        memory_registers=40,
        max_concurrent_steps=2,
        num_compute_wgs=2,
        wg_axis="wg",
        in_specs=[block_spec],
        out_specs=[block_spec,
                   # Create an index-invariant output.
                   plgpu.GPUBlockSpec(block_shape=(blk_m, blk_n),
                                      index_map=lambda i, j: (0, 0))
                   ],
    )
    mesh = plgpu.GPUMesh(
        grid=(1,),
        num_threads=3,
        axis_names=("_", "wg",),
        approx_math=True,
    )
    def run(refs):
      @pl.core_map(mesh)
      def _kernel_entry():
        pipeline(*refs)
    @jax.jit
    def run_function(x, o, o_last_block):
      _, out, out_last = pl.run_state(run)((x, o, o_last_block))
      return (out, out_last)
    out, out_last_block = run_function(x, o, o_last_block)
    np.testing.assert_array_equal(out, x)
    np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:])

  def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2):
    blk_m = blk_n = 64
    x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32)
    y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32)
    o = jnp.zeros((m, n), dtype=jnp.float32)

    def tiled_add_kernel(x_smem, y_smem, o_smem):
      # TODO(justinfu): Have each wg compute a separate slice
      # after multiple-indexers are supported.
      # This is currently a race, but the values written are the same.
      o_smem[...] = x_smem[...] + y_smem[...]

    pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
        tiled_add_kernel,
        grid=(m // blk_m, n // blk_n),
        max_concurrent_steps=2,
        num_compute_wgs=num_compute_wgs,
        memory_registers=40,
        wg_axis="wg",
        in_specs=[
            plgpu.GPUBlockSpec(
                block_shape=(blk_m, blk_n),
                index_map=lambda i, j: (i, j),
                transforms=[]),
            plgpu.GPUBlockSpec(
                block_shape=(blk_m, blk_n),
                index_map=lambda i, j: (i, j),
                transforms=[]),
        ],
        out_specs=[
            plgpu.GPUBlockSpec(
                block_shape=(blk_m, blk_n),
                index_map=lambda i, j: (i, j),
                transforms=[])],
    )
    mesh = plgpu.GPUMesh(
        grid=(1,),
        num_threads=num_compute_wgs + 1,
        axis_names=("_", "wg",),
        approx_math=True,
    )
    def run(refs):
      @pl.core_map(mesh)
      def _kernel_entry():
        pipeline(*refs)
    @jax.jit
    def run_function(x, y, o):
      _, _, out = pl.run_state(run)((x, y, o))
      return out
    out = run_function(x, y, o)
    reference = x + y
    np.testing.assert_allclose(out, reference, atol=1e-4)


class CoreMapTest(PallasTest):

  def test_multiple_wg(self):