rocm_jax/tests/pallas/mosaic_gpu_test.py
Adam Paszke 3d87a01bea [Pallas:MGPU] Adjust warpgroup lowering to the recent emit_pipeline changes
The Pallas-level pipelining generates a number of ops we haven't had to deal with before,
like conditionals, scans, etc.

PiperOrigin-RevId: 730899808
2025-02-25 08:39:44 -08:00

2158 lines
72 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import math
import operator
import os
import re
import tempfile
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
import jax.numpy as jnp
import numpy as np
# Optional native Mosaic GPU extension. It is left as ``None`` when the import
# fails; ``PallasTest.capture_stdout`` checks for that before using it.
mosaic_gpu_lib = None
try:
  from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
except ImportError:
  pass  # Deliberately optional: keep the ``None`` default assigned above.

jax.config.parse_flags_with_absl()
def _fori_loop(force_while: bool, lb, ub, body, init):
if force_while:
# using jnp.asarray make the matcher for while or scan to think
# that the bounds are dynamic and forces the use of the while
# primitive.
lb, ub = jnp.asarray(lb), jnp.asarray(ub)
return jax.lax.fori_loop(lb, ub, body, init)
def _sum_same_dtype(x):
# TODO(slebedev): Remove this once ``FragmentedArray`` supports
# ``reduce_sum`` for non-32-bit types.
return jnp.sum(x, dtype=x.dtype)
class PallasTest(jtu.JaxTestCase):
  """Common base for the Mosaic GPU Pallas tests.

  Skips every test on devices below CUDA compute capability 9.0. Also provides
  a helper that captures stdout produced by kernels.
  """

  def setUp(self):
    if jtu.is_cuda_compute_capability_at_least("9.0"):
      super().setUp()
    else:
      self.skipTest("Only works on a GPU with capability >= sm90")

  @contextlib.contextmanager
  def capture_stdout(self):
    """Captures stdout for the duration of the ``with`` body.

    Yields the object produced by ``jtu.capture_stdout``. Tests call it after
    the ``with`` block to read the captured text.

    Raises:
      ValueError: if the Mosaic GPU extension failed to import.
    """
    if mosaic_gpu_lib is None:
      raise ValueError("Running tests but missing Mosaic GPU extension")
    with jtu.capture_stdout() as get_captured:
      yield get_captured
      # A cudaDeviceSynchronize is needed to flush device printfs. It must
      # run while stdout is still being captured.
      mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices()
class PallasSm90ATest(PallasTest, jtu.CudaArchSpecificTest):
  """Base for tests that additionally require an sm90a target."""

  def setUp(self):
    # The sm90a check (from jtu.CudaArchSpecificTest) runs before the
    # base-class setUp chain.
    self.skip_unless_sm90a()
    super().setUp()
