[pallas:Mosaic GPU] Configurable smem scratch and a small bug fix in Mosaic GPU

PiperOrigin-RevId: 634813241
jax authors 2024-05-17 10:09:32 -07:00 committed by jax authors
parent e93f36aa7c
commit 815256687f
4 changed files with 60 additions and 20 deletions
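The user-facing change: a Pallas Mosaic GPU kernel can now size the runtime shared-memory scratch buffer by passing "smem_scratch_bytes" in compiler_params. A minimal usage sketch based on the test below (the kernel name and body are illustrative):

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    # Reserve 16 bytes of SMEM scratch: one float32 partial per warp.
    compiler_params={"smem_scratch_bytes": 4 * 4},
)
def center(x_ref, o_ref):
  # jnp.mean reduces over the input, which is what consumes the scratch.
  o_ref[...] = x_ref[...] - jnp.mean(x_ref[...], keepdims=True)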


@@ -56,6 +56,41 @@ class ModuleContext:
   grid_mapping: pl_core.GridMapping
   runtime_smem: ir.Value  # ir.MemRefType

+  def scratch_view(self, shapes: list[jax.ShapeDtypeStruct]) -> list[ir.Value]:
+    """Return memref views into the runtime scratch based on the shapes."""
+    smem_scratch_bytes = math.prod(ir.MemRefType(self.runtime_smem.type).shape)
+    required_scratch_bytes = sum(
+        math.prod(sh.shape) * jnp.dtype(sh.dtype).itemsize for sh in shapes
+    )
+    if smem_scratch_bytes < required_scratch_bytes:
+      raise ValueError(
+          f"Too few {smem_scratch_bytes=} provided (pass via compiler_params),"
+          f" we need {required_scratch_bytes} ({shapes=})"
+      )
+
+    views = []
+    off = 0
+    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+    for sh in shapes:
+      sh_bytes = math.prod(sh.shape) * jnp.dtype(sh.dtype).itemsize
+      # View the bytes at [off, off + sh_bytes) of the flat i8 scratch buffer
+      # as a flat memref of the requested dtype.
+      scratch_ty = ir.MemRefType.get(
+          [math.prod(sh.shape)],
+          mlir.dtype_to_ir_type(sh.dtype),
+          memory_space=smem,
+      )
+      views.append(
+          memref_dialect.view(scratch_ty, self.runtime_smem, _index(off), [])
+      )
+      off += sh_bytes
+
+    return views
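A quick sanity check of the byte accounting scratch_view performs, using the shape the reduce-sum rule below requests: one (4,) float32 buffer needs 16 bytes, which is exactly the 4 * 4 the test passes as smem_scratch_bytes:

import math
import numpy as np
import jax

shapes = [jax.ShapeDtypeStruct(shape=(4,), dtype=np.float32)]
required = sum(math.prod(s.shape) * np.dtype(s.dtype).itemsize for s in shapes)
assert required == 16  # matches compiler_params={"smem_scratch_bytes": 4 * 4}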
@dataclasses.dataclass
class LoweringRuleContext:
@@ -92,6 +127,7 @@ def lower_jaxpr_to_module(
     out_structs: tuple[jax.ShapeDtypeStruct, ...],
     jaxpr: jax_core.Jaxpr,
     name: str,
+    compiler_params: dict[str, Any],
 ) -> LoweringResult:
   assert len(jaxpr.outvars) == 0
   assert not grid_mapping.mapped_dims
@@ -145,8 +181,11 @@ def lower_jaxpr_to_module(
     launch_ctx.await_async_copy(0)

-  # TODO(cperivol): Allow the user to provide the size of the runtime shared memory.
-  extra_smem_scratch = [jax.ShapeDtypeStruct(shape=[4 * 4], dtype=np.int8)]
+  extra_smem_scratch = [
+      jax.ShapeDtypeStruct(
+          shape=[compiler_params.get("smem_scratch_bytes", 0)], dtype=np.int8
+      )
+  ]
   module, out_structs, gmem_scratch_bytes, _ = mosaic_gpu._lower_as_gpu_kernel(
       body,
       grid,
@@ -304,19 +343,9 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
   if axes != (0,):
     raise NotImplementedError("No support for axes other than 0 yet")
   [x_aval] = ctx.avals_in
-  # We need scratch to be able to store 128 items of x.
-  # TODO(cperivol): check that enough scratch size was provided
-  scratch = memref_dialect.subview(
-      ctx.module_context.runtime_smem,
-      offsets=[_index(0)],
-      sizes=[_index(4 * jnp.dtype(x_aval.dtype).itemsize)],
-      strides=[_index(1)],
-  )
-  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
-  scratch_ty = ir.MemRefType.get(
-      [4], mlir.dtype_to_ir_type(x_aval.dtype), memory_space=smem
-  )
-  scratch = memref_dialect.view(scratch_ty, scratch, _index(0), [])
+  [scratch] = ctx.module_context.scratch_view(
+      [jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype)]
+  )
   return mgpu.FragmentedArray.splat(x.reduce_sum(scratch), ())
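Since scratch_view accepts a list, a rule can request several typed buffers in one call and they are packed back to back; a hypothetical sketch (the second buffer is illustrative, not part of this commit):

# Hypothetical: two scratch buffers carved from the same runtime scratch.
sums, counts = ctx.module_context.scratch_view([
    jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype),
    jax.ShapeDtypeStruct(shape=(4,), dtype=jnp.int32),
])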
@@ -372,4 +401,4 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value:

 def _index(i: int) -> ir.Value:
-  return arith_dialect.constant(ir.IndexType.get(), i)
+  return arith_dialect.constant(ir.IndexType.get(), int(i))
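The int(i) coercion makes _index robust to NumPy integer scalars (such as the results of np.prod or np.cumprod), not just Python ints; a small illustration:

import numpy as np

size = np.prod((4,))  # np.int64, not a Python int
assert not isinstance(size, int)
# int(size) normalizes the value before arith_dialect.constant builds the
# MLIR index constant.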


@@ -77,6 +77,7 @@ def pallas_call_lowering(
       out_shapes,
       jaxpr,
       name,
+      compiler_params,
   )
   if debug:
     print(lowering_result.module.operation)


@@ -410,8 +410,8 @@ class FragmentedArray:
         vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg),
     )
     scratch_ty = ir.MemRefType(scratch.type)
-    if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [WARPGROUP_SIZE]:
-      raise ValueError(f"Expected sheape={(WARPGROUP_SIZE,)}, {self.mlir_dtype} (got {scratch_ty})")
+    if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
+      raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")
     if ir.FloatType.isinstance(self.mlir_dtype):
       op = arith.addf
@@ -421,7 +421,7 @@
       raise NotImplementedError(self.mlir_dtype)
     warp_result = utils.warp_tree_reduce(result, op, 32)
-    warp_id = arith.remui(gpu.thread_id(gpu.Dimension.x), c(32, index))
+    warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index))
     memref.store(warp_result, scratch, [warp_id])
     utils.commit_shared()
     zero_index = c(0, index)
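The bug fix in context: remui (mod 32) gives a thread's lane within its warp, while divui (div 32) gives the warp's index, and the per-warp partial results must be stored by warp index. This is also why the check above now expects 4 scratch slots: a 128-thread warpgroup contains 4 warps of 32. A plain-Python illustration:

WARPGROUP_SIZE = 128  # threads per Mosaic GPU warpgroup
WARP_SIZE = 32

thread_id = 37
lane_id = thread_id % WARP_SIZE   # 5: what remui computed (the wrong index)
warp_id = thread_id // WARP_SIZE  # 1: what divui computes (the correct slot)
assert WARPGROUP_SIZE // WARP_SIZE == 4  # slots needed, matching shape (4,)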


@@ -15,6 +15,7 @@
 import functools

 from absl.testing import absltest
 from absl.testing import parameterized
+import jax
 from jax._src import test_util as jtu
 from jax.experimental import pallas as pl
@@ -47,7 +48,8 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(256).astype(jnp.float32)
     np.testing.assert_array_equal(add_one(x), x + 1.0)

-  def test_layer_norm(self):
+  @parameterized.product(input_factor=[0.001, 1, 10, 100, 1000])
+  def test_layer_norm(self, input_factor):
     eps = 1e-5
     gamma = 1.0
     beta = 1.0
@@ -55,6 +57,7 @@ class PallasCallTest(PallasTest):
     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
+        compiler_params={"smem_scratch_bytes": 4 * 4},
     )
     def layer_norm(x_ref, o_ref):
       o_ref[...] = (x_ref[...] - jnp.mean(x_ref[...], keepdims=True)) * jax.lax.rsqrt(
@@ -66,9 +69,16 @@ class PallasCallTest(PallasTest):
           np.var(x, keepdims=True) + eps
       ) * gamma + beta

-    x = jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32)
-    np.testing.assert_allclose(layer_norm(x), layer_norm_np(x))
+    # Ones are always fully precise.
+    x = jnp.ones((256,)).astype(jnp.float32) * input_factor
+    np.testing.assert_allclose(layer_norm(x), layer_norm_np(x))
+
+    # Random values (and anything else) are not.
+    x = (
+        jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32)
+        * input_factor
+    )
+    # TODO(cperivol): find out why this particular case has a small-ish error.
+    rtol = 1e-07 if input_factor > 10 else 5e-5
+    np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=rtol)
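For reference, a NumPy oracle consistent with the fragment of layer_norm_np visible above (reconstructed from that fragment; the exact body in the test file may differ). It closes over eps, gamma and beta from the test:

def layer_norm_np(x):
  return (x - np.mean(x, keepdims=True)) / np.sqrt(
      np.var(x, keepdims=True) + eps
  ) * gamma + beta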
if __name__ == "__main__":
absltest.main()