mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[pallas:Mosaic GPU] Configurable smem scratch and a small bug fix in Mosaic GPU
PiperOrigin-RevId: 634813241
This commit is contained in:
parent
e93f36aa7c
commit
815256687f
@ -56,6 +56,41 @@ class ModuleContext:
|
||||
grid_mapping: pl_core.GridMapping
|
||||
runtime_smem: ir.Value # ir.MemRefType
|
||||
|
||||
def scratch_view(self, shapes: list[jax.ShapeDtypeStruct]) -> list[ir.Value]:
  """Return memref views into the runtime scratch based on the shapes.

  The runtime scratch buffer (``self.runtime_smem``) is split into
  consecutive, non-overlapping byte ranges, one per entry of ``shapes`` and
  in the same order. Each range is exposed as a flat (1-D) memref of the
  requested dtype in the ``#gpu.address_space<workgroup>`` memory space.

  Args:
    shapes: Shape/dtype description of every scratch buffer that is needed.

  Returns:
    One view per entry of ``shapes``. View ``i`` holds
    ``prod(shapes[i].shape)`` elements of ``shapes[i].dtype``.

  Raises:
    ValueError: If the runtime scratch is smaller than the total number of
      bytes requested.
  """
  # The element count of the runtime scratch is treated as its size in bytes.
  smem_scratch_bytes = math.prod(ir.MemRefType(self.runtime_smem.type).shape)
  # Byte size of every requested buffer, computed once and reused below.
  sizes_bytes = [
      math.prod(sh.shape) * jnp.dtype(sh.dtype).itemsize for sh in shapes
  ]
  required_scratch_bytes = sum(sizes_bytes)
  if smem_scratch_bytes < required_scratch_bytes:
    raise ValueError(
        f"Too few {smem_scratch_bytes=} provided (pass via compiler_params), we"
        f" need {required_scratch_bytes} ({shapes=})"
    )

  views = []
  off = 0  # Byte offset of the next free byte in the runtime scratch.
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  for sh, sh_bytes in zip(shapes, sizes_bytes):
    scratch_ty = ir.MemRefType.get(
        [math.prod(sh.shape)], mlir.dtype_to_ir_type(sh.dtype), memory_space=smem
    )
    # View straight into the scratch buffer: the byte shift must be the
    # offset at which *this* buffer starts, so it is taken before `off` is
    # advanced past the buffer.
    # NOTE(review): offsets are not padded for alignment; confirm whether
    # mixed-dtype requests need it.
    views.append(
        memref_dialect.view(scratch_ty, self.runtime_smem, _index(off), [])
    )
    off += sh_bytes

  return views
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LoweringRuleContext:
|
||||
@ -92,6 +127,7 @@ def lower_jaxpr_to_module(
|
||||
out_structs: tuple[jax.ShapeDtypeStruct, ...],
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
compiler_params: dict[str, Any],
|
||||
) -> LoweringResult:
|
||||
assert len(jaxpr.outvars) == 0
|
||||
assert not grid_mapping.mapped_dims
|
||||
@ -145,8 +181,11 @@ def lower_jaxpr_to_module(
|
||||
|
||||
launch_ctx.await_async_copy(0)
|
||||
|
||||
# TODO(cperivol): Allow the user to provide the size of the runtime shared memory.
|
||||
extra_smem_scratch = [jax.ShapeDtypeStruct(shape=[4 * 4], dtype=np.int8)]
|
||||
extra_smem_scratch = [
|
||||
jax.ShapeDtypeStruct(
|
||||
shape=[compiler_params.get("smem_scratch_bytes", 0)], dtype=np.int8
|
||||
)
|
||||
]
|
||||
module, out_structs, gmem_scratch_bytes, _ = mosaic_gpu._lower_as_gpu_kernel(
|
||||
body,
|
||||
grid,
|
||||
@ -304,19 +343,9 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
||||
if axes != (0,):
|
||||
raise NotImplementedError("No support for axes other than 0 yet")
|
||||
[x_aval] = ctx.avals_in
|
||||
# We need scratch to be able to store 128 items of x.
|
||||
# TODO(cperivol): check that enough scratch size was provided
|
||||
scratch = memref_dialect.subview(
|
||||
ctx.module_context.runtime_smem,
|
||||
offsets=[_index(0)],
|
||||
sizes=[_index(4 * jnp.dtype(x_aval.dtype).itemsize)],
|
||||
strides=[_index(1)],
|
||||
[scratch] = ctx.module_context.scratch_view(
|
||||
[jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype)]
|
||||
)
|
||||
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
|
||||
scratch_ty = ir.MemRefType.get(
|
||||
[4], mlir.dtype_to_ir_type(x_aval.dtype), memory_space=smem
|
||||
)
|
||||
scratch = memref_dialect.view(scratch_ty, scratch, _index(0), [])
|
||||
return mgpu.FragmentedArray.splat(x.reduce_sum(scratch), ())
|
||||
|
||||
|
||||
@ -372,4 +401,4 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value:
|
||||
|
||||
|
||||
def _index(i: int) -> ir.Value:
  """Return an ``index``-typed constant with the value ``i``.

  ``i`` is passed through ``int()`` first, so integer-like values that are
  not plain Python ints are accepted as well.
  """
  return arith_dialect.constant(ir.IndexType.get(), int(i))
|
||||
|
@ -77,6 +77,7 @@ def pallas_call_lowering(
|
||||
out_shapes,
|
||||
jaxpr,
|
||||
name,
|
||||
compiler_params,
|
||||
)
|
||||
if debug:
|
||||
print(lowering_result.module.operation)
|
||||
|
@ -410,8 +410,8 @@ class FragmentedArray:
|
||||
vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg),
|
||||
)
|
||||
scratch_ty = ir.MemRefType(scratch.type)
|
||||
if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [WARPGROUP_SIZE]:
|
||||
raise ValueError(f"Expected sheape={(WARPGROUP_SIZE,)}, {self.mlir_dtype} (got {scratch_ty})")
|
||||
if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
|
||||
raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")
|
||||
|
||||
if ir.FloatType.isinstance(self.mlir_dtype):
|
||||
op = arith.addf
|
||||
@ -421,7 +421,7 @@ class FragmentedArray:
|
||||
raise NotImplementedError(self.mlir_dtype)
|
||||
|
||||
warp_result = utils.warp_tree_reduce(result, op, 32)
|
||||
warp_id = arith.remui(gpu.thread_id(gpu.Dimension.x), c(32, index))
|
||||
warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index))
|
||||
memref.store(warp_result, scratch, [warp_id])
|
||||
utils.commit_shared()
|
||||
zero_index = c(0, index)
|
||||
|
@ -15,6 +15,7 @@
|
||||
import functools
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental import pallas as pl
|
||||
@ -47,7 +48,8 @@ class PallasCallTest(PallasTest):
|
||||
x = jnp.arange(256).astype(jnp.float32)
|
||||
np.testing.assert_array_equal(add_one(x), x + 1.0)
|
||||
|
||||
def test_layer_norm(self):
|
||||
@parameterized.product(input_factor=[0.001, 1, 10, 100, 100])
|
||||
def test_layer_norm(self, input_factor):
|
||||
eps = 1e-5
|
||||
gamma = 1.0
|
||||
beta = 1.0
|
||||
@ -55,6 +57,7 @@ class PallasCallTest(PallasTest):
|
||||
@functools.partial(
|
||||
pl.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
|
||||
compiler_params={"smem_scratch_bytes": 4 * 4}
|
||||
)
|
||||
def layer_norm(x_ref, o_ref):
|
||||
o_ref[...] = (x_ref[...] - jnp.mean(x_ref[...], keepdims=True)) * jax.lax.rsqrt(
|
||||
@ -66,9 +69,16 @@ class PallasCallTest(PallasTest):
|
||||
np.var(x, keepdims=True) + eps
|
||||
) * gamma + beta
|
||||
|
||||
x = jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32)
|
||||
# Ones are always fully precise
|
||||
x = jnp.ones((256,)).astype(jnp.float32) * input_factor
|
||||
np.testing.assert_allclose(layer_norm(x), layer_norm_np(x))
|
||||
|
||||
# random (and anything else is not)
|
||||
x = jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32) * input_factor
|
||||
# TODO(cperivol): find out why in this particular case we have a small-ish error.
|
||||
rtol = 1e-07 if input_factor > 10 else 5e-5
|
||||
np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=rtol)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user