class PallasCallTest(PallasTest):
@parameterized.named_parameters(
("add_one", lambda x: x + 1.),
("logistic", jax.lax.logistic),
("exp", jax.lax.exp),
("square", lambda x: x ** 2),
("rsqrt", jax.lax.rsqrt),
("tanh", jax.lax.tanh, 1e-6),
)
def test_unary_op(self, unary, rtol=1e-7):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[...] = unary(x_ref[...])
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_allclose(kernel(x), unary(x), rtol=rtol)
@parameterized.product(
op=[
operator.add,
lambda x, _: x + 1, # for int->vector conversion
operator.sub,
operator.mul,
lax.div,
jnp.minimum,
jnp.maximum,
],
dtype=[jnp.float32, jnp.int32, jnp.uint32],
thread_semantics=[*plgpu.ThreadSemantics],
)
def test_binary_op(self, op, dtype, thread_semantics):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], dtype),
compiler_params=plgpu.GPUCompilerParams(
thread_semantics=thread_semantics
),
)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = op(x_ref[...], y_ref[...])
key0, key1 = jax.random.split(jax.random.key(0), 2)
x = (jax.random.uniform(key0, [256]) * 42).astype(dtype)
y = (jax.random.uniform(key1, [256]) * 42).astype(dtype)
np.testing.assert_array_equal(kernel(x, y), op(x, y))
@parameterized.product(
op=[
lax.eq,
operator.ne,
operator.lt,
operator.le,
operator.gt,
operator.ge,
],
# TODO(slebedev): Support integral types.
dtype=[jnp.float32, jnp.int32, jnp.uint32],
thread_semantics=[*plgpu.ThreadSemantics],
)
def test_comparison_op(self, op, dtype, thread_semantics):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], dtype),
compiler_params=plgpu.GPUCompilerParams(
thread_semantics=thread_semantics
),
)
def kernel(o_ref):
o_ref[...] = jnp.broadcast_to(
op(dtype(42), dtype(24)).astype(dtype), o_ref.shape
)
np.testing.assert_array_equal(kernel(), jnp.full([256], op(42, 24), dtype))
def test_add_first(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = x_ref[...] + y_ref[0]
x = jnp.arange(256).astype(jnp.float32)
y = jnp.flip(x).reshape(1, 256)
np.testing.assert_array_equal(kernel(x, y), x + y[0])
def test_reshape(self):
shape1, shape2 = (128,), (2, 16, 4)
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32),
)
def kernel(x_ref, out_ref):
x_ref_reshaped = x_ref.reshape(shape2)
self.assertEqual(x_ref.shape, shape1)
self.assertEqual(x_ref_reshaped.shape, shape2)
out_ref[...] = x_ref_reshaped[...]
x = jnp.arange(math.prod(shape1)).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x.reshape(shape2))
def test_add_xy_indexed(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
idx = _sum_same_dtype(y_ref[...])
o_ref[...] = x_ref[idx]
x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
y = jnp.zeros(128, dtype=jnp.int32)
np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)])
def test_add_one_grid(self):
@functools.partial(
pl.pallas_call,
in_specs=[pl.BlockSpec((128,), lambda *i: i)],
out_specs=pl.BlockSpec((128,), lambda *i: i),
out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32),
grid=2,
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1.0
x = jnp.arange(128 * 2).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
def test_add_one_grid_with_scratch(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32),
in_specs=[pl.BlockSpec((128,), lambda *i: i)],
out_specs=pl.BlockSpec((128,), lambda *i: i),
scratch_shapes=[plgpu.SMEM((128,), jnp.float32)],
grid=2,
)
def kernel(x_ref, o_ref, scratch_ref):
scratch_ref[...] = x_ref[...] + 1
o_ref[...] = scratch_ref[...]
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
@parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16])
def test_add_one_grid_pipelined(self, max_concurrent_steps):
@functools.partial(
pl.pallas_call,
in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "sequential"],
max_concurrent_steps=max_concurrent_steps,
),
grid=(2, 4),
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1.0
x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
def test_add_one_grid_pipelined_program_id(self):
@functools.partial(
pl.pallas_call,
out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)),
out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "sequential"],
max_concurrent_steps=2,
),
grid=(4, 4),
)
def kernel(o_ref):
o_ref[...] = jnp.broadcast_to(pl.program_id(1), o_ref.shape)
np.testing.assert_array_equal(
kernel(),
jnp.repeat(jnp.repeat(jnp.arange(4), 16)[None], 16, axis=0),
)
def test_add_one_grid_pipelined_sequential_invariant_output(self):
@functools.partial(
pl.pallas_call,
in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))],
out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)),
out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "sequential"],
max_concurrent_steps=2,
),
grid=(2, 4),
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1.0
x = jnp.arange(32 * 2 * 64).reshape((32 * 2, 64)).astype(jnp.float32)
y = jnp.empty_like(x)
for i in range(2):
i_slice = slice(32 * i, 32 * (i + 1))
for j in range(4):
j_slice = slice(16 * j, 16 * (j + 1))
y = y.at[i_slice, :16].set(x[i_slice, j_slice] + 1)
# We only compare the elements in the first 16 columns, because the rest
# are never written to.
np.testing.assert_array_equal(kernel(x)[:, :16], y[:, :16])
@parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32)
def test_iota(self, dtype):
dimension = 1
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct((128, 128), dtype),
)
def kernel(o_ref):
o_ref[...] = plgpu.broadcasted_iota(dtype, (128, 128), dimension, layout=plgpu.Layout.WGMMA)
np.testing.assert_array_equal(kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension))
@parameterized.product(
indexer=[..., slice(128), slice(None, 128)],
thread_semantics=[*plgpu.ThreadSemantics],
)
def test_copy_smem_to_gmem(self, indexer, thread_semantics):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
scratch_shapes=[plgpu.SMEM((256,), jnp.float32)],
compiler_params=plgpu.GPUCompilerParams(
thread_semantics=thread_semantics
),
)
def kernel(x_ref, o_ref_gmem, scratch_ref):
scratch_ref[...] = x_ref[...] + 1
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer])
plgpu.wait_smem_to_gmem(0)
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)
@parameterized.named_parameters(
{"testcase_name": "1d_none",
"shape": (256,), "indexers": (slice(0, 128), slice(None, 32))},
{"testcase_name": "1d_offset",
"shape": (256,), "indexers": (slice(32, 96), slice(0, 32))},
{"testcase_name": "2d_extract",
"shape": (64, 64), "indexers": (4, slice(0, 64))},
)
def test_copy_smem_to_gmem_with_multiple_gmem_indexers(self, shape, indexers):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
scratch_shapes=[plgpu.SMEM(shape, jnp.float32)],
)
def kernel(x_ref, o_ref_gmem, scratch_ref):
scratch_ref[...] = x_ref[...] + 1
plgpu.commit_smem()
for indexer in indexers:
scratch_ref = scratch_ref.at[indexer]
o_ref_gmem = o_ref_gmem.at[indexer]
plgpu.copy_smem_to_gmem(scratch_ref, o_ref_gmem)
plgpu.wait_smem_to_gmem(0)
x = jnp.arange(np.prod(shape)).astype(jnp.float32).reshape(*shape)
result = kernel(x)
ref = x + 1.0
for indexer in indexers:
result = result[indexer]
ref = ref[indexer]
np.testing.assert_array_equal(result, ref)
@parameterized.product(indexer=[..., slice(128), slice(None, 128)])
def test_copy_gmem_to_smem(self, indexer):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
scratch_shapes=[
plgpu.SMEM((256,), jnp.float32),
plgpu.Barrier(num_arrivals=1),
],
)
def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
plgpu.copy_gmem_to_smem(
x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier_ref
)
plgpu.barrier_wait(barrier_ref)
o_ref[...] = scratch_ref[...] + 1
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)
@parameterized.named_parameters(
{
"testcase_name": "1d_none",
"shape": (256,),
"indexers": (slice(0, 128), slice(None, 32)),
},
{
"testcase_name": "1d_offset",
"shape": (256,),
"indexers": (slice(32, 96), slice(0, 32)),
},
{
"testcase_name": "2d_extract_static",
"shape": (64, 64),
"indexers": (4, slice(0, 64)),
},
{
"testcase_name": "2d_extract_dyn",
"shape": (64, 64),
"indexers": lambda in_dev: (
pl.program_id(0) + 4 if in_dev else jnp.array(4),
slice(0, 64),
),
},
)
def test_copy_gmem_to_smem_with_multiple_gmem_indexers(self, shape, indexers):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
scratch_shapes=[plgpu.SMEM(shape, jnp.float32),
plgpu.Barrier(num_arrivals=1),
],
grid=(1,),
)
def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
scratch_ref_sliced = scratch_ref
for indexer in indexers(True) if callable(indexers) else indexers:
scratch_ref_sliced = scratch_ref_sliced.at[indexer]
x_ref_gmem = x_ref_gmem.at[indexer]
plgpu.copy_gmem_to_smem(
x_ref_gmem, scratch_ref_sliced, barrier_ref
)
plgpu.barrier_wait(barrier_ref)
o_ref[...] = scratch_ref[...] + 1
x = jnp.arange(np.prod(shape)).astype(jnp.float32).reshape(*shape)
result = kernel(x)
ref = x + 1.0
for indexer in indexers(False) if callable(indexers) else indexers:
result = result[indexer]
ref = ref[indexer]
np.testing.assert_array_equal(result, ref)
def test_gmem_to_smem_with_multiple_smem_indexers(self):
x = jax.random.uniform(jax.random.key(0), (2, 64, 64), dtype=jnp.float32)
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
scratch_shapes=[
plgpu.SMEM(x.shape, jnp.float32),
plgpu.Barrier(num_arrivals=1),
],
)
def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref)
plgpu.barrier_wait(barrier_ref)
x_sliced = scratch_ref.at[0, :, :] # shape=(64, 64)
o_ref[pl.ds(0, 32), :] = x_sliced[pl.ds(0, 32), :]
o_ref[pl.ds(32, 32), :] = x_sliced[pl.ds(32, 32), :]
np.testing.assert_array_equal(extract_x0(x), x[0])
def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self):
x = jnp.arange(512 * 512, dtype=jnp.int32).reshape(512, 512)
@functools.partial(
pl.pallas_call,
grid=(4, 4),
out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32),
in_specs=(plgpu.GPUBlockSpec(
block_shape=(128, 128),
index_map=lambda i, j: (i, j),
memory_space=plgpu.SMEM,
transforms=(plgpu.TilingTransform((64, 32)),
plgpu.SwizzleTransform(128))),),
out_specs=(plgpu.GPUBlockSpec(
block_shape=(64, 32),
index_map=lambda i, j: (i, j),
memory_space=plgpu.SMEM,)),
)
def kernel(x_ref, o_ref):
x_sliced = x_ref.at[0:64, 32:96].at[:, 0:32] # get x_ref[0:64, 32:64]
o_ref[...] = x_sliced[...]
ref = jnp.concatenate([x[blk:blk+64, :] for blk in range(0, 512, 128)])
ref = jnp.concatenate(
[ref[:, blk+32:blk+64] for blk in range(0, 512, 128)], axis=1)
np.testing.assert_array_equal(kernel(x), ref)
@parameterized.product(indexer=[0, 1, 2, 3])
def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
scratch_shapes=[
plgpu.SMEM((128,), jnp.float32),
plgpu.Barrier(num_arrivals=1, num_barriers=4),
],
)
def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
plgpu.copy_gmem_to_smem(
x_ref_gmem, scratch_ref, barrier_ref.at[indexer]
)
plgpu.barrier_wait(barrier_ref.at[indexer])
o_ref[...] = scratch_ref[...] + 1
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
@parameterized.named_parameters(("_g2s", False), ("_s2g", True))
def test_copy_with_transforms(self, to_smem):
def kernel(x_ref, o_ref, barrier_ref):
if to_smem:
plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier_ref)
plgpu.barrier_wait(barrier_ref)
else:
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(x_ref, o_ref)
plgpu.wait_smem_to_gmem(0)
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
out_spec = plgpu.GPUBlockSpec(
(128, 128),
lambda: (0, 0),
transforms=(
plgpu.TilingTransform((64, 32)),
plgpu.SwizzleTransform(128),
),
memory_space=plgpu.SMEM,
)
if not to_smem:
in_spec, out_spec = out_spec, in_spec
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
in_specs=(in_spec,),
out_specs=out_spec,
scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
np.testing.assert_array_equal(f(x), x)
def test_scoped_copy_with_transforms(self):
ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))
def kernel(x_ref, o_ref, barrier_ref):
def body(tmp_ref):
plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref)
plgpu.barrier_wait(barrier_ref)
o_ref[...] = tmp_ref[...] * 2
pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
out_spec = plgpu.GPUBlockSpec(
(128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM,
)
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
in_specs=(in_spec,),
out_specs=out_spec,
scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
np.testing.assert_array_equal(f(x), x * 2)
def test_copy_with_transforms_and_indexing(self):
def kernel(x_ref, o_ref, barrier_ref):
for i in range(2):
plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier_ref)
plgpu.barrier_wait(barrier_ref)
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
out_spec = plgpu.GPUBlockSpec(
(2, 128, 128),
lambda: (0, 0, 0),
transforms=(
plgpu.TilingTransform((64, 32)),
plgpu.TransposeTransform((0, 2, 1, 3, 4)),
plgpu.SwizzleTransform(128),
),
memory_space=plgpu.SMEM,
)
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32),
in_specs=(in_spec,),
out_specs=out_spec,
scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0))
def test_indexing_before_transpose(self):
def kernel(x_ref, o_ref, barrier_ref):
for i in range(2):
plgpu.copy_gmem_to_smem(
x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier_ref
)
plgpu.barrier_wait(barrier_ref)
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
out_spec = plgpu.GPUBlockSpec(
(2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM,
)
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32),
in_specs=(in_spec,),
out_specs=out_spec,
scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(2 * 64 * 128, dtype=jnp.float32).reshape(2, 64, 128)
xt = x.transpose((1, 0, 2))
np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0))
def test_copy_gmem_to_smem_in_run_scoped(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
)
def kernel(x_ref_gmem, o_ref):
def body(barrier_ref):
def inner_body(scratch_ref):
plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref)
plgpu.barrier_wait(barrier_ref)
o_ref[...] = scratch_ref[...] + 1
pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))
pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
def test_add_doubled_sum(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + jnp.sum(x_ref[...]) + jnp.sum(x_ref[...])
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + x.sum()*2)
@parameterized.parameters(False, True)
def test_rsqrt(self, approx_math):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math),
)
def kernel(x_ref, o_ref):
o_ref[...] = jax.lax.rsqrt(x_ref[...])
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_allclose(kernel(x), jax.lax.rsqrt(x))
@parameterized.product(input_factor=[0.001, 1, 10, 100, 100])
def test_layer_norm(self, input_factor):
eps = 1e-5
gamma = 1.0
beta = 1.0
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def layer_norm(x_ref, o_ref):
x_mean = jnp.mean(x_ref[...])
x_centered = x_ref[...] - x_mean
o_ref[...] = (
x_centered * jax.lax.rsqrt(jnp.mean(x_centered**2) + eps) * gamma
+ beta
)
def layer_norm_np(x):
x_mean = np.mean(x)
x_centered = x - x_mean
return (x_centered / np.sqrt(np.mean(x_centered**2) + eps) * gamma) + beta
# Ones are always fully precise
x = jnp.ones((256,)).astype(jnp.float32) * input_factor
np.testing.assert_allclose(layer_norm(x), layer_norm_np(x))
# random (and anything else is not)
x = (
jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32)
* input_factor
)
np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5)
def test_print(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, o_ref):
del x_ref, o_ref
pl.debug_print("It works!")
x = jnp.arange(256).astype(jnp.float32)
with self.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertEqual(output(), "It works!\n")
def test_print_wgmma_tiled_layout(self):
shape = (128, 64)
size = math.prod(shape)
def kernel(x_ref, o_ref):
pl.debug_print("{}", x_ref[...])
spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)))
x = jnp.arange(size, dtype=jnp.float32).reshape(shape)
f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec)
with self.capture_stdout() as get_output:
jax.block_until_ready(f(x))
output = get_output()
results = re.findall(r"\[(\d+), (\d+)\]/\[128, 64\]: (\d+)", output)
self.assertLen(results, size)
for i, j, v in results:
i, j, v = map(int, (i, j, v))
self.assertEqual(v, i * shape[1] + j)
def test_print_scalar(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
)
def kernel(x_ref, o_ref):
del o_ref
pl.debug_print("x.sum() = {}", _sum_same_dtype(x_ref[...]))
x = jnp.arange(256, dtype=jnp.int32)
with self.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertIn(f"x.sum() = {x.sum()}", output())
def test_print_scalar_array(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
)
def kernel(x_ref, o_ref):
del o_ref
pl.debug_print("x.sum() = {}", _sum_same_dtype(x_ref[...]) + 1)
x = jnp.arange(256, dtype=jnp.int32)
with self.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertIn(f"x.sum() = {x.sum() + 1}", output())
def test_print_array(self):
in_shape = [2, 1, 64, 64]
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32),
)
def kernel(x_ref, o_ref):
del o_ref
pl.debug_print("x: {}", x_ref[...])
x = jnp.arange(math.prod(in_shape), dtype=jnp.int32).reshape(in_shape)
with self.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output())
def test_load_scalar(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct((128,), jnp.int32),
in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.broadcast_to(x_ref[10], (128,))
np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)),
jnp.full((128,), 10, dtype=jnp.int32))
@parameterized.product(thread_semantics=[*plgpu.ThreadSemantics])
def test_run_scoped(self, thread_semantics):
def kernel(x_ref, o_ref):
def body(tmp_ref):
self.assertEqual(tmp_ref.shape, (8, 128))
tmp_ref[...] = x_ref[...] + 1.0
return tmp_ref[...]
tmp = pl.run_scoped(body, plgpu.SMEM((8, 128), jnp.float32))
self.assertEqual(tmp.shape, (8, 128))
o_ref[...] = tmp
inp = np.ones((8, 128), jnp.float32)
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
compiler_params=plgpu.GPUCompilerParams(
thread_semantics=thread_semantics
),
)
o = f(inp)
np.testing.assert_array_equal(o, inp + 1.0)
def test_program_id(self):
@functools.partial(
pl.pallas_call,
in_specs=(),
out_specs=pl.BlockSpec((128,), lambda *i: i),
out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
grid=2,
)
def kernel(o_ref):
o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
np.testing.assert_array_equal(
kernel(),
jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32),
)
def test_program_id_in_squashed_grid(self):
# Tests whether a grid with >3 logical dimensions is correctly squashed to
# 3 CUDA grid dimensions.
grid = (2, 3, 4, 5)
@functools.partial(
pl.pallas_call,
in_specs=(),
out_specs=pl.BlockSpec((1,) * len(grid) + (128,), lambda *i: (*i, 0)),
out_shape=jax.ShapeDtypeStruct([*grid, 128], jnp.int32),
grid=grid,
)
def kernel(o_ref):
mult = 1
idx = 0
for axis in range(len(grid)-1, -1, -1):
idx += pl.program_id(axis) * mult
mult *= pl.num_programs(axis)
o_ref[...] = jnp.full(o_ref.shape, idx)
np.testing.assert_array_equal(
kernel()[:, :, :, :, 0],
jnp.arange(math.prod(grid), dtype=jnp.int32).reshape(*grid)
)
def test_program_id_in_block_spec(self):
@functools.partial(
pl.pallas_call,
in_specs=(pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),),
out_specs=pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),
out_shape=jax.ShapeDtypeStruct([2, 128], jnp.int32),
grid=2,
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...]
x = jnp.arange(2 * 128, dtype=jnp.int32).reshape([2, 128])
np.testing.assert_array_equal(kernel(x), x)
def test_num_programs(self):
@functools.partial(
pl.pallas_call,
in_specs=(),
out_specs=pl.BlockSpec((128,), lambda *i: i),
out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
grid=2,
)
def kernel(o_ref):
o_ref[...] = jnp.full(o_ref.shape, pl.num_programs(0), o_ref.dtype)
np.testing.assert_array_equal(
kernel(),
jnp.full([256], 2, dtype=jnp.int32),
)
def test_swizzled_blockspec_shapes(self):
spec = plgpu.GPUBlockSpec(
(128, 64),
lambda *i: i,
transforms=(
plgpu.TilingTransform((64, 64)),
plgpu.SwizzleTransform(128),
),
)
@functools.partial(
pl.pallas_call,
in_specs=[spec],
out_specs=spec,
out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16),
grid=(2, 2),
)
def kernel(x_ref, o_ref):
assert x_ref.shape == (128, 64), x_ref.shape
o_ref[...] = x_ref[...]
x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
np.testing.assert_array_equal(kernel(x), x)
@parameterized.product(
force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics]
)
def test_fori_loop_array(self, force_while, thread_semantics):
if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
# TODO(apaszke,bchetioui,slebedev): Support while + array carries.
self.skipTest("WG semantics unsupported")
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics),
)
def kernel(x_ref, o_ref):
# Equivalent to x_ref[...] + 2 + 3.
o_ref[...] = _fori_loop(force_while, 2, 4, lambda i, x: x + i, x_ref[...])
x = jnp.arange(256, dtype=jnp.int32)
np.testing.assert_array_equal(kernel(x), x + 2 + 3)
@parameterized.product(
force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics]
)
def test_fori_loop_scalar(self, force_while, thread_semantics):
if force_while and thread_semantics == plgpu.ThreadSemantics.Warpgroup:
self.skipTest("WG semantics does not support force_while.")
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics),
)
def kernel(o_ref):
# Equivalent to 2 + 3.
o_ref[...] = jax.lax.broadcast(
_fori_loop(force_while, 2, 4, lambda i, x: x + i, jnp.int32(0)),
o_ref.shape,
)
np.testing.assert_array_equal(kernel(), jnp.full([256], 5, jnp.int32))
def test_fori_loop_dynamic_bounds(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
grid=(1,)
)
def kernel(o_ref):
zero = pl.program_id(0)
# Equivalent to 2 + 3.
o_ref[...] = jax.lax.broadcast(
jax.lax.fori_loop(2 + zero, 4 + zero, lambda i, x: x + i, 0), o_ref.shape
)
np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))
@parameterized.parameters(False, True)
def test_fori_loop_tuple(self, force_while):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
)
def kernel(o_ref):
def body(step, xs):
return tuple(
jax.lax.cond(step % 2 == 0, lambda x: x + 1, lambda x: x, x)
for x in xs
)
# Equivalent to 3 * (0 + 1).
o_ref[...] = jax.lax.broadcast(
sum(_fori_loop(force_while, 2, 4, body, (jnp.int32(0),) * 3)),
o_ref.shape,
)
np.testing.assert_array_equal(
kernel(), jnp.full([256], 3 * (0 + 1), jnp.int32)
)
@parameterized.parameters(False, True)
def test_fori_loop_indexed_store(self, force_while):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
def body(idx, _):
o_ref[idx] = x_ref[idx] + y_ref[idx]
return ()
_fori_loop(force_while, 0, 4, body, ())
x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
y = x + 1
np.testing.assert_array_equal(kernel(x, y), x + y)
def test_while_loop(self):
@functools.partial(
pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32)
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.zeros(o_ref.shape, dtype=jnp.int32)
def cond(acc):
_, last_o = acc
return _sum_same_dtype(last_o) < 128*10
def body(acc):
i, _ = acc
o_ref[...] += x_ref[i]
return i+1, o_ref[...]
_ = jax.lax.while_loop(cond, body, (0, o_ref[...]))
np.testing.assert_array_equal(
kernel(jnp.ones([128, 128], jnp.int32)), jnp.full([128], 10, jnp.int32)
)
def test_while_loop_layout_mismatch(self):
@functools.partial(
pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32)
)
def kernel(o_ref):
def cond(acc):
return _sum_same_dtype(acc) < 128
def body(acc):
del acc # Unused.
# We deliberately do a cast here to trigger a layout mismatch.
return plgpu.layout_cast(
jnp.zeros(o_ref.shape, o_ref.dtype), plgpu.Layout.WGMMA_ROW
)
_ = jax.lax.while_loop(cond, body, o_ref[...])
with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"):
kernel()
@parameterized.parameters([*plgpu.ThreadSemantics])
def test_cond(self, thread_semantics):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics),
)
def kernel(x_ref, o_ref):
jax.lax.cond(
x_ref[0] % 2 == 0,
lambda: pl.debug_print("acc % 2"),
lambda: pl.debug_print("acc"),
)
o_ref[...] = jnp.broadcast_to(jnp.asarray(0, dtype=o_ref.dtype), o_ref.shape)
x = jnp.full((256,), 1234, dtype=jnp.int32)
with self.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertIn("acc % 2", output())
def test_cond_returning_array(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
)
def kernel(x_ref, o_ref):
acc = _sum_same_dtype(x_ref[...])
acc2, acc = jax.lax.cond(
acc % 2 == 0,
lambda: (acc * 2, acc),
lambda: (acc, acc * 2),
)
o_ref[...] = jnp.broadcast_to(acc + acc2, o_ref.shape)
x = jnp.arange(256, dtype=jnp.int32)
np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256]))
def test_input_output_aliases(self):
# Note that we're writing to the input pointer, which should alias b_ptr.
def kernel(a_ref, b_ref):
del b_ref
a_ref[...] = jnp.ones_like(a_ref)
a = np.zeros((64, 64), dtype=jnp.float32)
b = pl.pallas_call(
kernel,
in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM),
input_output_aliases={0: 0},
out_shape=a,
)(a)
np.testing.assert_array_equal(b, np.ones_like(a))
def test_slicing(self):
left = upper = slice(None, 64)
right = lower = slice(64, None)
# We rotate the four quadrants of the input clockwise.
def rotate(src, dst):
dst[upper, left] = src[lower, left]
dst[upper, right] = src[upper, left]
dst[lower, right] = src[upper, right]
dst[lower, left] = src[lower, right]
x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
spec = plgpu.GPUBlockSpec(
(128, 128),
lambda: (0, 0),
transforms=(
plgpu.TilingTransform((64, 64)),
plgpu.SwizzleTransform(128),
),
)
f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec)
expected = np.empty_like(x)
rotate(x, expected)
np.testing.assert_array_equal(f(x), expected)
def test_layout_cast(self, shape=(256, 64)):
  """Casting a splat to the WGMMA layout preserves its values."""
  fill_value = 42.0

  def kernel(o_ref):
    splat = jnp.full(shape, fill_value, jnp.float32)
    o_ref[...] = plgpu.layout_cast(splat, plgpu.Layout.WGMMA)

  cast_kernel = pl.pallas_call(
      kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32)
  )
  expected = jnp.full(shape, fill_value, jnp.float32)
  np.testing.assert_array_equal(cast_kernel(), expected)
