mirror of https://github.com/ROCm/jax.git
[pallas:mosaic_gpu] Added WG lowering rules for TMA primitives and run_scoped_p
PiperOrigin-RevId: 730780335
parent 80848ad859
commit 7eadc64b5a
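The new warpgroup (WG) lowering path is opted into per kernel through the compiler params, which is what the updated tests at the end of this diff do. Below is a rough usage sketch assembled from those test changes; the kernel body and shapes are illustrative, not prescribed by this commit:

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    scratch_shapes=[plgpu.SMEM((256,), jnp.float32)],
    # Selects the new Warpgroup lowering rules instead of the Lane ones.
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup
    ),
)
def kernel(x_ref, o_ref_gmem, scratch_ref):
  scratch_ref[...] = x_ref[...] + 1
  plgpu.copy_smem_to_gmem(scratch_ref, o_ref_gmem)
  plgpu.wait_smem_to_gmem(0)


out = kernel(jnp.arange(256, dtype=jnp.float32))  # out == x + 1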
@@ -79,6 +79,11 @@ def _align_to(x: int, alignment: int):
   return x
 
 
+@dataclasses.dataclass(frozen=True)
+class ResourceEstimatorContext:
+  arrival_multiplier: int
+
+
 @dataclasses.dataclass(kw_only=True, frozen=True)
 class Resources:
   smem_scratch_bytes: int = 0
@@ -118,7 +123,9 @@ class Resources:
 
 class ResourceEstimator(Protocol):
 
-  def __call__(self, *args: Any, **params: Any) -> Resources:
+  def __call__(
+      self, ctx: ResourceEstimatorContext, *args: Any, **params: Any
+  ) -> Resources:
     ...
 
 
@@ -133,7 +140,9 @@ def _register_resource_estimator(primitive: jax_core.Primitive):
   return deco
 
 
-def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources:
+def _estimate_resources(
+    ctx: ResourceEstimatorContext, jaxpr: jax_core.Jaxpr
+) -> Resources:
   """Estimates the resources required by the kernel."""
   rs = Resources(smem_scratch_bytes=0)
   for eqn in jaxpr.eqns:
@@ -142,34 +151,48 @@ def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources:
     if rule is None:
       # Assume that unsupported primitives are neutral wrt resource usage.
       continue
-    rs |= rule(*(invar.aval for invar in eqn.invars), **eqn.params)
+    rs |= rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params)
 
   return rs
 
 
 @_register_resource_estimator(lax.cond_p)
-def _cond_resource_estimator(*args, branches) -> int:
+def _cond_resource_estimator(
+    ctx: ResourceEstimatorContext, *args, branches
+) -> int:
   del args  # Unused.
   return functools.reduce(
       lambda a, b: a | b,
-      (_estimate_resources(branch.jaxpr) for branch in branches),
+      (_estimate_resources(ctx, branch.jaxpr) for branch in branches),
   )
 
 
 @_register_resource_estimator(lax.scan_p)
-def _scan_resource_estimator(*args, jaxpr: jax_core.ClosedJaxpr, **params) -> int:
+def _scan_resource_estimator(
+    ctx: ResourceEstimatorContext, *args, jaxpr: jax_core.ClosedJaxpr, **params
+) -> int:
   del args, params  # Unused.
-  return _estimate_resources(jaxpr)
+  return _estimate_resources(ctx, jaxpr)
 
 
 @_register_resource_estimator(lax.while_p)
-def _while_resource_estimator(*args, cond_jaxpr: jax_core.ClosedJaxpr, body_jaxpr: jax_core.ClosedJaxpr, **params) -> int:
+def _while_resource_estimator(
+    ctx: ResourceEstimatorContext,
+    *args,
+    cond_jaxpr: jax_core.ClosedJaxpr,
+    body_jaxpr: jax_core.ClosedJaxpr,
+    **params,
+) -> int:
   del args, params  # Unused.
-  return _estimate_resources(cond_jaxpr) | _estimate_resources(body_jaxpr)
+  return _estimate_resources(ctx, cond_jaxpr) | _estimate_resources(
+      ctx, body_jaxpr
+  )
 
 
 @_register_resource_estimator(primitives.run_scoped_p)
-def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
+def _run_scoped_resource_estimator(
+    ctx: ResourceEstimatorContext, *consts, jaxpr: jax_core.Jaxpr
+) -> int:
   del consts  # Unused.
   rs = Resources()
   for v in jaxpr.invars:
@@ -178,7 +201,7 @@ def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
       rs += Resources(
           barrier_counts=collections.Counter([
               mgpu.Barrier(
-                  aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape
+                  aval.dtype.num_arrivals * ctx.arrival_multiplier, *aval.shape
               )
           ])
       )
@@ -186,11 +209,14 @@ def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
     rs += Resources(
         smem_scratch_bytes=math.prod(aval.shape) * aval.dtype.itemsize
     )
-  return rs + _estimate_resources(jaxpr)
+  return rs + _estimate_resources(ctx, jaxpr)
 
 
 @_register_resource_estimator(lax.reduce_sum_p)
-def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
+def _reduce_sum_resource_estimator(
+    ctx: ResourceEstimatorContext, x_aval: jax_core.ShapedArray, *, axes
+) -> int:
+  del ctx, axes  # Unused.
   # We don't need shmem for some reductons, but it depends on the layout, so we
   # conservatively request some scratch space.
   return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize)
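Every estimator above now takes a ResourceEstimatorContext as its first argument, so barrier arrival counts can be scaled by ctx.arrival_multiplier instead of a hard-coded WARPGROUP_SIZE. A self-contained sketch of the new protocol; the dataclasses here are simplified stand-ins for the real Pallas classes (the real Resources tracks a Counter of mgpu.Barrier, not a plain int), and "my_p" is a hypothetical primitive used only for illustration:

import dataclasses


@dataclasses.dataclass(frozen=True)
class ResourceEstimatorContext:
  arrival_multiplier: int


@dataclasses.dataclass(kw_only=True, frozen=True)
class Resources:
  smem_scratch_bytes: int = 0
  barrier_arrivals: int = 0


def my_p_resource_estimator(ctx: ResourceEstimatorContext, *avals, num_arrivals):
  # Lane semantics: every thread in the warpgroup arrives, so the count is
  # scaled by the multiplier; Warpgroup semantics: the multiplier is 1.
  del avals  # Unused.
  return Resources(barrier_arrivals=num_arrivals * ctx.arrival_multiplier)


print(my_p_resource_estimator(ResourceEstimatorContext(128), num_arrivals=1))
print(my_p_resource_estimator(ResourceEstimatorContext(1), num_arrivals=1))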
@@ -786,9 +812,12 @@ def lower_jaxpr_to_module(
         "All scratch operands must be SMEM references or accumulators (ACC),"
         f" but got: {scratch_avals}"
     )
-  rs = _estimate_resources(jaxpr)
+  arrival_multiplier = (
+      WARPGROUP_SIZE if thread_semantics == mgpu.ThreadSemantics.Lane else 1
+  )
+  rs = _estimate_resources(ResourceEstimatorContext(arrival_multiplier), jaxpr)
   extra_barriers = [
-      mgpu.Barrier(aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape)
+      mgpu.Barrier(aval.dtype.num_arrivals * arrival_multiplier, *aval.shape)
       for aval in scratch_avals
       if isinstance(aval.dtype, gpu_core.BarrierType)
   ]
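The arrival_multiplier computed here captures the key difference between the two semantics: under Lane semantics each of the WARPGROUP_SIZE threads arrives at a barrier individually, while under Warpgroup semantics the warpgroup counts as a single arriving unit. A small arithmetic sketch (assuming WARPGROUP_SIZE == 128, its value in Mosaic GPU):

WARPGROUP_SIZE = 128  # Assumed warpgroup width (4 warps of 32 threads).


def expected_arrivals(num_arrivals: int, lane_semantics: bool) -> int:
  # Mirrors mgpu.Barrier(aval.dtype.num_arrivals * arrival_multiplier, ...).
  arrival_multiplier = WARPGROUP_SIZE if lane_semantics else 1
  return num_arrivals * arrival_multiplier


print(expected_arrivals(1, lane_semantics=True))   # 128: one arrival per thread
print(expected_arrivals(1, lane_semantics=False))  # 1: one arrival per warpgroup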
@@ -1562,15 +1591,24 @@ def _debug_print_lowering_rule(
 
 
 @register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Warpgroup)
 def _run_scoped_lowering_rule(
     ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr
 ):
   input_refs = []
   should_discharge = []
   alloc_stack = contextlib.ExitStack()
+  arrival_multiplier = (
+      WARPGROUP_SIZE if ctx.thread_semantics == mgpu.ThreadSemantics.Lane else 1
+  )
   for v in jaxpr.invars:
     aval = v.aval
     if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef):
+      if ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup:
+        # TODO(bchetioui): Fix this and remove the NotImplementedError.
+        raise NotImplementedError(
+            "WGMMA accumulators are not supported with Warpgroup semantics."
+        )
       mlir_dtype = mlir.dtype_to_ir_type(aval.dtype)
       input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype))
       should_discharge.append(True)
@@ -1578,7 +1616,7 @@ def _run_scoped_lowering_rule(
       input_refs.append(
           ctx.module_ctx.reserve_barrier(
               mgpu.Barrier(
-                  aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape
+                  aval.dtype.num_arrivals * arrival_multiplier, *aval.shape
               )
           )
       )
@@ -1604,13 +1642,23 @@ def _run_scoped_lowering_rule(
     discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge)
     new_input_vals = consts + tuple(input_refs)
     outs = lower_jaxpr_to_mosaic_gpu(
-        ctx.module_ctx, ctx.launch_ctx, discharged_jaxpr, new_input_vals, ()
+        ctx.module_ctx,
+        ctx.launch_ctx,
+        discharged_jaxpr,
+        new_input_vals,
+        (),
+        thread_semantics=ctx.thread_semantics,
     )
     # Discharge appends to the output the refs that got discharged.
     outs = outs[:-sum(should_discharge)]
   else:
     outs = lower_jaxpr_to_mosaic_gpu(
-        ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
+        ctx.module_ctx,
+        ctx.launch_ctx,
+        jaxpr,
+        input_refs,
+        consts,
+        thread_semantics=ctx.thread_semantics,
     )
 
   assert len(outs) == len(jaxpr.outvars), (jaxpr, outs)
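For reference, run_scoped (whose lowering gains the Warpgroup registration above) is the user-facing primitive that allocates scratch values for the duration of a body function. A minimal sketch of its use, following the pattern of the updated test_run_scoped further down; the shapes are illustrative:

import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def kernel(x_ref, o_ref):
  def body(tmp_ref):
    # tmp_ref is an (8, 128) SMEM reference that only exists inside body.
    tmp_ref[...] = x_ref[...] + 1.0
    return tmp_ref[...]

  o_ref[...] = pl.run_scoped(body, plgpu.SMEM((8, 128), jnp.float32))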
@@ -74,6 +74,9 @@ def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params):
 
 
 @lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane)
+@lowering.register_lowering_rule(
+    copy_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup
+)
 def _copy_smem_to_gmem_lowering(
     ctx: lowering.LoweringRuleContext,
     src,
@@ -97,15 +100,49 @@ def _copy_smem_to_gmem_lowering(
   dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms)
   src, src_transforms = lowering._handle_indexing(src, src_transforms)
   copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms)
-  ctx.launch_ctx.async_copy(
-      src_ref=src,
-      dst_ref=dst,
-      predicate=predicate,
-      **copy_params,
+  if ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
+    ctx.launch_ctx.async_copy(
+        src_ref=src,
+        dst_ref=dst,
+        predicate=predicate,
+        **copy_params,
+    )
+    return ()
+
+  if "gmem_slice" not in copy_params:
+    i32 = ir.IntegerType.get_signless(32)
+    slice_lengths = ir.MemRefType(src.type).shape
+    indices = [mgpu.utils.c(0, i32)] * len(slice_lengths)
+  else:
+    indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"])
+  assert copy_params.get("swizzle") is None
+  assert not copy_params.get("gmem_transform")
+  mgpu.dialect.async_store(
+      src, dst, indices, slice_lengths, predicate=predicate
   )
   return ()
+
+
+def _split_gmem_slice(gmem_slice):
+  i32 = ir.IntegerType.get_signless(32)
+  indices = []
+  slice_lengths = []
+  for idx in gmem_slice:
+    match idx:
+      case slice():
+        indices.append(mgpu_utils.c(idx.start, i32))
+        slice_lengths.append(idx.stop - idx.start)
+      case mgpu.DynamicSlice():
+        indices.append(arith_dialect.index_cast(i32, idx.base))
+        slice_lengths.append(idx.length)
+      case ir.Value():
+        indices.append(arith_dialect.index_cast(i32, idx))
+        slice_lengths.append(-1)
+      case _:
+        raise NotImplementedError(f"Unsupported GMEM slice: {idx}")
+  return indices, slice_lengths
 
 
 def _extract_gmem_copy_params(transforms):
   if not transforms:
     return {}
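The new _split_gmem_slice helper turns a GMEM slice into the parallel indices / slice_lengths lists that mgpu.dialect.async_store (and async_load below) expect, with a length of -1 marking a dimension indexed by a dynamic scalar. A pure-Python analogue of that mapping, using plain ints and tuples as stand-ins for MLIR values and mgpu.DynamicSlice:

def split_gmem_slice_demo(gmem_slice):
  # Illustrative only: the real helper emits MLIR index values via
  # arith.index_cast / mgpu_utils.c instead of plain Python ints.
  indices, slice_lengths = [], []
  for idx in gmem_slice:
    if isinstance(idx, slice):      # static slice: start index + static length
      indices.append(idx.start)
      slice_lengths.append(idx.stop - idx.start)
    elif isinstance(idx, tuple):    # (base, length) stands in for mgpu.DynamicSlice
      base, length = idx
      indices.append(base)
      slice_lengths.append(length)
    else:                           # dynamic scalar index: -1 squeezes the dim
      indices.append(idx)
      slice_lengths.append(-1)
  return indices, slice_lengths


print(split_gmem_slice_demo([slice(0, 128), (7, 64), 3]))
# ([0, 7, 3], [128, 64, -1])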
@@ -118,6 +155,7 @@ def _extract_gmem_copy_params(transforms):
       gmem_slice=lowering._ndindexer_indices(indexer),
   )
 
+
 def _extract_smem_copy_params(transforms):
   if not transforms:
     return {}
@@ -188,6 +226,9 @@ def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params):
 
 
 @lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane)
+@lowering.register_lowering_rule(
+    copy_gmem_to_smem_p, mgpu.ThreadSemantics.Warpgroup
+)
 def _copy_gmem_to_smem_lowering(
     ctx: lowering.LoweringRuleContext,
     src,
@@ -220,17 +261,38 @@ def _copy_gmem_to_smem_lowering(
   )
   dst_ty = ir.MemRefType(dst.type)
   bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type)
-  if bytes % WARPGROUP_SIZE:
-    raise NotImplementedError("Only aligned copies are supported")
-  # We arrive uniformly from each thread in the WG, so we need to divide the
-  # number of bytes by the number of threads in the WG.
-  # TODO: apaszke - Relax this. We can just select the WG leader and have it
-  # arrive with the whole transfer size, while everyone else arrives with 0.
-  # But we should continue using this scheme as it's likely to be faster.
-  bytes //= WARPGROUP_SIZE
-  barrier.arrive_expect_tx(bytes)
-  ctx.launch_ctx.async_copy(
-      src_ref=src, dst_ref=dst, barrier=barrier, arrive=False, **copy_params
+  if ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
+    if bytes % WARPGROUP_SIZE:
+      raise NotImplementedError("Only aligned copies are supported")
+    # We arrive uniformly from each thread in the WG, so we need to divide the
+    # number of bytes by the number of threads in the WG.
+    # TODO: apaszke - Relax this. We can just select the WG leader and have it
+    # arrive with the whole transfer size, while everyone else arrives with 0.
+    # But we should continue using this scheme as it's likely to be faster.
+    bytes //= WARPGROUP_SIZE
+    barrier.arrive_expect_tx(bytes)
+    ctx.launch_ctx.async_copy(
+        src_ref=src, dst_ref=dst, barrier=barrier, arrive=False, **copy_params
+    )
+    return ()
+
+  if "gmem_slice" not in copy_params:
+    i32 = ir.IntegerType.get_signless(32)
+    slice_lengths = ir.MemRefType(src.type).shape
+    indices = [mgpu.utils.c(0, i32)] * len(slice_lengths)
+  else:
+    indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"])
+  assert copy_params.get("swizzle") is None
+  assert not copy_params.get("gmem_transform")
+  barrier_ref = barrier.as_dialect_barrier_memref()
+  mgpu.dialect.arrive_expect_tx(barrier_ref, bytes)
+  mgpu.dialect.async_load(
+      src,
+      dst,
+      barrier_ref,
+      indices,
+      slice_lengths,
+      collective=ir.ArrayAttr.get([]),
   )
   return ()
 
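The two branches account for the transfer size differently: the Lane path divides the byte count across the WARPGROUP_SIZE threads, each of which calls arrive_expect_tx, while the Warpgroup path issues a single mgpu.dialect.arrive_expect_tx with the full byte count. A small worked example (assuming WARPGROUP_SIZE == 128 and a (128, 128) float32 destination, both illustrative):

import math

WARPGROUP_SIZE = 128  # Assumed warpgroup width.

shape, itemsize = (128, 128), 4              # float32 tile
total_bytes = math.prod(shape) * itemsize    # 65536
per_thread_bytes = total_bytes // WARPGROUP_SIZE  # 512

# Lane semantics: every thread expects 512 bytes; Warpgroup semantics: one
# arrive_expect_tx(barrier_ref, 65536) covers the whole transfer.
print(total_bytes, per_thread_bytes)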
@@ -346,6 +408,7 @@ def _barrier_wait_abstract_eval(barrier, *args, **params):
 
 
 @lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane)
+@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Warpgroup)
 def _barrier_wait_lowering(
     ctx: lowering.LoweringRuleContext,
     barrier,
@@ -383,6 +446,9 @@ def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only):
 
 
 @lowering.register_lowering_rule(wait_smem_to_gmem_p, mgpu.ThreadSemantics.Lane)
+@lowering.register_lowering_rule(
+    wait_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup
+)
 def _wait_smem_to_gmem_lowering(
     ctx: lowering.LoweringRuleContext, n, *, wait_read_only
 ):
@@ -715,6 +781,7 @@ def _commit_smem_abstract_eval():
 
 
 @lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane)
+@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Warpgroup)
 def _commit_smem_lowering(ctx: lowering.LoweringRuleContext):
   mgpu.commit_shared()
   return ()
@@ -320,10 +320,18 @@ def _mgpu_async_load_op_lowering_rule(
 
   dst_layout = ir.MemRefType(load_op.destination.type).layout
   swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout)
+
+  gmem_slice = []
+  for idx, size in zip(load_op.indices, load_op.slice_lengths):
+    idx = arith.index_cast(ir.IndexType.get(), idx)
+    v = idx if size < 0 else utils.DynamicSlice(idx, size)
+    gmem_slice.append(v)
+
   # TODO(dasenov): Add support for the remaining op properties.
   ctx.launch_context.async_copy(
       src_ref=load_op.source,
       dst_ref=transform_memref(load_op.destination, transforms),
+      gmem_slice=tuple(gmem_slice),
       barrier=barrier,
       arrive=False,
       uniform=True,
@@ -342,10 +350,18 @@ def _mgpu_async_store_op_lowering_rule(
 
   src_layout = ir.MemRefType(store_op.source.type).layout
   swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout)
+
+  gmem_slice = []
+  for idx, size in zip(store_op.indices, store_op.slice_lengths):
+    idx = arith.index_cast(ir.IndexType.get(), idx)
+    v = idx if size < 0 else utils.DynamicSlice(idx, size)
+    gmem_slice.append(v)
+
   # TODO(dasenov): Add support for the remaining op properties.
   ctx.launch_context.async_copy(
       src_ref=transform_memref(store_op.source, transforms),
       dst_ref=store_op.destination,
+      gmem_slice=tuple(gmem_slice),
       swizzle=swizzle,
       gmem_transform=transforms,
       uniform=True,
@@ -312,13 +312,19 @@ class PallasCallTest(PallasTest):
 
     np.testing.assert_array_equal(kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension))
 
-  @parameterized.product(indexer=[..., slice(128), slice(None, 128)])
-  def test_copy_smem_to_gmem(self, indexer):
+  @parameterized.product(
+      indexer=[..., slice(128), slice(None, 128)],
+      thread_semantics=[*plgpu.ThreadSemantics],
+  )
+  def test_copy_smem_to_gmem(self, indexer, thread_semantics):
     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
         out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
         scratch_shapes=[plgpu.SMEM((256,), jnp.float32)],
+        compiler_params=plgpu.GPUCompilerParams(
+            thread_semantics=thread_semantics
+        ),
     )
     def kernel(x_ref, o_ref_gmem, scratch_ref):
       scratch_ref[...] = x_ref[...] + 1
@@ -775,7 +781,8 @@ class PallasCallTest(PallasTest):
     np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)),
                                   jnp.full((128,), 10, dtype=jnp.int32))
 
-  def test_run_scoped(self):
+  @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics])
+  def test_run_scoped(self, thread_semantics):
     def kernel(x_ref, o_ref):
       def body(tmp_ref):
         self.assertEqual(tmp_ref.shape, (8, 128))
@@ -790,6 +797,9 @@ class PallasCallTest(PallasTest):
     f = pl.pallas_call(
         kernel,
         out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
+        compiler_params=plgpu.GPUCompilerParams(
+            thread_semantics=thread_semantics
+        ),
     )
     o = f(inp)
     np.testing.assert_array_equal(o, inp + 1.0)