Mirror of https://github.com/ROCm/jax.git
[pallas_mgpu] Add a test for emit_pipeline with wgmma.
PiperOrigin-RevId: 723012611
parent 124e123946
commit c7d535d3c9
@@ -1585,6 +1585,74 @@ class PipelineTest(PallasTest):
    )
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)

  def test_emit_pipeline_with_wgmma(self):
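    # wgmma is a Hopper-only instruction, so skip unless running on sm90a.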
    self.skip_unless_sm90a()

    m, n, k = 256, 256, 256
    dtype = jnp.float16
    key = jax.random.key(42)
    x = jax.random.uniform(key, shape=(m, k), dtype=dtype)
    y = jax.random.uniform(key, shape=(k, n), dtype=dtype)

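    # 128-byte swizzle; for float16 (2 bytes) that is 64 elements per swizzled chunk.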
    swizzle = 128
    swizzle_elems = swizzle // jnp.dtype(x.dtype).itemsize

    tile_m = 64
    tile_n = 64
    tile_k = swizzle_elems

    grid_m = m // tile_m
    grid_n = n // tile_n
    grid_k = k // tile_k

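    # Each kernel invocation produces one (tile_m, tile_n) output block by
    # streaming (tile_m, tile_k) and (tile_k, tile_n) operand tiles through
    # SMEM and accumulating every pipeline step with wgmma.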
    def kernel(a_gmem, b_gmem, c_smem, acc_reg):
      def pipeline_body(a_smem, b_smem):
        plgpu.wgmma(acc_reg, a_smem, b_smem)

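      # Stream the operand tiles over grid_k steps, double-buffering the SMEM
      # windows (max_concurrent_steps=2). delay_release=1 holds each buffer
      # for one extra step so the asynchronous wgmma issued from it is not
      # overwritten while still in flight.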
      plgpu.emit_pipeline(
          pipeline_body,
          in_specs=[
              plgpu.GPUBlockSpec(
                  (tile_m, tile_k),
                  lambda i: (0, i),
                  transforms=(
                      plgpu.TilingTransform((64, swizzle_elems)),
                      plgpu.SwizzleTransform(swizzle),
                  ),
              ),
              plgpu.GPUBlockSpec(
                  (tile_k, tile_n),
                  lambda i: (i, 0),
                  transforms=(
                      plgpu.TilingTransform((swizzle_elems, swizzle_elems)),
                      plgpu.SwizzleTransform(swizzle),
                  ),
              ),
          ],
          grid=(grid_k,),
          max_concurrent_steps=2,
          delay_release=1,
      )(a_gmem, b_gmem)

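      # Write the final accumulator values into the output block in SMEM.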
      c_smem[...] = acc_reg[...]

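    # The outer pallas_call tiles the output over a (grid_m, grid_n) grid;
    # the operands stay whole in GMEM and are sliced by the inner pipeline.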
    @jax.jit
    def matmul(a: jax.Array, b: jax.Array) -> jax.Array:
      return pl.pallas_call(
          kernel,
          in_specs=[
              pl.BlockSpec(memory_space=plgpu.GMEM),
              pl.BlockSpec(memory_space=plgpu.GMEM),
          ],
          out_specs=pl.BlockSpec((tile_m, tile_n), lambda m, n: (m, n)),
          grid=(grid_m, grid_n),
          out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
          scratch_shapes=[plgpu.ACC((tile_m, tile_n), dtype)],
      )(a, b)

    res = matmul(x, y)
    np.testing.assert_allclose(res, x @ y, rtol=0.4)


class WarpSpecializedPipelineTest(PallasTest):