def test_profiler(self):
  """Named scopes inside the kernel appear in the dumped profile."""

  def kernel(x_ref, o_ref):
    with jax.named_scope("add"):
      with jax.named_scope("load"):
        loaded = x_ref[...]
      doubled = loaded + loaded
      with jax.named_scope("store"):
        o_ref[...] = doubled

  markers = ('"name": "add"', '"name": "load"', '"name": "store"')
  with tempfile.TemporaryDirectory() as tmpdir:
    x = jnp.arange(256).astype(jnp.float32)
    profiled_kernel = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
        compiler_params=plgpu.GPUCompilerParams(
            profile_space=16, profile_dir=tmpdir
        ),
    )
    y = profiled_kernel(x)
    jax.block_until_ready(y)
    jax.effects_barrier()
    # Exactly one profile file is expected in the directory.
    [trace_name] = os.listdir(tmpdir)
    with open(os.path.join(tmpdir, trace_name), "r") as trace_file:
      trace = trace_file.read()
    # Each scope name is expected to appear exactly twice in the trace.
    for marker in markers:
      self.assertEqual(trace.count(marker), 2)
    np.testing.assert_array_equal(y, x + x)
@parameterized.parameters(
    (jnp.float16, jnp.float16),  # Noop
    (jnp.int16, jnp.bfloat16),
    (jnp.int16, jnp.float16),
    (jnp.uint16, jnp.float16),
    (jnp.float32, jnp.int32),
    (jnp.float32, jnp.uint32),
    (jnp.uint32, jnp.int32),
    (jnp.int32, jnp.uint32),
)
def test_bitcast_convert_type(self, in_dtype, out_dtype):
  """In-kernel same-width ``lax.bitcast_convert_type`` matches XLA's result.

  Args:
    in_dtype: dtype of the kernel input.
    out_dtype: dtype to reinterpret the input bits as.
  """
  m, n = 16, 8
  out_shape = jax.ShapeDtypeStruct((m, n), out_dtype)

  @functools.partial(pl.pallas_call, out_shape=out_shape, grid=())
  def convert(x_ref, y_ref):
    # ``new_dtype`` is a dtype argument. Pass the target dtype itself rather
    # than the ShapeDtypeStruct, as the host-side reference below does.
    y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_dtype)

  x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n))
  y = convert(x)
  y_ref = jax.lax.bitcast_convert_type(x, out_dtype)
  np.testing.assert_array_equal(y, y_ref)
