Mirror of https://github.com/ROCm/jax.git

commit 70ee8e1161 (parent 94abaf430e)

[pallas:mosaic_gpu] pl.run_scoped now supports scoped barriers

PiperOrigin-RevId: 684449776
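For context, the pattern this commit enables: a Mosaic GPU barrier can now be allocated with pl.run_scoped, just like SMEM scratch. A minimal sketch that mirrors the test added below, assuming the aliases used by the test suite (pl for jax.experimental.pallas, plgpu for its mosaic_gpu backend):

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
)
def kernel(x_ref_gmem, o_ref):
  def body(barrier_ref):
    # barrier_ref is live only inside this scope; the lowering reserves a
    # physical barrier on entry and returns it to the pool on exit.
    def inner_body(scratch_ref):
      plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
      plgpu.wait_barrier(barrier_ref)
      o_ref[...] = scratch_ref[...] + 1
    pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))
  pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))


x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)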
@@ -16,12 +16,13 @@
 from __future__ import annotations
 
-from collections.abc import Sequence
+import collections
+from collections.abc import MutableMapping, MutableSequence, Sequence
 import dataclasses
 import functools
 import itertools as it
 import math
-from typing import Any, cast
+from typing import Any, Protocol, cast
 
 import jax
 from jax import lax
@@ -59,44 +60,101 @@ zip, unsafe_zip = util.safe_zip, zip
 partial = functools.partial
 SMEM = gpu_core.SMEM
 
-_smem_estimators = {}
+
+@dataclasses.dataclass(kw_only=True, frozen=True)
+class Resources:
+  smem_scratch_bytes: int
+  barriers: collections.Counter[mgpu.Barrier] = dataclasses.field(
+      default_factory=collections.Counter
+  )
+
+  def __add__(self, other: Resources) -> Resources:
+    # TODO(slebedev): Optimize this.
+    #
+    # At the moment, if we have run_scoped(b1) followed by run_scoped(b2)
+    # we will allocate two barriers, even though one would be enough.
+    return Resources(
+        smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes,
+        barriers=self.barriers + other.barriers,
+    )
+
+  def __or__(self, other: Resources) -> Resources:
+    return Resources(
+        smem_scratch_bytes=max(
+            self.smem_scratch_bytes, other.smem_scratch_bytes
+        ),
+        barriers=self.barriers | other.barriers,
+    )
+
+
+class ResourceEstimator(Protocol):
+
+  def __call__(self, *args: Any, **params: Any) -> Resources:
+    ...
 
 
-def _regiter_smem_estimator(primitive: jax_core.Primitive):
+_resource_estimators: dict[jax_core.Primitive, ResourceEstimator] = {}
+
+
+def _register_resource_estimator(primitive: jax_core.Primitive):
   def deco(fn):
-    _smem_estimators[primitive] = fn
+    _resource_estimators[primitive] = fn
     return fn
 
   return deco
 
 
-def _estimate_smem_scratch_bytes(jaxpr: jax_core.Jaxpr) -> int:
-  """Estimates the amount of SMEM scratch bytes required by the kernel."""
-  max_used = 0
+def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources:
+  """Estimates the resources required by the kernel."""
+  rs = Resources(smem_scratch_bytes=0)
   for eqn in jaxpr.eqns:
     # TODO(slebedev): Add support for other primitives, notably control flow.
-    rule = _smem_estimators.get(eqn.primitive)
+    rule = _resource_estimators.get(eqn.primitive)
     if rule is None:
-      # Assume that unsupported primitives are neutral wrt SMEM usage.
+      # Assume that unsupported primitives are neutral wrt resource usage.
       continue
-    max_used = max(
-        max_used, rule(*(invar.aval for invar in eqn.invars), **eqn.params)
-    )
-  return max_used
+    rs |= rule(*(invar.aval for invar in eqn.invars), **eqn.params)
+  return rs
 
 
-@_regiter_smem_estimator(primitives.run_scoped_p)
-def _run_scoped_smem_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
+@_register_resource_estimator(lax.cond_p)
+def _cond_resource_estimator(*args, branches) -> int:
+  del args  # Unused.
+  return functools.reduce(
+      lambda a, b: a | b,
+      (_estimate_resources(branch.jaxpr) for branch in branches),
+  )
+
+
+@_register_resource_estimator(lax.scan_p)
+def _scan_resource_estimator(*args, jaxpr: jax_core.ClosedJaxpr, **params) -> int:
+  del args, params  # Unused.
+  return _estimate_resources(jaxpr)
+
+
+@_register_resource_estimator(primitives.run_scoped_p)
+def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
   del consts  # Unused.
-  in_avals = (v.aval.inner_aval for v in jaxpr.invars)
-  return sum(math.prod(aval.shape) * aval.dtype.itemsize for aval in in_avals)
+  smem_scratch_bytes = 0
+  barriers = []
+  for v in jaxpr.invars:
+    aval = v.aval
+    if isinstance(aval.dtype, gpu_core.BarrierType):
+      barriers.append(mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape))
+    else:
+      smem_scratch_bytes += math.prod(aval.shape) * aval.dtype.itemsize
+  rs = Resources(
+      smem_scratch_bytes=smem_scratch_bytes,
+      barriers=collections.Counter(barriers),
+  )
+  return rs + _estimate_resources(jaxpr)
 
 
-@_regiter_smem_estimator(lax.reduce_sum_p)
-def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
+@_register_resource_estimator(lax.reduce_sum_p)
+def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
   if axes != (0,):
     raise NotImplementedError("No support for axes other than 0 yet")
-  return 4 * x_aval.dtype.itemsize
+  return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize)
 
 
 @dataclasses.dataclass
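A note on the two combinators above: __add__ models sequential composition (SMEM requirements and barrier counts accumulate), while __or__ models mutually exclusive composition such as cond branches (the maximum of each requirement; Counter's | is an elementwise max). A small illustration with made-up values, assuming the mgpu alias from this module's imports:

import collections

b = mgpu.Barrier(arrival_count=1)

first = Resources(
    smem_scratch_bytes=128, barriers=collections.Counter({b: 1})
)
second = Resources(
    smem_scratch_bytes=512, barriers=collections.Counter({b: 2})
)

seq = first + second  # one scope runs after the other
assert seq.smem_scratch_bytes == 640 and seq.barriers[b] == 3
alt = first | second  # at most one runs, e.g. cond branches
assert alt.smem_scratch_bytes == 512 and alt.barriers[b] == 2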
@@ -106,7 +164,21 @@ class ModuleContext:
   program_ids: Sequence[ir.Value] | None
   approx_math: bool
   runtime_smem: ir.Value  # ir.MemRefType
-  smem_used_bytes: int = 0
+  smem_used_bytes: int
+  runtime_barriers: MutableMapping[
+      mgpu.Barrier, MutableSequence[mgpu.BarrierRef]
+  ]
+
+  def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef:
+    """Reserves a barrier.
+
+    Raises:
+      RuntimeError: If the barrier is already reserved.
+    """
+    available = self.runtime_barriers.get(barrier, [])
+    if not available:
+      raise RuntimeError(f"Barrier {barrier} is already reserved")
+    return available.pop()
 
   # TODO(cperivol): Only return the shapes and figure out the sizes when freeing.
   def scratch_view(
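reserve_barrier hands out physical barriers from a per-spec free list populated at launch time; each request pops one ref, so two scopes asking for the same mgpu.Barrier spec get distinct slots. Note that the RuntimeError also fires when the spec was never allocated at all, not only when every slot is taken. A toy model of the bookkeeping, with hypothetical string specs and int refs standing in for mgpu.Barrier and mgpu.BarrierRef:

import collections

# Hypothetical stand-ins: specs are strings, refs are ints.
free_list: dict[str, list[int]] = collections.defaultdict(list)
free_list["Barrier(arrival_count=1)"] = [1, 0]  # two physical slots

def reserve(spec: str) -> int:
  available = free_list.get(spec, [])
  if not available:
    raise RuntimeError(f"Barrier {spec} is already reserved")
  return available.pop()

assert reserve("Barrier(arrival_count=1)") == 0
assert reserve("Barrier(arrival_count=1)") == 1
# A third request for the same spec raises RuntimeError.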
@@ -352,7 +424,7 @@ def lower_jaxpr_to_module(
     in_buffers_smem, out_buffers_smem = util.split_list(
         buffers_smem, [grid_mapping.num_inputs]
     )
-    barriers, *extra_barriers = barriers
+    barriers, runtime_barriers, extra_barriers = barriers
 
     parallel_count = it.count()
     program_ids_template = [
@@ -367,9 +439,21 @@ def lower_jaxpr_to_module(
       step = arith_dialect.index_cast(ir.IntegerType.get_signless(32), step)
       return [step if pid is None else pid for pid in program_ids_template]
 
+    grouped_barriers = collections.defaultdict(list)
+    for barrier, barrier_ref in zip(
+        sorted(rs.barriers.elements()), runtime_barriers
+    ):
+      grouped_barriers[barrier].append(barrier_ref)
     module_ctx = ModuleContext(
-        name_and_src_info.name, grid_mapping, None, approx_math, runtime_smem
+        name_and_src_info.name,
+        grid_mapping,
+        None,
+        approx_math,
+        runtime_smem,
+        smem_used_bytes=0,
+        runtime_barriers=grouped_barriers,
     )
+    del runtime_smem, grouped_barriers, runtime_barriers
 
     smem_scratch_it = iter(scratch_buffers_smem)
     scratch_buffers_template = []
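The pairing above only works because the flat list of runtime barrier refs is created from sorted(rs.barriers.elements()) at launch time (see the hunk changing the mgpu.Barrier tuple further down), so re-sorting here walks the allocations in the same order. A sketch of the grouping with hypothetical values:

import collections

# Hypothetical: two specs, three allocations, already in sorted order.
specs = ["b1", "b1", "b2"]       # stands in for sorted(rs.barriers.elements())
refs = ["ref0", "ref1", "ref2"]  # stands in for runtime_barriers

grouped = collections.defaultdict(list)
for spec, ref in zip(specs, refs):
  grouped[spec].append(ref)

assert grouped == {"b1": ["ref0", "ref1"], "b2": ["ref2"]}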
@@ -611,6 +695,7 @@ def lower_jaxpr_to_module(
         "All scratch operands must be SMEM references or accumulators (ACC),"
         f" but got: {scratch_avals}"
     )
+  rs = _estimate_resources(jaxpr)
   extra_barriers = [
       mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
       for aval in scratch_avals
@@ -624,7 +709,7 @@ def lower_jaxpr_to_module(
   ]
   smem_scratch_bytes = compiler_params.get("smem_scratch_bytes")
   if smem_scratch_bytes is None:
-    smem_scratch_bytes = _estimate_smem_scratch_bytes(jaxpr)
+    smem_scratch_bytes = rs.smem_scratch_bytes
   extra_smem_scratch.append(
       jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8)
   )
@@ -641,7 +726,8 @@ def lower_jaxpr_to_module(
           *extra_smem_scratch,
          (
              mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps),
-              *extra_barriers,
+              [*sorted(rs.barriers.elements())],
+              extra_barriers,
          ),
      ),
      module_name=name_and_src_info.name,
@@ -979,21 +1065,28 @@ def _run_scoped_lowering_rule(
   input_refs = []
   bytes_allocated = 0
   should_discharge = []
-  for a in jaxpr.invars:
-    a = a.aval
-    if isinstance(a, gpu_core.WGMMAAbstractAccumulatorRef):
-      mlir_dtype = mlir.dtype_to_ir_type(a.dtype)
-      input_refs.append(mgpu.WGMMAAccumulator.zero(*a.shape, mlir_dtype))
+  for v in jaxpr.invars:
+    aval = v.aval
+    if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef):
+      mlir_dtype = mlir.dtype_to_ir_type(aval.dtype)
+      input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype))
       should_discharge.append(True)
-    elif a.memory_space == gpu_core.SMEM:
+    elif isinstance(aval.dtype, gpu_core.BarrierType):
+      input_refs.append(
+          ctx.module_ctx.reserve_barrier(
+              mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
+          )
+      )
+      should_discharge.append(False)
+    elif aval.memory_space == gpu_core.SMEM:
       ref_bytes, [input_ref] = ctx.module_ctx.scratch_view(
-          [jax.ShapeDtypeStruct(shape=a.shape, dtype=a.dtype)]
+          [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)]
       )
       bytes_allocated += ref_bytes
       input_refs.append(input_ref)
       should_discharge.append(False)
     else:
-      raise ValueError(f"Can't convert to ref: {a}")
+      raise ValueError(f"Can't convert to ref: {aval}")
 
   if any(should_discharge):
     # We convert consts to args, because we only have ir.Values and
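With the new branch, the lowering rule distinguishes three kinds of scoped allocation: WGMMA accumulators (initialized to zero and discharged), barriers (reserved from the runtime pool), and SMEM scratch (carved out of the runtime SMEM arena). From the kernel author's side the cases look alike; an illustrative sketch of the two new-path spellings inside a kernel body, taken from the test below (the trivial lambda bodies are placeholders, not meaningful kernels):

# Illustrative mapping from scoped allocation to the branch taken above
# (WGMMA accumulator refs take the first branch and are discharged):
pl.run_scoped(lambda barrier_ref: None, plgpu.Barrier(num_arrivals=1))    # -> reserve_barrier
pl.run_scoped(lambda scratch_ref: None, plgpu.SMEM((256,), jnp.float32))  # -> scratch_view

The remaining hunks update the Pallas Mosaic GPU test suite (PallasCallTest).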
@@ -253,6 +253,24 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(128).astype(jnp.float32)
     np.testing.assert_array_equal(kernel(x), x + 1.0)
 
+  def test_copy_gmem_to_smem_in_run_scoped(self):
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
+        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
+    )
+    def kernel(x_ref_gmem, o_ref):
+      def body(barrier_ref):
+        def inner_body(scratch_ref):
+          plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
+          plgpu.wait_barrier(barrier_ref)
+          o_ref[...] = scratch_ref[...] + 1
+        pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))
+      pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))
+
+    x = jnp.arange(256).astype(jnp.float32)
+    np.testing.assert_array_equal(kernel(x), x + 1.0)
+
   def test_add_doubled_sum(self):
     @functools.partial(
         pl.pallas_call,
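The new test exercises both scoped resource kinds at once: the outer run_scoped allocates the barrier, the inner one the SMEM buffer, and the estimator accounts for the nesting via rs + _estimate_resources(jaxpr). Worth noting for readers new to the API: copy_gmem_to_smem only enqueues the transfer, and wait_barrier blocks until it signals the barrier, which is why the read of scratch_ref comes after the wait. A hedged variant without the nesting; the estimator and lowering rule above both appear to handle a single run_scoped bundling both allocations, but the committed test only covers the nested form:

def kernel(x_ref_gmem, o_ref):
  def body(scratch_ref, barrier_ref):
    plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
    plgpu.wait_barrier(barrier_ref)  # block until the copy lands
    o_ref[...] = scratch_ref[...] + 1
  pl.run_scoped(
      body, plgpu.SMEM((256,), jnp.float32), plgpu.Barrier(num_arrivals=1)
  )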
@@ -375,7 +393,7 @@ class PallasCallTest(PallasTest):
 
     self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output())
 
-  def test_scoped_allocation(self):
+  def test_run_scoped(self):
     def kernel(x_ref, o_ref):
       def body(tmp_ref):
         self.assertEqual(tmp_ref.shape, (8, 128))
@@ -611,7 +629,6 @@ class PallasCallTest(PallasTest):
     )(a, b)
     np.testing.assert_allclose(res, a @ b, rtol=1e-3)
 
-
   def test_input_output_aliases(self):
     # Note that we're writing to the input pointer, which should alias b_ptr.
     def kernel(a_ref, b_ref):