[pallas:mosaic_gpu] pl.run_scoped now supports scoped barriers

PiperOrigin-RevId: 684449776
Sergei Lebedev 2024-10-10 08:15:37 -07:00 committed by jax authors
parent 94abaf430e
commit 70ee8e1161
2 changed files with 146 additions and 36 deletions
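What this commit enables, sketched as a self-contained kernel that mirrors the new test added below. The pallas/plgpu import paths and the "+ 1" body follow the test module and are illustrative, not authoritative:

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
)
def kernel(x_ref_gmem, o_ref):
  def body(barrier_ref):
    # The barrier is now a scoped resource, just like SMEM scratch.
    def inner_body(scratch_ref):
      plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
      plgpu.wait_barrier(barrier_ref)
      o_ref[...] = scratch_ref[...] + 1
    pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))
  pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))

x = jnp.arange(256).astype(jnp.float32)
# kernel(x) should equal x + 1.0 (see test_copy_gmem_to_smem_in_run_scoped below).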


@@ -16,12 +16,13 @@
from __future__ import annotations
from collections.abc import Sequence
import collections
from collections.abc import MutableMapping, MutableSequence, Sequence
import dataclasses
import functools
import itertools as it
import math
from typing import Any, cast
from typing import Any, Protocol, cast
import jax
from jax import lax
@@ -59,44 +60,101 @@ zip, unsafe_zip = util.safe_zip, zip
partial = functools.partial
SMEM = gpu_core.SMEM
_smem_estimators = {}
@dataclasses.dataclass(kw_only=True, frozen=True)
class Resources:
smem_scratch_bytes: int
barriers: collections.Counter[mgpu.Barrier] = dataclasses.field(
default_factory=collections.Counter
)
def __add__(self, other: Resources) -> Resources:
# TODO(slebedev): Optimize this.
#
# At the moment, if we have run_scoped(b1) followed by run_scoped(b2)
# we will allocate two barriers, even though one would be enough.
return Resources(
smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes,
barriers=self.barriers + other.barriers,
)
def __or__(self, other: Resources) -> Resources:
return Resources(
smem_scratch_bytes=max(
self.smem_scratch_bytes, other.smem_scratch_bytes
),
barriers=self.barriers | other.barriers,
)
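The two combinators behave differently by design: __add__ models scopes that are live one after another, so scratch bytes and barrier counts accumulate (hence the TODO above), while __or__ models mutually exclusive alternatives such as cond branches, taking the max scratch and the counter union. A standalone sketch, assuming the Resources class above is in scope and using placeholder strings instead of mgpu.Barrier specs:

import collections

a = Resources(smem_scratch_bytes=1024,
              barriers=collections.Counter({"Barrier(1)": 1}))
b = Resources(smem_scratch_bytes=2048,
              barriers=collections.Counter({"Barrier(1)": 1}))
assert (a + b).smem_scratch_bytes == 3072   # sequential scopes: sizes add up
assert (a + b).barriers["Barrier(1)"] == 2  # two barriers allocated (see TODO)
assert (a | b).smem_scratch_bytes == 2048   # alternatives: max scratch wins
assert (a | b).barriers["Barrier(1)"] == 1  # counter union: one barrier suffices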
def _regiter_smem_estimator(primitive: jax_core.Primitive):
class ResourceEstimator(Protocol):
def __call__(self, *args: Any, **params: Any) -> Resources:
...
_resource_estimators: dict[jax_core.Primitive, ResourceEstimator] = {}
def _register_resource_estimator(primitive: jax_core.Primitive):
def deco(fn):
_smem_estimators[primitive] = fn
_resource_estimators[primitive] = fn
return fn
return deco
def _estimate_smem_scratch_bytes(jaxpr: jax_core.Jaxpr) -> int:
"""Estimates the amount of SMEM scratch bytes required by the kernel."""
max_used = 0
def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources:
"""Estimates the resources required by the kernel."""
rs = Resources(smem_scratch_bytes=0)
for eqn in jaxpr.eqns:
# TODO(slebedev): Add support for other primitives, notably control flow.
rule = _smem_estimators.get(eqn.primitive)
rule = _resource_estimators.get(eqn.primitive)
if rule is None:
# Assume that unsupported primitives are neutral wrt SMEM usage.
# Assume that unsupported primitives are neutral wrt resource usage.
continue
max_used = max(
max_used, rule(*(invar.aval for invar in eqn.invars), **eqn.params)
)
return max_used
rs |= rule(*(invar.aval for invar in eqn.invars), **eqn.params)
return rs
@_regiter_smem_estimator(primitives.run_scoped_p)
def _run_scoped_smem_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
@_register_resource_estimator(lax.cond_p)
def _cond_resource_estimator(*args, branches) -> Resources:
del args # Unused.
return functools.reduce(
lambda a, b: a | b,
(_estimate_resources(branch.jaxpr) for branch in branches),
)
@_register_resource_estimator(lax.scan_p)
def _scan_resource_estimator(*args, jaxpr: jax_core.ClosedJaxpr, **params) -> Resources:
del args, params # Unused.
return _estimate_resources(jaxpr)
@_register_resource_estimator(primitives.run_scoped_p)
def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> Resources:
del consts # Unused.
in_avals = (v.aval.inner_aval for v in jaxpr.invars)
return sum(math.prod(aval.shape) * aval.dtype.itemsize for aval in in_avals)
smem_scratch_bytes = 0
barriers = []
for v in jaxpr.invars:
aval = v.aval
if isinstance(aval.dtype, gpu_core.BarrierType):
barriers.append(mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape))
else:
smem_scratch_bytes += math.prod(aval.shape) * aval.dtype.itemsize
rs = Resources(
smem_scratch_bytes=smem_scratch_bytes,
barriers=collections.Counter(barriers),
)
return rs + _estimate_resources(jaxpr)
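For the nested run_scoped kernel in the new test below, this estimator reports one Barrier(num_arrivals=1) for the outer scope and the inner SMEM buffer as scratch. A quick check of the scratch-byte arithmetic, illustrative and using plain numpy:

import numpy as np

# plgpu.SMEM((256,), jnp.float32) contributes prod(shape) * itemsize bytes.
smem_scratch_bytes = 256 * np.dtype(np.float32).itemsize
assert smem_scratch_bytes == 1024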
@_regiter_smem_estimator(lax.reduce_sum_p)
def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
@_register_resource_estimator(lax.reduce_sum_p)
def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> Resources:
if axes != (0,):
raise NotImplementedError("No support for axes other than 0 yet")
return 4 * x_aval.dtype.itemsize
return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize)
@dataclasses.dataclass
@@ -106,7 +164,21 @@ class ModuleContext:
program_ids: Sequence[ir.Value] | None
approx_math: bool
runtime_smem: ir.Value # ir.MemRefType
smem_used_bytes: int = 0
smem_used_bytes: int
runtime_barriers: MutableMapping[
mgpu.Barrier, MutableSequence[mgpu.BarrierRef]
]
def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef:
"""Reserves a barrier.
Raises:
RuntimeError: If the barrier is already reserved.
"""
available = self.runtime_barriers.get(barrier, [])
if not available:
raise RuntimeError(f"Barrier {barrier} is already reserved")
return available.pop()
# TODO(cperivol): Only return the shapes and figure out the sizes when freeing.
def scratch_view(
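The grouped_barriers mapping built further down pairs each mgpu.Barrier spec with the runtime barrier references allocated for it, and reserve_barrier pops one per scope. A standalone sketch of that bookkeeping, with hypothetical string stand-ins for the spec and the refs:

import collections

grouped = collections.defaultdict(list)
# Two runtime barriers were allocated for the same spec, e.g. because two
# run_scoped scopes each asked for plgpu.Barrier(num_arrivals=1).
grouped["Barrier(num_arrivals=1)"].extend(["ref0", "ref1"])

def reserve(spec):
  available = grouped.get(spec, [])
  if not available:
    raise RuntimeError(f"Barrier {spec} is already reserved")
  return available.pop()

assert reserve("Barrier(num_arrivals=1)") == "ref1"
assert reserve("Barrier(num_arrivals=1)") == "ref0"
# A third reservation for this spec would raise, mirroring reserve_barrier above.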
@@ -352,7 +424,7 @@ def lower_jaxpr_to_module(
in_buffers_smem, out_buffers_smem = util.split_list(
buffers_smem, [grid_mapping.num_inputs]
)
barriers, *extra_barriers = barriers
barriers, runtime_barriers, extra_barriers = barriers
parallel_count = it.count()
program_ids_template = [
@@ -367,9 +439,21 @@ def lower_jaxpr_to_module(
step = arith_dialect.index_cast(ir.IntegerType.get_signless(32), step)
return [step if pid is None else pid for pid in program_ids_template]
grouped_barriers = collections.defaultdict(list)
for barrier, barrier_ref in zip(
sorted(rs.barriers.elements()), runtime_barriers
):
grouped_barriers[barrier].append(barrier_ref)
module_ctx = ModuleContext(
name_and_src_info.name, grid_mapping, None, approx_math, runtime_smem
name_and_src_info.name,
grid_mapping,
None,
approx_math,
runtime_smem,
smem_used_bytes=0,
runtime_barriers=grouped_barriers,
)
del runtime_smem, grouped_barriers, runtime_barriers
smem_scratch_it = iter(scratch_buffers_smem)
scratch_buffers_template = []
@@ -611,6 +695,7 @@ def lower_jaxpr_to_module(
"All scratch operands must be SMEM references or accumulators (ACC),"
f" but got: {scratch_avals}"
)
rs = _estimate_resources(jaxpr)
extra_barriers = [
mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
for aval in scratch_avals
@@ -624,7 +709,7 @@ def lower_jaxpr_to_module(
]
smem_scratch_bytes = compiler_params.get("smem_scratch_bytes")
if smem_scratch_bytes is None:
smem_scratch_bytes = _estimate_smem_scratch_bytes(jaxpr)
smem_scratch_bytes = rs.smem_scratch_bytes
extra_smem_scratch.append(
jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8)
)
@@ -641,7 +726,8 @@ def lower_jaxpr_to_module(
*extra_smem_scratch,
(
mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps),
*extra_barriers,
[*sorted(rs.barriers.elements())],
extra_barriers,
),
),
module_name=name_and_src_info.name,
@@ -979,21 +1065,28 @@ def _run_scoped_lowering_rule(
input_refs = []
bytes_allocated = 0
should_discharge = []
for a in jaxpr.invars:
a = a.aval
if isinstance(a, gpu_core.WGMMAAbstractAccumulatorRef):
mlir_dtype = mlir.dtype_to_ir_type(a.dtype)
input_refs.append(mgpu.WGMMAAccumulator.zero(*a.shape, mlir_dtype))
for v in jaxpr.invars:
aval = v.aval
if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef):
mlir_dtype = mlir.dtype_to_ir_type(aval.dtype)
input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype))
should_discharge.append(True)
elif a.memory_space == gpu_core.SMEM:
elif isinstance(aval.dtype, gpu_core.BarrierType):
input_refs.append(
ctx.module_ctx.reserve_barrier(
mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
)
)
should_discharge.append(False)
elif aval.memory_space == gpu_core.SMEM:
ref_bytes, [input_ref] = ctx.module_ctx.scratch_view(
[jax.ShapeDtypeStruct(shape=a.shape, dtype=a.dtype)]
[jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)]
)
bytes_allocated += ref_bytes
input_refs.append(input_ref)
should_discharge.append(False)
else:
raise ValueError(f"Can't convert to ref: {a}")
raise ValueError(f"Can't convert to ref: {aval}")
if any(should_discharge):
# We convert consts to args, because we only have ir.Values and


@@ -253,6 +253,24 @@ class PallasCallTest(PallasTest):
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
def test_copy_gmem_to_smem_in_run_scoped(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
)
def kernel(x_ref_gmem, o_ref):
def body(barrier_ref):
def inner_body(scratch_ref):
plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
plgpu.wait_barrier(barrier_ref)
o_ref[...] = scratch_ref[...] + 1
pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))
pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
def test_add_doubled_sum(self):
@functools.partial(
pl.pallas_call,
@@ -375,7 +393,7 @@ class PallasCallTest(PallasTest):
self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output())
def test_scoped_allocation(self):
def test_run_scoped(self):
def kernel(x_ref, o_ref):
def body(tmp_ref):
self.assertEqual(tmp_ref.shape, (8, 128))
@@ -611,7 +629,6 @@ class PallasCallTest(PallasTest):
)(a, b)
np.testing.assert_allclose(res, a @ b, rtol=1e-3)
def test_input_output_aliases(self):
# Note that we're writing to the input pointer, which should alias b_ptr.
def kernel(a_ref, b_ref):