class PallasCallSm90ATest(PallasSm90ATest):
  """``pallas_call`` tests that exercise sm_90a-only features such as WGMMA."""

  @parameterized.parameters(False, True)
  def test_fori_loop_accumulator(self, force_while):
    """A loop carry seeded from an ACC ref works for both scan and while."""
    swizzled = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128))

    def kernel(i_ref, o_ref):
      def scope(acc_ref):
        def add_acc(_, carry):
          return carry + acc_ref[...]

        return _fori_loop(force_while, 0, 4, add_acc, acc_ref[...])

      o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...]))

    accumulate = pl.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=swizzled)
        ],
        out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16),
        out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)),
    )
    acc_init = jnp.ones((64, 64), dtype=jnp.float16)
    expected = jnp.full((64, 64), 5, dtype=jnp.float16)
    np.testing.assert_array_equal(accumulate(acc_init), expected)

  def test_realistic_matmul(self):
    """Grid-pipelined matmul that accumulates over k in an ACC scratch."""
    dtype = jnp.float16
    swizzle = 128
    elems_128b = swizzle // jnp.dtype(dtype).itemsize
    grid_m, grid_k, grid_n = 132, 10, 4
    tile_m = tile_n = 128
    tile_k = elems_128b
    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n

    def swizzled_tiling(rows, cols):
      return (plgpu.TilingTransform((rows, cols)), plgpu.SwizzleTransform(128))

    def kernel(a_ref, b_ref, o_ref, acc_ref):
      # Tiling must not alter the logical shapes of the references.
      assert a_ref.shape == (tile_m, tile_k)
      assert b_ref.shape == (tile_k, tile_n)
      assert o_ref.shape == acc_ref.shape == (tile_m, tile_n)
      plgpu.wgmma(acc_ref, a_ref, b_ref)

      @pl.when(pl.program_id(2) == grid_k - 1)
      def _epilogue():
        o_ref[...] = acc_ref[...].astype(dtype)

      # The last WGMMA is not awaited, hence delay_release=1 below.
      plgpu.wgmma_wait(1)

    a_key, b_key = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(a_key, shape=(m, k), dtype=dtype)
    b = jax.random.uniform(b_key, shape=(k, n), dtype=dtype)
    matmul = pl.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec(
                (tile_m, tile_k),
                lambda mi, ni, ki: (mi, ki),
                transforms=swizzled_tiling(64, elems_128b),
            ),
            plgpu.GPUBlockSpec(
                (tile_k, tile_n),
                lambda mi, ni, ki: (ki, ni),
                transforms=swizzled_tiling(elems_128b, elems_128b),
            ),
        ],
        out_specs=plgpu.GPUBlockSpec(
            (tile_m, tile_n),
            lambda mi, ni, ki: (mi, ni),
            transforms=swizzled_tiling(64, elems_128b),
        ),
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
        scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
        grid=(grid_m, grid_n, grid_k),
        compiler_params=plgpu.GPUCompilerParams(
            dimension_semantics=["parallel", "parallel", "sequential"],
            max_concurrent_steps=2,
            delay_release=1,
        ),
    )
    np.testing.assert_allclose(matmul(a, b), a @ b, rtol=1e-3)

  @parameterized.parameters(jnp.float16, jnp.float32)
  def test_wgmma(self, dtype):
    """WGMMA on SMEM operands; non-16-bit RHS is fed pre-transposed."""
    # TensorCores can only fuse transposes of 16-bit values, and RHS
    # is expected to be column major by default.
    rhs_transpose = jnp.dtype(dtype).itemsize != 2
    swizzle = 128
    elems_128b = swizzle // jnp.dtype(dtype).itemsize

    def kernel(a_ref, b_ref, o_ref):
      rhs = plgpu.transpose_ref(b_ref, (1, 0)) if rhs_transpose else b_ref

      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref, rhs)
        return acc_ref[...]

      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))

    a_key, b_key = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(a_key, shape=(64, 128), dtype=dtype)
    b_shape = (192, 128) if rhs_transpose else (128, 192)
    b = jax.random.uniform(b_key, shape=b_shape, dtype=dtype)

    rhs_transforms = [plgpu.TilingTransform((elems_128b, elems_128b))]
    if rhs_transpose:
      rhs_transforms.append(plgpu.TransposeTransform((1, 0, 2, 3)))
    rhs_transforms.append(plgpu.SwizzleTransform(128))

    lhs_spec = plgpu.GPUBlockSpec(
        (64, 128),
        lambda i, j: (i, j),
        transforms=(
            plgpu.TilingTransform((64, elems_128b)),
            plgpu.SwizzleTransform(128),
        ),
    )
    rhs_spec = plgpu.GPUBlockSpec(
        b_shape, lambda *idx: idx, transforms=tuple(rhs_transforms)
    )
    res = pl.pallas_call(
        kernel,
        in_specs=[lhs_spec, rhs_spec],
        out_specs=plgpu.GPUBlockSpec((64, 192), lambda *idx: idx),
        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
        grid=(1, 1),
    )(a, b)
    expected = a @ (b.T if rhs_transpose else b)
    np.testing.assert_allclose(res, expected, rtol=1e-3)

  def test_wgmma_registers(self):
    """WGMMA with the LHS operand passed as a register value."""

    def kernel(a_ref, b_ref, o_ref):
      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref[...], b_ref)
        return acc_ref[...]

      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))

    a_key, b_key = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(a_key, shape=(64, 128), dtype=jnp.float16)
    b = jax.random.uniform(b_key, shape=(128, 192), dtype=jnp.float16)
    swizzled = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128))
    res = pl.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=swizzled)
            for shape in ((64, 128), (128, 192))
        ],
        out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
    )(a, b)
    np.testing.assert_allclose(res, a @ b, rtol=1e-3)

  def test_wgmma_registers_init(self):
    """WGMMA accumulates on top of an ACC initialized from an input."""

    def kernel(a_ref, b_ref, i_ref, o_ref):
      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref[...], b_ref)

      o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...]))

    a_key, b_key, i_key = jax.random.split(jax.random.key(42), 3)
    a = jax.random.uniform(a_key, shape=(64, 128), dtype=jnp.float16)
    b = jax.random.uniform(b_key, shape=(128, 192), dtype=jnp.float16)
    i = jax.random.uniform(i_key, shape=(64, 192), dtype=jnp.float16) * 10
    swizzled = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128))
    res = pl.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=swizzled)
            for shape in ((64, 128), (128, 192), (64, 192))
        ],
        out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16),
    )(a, b, i)
    np.testing.assert_allclose(res, i + a @ b, rtol=2e-3)

  def test_wgmma_sliced_ref(self):
    """WGMMA operands may be ``.at[...]`` slices of larger SMEM refs."""

    def kernel(a_ref, b_ref, o_ref):
      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0])
        return acc_ref[...]

      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))

    a_key, b_key = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(a_key, shape=(2, 64, 128), dtype=jnp.float16)
    b = jax.random.uniform(b_key, shape=(2, 128, 192), dtype=jnp.float16)
    swizzled = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128))
    res = pl.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec(shape, lambda: (0, 0, 0), transforms=swizzled)
            for shape in ((2, 64, 128), (2, 128, 192))
        ],
        out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
    )(a, b)
    np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3)

  def test_wgmma_sliced_acc(self):
    """The ACC ref can be read back in column slices after WGMMA."""
    swizzle = 128
    elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize

    def kernel(a_ref, b_ref, o_ref):
      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref, b_ref)
        return acc_ref[:, :64], acc_ref[:, 64:]

      left, right = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))
      o_ref[:, :64] = left
      o_ref[:, 64:] = right

    a_key, b_key = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(a_key, shape=(64, 128), dtype=jnp.float16)
    b = jax.random.uniform(b_key, shape=(128, 128), dtype=jnp.float16)
    lhs_spec = plgpu.GPUBlockSpec(
        (64, 128),
        lambda i, j: (i, j),
        transforms=(
            plgpu.TilingTransform((64, elems_128b)),
            plgpu.SwizzleTransform(128),
        ),
    )
    rhs_spec = plgpu.GPUBlockSpec(
        (128, 128),
        lambda *idx: idx,
        transforms=(
            plgpu.TilingTransform((elems_128b, elems_128b)),
            plgpu.SwizzleTransform(128),
        ),
    )
    res = pl.pallas_call(
        kernel,
        in_specs=[lhs_spec, rhs_spec],
        out_specs=plgpu.GPUBlockSpec((64, 128), lambda *idx: idx),
        out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32),
        grid=(1, 1),
    )(a, b)
    np.testing.assert_allclose(res, a @ b, rtol=1e-3)
