[pallas:mosaic_gpu] Added WG lowering rules for TMA primitives and run_scoped_p

PiperOrigin-RevId: 730780335
Sergei Lebedev 2025-02-25 01:31:43 -08:00 committed by jax authors
parent 80848ad859
commit 7eadc64b5a
4 changed files with 178 additions and 37 deletions


@@ -79,6 +79,11 @@ def _align_to(x: int, alignment: int):
return x
@dataclasses.dataclass(frozen=True)
class ResourceEstimatorContext:
arrival_multiplier: int
@dataclasses.dataclass(kw_only=True, frozen=True)
class Resources:
smem_scratch_bytes: int = 0
@@ -118,7 +123,9 @@ class Resources:
class ResourceEstimator(Protocol):
def __call__(self, *args: Any, **params: Any) -> Resources:
def __call__(
self, ctx: ResourceEstimatorContext, *args: Any, **params: Any
) -> Resources:
...
@@ -133,7 +140,9 @@ def _register_resource_estimator(primitive: jax_core.Primitive):
return deco
def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources:
def _estimate_resources(
ctx: ResourceEstimatorContext, jaxpr: jax_core.Jaxpr
) -> Resources:
"""Estimates the resources required by the kernel."""
rs = Resources(smem_scratch_bytes=0)
for eqn in jaxpr.eqns:
@@ -142,34 +151,48 @@ def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources:
if rule is None:
# Assume that unsupported primitives are neutral wrt resource usage.
continue
rs |= rule(*(invar.aval for invar in eqn.invars), **eqn.params)
rs |= rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params)
return rs
@_register_resource_estimator(lax.cond_p)
def _cond_resource_estimator(*args, branches) -> int:
def _cond_resource_estimator(
ctx: ResourceEstimatorContext, *args, branches
) -> int:
del args # Unused.
return functools.reduce(
lambda a, b: a | b,
(_estimate_resources(branch.jaxpr) for branch in branches),
(_estimate_resources(ctx, branch.jaxpr) for branch in branches),
)
@_register_resource_estimator(lax.scan_p)
def _scan_resource_estimator(*args, jaxpr: jax_core.ClosedJaxpr, **params) -> int:
def _scan_resource_estimator(
ctx: ResourceEstimatorContext, *args, jaxpr: jax_core.ClosedJaxpr, **params
) -> int:
del args, params # Unused.
return _estimate_resources(jaxpr)
return _estimate_resources(ctx, jaxpr)
@_register_resource_estimator(lax.while_p)
def _while_resource_estimator(*args, cond_jaxpr: jax_core.ClosedJaxpr, body_jaxpr: jax_core.ClosedJaxpr, **params) -> int:
def _while_resource_estimator(
ctx: ResourceEstimatorContext,
*args,
cond_jaxpr: jax_core.ClosedJaxpr,
body_jaxpr: jax_core.ClosedJaxpr,
**params,
) -> int:
del args, params # Unused.
return _estimate_resources(cond_jaxpr) | _estimate_resources(body_jaxpr)
return _estimate_resources(ctx, cond_jaxpr) | _estimate_resources(
ctx, body_jaxpr
)
@_register_resource_estimator(primitives.run_scoped_p)
def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
def _run_scoped_resource_estimator(
ctx: ResourceEstimatorContext, *consts, jaxpr: jax_core.Jaxpr
) -> int:
del consts # Unused.
rs = Resources()
for v in jaxpr.invars:
@@ -178,7 +201,7 @@ def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
rs += Resources(
barrier_counts=collections.Counter([
mgpu.Barrier(
aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape
aval.dtype.num_arrivals * ctx.arrival_multiplier, *aval.shape
)
])
)
@@ -186,11 +209,14 @@ def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
rs += Resources(
smem_scratch_bytes=math.prod(aval.shape) * aval.dtype.itemsize
)
return rs + _estimate_resources(jaxpr)
return rs + _estimate_resources(ctx, jaxpr)
@_register_resource_estimator(lax.reduce_sum_p)
def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
def _reduce_sum_resource_estimator(
ctx: ResourceEstimatorContext, x_aval: jax_core.ShapedArray, *, axes
) -> int:
del ctx, axes # Unused.
# We don't need SMEM for some reductions, but it depends on the layout, so we
# conservatively request some scratch space.
return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize)
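
For orientation, here is a sketch (not part of the diff) of what a resource-estimator rule looks like after this change: the rule now receives a `ResourceEstimatorContext`, and barrier arrival counts are scaled by `ctx.arrival_multiplier` rather than the hard-coded `WARPGROUP_SIZE`. `my_barrier_p` is an invented primitive used purely for illustration; the other names are the module-level definitions shown in the hunks above.

```python
# Hypothetical rule for an invented primitive `my_barrier_p`. It assumes the
# patched module's own `Resources`, `ResourceEstimatorContext`, `mgpu`,
# `collections` and `_register_resource_estimator` are in scope.
@_register_resource_estimator(my_barrier_p)
def _my_barrier_resource_estimator(
    ctx: ResourceEstimatorContext, barrier_aval, **params
) -> Resources:
  del params  # Unused.
  # Lane semantics: every thread in the warpgroup arrives, so the multiplier
  # is WARPGROUP_SIZE. Warpgroup semantics: one arrival, multiplier 1.
  return Resources(
      barrier_counts=collections.Counter([
          mgpu.Barrier(
              barrier_aval.dtype.num_arrivals * ctx.arrival_multiplier,
              *barrier_aval.shape,
          )
      ])
  )
```

Rules that do not depend on thread semantics simply accept and ignore the context, as `_reduce_sum_resource_estimator` above does.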
@@ -786,9 +812,12 @@ def lower_jaxpr_to_module(
"All scratch operands must be SMEM references or accumulators (ACC),"
f" but got: {scratch_avals}"
)
rs = _estimate_resources(jaxpr)
arrival_multiplier = (
WARPGROUP_SIZE if thread_semantics == mgpu.ThreadSemantics.Lane else 1
)
rs = _estimate_resources(ResourceEstimatorContext(arrival_multiplier), jaxpr)
extra_barriers = [
mgpu.Barrier(aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape)
mgpu.Barrier(aval.dtype.num_arrivals * arrival_multiplier, *aval.shape)
for aval in scratch_avals
if isinstance(aval.dtype, gpu_core.BarrierType)
]
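
To make the scaling concrete, a tiny stand-alone illustration; it assumes `WARPGROUP_SIZE == 128` (four warps of 32 threads), which is the value used by Mosaic GPU.

```python
WARPGROUP_SIZE = 128  # Assumed here; defined by Mosaic GPU (4 warps x 32 threads).
num_arrivals = 1      # As declared on a plgpu.Barrier scratch value.

# ThreadSemantics.Lane: every thread arrives individually.
lane_barrier_arrivals = num_arrivals * WARPGROUP_SIZE   # 128
# ThreadSemantics.Warpgroup: the warpgroup arrives as a single unit.
warpgroup_barrier_arrivals = num_arrivals * 1           # 1
```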
@@ -1562,15 +1591,24 @@ def _debug_print_lowering_rule(
@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Warpgroup)
def _run_scoped_lowering_rule(
ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr
):
input_refs = []
should_discharge = []
alloc_stack = contextlib.ExitStack()
arrival_multiplier = (
WARPGROUP_SIZE if ctx.thread_semantics == mgpu.ThreadSemantics.Lane else 1
)
for v in jaxpr.invars:
aval = v.aval
if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef):
if ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup:
# TODO(bchetioui): Fix this and remove the NotImplementedError.
raise NotImplementedError(
"WGMMA accumulators are not supported with Warpgroup semantics."
)
mlir_dtype = mlir.dtype_to_ir_type(aval.dtype)
input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype))
should_discharge.append(True)
@@ -1578,7 +1616,7 @@ def _run_scoped_lowering_rule(
input_refs.append(
ctx.module_ctx.reserve_barrier(
mgpu.Barrier(
aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape
aval.dtype.num_arrivals * arrival_multiplier, *aval.shape
)
)
)
@@ -1604,13 +1642,23 @@ def _run_scoped_lowering_rule(
discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge)
new_input_vals = consts + tuple(input_refs)
outs = lower_jaxpr_to_mosaic_gpu(
ctx.module_ctx, ctx.launch_ctx, discharged_jaxpr, new_input_vals, ()
ctx.module_ctx,
ctx.launch_ctx,
discharged_jaxpr,
new_input_vals,
(),
thread_semantics=ctx.thread_semantics,
)
# Discharge appends to the output the refs that got discharged.
outs = outs[:-sum(should_discharge)]
else:
outs = lower_jaxpr_to_mosaic_gpu(
ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
ctx.module_ctx,
ctx.launch_ctx,
jaxpr,
input_refs,
consts,
thread_semantics=ctx.thread_semantics,
)
assert len(outs) == len(jaxpr.outvars), (jaxpr, outs)
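
End to end, the Warpgroup registration means a `run_scoped` allocation can now be compiled under either semantics mode. A minimal usage sketch, modeled on the updated `test_run_scoped` in this commit's test changes; the shapes and the choice of Warpgroup semantics are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def kernel(x_ref, o_ref):
  def body(tmp_ref):
    # tmp_ref is an SMEM scratch allocation scoped to this body.
    tmp_ref[...] = x_ref[...] + 1.0
    return tmp_ref[...]

  o_ref[...] = pl.run_scoped(body, plgpu.SMEM((8, 128), jnp.float32))


f = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup
    ),
)
out = f(jnp.zeros((8, 128), jnp.float32))  # out == 1.0 everywhere
```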


@@ -74,6 +74,9 @@ def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params):
@lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane)
@lowering.register_lowering_rule(
copy_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup
)
def _copy_smem_to_gmem_lowering(
ctx: lowering.LoweringRuleContext,
src,
@@ -97,15 +100,49 @@ def _copy_smem_to_gmem_lowering(
dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms)
src, src_transforms = lowering._handle_indexing(src, src_transforms)
copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms)
ctx.launch_ctx.async_copy(
src_ref=src,
dst_ref=dst,
predicate=predicate,
**copy_params,
if ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
ctx.launch_ctx.async_copy(
src_ref=src,
dst_ref=dst,
predicate=predicate,
**copy_params,
)
return ()
if "gmem_slice" not in copy_params:
i32 = ir.IntegerType.get_signless(32)
slice_lengths = ir.MemRefType(src.type).shape
indices = [mgpu.utils.c(0, i32)] * len(slice_lengths)
else:
indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"])
assert copy_params.get("swizzle") is None
assert not copy_params.get("gmem_transform")
mgpu.dialect.async_store(
src, dst, indices, slice_lengths, predicate=predicate
)
return ()
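
At the user level nothing changes: the same `copy_smem_to_gmem` call now lowers to `mgpu.dialect.async_store` when warpgroup semantics are selected. A sketch modeled on `test_copy_smem_to_gmem` from this commit's test changes; treat the exact parameter spelling as an assumption.

```python
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    scratch_shapes=[plgpu.SMEM((256,), jnp.float32)],
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup
    ),
)
def kernel(x_ref, o_ref_gmem, scratch_ref):
  scratch_ref[...] = x_ref[...] + 1
  plgpu.copy_smem_to_gmem(scratch_ref, o_ref_gmem)
  plgpu.wait_smem_to_gmem(0)  # Wait until no SMEM->GMEM copies are in flight.


out = kernel(jnp.arange(256, dtype=jnp.float32))  # out == x + 1
```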
def _split_gmem_slice(gmem_slice):
i32 = ir.IntegerType.get_signless(32)
indices = []
slice_lengths = []
for idx in gmem_slice:
match idx:
case slice():
indices.append(mgpu_utils.c(idx.start, i32))
slice_lengths.append(idx.stop - idx.start)
case mgpu.DynamicSlice():
indices.append(arith_dialect.index_cast(i32, idx.base))
slice_lengths.append(idx.length)
case ir.Value():
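# A dynamic scalar index: the dimension is squeezed, which the dialect ops
# encode as a slice length of -1.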
indices.append(arith_dialect.index_cast(i32, idx))
slice_lengths.append(-1)
case _:
raise NotImplementedError(f"Unsupported GMEM slice: {idx}")
return indices, slice_lengths
def _extract_gmem_copy_params(transforms):
if not transforms:
return {}
@@ -118,6 +155,7 @@ def _extract_gmem_copy_params(transforms):
gmem_slice=lowering._ndindexer_indices(indexer),
)
def _extract_smem_copy_params(transforms):
if not transforms:
return {}
@@ -188,6 +226,9 @@ def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params):
@lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane)
@lowering.register_lowering_rule(
copy_gmem_to_smem_p, mgpu.ThreadSemantics.Warpgroup
)
def _copy_gmem_to_smem_lowering(
ctx: lowering.LoweringRuleContext,
src,
@@ -220,17 +261,38 @@ def _copy_gmem_to_smem_lowering(
)
dst_ty = ir.MemRefType(dst.type)
bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type)
if bytes % WARPGROUP_SIZE:
raise NotImplementedError("Only aligned copies are supported")
# We arrive uniformly from each thread in the WG, so we need to divide the
# number of bytes by the number of threads in the WG.
# TODO: apaszke - Relax this. We can just select the WG leader and have it
# arrive with the whole transfer size, while everyone else arrives with 0.
# But we should continue using this scheme as it's likely to be faster.
bytes //= WARPGROUP_SIZE
barrier.arrive_expect_tx(bytes)
ctx.launch_ctx.async_copy(
src_ref=src, dst_ref=dst, barrier=barrier, arrive=False, **copy_params
if ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
if bytes % WARPGROUP_SIZE:
raise NotImplementedError("Only aligned copies are supported")
# We arrive uniformly from each thread in the WG, so we need to divide the
# number of bytes by the number of threads in the WG.
# TODO: apaszke - Relax this. We can just select the WG leader and have it
# arrive with the whole transfer size, while everyone else arrives with 0.
# But we should continue using this scheme as it's likely to be faster.
bytes //= WARPGROUP_SIZE
barrier.arrive_expect_tx(bytes)
ctx.launch_ctx.async_copy(
src_ref=src, dst_ref=dst, barrier=barrier, arrive=False, **copy_params
)
return ()
if "gmem_slice" not in copy_params:
i32 = ir.IntegerType.get_signless(32)
slice_lengths = ir.MemRefType(src.type).shape
indices = [mgpu.utils.c(0, i32)] * len(slice_lengths)
else:
indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"])
assert copy_params.get("swizzle") is None
assert not copy_params.get("gmem_transform")
barrier_ref = barrier.as_dialect_barrier_memref()
mgpu.dialect.arrive_expect_tx(barrier_ref, bytes)
mgpu.dialect.async_load(
src,
dst,
barrier_ref,
indices,
slice_lengths,
collective=ir.ArrayAttr.get([]),
)
return ()
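
The mirror-image primitive follows the same pattern: under warpgroup semantics both the barrier's expect-tx and the load itself go through dialect ops. A usage sketch modeled on the existing `copy_gmem_to_smem` tests; the argument names and the barrier keyword are taken from those tests and should be treated as an assumption.

```python
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


@functools.partial(
    pl.pallas_call,
    in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    scratch_shapes=[
        plgpu.SMEM((256,), jnp.float32),
        plgpu.Barrier(num_arrivals=1),
    ],
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup
    ),
)
def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
  plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
  plgpu.barrier_wait(barrier_ref)  # Wait for the TMA load to complete.
  o_ref[...] = scratch_ref[...] + 1


out = kernel(jnp.arange(256, dtype=jnp.float32))  # out == x + 1
```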
@@ -346,6 +408,7 @@ def _barrier_wait_abstract_eval(barrier, *args, **params):
@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane)
@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Warpgroup)
def _barrier_wait_lowering(
ctx: lowering.LoweringRuleContext,
barrier,
@@ -383,6 +446,9 @@ def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only):
@lowering.register_lowering_rule(wait_smem_to_gmem_p, mgpu.ThreadSemantics.Lane)
@lowering.register_lowering_rule(
wait_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup
)
def _wait_smem_to_gmem_lowering(
ctx: lowering.LoweringRuleContext, n, *, wait_read_only
):
@@ -715,6 +781,7 @@ def _commit_smem_abstract_eval():
@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane)
@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Warpgroup)
def _commit_smem_lowering(ctx: lowering.LoweringRuleContext):
mgpu.commit_shared()
return ()


@@ -320,10 +320,18 @@ def _mgpu_async_load_op_lowering_rule(
dst_layout = ir.MemRefType(load_op.destination.type).layout
swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout)
gmem_slice = []
for idx, size in zip(load_op.indices, load_op.slice_lengths):
idx = arith.index_cast(ir.IndexType.get(), idx)
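# A negative slice length marks a scalar index (the dimension is squeezed);
# otherwise rebuild a DynamicSlice of that length for async_copy's gmem_slice.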
v = idx if size < 0 else utils.DynamicSlice(idx, size)
gmem_slice.append(v)
# TODO(dasenov): Add support for the remaining op properties.
ctx.launch_context.async_copy(
src_ref=load_op.source,
dst_ref=transform_memref(load_op.destination, transforms),
gmem_slice=tuple(gmem_slice),
barrier=barrier,
arrive=False,
uniform=True,
@@ -342,10 +350,18 @@ def _mgpu_async_store_op_lowering_rule(
src_layout = ir.MemRefType(store_op.source.type).layout
swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout)
gmem_slice = []
for idx, size in zip(store_op.indices, store_op.slice_lengths):
idx = arith.index_cast(ir.IndexType.get(), idx)
v = idx if size < 0 else utils.DynamicSlice(idx, size)
gmem_slice.append(v)
# TODO(dasenov): Add support for the remaining op properties.
ctx.launch_context.async_copy(
src_ref=transform_memref(store_op.source, transforms),
dst_ref=store_op.destination,
gmem_slice=tuple(gmem_slice),
swizzle=swizzle,
gmem_transform=transforms,
uniform=True,


@@ -312,13 +312,19 @@ class PallasCallTest(PallasTest):
np.testing.assert_array_equal(kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension))
@parameterized.product(indexer=[..., slice(128), slice(None, 128)])
def test_copy_smem_to_gmem(self, indexer):
@parameterized.product(
indexer=[..., slice(128), slice(None, 128)],
thread_semantics=[*plgpu.ThreadSemantics],
)
def test_copy_smem_to_gmem(self, indexer, thread_semantics):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
scratch_shapes=[plgpu.SMEM((256,), jnp.float32)],
compiler_params=plgpu.GPUCompilerParams(
thread_semantics=thread_semantics
),
)
def kernel(x_ref, o_ref_gmem, scratch_ref):
scratch_ref[...] = x_ref[...] + 1
@@ -775,7 +781,8 @@ class PallasCallTest(PallasTest):
np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)),
jnp.full((128,), 10, dtype=jnp.int32))
def test_run_scoped(self):
@parameterized.product(thread_semantics=[*plgpu.ThreadSemantics])
def test_run_scoped(self, thread_semantics):
def kernel(x_ref, o_ref):
def body(tmp_ref):
self.assertEqual(tmp_ref.shape, (8, 128))
@@ -790,6 +797,9 @@ class PallasCallTest(PallasTest):
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
compiler_params=plgpu.GPUCompilerParams(
thread_semantics=thread_semantics
),
)
o = f(inp)
np.testing.assert_array_equal(o, inp + 1.0)