[pallas_mgpu] Add a test for emit_pipeline with wgmma.

PiperOrigin-RevId: 723012611
This commit is contained in:
Peter Buchlovsky 2025-02-04 03:24:29 -08:00 committed by jax authors
parent 124e123946
commit c7d535d3c9

View File

@@ -1585,6 +1585,74 @@ class PipelineTest(PallasTest):
)
np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
def test_emit_pipeline_with_wgmma(self):
  """Checks a tiled float16 matmul built from `emit_pipeline` and `wgmma`.

  The outer `pallas_call` runs over a `(grid_m, grid_n)` grid of output
  tiles. Inside each program, `emit_pipeline` walks the `k` dimension in
  `tile_k`-sized steps and feeds every pair of staged operand tiles to
  `wgmma`, which accumulates into the `plgpu.ACC` scratch accumulator.
  The accumulator is then stored to the output block and the result is
  compared against `x @ y`.

  The test is skipped unless `skip_unless_sm90a` lets it through.
  """
  self.skip_unless_sm90a()

  m, n, k = 256, 256, 256
  dtype = jnp.float16

  # Use independent keys for the two operands. Drawing twice from the same
  # key with the same shape and dtype yields identical arrays, which makes
  # the reference product far less discriminating.
  key_x, key_y = jax.random.split(jax.random.key(42))
  x = jax.random.uniform(key_x, shape=(m, k), dtype=dtype)
  y = jax.random.uniform(key_y, shape=(k, n), dtype=dtype)

  swizzle = 128
  # Number of elements of `dtype` that fit in one swizzle span (in bytes).
  swizzle_elems = swizzle // jnp.dtype(x.dtype).itemsize
  tile_m = 64
  tile_n = 64
  tile_k = swizzle_elems
  grid_m = m // tile_m
  grid_n = n // tile_n
  grid_k = k // tile_k

  def kernel(a_gmem, b_gmem, c_smem, acc_reg):
    # Position of this program in the outer (grid_m, grid_n) grid. The
    # pipeline index maps must use it; otherwise every program would
    # multiply the first row-tile of `a` by the first column-tile of `b`.
    pid_m = pl.program_id(0)
    pid_n = pl.program_id(1)

    def pipeline_body(a_smem, b_smem):
      plgpu.wgmma(acc_reg, a_smem, b_smem)
      # The pipeline below is built with delay_release=1, i.e. the operand
      # buffers of a step are handed back one step late. Allow at most one
      # WGMMA to remain in flight so buffers being reused are no longer read.
      plgpu.wgmma_wait(1)

    plgpu.emit_pipeline(
        pipeline_body,
        in_specs=[
            plgpu.GPUBlockSpec(
                (tile_m, tile_k),
                lambda i: (pid_m, i),
                transforms=(
                    plgpu.TilingTransform((64, swizzle_elems)),
                    plgpu.SwizzleTransform(swizzle),
                ),
            ),
            plgpu.GPUBlockSpec(
                (tile_k, tile_n),
                lambda i: (i, pid_n),
                transforms=(
                    plgpu.TilingTransform((swizzle_elems, swizzle_elems)),
                    plgpu.SwizzleTransform(swizzle),
                ),
            ),
        ],
        grid=(grid_k,),
        max_concurrent_steps=2,
        delay_release=1,
    )(a_gmem, b_gmem)

    c_smem[...] = acc_reg[...]

  @jax.jit
  def matmul(a: jax.Array, b: jax.Array) -> jax.Array:
    return pl.pallas_call(
        kernel,
        in_specs=[
            # Whole operands stay in GMEM; the inner pipeline does the tiling.
            pl.BlockSpec(memory_space=plgpu.GMEM),
            pl.BlockSpec(memory_space=plgpu.GMEM),
        ],
        out_specs=pl.BlockSpec((tile_m, tile_n), lambda i, j: (i, j)),
        grid=(grid_m, grid_n),
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        scratch_shapes=[plgpu.ACC((tile_m, tile_n), dtype)],
    )(a, b)

  res = matmul(x, y)
  # The accumulator is float16, so some rounding error versus the reference
  # is expected. The tolerance is still tight enough that a kernel computing
  # the wrong tile fails: with uniform inputs all entries sit near k / 4, so
  # a very loose rtol would accept any tile in place of any other.
  np.testing.assert_allclose(res, x @ y, rtol=5e-2)
class WarpSpecializedPipelineTest(PallasTest):