class PipelineTest(PallasTest):
  """Tests for manual and ``emit_pipeline``-based GMEM <-> SMEM pipelining."""

  def test_pipeline_mode(self):
    """Custom ``pipeline_mode`` on block specs is expected to be rejected."""

    def body(x_ref, y_ref, o_ref):
      x = x_ref[:]
      y = y_ref[:]
      o_ref[:] = x + y

    data_size = 64 * 256
    block_size = 256

    x = jnp.arange(data_size, dtype=jnp.float32)
    y = jnp.arange(data_size, dtype=jnp.float32)
    in_specs = [
        pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(2)),
        pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(1)),
    ]
    out_specs = pl.BlockSpec((block_size,), lambda *i: i)

    @jax.jit
    def vadd(x, y):
      return pl.pallas_call(
          body,
          out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
          in_specs=in_specs,
          out_specs=out_specs,
          grid=data_size // block_size,
      )(x, y)

    with self.assertRaisesRegex(Exception, "Pipeline mode is not supported"):
      vadd(x, y)

  def test_manual(self):
    """A hand-written double-buffered pipeline using barriers and async copies."""
    max_concurrent_steps = 2
    num_steps = 4

    def kernel(x_gmem, o_gmem):
      return pl.run_scoped(
          functools.partial(scoped_kernel, x_gmem, o_gmem),
          plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32),
          plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32),
          plgpu.Barrier(1, num_barriers=max_concurrent_steps),
      )

    def scoped_kernel(x_gmem, o_gmem, x_smem, o_smem, barrier):
      gmem_slice = pl.ds(pl.program_id(0) * 32, 32)

      def body(step, _):
        slot = step % max_concurrent_steps

        # Wait for the current GMEM->SMEM copy to complete.
        plgpu.barrier_wait(barrier.at[slot])
        # Wait for the previous output SMEM->GMEM copy to complete.
        plgpu.wait_smem_to_gmem(max_concurrent_steps - 1)

        o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0

        plgpu.commit_smem()
        plgpu.copy_smem_to_gmem(
            o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)]
        )

        fetch_step = step + max_concurrent_steps
        fetch_slot = slot  # (x + y) % y == x % y
        jax.lax.cond(
            fetch_step < num_steps,
            lambda: plgpu.copy_gmem_to_smem(
                x_gmem.at[gmem_slice, pl.ds(fetch_step * 16, 16)],
                x_smem.at[fetch_slot],
                barrier.at[fetch_slot],
            ),
            lambda: None,
        )
        return ()

      # Initialize the pipeline.
      for slot in range(min(max_concurrent_steps, num_steps)):
        plgpu.copy_gmem_to_smem(
            x_gmem.at[gmem_slice, pl.ds(slot * 16, 16)],
            x_smem.at[slot],
            barrier.at[slot],
        )

      jax.lax.fori_loop(0, num_steps, body, ())

      # Finalize the pipeline.
      plgpu.wait_smem_to_gmem(0)

    x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32)
    kernel_fn = pl.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(4, 1),
    )
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)

  @parameterized.parameters(
      ((),),
      ((plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)),),
  )
  def test_emit(self, transforms):
    """``emit_pipeline`` with and without tiling/swizzling transforms."""
    num_steps = 4

    def kernel(x_gmem, o_gmem):
      plgpu.emit_pipeline(
          kernel_body,
          in_specs=[
              plgpu.GPUBlockSpec(
                  (64, 64), lambda i: (0, i), transforms=transforms
              )
          ],
          out_specs=[
              plgpu.GPUBlockSpec(
                  (64, 64), lambda i: (0, i), transforms=transforms
              )
          ],
          grid=(num_steps,),
          max_concurrent_steps=2,
      )(x_gmem, o_gmem)

    def kernel_body(x_smem, o_smem):
      # +1 for the indexing done by ``emit_pipeline``.
      self.assertLen(x_smem.transforms, len(transforms) + 1)
      o_smem[...] = x_smem[...] + 1.0

    x = jnp.arange(64 * num_steps * 64)
    x = x.reshape(-1, num_steps * 64).astype(jnp.float32)
    kernel_fn = pl.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)

  def test_nested_emit(self):
    """An ``emit_pipeline`` nested inside another (grid-less) one."""
    num_steps = 4

    def kernel(x_gmem, o_gmem):
      plgpu.emit_pipeline(
          nested_kernel,
          in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
          out_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
          grid=(),
      )(x_gmem, o_gmem)

    def nested_kernel(x_gmem, o_gmem):
      plgpu.emit_pipeline(
          nested_kernel_body,
          in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
          out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
          grid=(num_steps,),
          max_concurrent_steps=2,
      )(x_gmem, o_gmem)

    def nested_kernel_body(x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    x = jnp.arange(32 * num_steps * 16)
    x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
    kernel_fn = pl.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)

  def test_emit_with_grid_invariant_output(self):
    """An output whose block index is the same for every pipeline step."""
    num_steps = 4

    def kernel(x_gmem, o_gmem):
      plgpu.emit_pipeline(
          kernel_body,
          in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
          out_specs=[pl.BlockSpec((32, 16), lambda i: (0, 0))],
          grid=(num_steps,),
          max_concurrent_steps=2,
      )(x_gmem, o_gmem)

    def kernel_body(x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    x = jnp.arange(32 * num_steps * 16)
    x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
    kernel_fn = pl.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    # Every step writes the same output block (index_map returns (0, 0)), so
    # only the values produced by the final step are expected to remain.
    last_step = slice(16 * (num_steps - 1), 16 * num_steps)
    expected = x[:, last_step] + 1
    # We only compare the elements in the first 16 columns, because the rest
    # are never written to.
    np.testing.assert_array_equal(kernel_fn(x)[:, :16], expected)

  def test_emit_with_parallel_grid(self):
    """``emit_pipeline`` inside a kernel that has its own parallel grid."""
    num_steps1 = 4
    num_steps2 = 5

    def kernel(x_gmem, o_gmem):
      pid = pl.program_id(0)
      plgpu.emit_pipeline(
          kernel_body,
          in_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))],
          out_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))],
          grid=(num_steps2,),
          max_concurrent_steps=2,
      )(x_gmem, o_gmem)

    def kernel_body(x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    x = jnp.arange(num_steps1 * 32 * num_steps2 * 16)
    x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32)
    kernel_fn = pl.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(num_steps1,),
    )
    y = x + 1.0
    np.testing.assert_array_equal(kernel_fn(x), y)

  def test_emit_with_2d_grid(self):
    """``emit_pipeline`` over a two-dimensional grid of 3D blocks."""
    num_steps1 = 4
    num_steps2 = 5

    def kernel(x_gmem, o_gmem):
      plgpu.emit_pipeline(
          kernel_body,
          in_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))],
          out_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))],
          grid=(num_steps1, num_steps2),
          max_concurrent_steps=2,
      )(x_gmem, o_gmem)

    def kernel_body(x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8)
    x = x.reshape(-1, num_steps1 * 16, num_steps2 * 8).astype(jnp.float32)
    kernel_fn = pl.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
class PipelineSm90ATest(PallasSm90ATest):
  """``emit_pipeline`` tests that rely on sm_90a features (WGMMA)."""

  def test_realistic_matmul(self):
    """Matmul whose k loop is pipelined via ``emit_pipeline`` into an ACC."""
    dtype = jnp.float16
    swizzle = 128
    elems_128b = swizzle // jnp.dtype(dtype).itemsize
    grid_m, grid_k, grid_n = 132, 10, 4
    tile_m = tile_n = 128
    assert tile_m % elems_128b == 0
    tile_k = elems_128b
    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n

    def swizzled_tiling(rows, cols):
      return (plgpu.TilingTransform((rows, cols)), plgpu.SwizzleTransform(128))

    def kernel(a_gmem, b_gmem, o_smem, acc):
      def kernel_body(a_smem, b_smem):
        assert a_smem.shape == (tile_m, tile_k)
        assert b_smem.shape == (tile_k, tile_n)
        plgpu.wgmma(acc, a_smem, b_smem)
        plgpu.wgmma_wait(1)

      row_block = pl.program_id(0)
      col_block = pl.program_id(1)
      lhs_spec = plgpu.GPUBlockSpec(
          (tile_m, tile_k),
          lambda kk: (row_block, kk),
          transforms=swizzled_tiling(64, elems_128b),
      )
      rhs_spec = plgpu.GPUBlockSpec(
          (tile_k, tile_n),
          lambda kk: (kk, col_block),
          transforms=swizzled_tiling(elems_128b, elems_128b),
      )
      plgpu.emit_pipeline(
          kernel_body,
          in_specs=[lhs_spec, rhs_spec],
          grid=(grid_k,),
          max_concurrent_steps=2,
          delay_release=1,
      )(a_gmem, b_gmem)
      o_smem[...] = acc[...].astype(dtype)

    a_key, b_key = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(a_key, shape=(m, k), dtype=dtype)
    b = jax.random.uniform(b_key, shape=(k, n), dtype=dtype)
    matmul = pl.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM) for _ in range(2)],
        out_specs=plgpu.GPUBlockSpec(
            (tile_m, tile_n),
            lambda mi, ni: (mi, ni),
            transforms=swizzled_tiling(64, elems_128b),
        ),
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
        scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
        grid=(grid_m, grid_n),
    )
    np.testing.assert_array_equal(matmul(a, b), a @ b)
class WarpSpecializedPipelineTest(PallasTest):
  """Tests for ``emit_pipeline_warp_specialized`` (memory + compute WGs)."""

  @parameterized.product(
      m=[512], n=[512], manual_consumed_barriers=[False, True]
  )
  def test_pipelined_copy(self, m, n, manual_consumed_barriers):
    """Warp-specialized copy that also writes a grid-invariant output."""
    x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16)
    o = jnp.zeros((m, n), dtype=jnp.float16)
    blk_m = blk_n = 64
    o_last_block = jnp.zeros((blk_m, blk_n), dtype=jnp.float16)

    def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers):
      # TODO(justinfu): Have each wg compute a separate slice
      # after multiple-indexers are supported.
      # This is currently a race, but the values written are the same.
      o_smem[...] = x_smem[...]
      o_last_block_smem[...] = x_smem[...]
      if manual_consumed_barriers:
        (x_barrier,) = consumed_barriers
        plgpu.barrier_arrive(x_barrier)

    tiled_spec = plgpu.GPUBlockSpec(
        block_shape=(blk_m, blk_n),
        index_map=lambda i, j: (i, j),
        transforms=[],
    )
    # Index-invariant output: every grid step maps to block (0, 0).
    invariant_spec = plgpu.GPUBlockSpec(
        block_shape=(blk_m, blk_n), index_map=lambda i, j: (0, 0)
    )
    pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
        copy_kernel,
        grid=(m // blk_m, n // blk_n),
        memory_registers=40,
        max_concurrent_steps=2,
        num_compute_wgs=2,
        wg_axis="wg",
        manual_consumed_barriers=manual_consumed_barriers,
        in_specs=[tiled_spec],
        out_specs=[tiled_spec, invariant_spec],
    )
    mesh = plgpu.GPUMesh(grid=(1,), num_threads=3, axis_names=("_", "wg"))

    def run(refs):
      @pl.core_map(
          mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
      )
      def _kernel_entry():
        pipeline(*refs)

    @jax.jit
    def run_function(x, o, o_last_block):
      _, out, out_last = pl.run_state(run)((x, o, o_last_block))
      return out, out_last

    out, out_last_block = run_function(x, o, o_last_block)
    np.testing.assert_array_equal(out, x)
    np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:])

  def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2):
    """Warp-specialized elementwise addition of two inputs."""
    blk_m = blk_n = 64
    x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32)
    y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32)
    o = jnp.zeros((m, n), dtype=jnp.float32)

    def tiled_add_kernel(x_smem, y_smem, o_smem):
      # TODO(justinfu): Have each wg compute a separate slice
      # after multiple-indexers are supported.
      # This is currently a race, but the values written are the same.
      o_smem[...] = x_smem[...] + y_smem[...]

    def tile_spec():
      return plgpu.GPUBlockSpec(
          block_shape=(blk_m, blk_n),
          index_map=lambda i, j: (i, j),
          transforms=[],
      )

    pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
        tiled_add_kernel,
        grid=(m // blk_m, n // blk_n),
        max_concurrent_steps=2,
        num_compute_wgs=num_compute_wgs,
        memory_registers=40,
        wg_axis="wg",
        in_specs=[tile_spec(), tile_spec()],
        out_specs=[tile_spec()],
    )
    mesh = plgpu.GPUMesh(
        grid=(1,), num_threads=num_compute_wgs + 1, axis_names=("_", "wg")
    )

    def run(refs):
      @pl.core_map(
          mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
      )
      def _kernel_entry():
        pipeline(*refs)

    @jax.jit
    def run_function(x, y, o):
      *_, out = pl.run_state(run)((x, y, o))
      return out

    out = run_function(x, y, o)
    np.testing.assert_allclose(out, x + y, atol=1e-4)

  def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2):
    """Compute WGs thread a register carry through the pipeline loop."""
    blk_m = blk_n = 64
    x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32)
    acc_init = jnp.zeros((blk_m, blk_n), dtype=jnp.float32)

    def _scoped(acc_smem, x_gmem, acc_gmem):
      def _compute_thread():
        # Cast the init value to the same layout as x_smem, so the pipeline
        # loop carry has a constant signature.
        zeros = jnp.full((blk_m, blk_n,), 0, dtype=jnp.float32)
        o_acc = plgpu.layout_cast(
            zeros, plgpu.Layout.WG_STRIDED((blk_m, blk_n), vec_size=2)
        )
        # Hand the initial carry to the pipeline emitter; the final carry is
        # sent back into this generator.
        (o_final,) = yield (o_acc,)
        # Note that both compute WGs are doing identical work so the potential
        # race condition on the store here won't affect the result.
        acc_smem[...] = o_final
        plgpu.commit_smem()
        plgpu.copy_smem_to_gmem(acc_smem, acc_gmem)
        plgpu.wait_smem_to_gmem(0)

      def tiled_acc_kernel(x_smem, carry):
        (o_carry,) = carry
        return (x_smem[...] + o_carry,)

      pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
          tiled_acc_kernel,
          grid=(m // blk_m, n // blk_n),
          max_concurrent_steps=2,
          num_compute_wgs=num_compute_wgs,
          memory_registers=40,
          wg_axis="wg",
          carry_coroutine=_compute_thread,
          in_specs=[
              plgpu.GPUBlockSpec(
                  block_shape=(blk_m, blk_n),
                  index_map=lambda i, j: (i, j),
                  transforms=[],
              ),
          ],
          out_specs=[],
      )
      pipeline(x_gmem)

    mesh = plgpu.GPUMesh(
        grid=(1,),
        num_threads=num_compute_wgs + 1,
        axis_names=("_", "wg"),
    )

    def run(refs):
      x_ref, acc_ref = refs

      @pl.core_map(mesh)
      def _kernel_entry():
        pl.run_scoped(
            functools.partial(_scoped, x_gmem=x_ref, acc_gmem=acc_ref),
            plgpu.SMEM((blk_m, blk_n), jnp.float32),
        )

    @jax.jit
    def run_function(x, acc):
      _, out_acc = pl.run_state(run)((x, acc))
      return out_acc

    out_acc = run_function(x, acc_init)
    # Reference: sum all (blk_m, blk_n) tiles of x, first over row blocks,
    # then over column blocks.
    row_tiles = jnp.stack(np.split(x, m // blk_m, axis=0))
    ref = jnp.sum(row_tiles, axis=0)
    col_tiles = jnp.stack(np.split(ref, n // blk_n, axis=1))
    ref = jnp.sum(col_tiles, axis=0)
    np.testing.assert_allclose(out_acc, ref, atol=1e-4)
class CoreMapTest(PallasTest):
  """Tests for ``pl.core_map`` over ``plgpu.GPUMesh`` meshes."""

  def test_multiple_wg(self):
    """Each warpgroup writes its own index into its own row."""
    mesh = plgpu.GPUMesh(num_threads=2, axis_names=("y",))

    @jax.jit
    def f():
      def inner(y_ref):
        @pl.core_map(mesh)
        def kernel():
          wg_idx = jax.lax.axis_index("y")
          y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,))

      return pl.run_state(inner)(jnp.zeros((2, 128), np.int32))

    expected = np.repeat(np.arange(2), 128).reshape(2, 128)
    np.testing.assert_array_equal(f(), expected)

  def test_multiple_wg_with_grid(self):
    """Multi-axis ``axis_index`` follows the order of the axis names given."""
    mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg"))

    @jax.jit
    def f():
      def inner(y_ref):
        @pl.core_map(mesh)
        def kernel():
          xy_idx = jax.lax.axis_index(("x", "y"))
          yx_idx = jax.lax.axis_index(("y", "x"))
          wg_idx = jax.lax.axis_index("wg")
          num_wgs = jax.lax.psum(1, "wg")
          value = yx_idx * num_wgs + wg_idx
          y_ref[xy_idx, wg_idx] = jnp.broadcast_to(value, (128,))

      return pl.run_state(inner)(jnp.zeros((4, 2, 128), np.int32))

    expected = np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128)
    np.testing.assert_array_equal(f(), expected)

  def test_multiple_wg_with_squashed_grid(self):
    """A grid with more than 3 logical dims is squashed into 3 CUDA dims."""
    b = 4
    x_dim = 3
    y_dim = 5
    z_dim = 7
    num_threads = 2
    names = ("b", "x", "y", "z", "wg")
    mesh = plgpu.GPUMesh(
        grid=(b, x_dim, y_dim, z_dim),
        num_threads=num_threads,
        axis_names=names,
    )

    @jax.jit
    def f():
      def inner(y_ref):
        @pl.core_map(mesh)
        def _():
          per_axis = tuple(jax.lax.axis_index(name) for name in names)
          linear_idx = jax.lax.axis_index(names)
          y_ref[per_axis] = jnp.broadcast_to(linear_idx, (128,))

      y_init = jnp.zeros((b, x_dim, y_dim, z_dim, num_threads, 128), np.int32)
      return pl.run_state(inner)(y_init)

    result = f()[..., 0]
    expected = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape(
        result.shape
    )
    np.testing.assert_array_equal(result, expected)

  def test_cross_wg_barrier(self):
    """A barrier with two arrivals synchronizes two warpgroups."""
    mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",))

    @jax.jit
    def f():
      def inner(y_ref):
        @pl.core_map(mesh)
        def kernel():
          def scoped(barrier):
            plgpu.barrier_arrive(barrier)
            plgpu.barrier_wait(barrier)
            wg_idx = jax.lax.axis_index("wg")
            y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,))

          # Each warpgroup is a single logical thread!
          pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2))

      return pl.run_state(inner)(jnp.zeros((2, 128), np.int32))

    np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128))
class ExamplesTest(PallasTest):
  """Progressively more advanced ``plgpu.kernel`` examples (tutorial stages)."""

  def test_stage0(self):
    """Basic: a single-block elementwise add."""

    def body(l_ref, r_ref, o_ref):
      o_ref[...] = l_ref[...] + r_ref[...]

    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
    add = plgpu.kernel(body, out_shape=x)
    np.testing.assert_allclose(add(x, x), x + x)

  def test_stage1(self):
    """Multi-block kernels: each block handles a slice of rows."""
    row_block = 64

    def body(l_ref, r_ref, o_ref):
      rows = pl.ds(lax.axis_index("rows") * row_block, row_block)
      o_ref[rows] = l_ref[rows] + r_ref[rows]

    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
    add = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))
    np.testing.assert_allclose(add(x, x), x + x)

  def test_stage3(self):
    """Async copies: explicit GMEM<->SMEM transfers with a barrier."""
    row_block, col_block = 64, 128

    def body(l_ref, r_ref, o_ref):
      rows = pl.ds(lax.axis_index("rows") * row_block, row_block)

      def scoped(l_smem, r_smem, o_smem, barrier):
        plgpu.copy_gmem_to_smem(l_ref.at[rows], l_smem, barrier)
        plgpu.copy_gmem_to_smem(r_ref.at[rows], r_smem, barrier)
        plgpu.barrier_wait(barrier)
        o_smem[...] = l_smem[...] + r_smem[...]
        plgpu.commit_smem()
        plgpu.copy_smem_to_gmem(o_smem, o_ref.at[rows])
        plgpu.wait_smem_to_gmem(0)

      smem_tile = plgpu.SMEM((row_block, col_block), jnp.float16)
      pl.run_scoped(
          scoped,
          smem_tile,
          smem_tile,
          smem_tile,
          plgpu.Barrier(num_arrivals=2),
      )

    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
    add = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))
    np.testing.assert_allclose(add(x, x), x + x)

  def test_stage4(self):
    """Pipelining: ``emit_pipeline`` over column blocks."""
    row_block, col_block = 64, 32

    def body(l_ref, r_ref, o_ref):
      def compute(l_smem, r_smem, o_smem):
        o_smem[...] = l_smem[...] + r_smem[...]

      r = lax.axis_index("rows")
      block = pl.BlockSpec((row_block, col_block), lambda c: (r, c))
      plgpu.emit_pipeline(
          compute,
          grid=(l_ref.shape[1] // col_block,),
          in_specs=[block, block],
          out_specs=[block],
      )(l_ref, r_ref, o_ref)

    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
    add = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))
    np.testing.assert_allclose(add(x, x), x + x)

  def test_stage5(self):
    """Transforms: pipelined blocks with tiling and swizzling."""
    row_block, col_block = 64, 32

    def body(l_ref, r_ref, o_ref):
      def compute(l_smem, r_smem, o_smem):
        o_smem[...] = l_smem[...] + r_smem[...]

      r = lax.axis_index("rows")
      block = plgpu.GPUBlockSpec(
          (row_block, col_block),
          lambda c: (r, c),
          transforms=(
              plgpu.TilingTransform((64, 32)),
              plgpu.SwizzleTransform(64),
          ),
      )
      plgpu.emit_pipeline(
          compute,
          grid=(l_ref.shape[1] // col_block,),
          in_specs=[block, block],
          out_specs=[block],
      )(l_ref, r_ref, o_ref)

    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
    add = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))
    np.testing.assert_allclose(add(x, x), x + x)
class ExamplesSm90ATest(PallasSm90ATest):
  """Tutorial-style examples that require sm_90a features."""

  def test_stage6(self):
    """WGMMA: a pipelined matmul built from ``emit_pipeline`` and WGMMA."""
    m_block = n_block = 64
    k_block = 32

    def body(l_ref, r_ref, o_ref):
      def compute(l_smem, r_smem, o_smem):
        def do_wgmma(acc_ref):
          plgpu.wgmma(acc_ref, l_smem, r_smem)
          return acc_ref[...]

        o_smem[...] += pl.run_scoped(
            do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16)
        )

      m, n = lax.axis_index("m"), lax.axis_index("n")
      lo_transforms = (
          plgpu.TilingTransform((64, 32)),
          plgpu.SwizzleTransform(64),
      )
      r_transforms = (
          plgpu.TilingTransform((32, 32)),
          plgpu.SwizzleTransform(64),
      )
      lhs_spec = plgpu.GPUBlockSpec(
          (m_block, k_block), lambda k: (m, k), transforms=lo_transforms
      )
      rhs_spec = plgpu.GPUBlockSpec(
          (k_block, n_block), lambda k: (k, n), transforms=r_transforms
      )
      out_spec = plgpu.GPUBlockSpec(
          (m_block, n_block), lambda k: (m, n), transforms=lo_transforms
      )
      plgpu.emit_pipeline(
          compute,
          grid=(l_ref.shape[1] // k_block,),
          in_specs=[lhs_spec, rhs_spec],
          out_specs=[out_spec],
      )(l_ref, r_ref, o_ref)

    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
    matmul = plgpu.kernel(
        body, out_shape=x, grid=(2, 2), axis_names=("m", "n")
    )
    np.testing.assert_allclose(matmul(x, x), x @ x)

  # TODO(apaszke): Clusters and multicast
if __name__ == "__main__":
  # Entry point when the module is executed directly: hand off to absltest.
  absltest.main()