Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import dataclasses
import functools
import math
import operator
import os
import re
import tempfile
from typing import ClassVar

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.pallas import pallas_call
from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering
from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives
from jax._src.state import discharge
from jax.experimental import pallas as pl
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.pallas import mosaic_gpu as plgpu
import jax.numpy as jnp
import numpy as np

try:
  from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
except ImportError:
  mosaic_gpu_lib = None


jax.config.parse_flags_with_absl()


def _fori_loop(force_while: bool, lb, ub, body, init):
  if force_while:
    # Wrapping the bounds in jnp.asarray makes the while/scan matcher treat
    # them as dynamic, which forces the use of the while primitive.
    lb, ub = jnp.asarray(lb), jnp.asarray(ub)
  return jax.lax.fori_loop(lb, ub, body, init)


def _sum_same_dtype(x):
  # TODO(slebedev): Remove this once ``FragmentedArray`` supports
  # ``reduce_sum`` for non-32-bit types.
  return jnp.sum(x, dtype=x.dtype)


class PallasTestMetaclass(parameterized.TestGeneratorMetaclass):

  def __new__(mcs, *args, thread_semantics=plgpu.ThreadSemantics.Lane):
    cls = super().__new__(mcs, *args)
    cls.THREAD_SEMANTICS = thread_semantics
    return cls
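

# Passing `thread_semantics` as a class keyword re-runs every test defined in
# (or inherited by) a class under the requested semantics; see e.g.
# PallasCallWGTest below.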


class PallasTest(jtu.JaxTestCase, metaclass=PallasTestMetaclass):
  THREAD_SEMANTICS: ClassVar[plgpu.ThreadSemantics]

  def setUp(self):
    if not jtu.is_cuda_compute_capability_at_least("9.0"):
      self.skipTest("Only works on a GPU with capability >= sm90")
    context_stack = contextlib.ExitStack()
    context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True))
    self.addCleanup(context_stack.close)

    super().setUp()

  def skip_if_wg_semantics(self):
    if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Warpgroup:
      self.skipTest("Not supported under WG semantics")

  def kernel(self, *args, **kwargs):
    compiler_params = dataclasses.replace(
        kwargs.pop("compiler_params", plgpu.GPUCompilerParams()),
        thread_semantics=self.THREAD_SEMANTICS,
    )
    return plgpu.kernel(*args, compiler_params=compiler_params, **kwargs)

  def pallas_call(self, *args, **kwargs):
    compiler_params = dataclasses.replace(
        kwargs.pop("compiler_params", plgpu.GPUCompilerParams()),
        thread_semantics=self.THREAD_SEMANTICS,
    )
    return pl.pallas_call(*args, compiler_params=compiler_params, **kwargs)

  @contextlib.contextmanager
  def capture_stdout(self):
    if mosaic_gpu_lib is None:
      raise ValueError("Running tests but missing Mosaic GPU extension")
    with jtu.capture_stdout() as stdout:
      yield stdout
      # We need to cudaDeviceSynchronize to make sure printfs are flushed.
      mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices()


class PallasSm90ATest(PallasTest, jtu.CudaArchSpecificTest):

  def setUp(self):
    self.skip_unless_sm90a()
    super().setUp()


class PallasSm100ATest(PallasTest, jtu.CudaArchSpecificTest):

  def setUp(self):
    self.skip_unless_sm100a()
    super().setUp()


class PallasCallTest(PallasTest):

  @parameterized.product(
      op=[
          lax.neg,
          lax.bitwise_not,
          lax.logistic,
          lax.exp,
          lambda x: x**2,
          lax.rsqrt,
          lax.tanh,
          lax.log,
      ],
      approx_math=[True, False],
  )
  def test_unary_op(self, op, approx_math):
    dtype = jnp.int32 if op is lax.bitwise_not else jnp.float32

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], dtype),
        compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = op(x_ref[...])

    x = jnp.arange(256).astype(dtype)
    np.testing.assert_allclose(
        kernel(x), op(x), rtol=1e-5 if approx_math else 3e-7
    )

  @parameterized.product(
      op=[
          operator.add,
          lambda x, _: x + 1,  # for int->vector conversion
          operator.sub,
          operator.mul,
          lax.div,
          jnp.minimum,
          jnp.maximum,
      ],
      dtype=[jnp.float32, jnp.int32, jnp.uint32],
  )
  def test_binary_op(self, op, dtype):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype)
    )
    def kernel(x_ref, y_ref, o_ref):
      o_ref[...] = op(x_ref[...], y_ref[...])

    key0, key1 = jax.random.split(jax.random.key(0), 2)
    x = (jax.random.uniform(key0, [256]) * 42).astype(dtype)
    y = (jax.random.uniform(key1, [256]) * 42).astype(dtype)
    np.testing.assert_array_equal(kernel(x, y), op(x, y))

  @parameterized.product(
      op=[
          lax.eq,
          operator.ne,
          operator.lt,
          operator.le,
          operator.gt,
          operator.ge,
      ],
      # TODO(slebedev): Support integral types.
      dtype=[jnp.float32, jnp.int32, jnp.uint32],
  )
  def test_comparison_op(self, op, dtype):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype)
    )
    def kernel(o_ref):
      o_ref[...] = jnp.broadcast_to(
          op(dtype(42), dtype(24)).astype(dtype), o_ref.shape
      )

    np.testing.assert_array_equal(kernel(), jnp.full([256], op(42, 24), dtype))

  def test_add_first(self):

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def kernel(x_ref, y_ref, o_ref):
      o_ref[...] = x_ref[...] + y_ref[0]

    x = jnp.arange(256).astype(jnp.float32)
    y = jnp.flip(x).reshape(1, 256)
    np.testing.assert_array_equal(kernel(x, y), x + y[0])

  @parameterized.product(shape=[(128,), (128, 128)])
  def test_reduce_sum(self, shape):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32)
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape)

    x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), jnp.sum(x))

  def test_reshape(self):
    self.skip_if_wg_semantics()

    shape1, shape2 = (128,), (2, 16, 4)

    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32)
    )
    def kernel(x_ref, out_ref):
      x_ref_reshaped = x_ref.reshape(shape2)
      self.assertEqual(x_ref.shape, shape1)
      self.assertEqual(x_ref_reshaped.shape, shape2)
      out_ref[...] = x_ref_reshaped[...]

    x = jnp.arange(math.prod(shape1)).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x.reshape(shape2))

  def test_add_xy_indexed(self):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32)
    )
    def kernel(x_ref, y_ref, o_ref):
      idx = _sum_same_dtype(y_ref[...])
      o_ref[...] = x_ref[idx]

    x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
    y = jnp.zeros(128, dtype=jnp.int32)
    np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)])

  def test_add_one_grid(self):

    @functools.partial(
        self.pallas_call,
        in_specs=[pl.BlockSpec((128,), lambda *i: i)],
        out_specs=pl.BlockSpec((128,), lambda *i: i),
        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32),
        grid=2,
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(128 * 2).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

  def test_add_one_grid_with_scratch(self):
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32),
        in_specs=[pl.BlockSpec((128,), lambda *i: i)],
        out_specs=pl.BlockSpec((128,), lambda *i: i),
        scratch_shapes=[plgpu.SMEM((128,), jnp.float32)],
        grid=2,
    )
    def kernel(x_ref, o_ref, scratch_ref):
      scratch_ref[...] = x_ref[...] + 1
      o_ref[...] = scratch_ref[...]

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

  @parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16])
  def test_add_one_grid_pipelined(self, max_concurrent_steps):
    @functools.partial(
        self.pallas_call,
        in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
        out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
        compiler_params=plgpu.GPUCompilerParams(
            dimension_semantics=["parallel", "sequential"],
            max_concurrent_steps=max_concurrent_steps,
        ),
        grid=(2, 4),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

  def test_add_one_grid_pipelined_program_id(self):
    @functools.partial(
        self.pallas_call,
        out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32),
        compiler_params=plgpu.GPUCompilerParams(
            dimension_semantics=["parallel", "sequential"],
            max_concurrent_steps=2,
        ),
        grid=(4, 4),
    )
    def kernel(o_ref):
      o_ref[...] = jnp.broadcast_to(pl.program_id(1), o_ref.shape)

    np.testing.assert_array_equal(
        kernel(),
        jnp.repeat(jnp.repeat(jnp.arange(4), 16)[None], 16, axis=0),
    )

  def test_add_one_grid_pipelined_sequential_invariant_output(self):

    @functools.partial(
        self.pallas_call,
        in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))],
        out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)),
        out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32),
        compiler_params=plgpu.GPUCompilerParams(
            dimension_semantics=["parallel", "sequential"],
            max_concurrent_steps=2,
        ),
        grid=(2, 4),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(32 * 2 * 64).reshape((32 * 2, 64)).astype(jnp.float32)
    y = jnp.empty_like(x)
    for i in range(2):
      i_slice = slice(32 * i, 32 * (i + 1))
      for j in range(4):
        j_slice = slice(16 * j, 16 * (j + 1))
        y = y.at[i_slice, :16].set(x[i_slice, j_slice] + 1)
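
    # Every sequential step writes the same output block (i, 0), so the value
    # that lands in GMEM is the one from the last step, j == 3; the loop above
    # mirrors that by overwriting y[:, :16] in order.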

    # We only compare the elements in the first 16 columns, because the rest
    # are never written to.
    np.testing.assert_array_equal(kernel(x)[:, :16], y[:, :16])

  @parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32)
  def test_iota(self, dtype):
    self.skip_if_wg_semantics()

    dimension = 1

    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((128, 128), dtype)
    )
    def kernel(o_ref):
      o_ref[...] = plgpu.broadcasted_iota(
          dtype, o_ref.shape, dimension, layout=plgpu.Layout.WGMMA
      )

    np.testing.assert_array_equal(
        kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension)
    )

  def test_inline_mgpu(self):
    dtype = jnp.bfloat16
    self.skip_if_wg_semantics()

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((128, 128), dtype),
        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
        scratch_shapes=[
            plgpu.SMEM((128, 128), dtype),
            plgpu.Barrier(num_arrivals=1),
        ],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    )
    def kernel(x_ref, o_ref, smem_ref, barrier):
      plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier)
      plgpu.barrier_wait(barrier)
      arr = jnp.ones_like(x_ref)

      @plgpu.inline_mgpu(
          smem_ref,
          o_ref,
          arr,
          arg_types=[
              plgpu.RefType(),
              plgpu.RefType(),
              plgpu.Layout.WG_SPLAT(x_ref.shape),
          ],
      )
      def _(ctx, smem_ref, o_ref, y):
        del ctx
        x = mgpu.FragmentedArray.load_strided(smem_ref)
        (x + y).store_untiled(o_ref)

    key = jax.random.key(0)
    x = (jax.random.uniform(key, (128, 128)) * 42).astype(dtype)
    np.testing.assert_array_equal(kernel(x), x + 1)

  @parameterized.product(indexer=[..., slice(128), slice(None, 128)])
  def test_copy_smem_to_gmem(self, indexer):
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        scratch_shapes=[plgpu.SMEM((256,), jnp.float32)],
    )
    def kernel(x_ref, o_ref_gmem, scratch_ref):
      scratch_ref[...] = x_ref[...] + 1
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer])
      plgpu.wait_smem_to_gmem(0)

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)

  @parameterized.parameters(jnp.bfloat16, jnp.float16, jnp.float32)
  def test_copy_smem_to_gmem_reduction(self, dtype):
    @functools.partial(
        pl.pallas_call,
        grid=(200,),
        in_specs=[
            pl.BlockSpec((128,), lambda *i: i),
            pl.BlockSpec(memory_space=plgpu.GMEM),
        ],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct([128], dtype),
        scratch_shapes=[plgpu.SMEM((128,), dtype)],
        input_output_aliases={1: 0},
    )
    def kernel(x_ref, o_ref_gmem, o_ref_gmem_alias, scratch_ref):
      del o_ref_gmem_alias
      scratch_ref[...] = x_ref[...]
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(
          scratch_ref.at[...], o_ref_gmem.at[...], reduction_op="add"
      )
      plgpu.wait_smem_to_gmem(0)

    x = jnp.ones(200 * 128).astype(dtype)  # 200 blocks
    output = jnp.zeros(128).astype(dtype)
    output = kernel(x, output)
    output_val = x.reshape(-1, 128).sum(axis=0)
    np.testing.assert_array_equal(output, output_val)

  @parameterized.named_parameters(
      {
          "testcase_name": "1d_none",
          "shape": (256,),
          "indexers": (slice(0, 128), slice(None, 32)),
      },
      {
          "testcase_name": "1d_offset",
          "shape": (256,),
          "indexers": (slice(32, 96), slice(0, 32)),
      },
      {
          "testcase_name": "2d_extract",
          "shape": (64, 64),
          "indexers": (4, slice(0, 64)),
      },
  )
  def test_copy_smem_to_gmem_with_multiple_gmem_indexers(self, shape, indexers):

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        scratch_shapes=[plgpu.SMEM(shape, jnp.float32)],
    )
    def kernel(x_ref, o_ref_gmem, scratch_ref):
      scratch_ref[...] = x_ref[...] + 1
      plgpu.commit_smem()
      for indexer in indexers:
        scratch_ref = scratch_ref.at[indexer]
        o_ref_gmem = o_ref_gmem.at[indexer]
      plgpu.copy_smem_to_gmem(scratch_ref, o_ref_gmem)
      plgpu.wait_smem_to_gmem(0)

    x = jnp.arange(np.prod(shape)).astype(jnp.float32).reshape(*shape)
    result = kernel(x)
    ref = x + 1.0
    for indexer in indexers:
      result = result[indexer]
      ref = ref[indexer]
    np.testing.assert_array_equal(result, ref)

  @parameterized.product(indexer=[..., slice(128), slice(None, 128)])
  def test_copy_gmem_to_smem(self, indexer):

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
        scratch_shapes=[
            plgpu.SMEM((256,), jnp.float32),
            plgpu.Barrier(num_arrivals=1),
        ],
    )
    def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
      plgpu.copy_gmem_to_smem(
          x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier_ref
      )
      plgpu.barrier_wait(barrier_ref)
      o_ref[...] = scratch_ref[...] + 1

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)
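
  # The last two tests exercise the canonical async-copy idioms used
  # throughout this file. Roughly:
  #   GMEM -> SMEM: copy_gmem_to_smem(src, dst, barrier) issues the copy,
  #     and barrier_wait(barrier) blocks until the data has landed in SMEM.
  #   SMEM -> GMEM: commit_smem() makes prior SMEM writes visible to the copy
  #     engine, copy_smem_to_gmem(src, dst) issues the copy, and
  #     wait_smem_to_gmem(0) blocks until no copies remain in flight.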

  @parameterized.named_parameters(
      {
          "testcase_name": "1d_none",
          "shape": (256,),
          "indexers": (slice(0, 128), slice(None, 32)),
      },
      {
          "testcase_name": "1d_offset",
          "shape": (256,),
          "indexers": (slice(32, 96), slice(0, 32)),
      },
      {
          "testcase_name": "2d_extract_static",
          "shape": (64, 64),
          "indexers": (4, slice(0, 64)),
      },
      {
          "testcase_name": "2d_extract_dyn",
          "shape": (64, 64),
          "indexers": lambda in_dev: (
              pl.program_id(0) + 4 if in_dev else jnp.array(4),
              slice(0, 64),
          ),
      },
  )
  def test_copy_gmem_to_smem_with_multiple_gmem_indexers(self, shape, indexers):

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
        scratch_shapes=[
            plgpu.SMEM(shape, jnp.float32),
            plgpu.Barrier(num_arrivals=1),
        ],
        grid=(1,),
    )
    def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
      scratch_ref_sliced = scratch_ref
      for indexer in indexers(True) if callable(indexers) else indexers:
        scratch_ref_sliced = scratch_ref_sliced.at[indexer]
        x_ref_gmem = x_ref_gmem.at[indexer]
      plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref_sliced, barrier_ref)
      plgpu.barrier_wait(barrier_ref)
      o_ref[...] = scratch_ref[...] + 1

    x = jnp.arange(np.prod(shape)).astype(jnp.float32).reshape(*shape)
    result = kernel(x)
    ref = x + 1.0
    for indexer in indexers(False) if callable(indexers) else indexers:
      result = result[indexer]
      ref = ref[indexer]
    np.testing.assert_array_equal(result, ref)

  def test_gmem_to_smem_with_multiple_smem_indexers(self):
    x = jax.random.uniform(jax.random.key(0), (2, 64, 64), dtype=jnp.float32)

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32),
        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
        scratch_shapes=[
            plgpu.SMEM(x.shape, jnp.float32),
            plgpu.Barrier(num_arrivals=1),
        ],
    )
    def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
      plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref)
      plgpu.barrier_wait(barrier_ref)
      x_sliced = scratch_ref.at[0, :, :]  # shape=(64, 64)
      o_ref[pl.ds(0, 32), :] = x_sliced[pl.ds(0, 32), :]
      o_ref[pl.ds(32, 32), :] = x_sliced[pl.ds(32, 32), :]

    np.testing.assert_array_equal(extract_x0(x), x[0])

  def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self):
    self.skip_if_wg_semantics()

    x = jnp.arange(512 * 512, dtype=jnp.int32).reshape(512, 512)

    @functools.partial(
        self.pallas_call,
        grid=(4, 4),
        out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32),
        in_specs=(
            plgpu.GPUBlockSpec(
                block_shape=(128, 128),
                index_map=lambda i, j: (i, j),
                memory_space=plgpu.SMEM,
                transforms=(
                    plgpu.TilingTransform((8, 32)),
                    plgpu.SwizzleTransform(128),
                ),
            ),
        ),
        out_specs=(
            plgpu.GPUBlockSpec(
                block_shape=(64, 32),
                index_map=lambda i, j: (i, j),
                memory_space=plgpu.SMEM,
            )
        ),
    )
    def kernel(x_ref, o_ref):
      x_sliced = x_ref.at[0:64, 32:96].at[:, 0:32]  # get x_ref[0:64, 32:64]
      o_ref[...] = x_sliced[...]

    ref = jnp.concatenate([x[blk : blk + 64, :] for blk in range(0, 512, 128)])
    ref = jnp.concatenate(
        [ref[:, blk + 32 : blk + 64] for blk in range(0, 512, 128)], axis=1
    )
    np.testing.assert_array_equal(kernel(x), ref)

  @parameterized.product(indexer=[0, 1, 2, 3])
  def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer):

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
        scratch_shapes=[
            plgpu.SMEM((128,), jnp.float32),
            plgpu.Barrier(num_arrivals=1, num_barriers=4),
        ],
    )
    def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
      plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref.at[indexer])
      plgpu.barrier_wait(barrier_ref.at[indexer])
      o_ref[...] = scratch_ref[...] + 1

    x = jnp.arange(128).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

  @parameterized.named_parameters(("_g2s", False), ("_s2g", True))
  def test_copy_with_transforms(self, to_smem):
    self.skip_if_wg_semantics()

    def kernel(x_ref, o_ref, barrier_ref):
      if to_smem:
        plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier_ref)
        plgpu.barrier_wait(barrier_ref)
      else:
        plgpu.commit_smem()
        plgpu.copy_smem_to_gmem(x_ref, o_ref)
        plgpu.wait_smem_to_gmem(0)

    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
    out_spec = plgpu.GPUBlockSpec(
        transforms=(
            plgpu.TilingTransform((8, 32)),
            plgpu.SwizzleTransform(128),
        ),
        memory_space=plgpu.SMEM,
    )
    if not to_smem:
      in_spec, out_spec = out_spec, in_spec
    f = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
        in_specs=(in_spec,),
        out_specs=out_spec,
        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
    )
    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
    np.testing.assert_array_equal(f(x), x)

  def test_scoped_copy_with_transforms(self):
    self.skip_if_wg_semantics()

    ts = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128))

    def kernel(x_ref, o_ref, barrier_ref):
      def body(tmp_ref):
        plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref)
        plgpu.barrier_wait(barrier_ref)
        o_ref[...] = tmp_ref[...] * 2

      pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))

    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
    out_spec = plgpu.GPUBlockSpec(transforms=ts, memory_space=plgpu.SMEM)
    f = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
        in_specs=(in_spec,),
        out_specs=out_spec,
        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
    )
    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
    np.testing.assert_array_equal(f(x), x * 2)

  def test_scoped_copy_with_user_transforms(self):
    def kernel(x_ref, o_ref, barrier_ref):
      def body(tmp_ref):
        tmp_ref = plgpu.unswizzle_ref(tmp_ref, 128)
        tmp_ref = plgpu.untile_ref(tmp_ref, (8, 32))
        plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref)
        plgpu.barrier_wait(barrier_ref)
        o_ref[...] = tmp_ref[...] * 2

      pl.run_scoped(body, plgpu.SMEM((16, 4, 8, 32), jnp.float32))

    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
    f = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
        in_specs=(in_spec,),
        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
    )
    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
    np.testing.assert_array_equal(f(x), x * 2)
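
  # The (16, 4, 8, 32) scratch above is the physical shape of a logical
  # (128, 128) array tiled into (8, 32) tiles: 128 / 8 = 16 tile rows and
  # 128 / 32 = 4 tile columns. unswizzle_ref/untile_ref peel these manually
  # applied transforms back off, so the copy sees a plain (128, 128) view.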

  def test_copy_with_transforms_and_indexing(self):
    self.skip_if_wg_semantics()

    def kernel(x_ref, o_ref, barrier_ref):
      for i in range(2):
        plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier_ref)
        plgpu.barrier_wait(barrier_ref)

    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
    out_spec = plgpu.GPUBlockSpec(
        transforms=(
            plgpu.TilingTransform((8, 32)),
            plgpu.TransposeTransform((0, 2, 1, 3, 4)),
            plgpu.SwizzleTransform(128),
        ),
        memory_space=plgpu.SMEM,
    )
    f = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32),
        in_specs=(in_spec,),
        out_specs=out_spec,
        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
    )
    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
    np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0))

  @parameterized.product(
      src_memory_space=[plgpu.SMEM, plgpu.GMEM],
      layout=[plgpu.Layout.WG_STRIDED((128,), vec_size=1), None],
  )
  def test_load_to_strided_layout_with_indexing(self, src_memory_space, layout):
    self.skip_if_wg_semantics()

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32),
        in_specs=[pl.BlockSpec(memory_space=src_memory_space)],
        out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM),
    )
    def kernel(x_ref, o_ref):
      for i in range(2):
        x = plgpu.load(x_ref, (i,), layout=layout)
        o_ref[i, ...] = x

    x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128)
    np.testing.assert_array_equal(kernel(x), x)

  @parameterized.product(
      src_memory_space=[plgpu.SMEM, plgpu.GMEM],
      layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL],
      m=[64, 128, 192],
  )
  def test_load_to_wgmma_row_col_layout_with_indexing(
      self, src_memory_space, layout, m
  ):
    self.skip_if_wg_semantics()

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32),
        in_specs=[pl.BlockSpec(memory_space=src_memory_space)],
        out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM),
    )
    def kernel(x_ref, o_ref):
      for i in range(2):
        x = plgpu.load(x_ref, (i,), layout=layout)
        o_ref[i, ...] = x

    x = jnp.arange(2 * m, dtype=jnp.float32).reshape(2, m)
    np.testing.assert_array_equal(kernel(x), x)

  @parameterized.product(
      src_memory_space=[plgpu.SMEM],
      layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL],
  )
  def test_load_row_input_to_wgmma_with_transforms(
      self, src_memory_space, layout
  ):
    self.skip_if_wg_semantics()

    m, k, n = 64, 128, 192
    key1, key2 = jax.random.split(jax.random.key(42), 2)
    if layout == plgpu.Layout.WGMMA_ROW:
      input_shape = (m,)
      broadcast_dim = 0
      expand_dim = 1
    else:
      input_shape = (k,)
      broadcast_dim = 1
      expand_dim = 0
    a = jax.random.uniform(key1, shape=input_shape, dtype=jnp.float16)
    b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)

    def kernel(x_ref, y_ref, o_ref):
      x = plgpu.load(x_ref, (), layout=layout)
      x = lax.broadcast_in_dim(x, (m, k), [broadcast_dim])

      def compute(acc_ref):
        plgpu.wgmma(acc_ref, x, y_ref)
        return acc_ref[...]

      out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32))
      o_ref[...] = out

    f = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32),
        in_specs=(
            pl.BlockSpec(memory_space=src_memory_space),
            plgpu.GPUBlockSpec(
                transforms=(
                    plgpu.TilingTransform((8, 64)),
                    plgpu.SwizzleTransform(128),
                ),
            ),
        ),
        out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM),
    )

    out_ref = (
        jnp.broadcast_to(jnp.expand_dims(a, axis=expand_dim), (m, k)) @ b
    )
    np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3)

  def test_indexing_before_transpose(self):
    self.skip_if_wg_semantics()

    def kernel(x_ref, o_ref, barrier_ref):
      for i in range(2):
        plgpu.copy_gmem_to_smem(
            x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier_ref
        )
        plgpu.barrier_wait(barrier_ref)

    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
    out_spec = plgpu.GPUBlockSpec(memory_space=plgpu.SMEM)
    f = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32),
        in_specs=(in_spec,),
        out_specs=out_spec,
        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
    )
    x = jnp.arange(2 * 64 * 128, dtype=jnp.float32).reshape(2, 64, 128)
    xt = x.transpose((1, 0, 2))
    np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0))

  def test_copy_gmem_to_smem_in_run_scoped(self):

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
    )
    def kernel(x_ref_gmem, o_ref):
      def body(barrier_ref):
        def inner_body(scratch_ref):
          plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref)
          plgpu.barrier_wait(barrier_ref)
          o_ref[...] = scratch_ref[...] + 1

        pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))

      pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

  def test_add_doubled_sum(self):

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + jnp.sum(x_ref[...]) + jnp.sum(x_ref[...])

    x = jnp.arange(128).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + x.sum() * 2)

  @parameterized.product(input_factor=[0.001, 1, 10, 100])
  def test_layer_norm(self, input_factor):
    eps = 1e-5
    gamma = 1.0
    beta = 1.0

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def layer_norm(x_ref, o_ref):
      x_mean = jnp.mean(x_ref[...])
      x_centered = x_ref[...] - x_mean
      o_ref[...] = (
          x_centered * jax.lax.rsqrt(jnp.mean(x_centered**2) + eps) * gamma
          + beta
      )

    def layer_norm_np(x):
      x_mean = np.mean(x)
      x_centered = x - x_mean
      return (x_centered / np.sqrt(np.mean(x_centered**2) + eps) * gamma) + beta

    # An all-ones input is exactly representable, so the comparison is exact.
    x = jnp.ones((256,)).astype(jnp.float32) * input_factor
    np.testing.assert_allclose(layer_norm(x), layer_norm_np(x))

    # Random inputs are not, so we allow a small tolerance.
    x = (
        jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32)
        * input_factor
    )
    np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5)

  def test_print(self):

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def kernel(x_ref, o_ref):
      del x_ref, o_ref
      pl.debug_print("It works!")

    x = jnp.arange(256).astype(jnp.float32)
    with self.capture_stdout() as output:
      jax.block_until_ready(kernel(x))
    self.assertEqual(output(), "It works!\n")

  def test_print_wgmma_tiled_layout(self):
    self.skip_if_wg_semantics()

    shape = (128, 64)
    size = math.prod(shape)

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
        in_specs=[
            plgpu.GPUBlockSpec(
                transforms=(
                    plgpu.TilingTransform((8, 32)),
                    plgpu.SwizzleTransform(128),
                )
            )
        ],
    )
    def kernel(x_ref, o_ref):
      del o_ref  # Unused.
      pl.debug_print("prefix {}", x_ref[...])

    x = jnp.arange(size, dtype=jnp.float32).reshape(shape)
    with self.capture_stdout() as get_output:
      jax.block_until_ready(kernel(x))

    output = get_output()
    results = re.findall(r"prefix \[(\d+), (\d+)\]: (\d+).?\d*", output)
    self.assertLen(results, size, output)
    for i, j, v in results:
      i, j, v = map(int, (i, j, v))
      self.assertEqual(v, i * shape[1] + j)

  def test_print_scalar(self):
    self.skip_if_wg_semantics()

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
    )
    def kernel(x_ref, o_ref):
      del o_ref
      pl.debug_print("x.sum() = {}", _sum_same_dtype(x_ref[...]))

    x = jnp.arange(256, dtype=jnp.int32)
    with self.capture_stdout() as output:
      jax.block_until_ready(kernel(x))

    self.assertIn(f"x.sum() = {x.sum()}", output())

  def test_print_scalar_array(self):
    self.skip_if_wg_semantics()

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
    )
    def kernel(x_ref, o_ref):
      del o_ref
      pl.debug_print("x.sum() = {}", _sum_same_dtype(x_ref[...]) + 1)

    x = jnp.arange(256, dtype=jnp.int32)
    with self.capture_stdout() as output:
      jax.block_until_ready(kernel(x))

    self.assertIn(f"x.sum() = {x.sum() + 1}", output())

  def test_print_array(self):
    self.skip_if_wg_semantics()

    in_shape = [2, 1, 64, 64]

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32),
    )
    def kernel(x_ref, o_ref):
      del o_ref
      pl.debug_print("x: {}", x_ref[...])

    x = jnp.arange(math.prod(in_shape), dtype=jnp.int32).reshape(in_shape)
    with self.capture_stdout() as output:
      jax.block_until_ready(kernel(x))
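
    # Sanity check on the expected element: in the row-major shape
    # [2, 1, 64, 64], index (1, 0, 43, 23) holds 1*64*64 + 43*64 + 23 == 6871.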
    self.assertIn("x: [1, 0, 43, 23]: 6871\n", output())

  def test_load_scalar(self):
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((128,), jnp.int32),
        in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = jnp.broadcast_to(x_ref[10], (128,))

    np.testing.assert_array_equal(
        kernel(jnp.arange(11, dtype=jnp.int32)),
        jnp.full((128,), 10, dtype=jnp.int32),
    )

  def test_run_scoped(self):
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    def kernel(x_ref, o_ref):
      def body(tmp_ref):
        self.assertEqual(tmp_ref.shape, (8, 128))
        tmp_ref[...] = x_ref[...] + 1.0
        return tmp_ref[...]

      tmp = pl.run_scoped(body, plgpu.SMEM((8, 128), jnp.float32))
      self.assertEqual(tmp.shape, (8, 128))
      o_ref[...] = tmp

    x = np.ones((8, 128), jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

  def test_program_id(self):
    @functools.partial(
        self.pallas_call,
        in_specs=(),
        out_specs=pl.BlockSpec((128,), lambda *i: i),
        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
        grid=2,
    )
    def kernel(o_ref):
      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))

    np.testing.assert_array_equal(
        kernel(),
        jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32),
    )

  def test_program_id_in_squashed_grid(self):
    # Tests whether a grid with >3 logical dimensions is correctly squashed to
    # 3 CUDA grid dimensions.
    grid = (2, 3, 4, 5)

    @functools.partial(
        self.pallas_call,
        in_specs=(),
        out_specs=pl.BlockSpec((1,) * len(grid) + (128,), lambda *i: (*i, 0)),
        out_shape=jax.ShapeDtypeStruct([*grid, 128], jnp.int32),
        grid=grid,
    )
    def kernel(o_ref):
      mult = 1
      idx = 0
      for axis in range(len(grid) - 1, -1, -1):
        idx += pl.program_id(axis) * mult
        mult *= pl.num_programs(axis)
      o_ref[...] = jnp.full(o_ref.shape, idx)

    np.testing.assert_array_equal(
        kernel()[:, :, :, :, 0],
        jnp.arange(math.prod(grid), dtype=jnp.int32).reshape(*grid),
    )
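
  # The kernel above reconstructs the row-major linear index
  # ((i0 * 3 + i1) * 4 + i2) * 5 + i3 from the squashed program IDs, which is
  # exactly the layout produced by jnp.arange(math.prod(grid)).reshape(*grid).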

  def test_program_id_in_block_spec(self):
    @functools.partial(
        self.pallas_call,
        in_specs=(pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),),
        out_specs=pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),
        out_shape=jax.ShapeDtypeStruct([2, 128], jnp.int32),
        grid=2,
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]

    x = jnp.arange(2 * 128, dtype=jnp.int32).reshape([2, 128])
    np.testing.assert_array_equal(kernel(x), x)

  def test_num_programs(self):
    @functools.partial(
        self.pallas_call,
        in_specs=(),
        out_specs=pl.BlockSpec((128,), lambda *i: i),
        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
        grid=2,
    )
    def kernel(o_ref):
      o_ref[...] = jnp.full(o_ref.shape, pl.num_programs(0), o_ref.dtype)

    np.testing.assert_array_equal(
        kernel(),
        jnp.full([256], 2, dtype=jnp.int32),
    )

  def test_swizzled_blockspec_shapes(self):
    self.skip_if_wg_semantics()

    spec = plgpu.GPUBlockSpec(
        (128, 64),
        lambda *i: i,
        transforms=(
            plgpu.TilingTransform((8, 64)),
            plgpu.SwizzleTransform(128),
        ),
    )

    @functools.partial(
        self.pallas_call,
        in_specs=[spec],
        out_specs=spec,
        out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16),
        grid=(2, 2),
    )
    def kernel(x_ref, o_ref):
      assert x_ref.shape == (128, 64), x_ref.shape
      o_ref[...] = x_ref[...]

    x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
    np.testing.assert_array_equal(kernel(x), x)

  @parameterized.product(force_while=[False, True])
  def test_fori_loop_array(self, force_while):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32)
    )
    def kernel(x_ref, o_ref):
      # Equivalent to x_ref[...] + 2 + 3.
      o_ref[...] = _fori_loop(force_while, 2, 4, lambda i, x: x + i, x_ref[...])

    x = jnp.arange(256, dtype=jnp.int32)
    np.testing.assert_array_equal(kernel(x), x + 2 + 3)

  @parameterized.product(force_while=[False, True])
  def test_fori_loop_scalar(self, force_while):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32)
    )
    def kernel(o_ref):
      # Equivalent to 2 + 3.
      o_ref[...] = jax.lax.broadcast(
          _fori_loop(force_while, 2, 4, lambda i, x: x + i, jnp.int32(0)),
          o_ref.shape,
      )

    np.testing.assert_array_equal(kernel(), jnp.full([256], 5, jnp.int32))

  def test_fori_loop_dynamic_bounds(self):
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
        grid=(1,),
    )
    def kernel(o_ref):
      zero = pl.program_id(0)
      # Equivalent to 2 + 3.
      o_ref[...] = jax.lax.broadcast(
          jax.lax.fori_loop(2 + zero, 4 + zero, lambda i, x: x + i, 0),
          o_ref.shape,
      )

    np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))

  @parameterized.product(force_while=[False, True])
  def test_fori_loop_tuple(self, force_while):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32)
    )
    def kernel(o_ref):
      def body(step, xs):
        return tuple(
            jax.lax.cond(step % 2 == 0, lambda x: x + 1, lambda x: x, x)
            for x in xs
        )

      # Equivalent to 3 * (0 + 1).
      o_ref[...] = jax.lax.broadcast(
          sum(_fori_loop(force_while, 2, 4, body, (jnp.int32(0),) * 3)),
          o_ref.shape,
      )

    np.testing.assert_array_equal(
        kernel(), jnp.full([256], 3 * (0 + 1), jnp.int32)
    )

  @parameterized.product(force_while=[False, True])
  def test_fori_loop_indexed_store(self, force_while):
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32),
    )
    def kernel(x_ref, y_ref, o_ref):
      def body(idx, _):
        o_ref[idx] = x_ref[idx] + y_ref[idx]
        return ()

      _fori_loop(force_while, 0, 4, body, ())

    x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
    y = x + 1
    np.testing.assert_array_equal(kernel(x, y), x + y)

  def test_while_loop(self):
    self.skip_if_wg_semantics()

    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32)
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = jnp.zeros(o_ref.shape, dtype=jnp.int32)

      def cond(acc):
        _, last_o = acc
        return _sum_same_dtype(last_o) < 128 * 10

      def body(acc):
        i, _ = acc
        o_ref[...] += x_ref[i]
        return i + 1, o_ref[...]

      _ = jax.lax.while_loop(cond, body, (0, o_ref[...]))

    np.testing.assert_array_equal(
        kernel(jnp.ones([128, 128], jnp.int32)), jnp.full([128], 10, jnp.int32)
    )

  def test_while_loop_layout_mismatch(self):
    self.skip_if_wg_semantics()  # `plgpu.layout_cast` is not supported.

    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32)
    )
    def kernel(o_ref):
      def cond(acc):
        return _sum_same_dtype(acc) < 128

      def body(acc):
        del acc  # Unused.

        # We deliberately do a cast here to trigger a layout mismatch.
        return plgpu.layout_cast(
            jnp.zeros(o_ref.shape, o_ref.dtype), plgpu.Layout.WGMMA_ROW
        )

      _ = jax.lax.while_loop(cond, body, o_ref[...])

    with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"):
      kernel()

  def test_cond(self):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32)
    )
    def kernel(x_ref, o_ref):
      jax.lax.cond(
          x_ref[0] % 2 == 0,
          lambda: pl.debug_print("acc % 2"),
          lambda: pl.debug_print("acc"),
      )
      o_ref[...] = jnp.broadcast_to(
          jnp.asarray(0, dtype=o_ref.dtype), o_ref.shape
      )

    x = jnp.full((256,), 1234, dtype=jnp.int32)
    with self.capture_stdout() as output:
      jax.block_until_ready(kernel(x))

    self.assertIn("acc % 2", output())

  def test_cond_returning_array(self):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32)
    )
    def kernel(x_ref, o_ref):
      acc_sum = _sum_same_dtype(x_ref[...])
      acc2, acc = jax.lax.cond(
          acc_sum % 2 == 0,
          lambda: (acc_sum * 2, x_ref[...]),
          lambda: (acc_sum, x_ref[...]),
      )
      o_ref[...] = jnp.broadcast_to(_sum_same_dtype(acc) + acc2, o_ref.shape)

    x = jnp.arange(256, dtype=jnp.int32)
    np.testing.assert_array_equal(
        kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])
    )

  def test_tile_slicing(self):
    # Not testing with warpgroup semantics, because we want to enforce a
    # layout.
    self.skip_if_wg_semantics()

    shape = (256, 128)
    block_spec = plgpu.GPUBlockSpec(
        transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
    )

    @functools.partial(
        self.pallas_call,
        in_specs=[block_spec],
        out_specs=block_spec,
        out_shape=jax.ShapeDtypeStruct((64, 64), jnp.uint16),
    )
    def kernel(x_ref, o_ref):
      def sum_tiles(row, acc):
        row_slice = pl.ds(row * 64, 64)
        for col in range(128 // 64):
          acc += x_ref[row_slice, pl.ds(col * 64, 64)]
        return acc

      acc = plgpu.layout_cast(
          jnp.zeros((64, 64), jnp.uint16), plgpu.Layout.WGMMA
      )
      o_ref[...] = _fori_loop(False, 0, 256 // 64, sum_tiles, acc)

    x = jnp.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape)
    y = x.reshape(256 // 64, 64, 128 // 64, 64).sum(
        axis=(0, 2), dtype=jnp.uint16
    )
    np.testing.assert_array_equal(kernel(x), y)
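
  # The reference computes the same thing as the kernel: reshaping to
  # (256 // 64, 64, 128 // 64, 64) exposes the 64x64 tiles, and summing over
  # axes 0 and 2 adds them up tile by tile, just like the sliced loads above.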

  def test_input_output_aliases(self):
    # Note that we're writing to the input pointer, which should alias b_ptr.
    def kernel(a_ref, b_ref):
      del b_ref
      a_ref[...] = jnp.ones_like(a_ref)

    a = np.zeros((64, 64), dtype=jnp.float32)
    b = self.pallas_call(
        kernel,
        in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
        out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM),
        input_output_aliases={0: 0},
        out_shape=a,
    )(a)
    np.testing.assert_array_equal(b, np.ones_like(a))

  def test_slicing(self):
    self.skip_if_wg_semantics()

    left = upper = slice(None, 64)
    right = lower = slice(64, None)

    # We rotate the four quadrants of the input clockwise.
    def rotate(src, dst):
      dst[upper, left] = src[lower, left]
      dst[upper, right] = src[upper, left]
      dst[lower, right] = src[upper, right]
      dst[lower, left] = src[lower, right]

    x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
    spec = plgpu.GPUBlockSpec(
        transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
    )
    f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec)
    expected = np.empty_like(x)
    rotate(x, expected)
    np.testing.assert_array_equal(f(x), expected)

  def test_layout_cast(self, shape=(256, 64)):
    self.skip_if_wg_semantics()  # `plgpu.layout_cast` is not supported.

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
    )
    def kernel(o_ref):
      o_ref[...] = plgpu.layout_cast(
          jnp.full(shape, 42.0, jnp.float32), plgpu.Layout.WGMMA
      )

    x = jnp.full(shape, 42.0, jnp.float32)
    np.testing.assert_array_equal(kernel(), x)

  def test_wgmma_transposed_layout(self):
    """Tests that the result of a WGMMA can be stored transposed, using the
    WGMMA_TRANSPOSED layout.
    """
    dtype = jnp.dtype(jnp.float16)
    swizzle_elems = 128 // dtype.itemsize
    shape = (128, 128)

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        scratch_shapes=[
            plgpu.SMEM(
                shape,
                dtype,
                transforms=(
                    plgpu.TilingTransform((8, swizzle_elems)),
                    plgpu.SwizzleTransform(128),
                ),
            )
        ],
    )
    def kernel(o_ref, smem):
      iota = plgpu.broadcasted_iota(
          dtype, o_ref.shape, 0, layout=plgpu.Layout.WGMMA
      ) * o_ref.shape[0]
      iota += plgpu.broadcasted_iota(
          dtype, o_ref.shape, 1, layout=plgpu.Layout.WGMMA
      )

      smem_trns = plgpu.transpose_ref(smem, (1, 0))
      smem_trns[...] = plgpu.layout_cast(iota, plgpu.Layout.WGMMA_TRANSPOSED)
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(smem, o_ref)

    x = jnp.arange(128 * 128, dtype=dtype).reshape((128, 128)).T
    np.testing.assert_array_equal(kernel(), x)
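
  # The kernel above materializes v[i, j] = i * 128 + j in registers and
  # stores it through a transposed SMEM view, so GMEM receives the transpose,
  # i.e. jnp.arange(128 * 128).reshape(128, 128).T.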

  def test_profiler(self):
    self.skip_if_wg_semantics()  # Transform inference fails.

    def kernel(x_ref, o_ref):
      with jax.named_scope("add"):
        with jax.named_scope("load"):
          x = x_ref[...]
        o = x + x
        with jax.named_scope("store"):
          o_ref[...] = o

    with tempfile.TemporaryDirectory() as tmpdir:
      x = jnp.arange(256).astype(jnp.float32)
      y = self.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
          compiler_params=plgpu.GPUCompilerParams(
              profile_space=16, profile_dir=tmpdir
          ),
      )(x)
      jax.block_until_ready(y)
      jax.effects_barrier()
      [name] = os.listdir(tmpdir)
      with open(os.path.join(tmpdir, name), "r") as f:
        data = f.read()
      self.assertEqual(data.count('"name": "add"'), 2)
      self.assertEqual(data.count('"name": "load"'), 2)
      self.assertEqual(data.count('"name": "store"'), 2)
      np.testing.assert_array_equal(y, x + x)

  @parameterized.product(
      dtypes=[
          (jnp.float16, jnp.float16),  # Noop
          (jnp.int16, jnp.bfloat16),
          (jnp.int16, jnp.float16),
          (jnp.uint16, jnp.float16),
          (jnp.float32, jnp.int32),
          (jnp.float32, jnp.uint32),
          (jnp.uint32, jnp.int32),
          (jnp.int32, jnp.uint32),
      ],
  )
  def test_bitcast_convert_type(self, dtypes):
    in_dtype, out_dtype = dtypes
    m, n = 16, 8
    out_shape = jax.ShapeDtypeStruct((m, n), out_dtype)

    @functools.partial(self.pallas_call, out_shape=out_shape)
    def convert(x_ref, y_ref):
      y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape)

    x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n))
    np.testing.assert_array_equal(
        convert(x), jax.lax.bitcast_convert_type(x, out_dtype)
    )

  def test_optimization_barrier(self):
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = lax.optimization_barrier(x_ref[...])

    x = jax.lax.iota(jnp.float32, 128)
    np.testing.assert_array_equal(kernel(x), x)

  def test_optimization_barrier_multiple_inputs(self):
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
    )
    def kernel(x_ref, y_ref, o_ref):
      x, y = lax.optimization_barrier([x_ref[...], y_ref[...]])
      o_ref[...] = x + y

    x = jax.lax.iota(jnp.float32, 128)
    y = jax.lax.iota(jnp.float32, 128) * 3
    np.testing.assert_array_equal(kernel(x, y), x + y)


class PallasCallWGTest(
    PallasCallTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
):
  ...

  def test_missing_primitive_lowerings_are_tracked(self):
    # This test is a way to keep track of which primitives need to be adapted
    # to using warpgroup semantics. Once the set is empty, we should be able
    # to enable warpgroup semantics by default (assuming we haven't
    # overspecialized lowerings).
    rules = mgpu_lowering.mosaic_lowering_rules
    wg_lowered_primitives = set(rules[plgpu.ThreadSemantics.Warpgroup])
    lane_lowered_primitives = set(rules[plgpu.ThreadSemantics.Lane])

    actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives
    expected_missing_primitives = {
        mgpu_primitives.inline_mgpu_p,
        mgpu_primitives.broadcasted_iota_p,
        mgpu_primitives.layout_cast_p,
        mgpu_primitives.load_p,
        lax.slice_p,
        discharge.run_state_p,
    }

    self.assertSetEqual(actual_missing_primitives, expected_missing_primitives)


class PallasCallSm90ATest(PallasSm90ATest):

  @parameterized.parameters(False, True)
  def test_fori_loop_accumulator(self, force_while):
    # ``pl.run_state`` is not supported in WG semantics.
    self.skip_if_wg_semantics()

    transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))

    @functools.partial(
        self.pallas_call,
        in_specs=[plgpu.GPUBlockSpec((64, 64), transforms=transforms)],
        out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16),
        out_specs=plgpu.GPUBlockSpec((64, 64)),
    )
    def kernel(i_ref, o_ref):
      def scope(acc_ref):
        return _fori_loop(
            force_while, 0, 4, lambda _, v: v + acc_ref[...], acc_ref[...]
        )

      o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...]))

    acc_ini = jnp.ones((64, 64), dtype=jnp.float16)
    np.testing.assert_array_equal(
        kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)
    )

  def test_realistic_matmul(self):
    dtype = jnp.float16
    swizzle = 128
    elems_128b = swizzle // jnp.dtype(dtype).itemsize
    grid_m, grid_k, grid_n = 132, 10, 4
    tile_m = tile_n = 128
    tile_k = elems_128b
    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n

    def kernel(a_ref, b_ref, o_ref, acc_ref):
      # Make sure tiling does not alter the shape of references.
      assert a_ref.shape == (tile_m, tile_k)
      assert b_ref.shape == (tile_k, tile_n)
      assert o_ref.shape == acc_ref.shape == (tile_m, tile_n)
      plgpu.wgmma(acc_ref, a_ref, b_ref)
      is_last_step = pl.program_id(2) == grid_k - 1

      @pl.when(is_last_step)
      def _epilogue():
        o_ref[...] = acc_ref[...].astype(dtype)

      # We don't await the last WGMMA, hence delay_release=1.
      plgpu.wgmma_wait(1)

    key1, key2 = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
    b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)

    lhs_spec = pl.BlockSpec(
        (tile_m, tile_k),
        lambda m, n, k: (m, k),
    )
    rhs_spec = pl.BlockSpec(
        (tile_k, tile_n),
        lambda m, n, k: (k, n),
    )
    out_spec = pl.BlockSpec(
        (tile_m, tile_n),
        lambda m, n, k: (m, n),
    )

    if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane:
      lhs_spec = plgpu.GPUBlockSpec(
          lhs_spec.block_shape,
          lhs_spec.index_map,
          transforms=(
              plgpu.TilingTransform((8, elems_128b)),
              plgpu.SwizzleTransform(128),
          ),
      )
      rhs_spec = plgpu.GPUBlockSpec(
          rhs_spec.block_shape,
          rhs_spec.index_map,
          transforms=(
              plgpu.TilingTransform((8, elems_128b)),
              plgpu.SwizzleTransform(128),
          ),
      )
      out_spec = plgpu.GPUBlockSpec(
          out_spec.block_shape,
          out_spec.index_map,
          transforms=(
              plgpu.TilingTransform((8, elems_128b)),
              plgpu.SwizzleTransform(128),
          ),
      )

    res = self.pallas_call(
        kernel,
        in_specs=[lhs_spec, rhs_spec],
        out_specs=out_spec,
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
        scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
        grid=(grid_m, grid_n, grid_k),
        compiler_params=plgpu.GPUCompilerParams(
            dimension_semantics=["parallel", "parallel", "sequential"],
            max_concurrent_steps=2,
            delay_release=1,
        ),
    )(a, b)
    np.testing.assert_allclose(res, a @ b, rtol=1e-3)
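
  # Note the interplay above: wgmma_wait(1) only waits for all but the most
  # recent WGMMA, so delay_release=1 keeps each operand buffer alive for one
  # extra pipeline step before it can be reused.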
|
|
|
|
@parameterized.parameters(jnp.float16, jnp.float32)
|
|
def test_wgmma(self, dtype):
|
|
self.skip_if_wg_semantics()
|
|
|
|
# TensorCores can only fuse transposes of 16-bit values, and RHS
|
|
# is expected to be column major by default.
|
|
rhs_transpose = jnp.dtype(dtype).itemsize != 2
|
|
swizzle = 128
|
|
elems_128b = swizzle // jnp.dtype(dtype).itemsize
|
|
def kernel(a_ref, b_ref, o_ref):
|
|
if rhs_transpose:
|
|
b_ref = plgpu.transpose_ref(b_ref, (1, 0))
|
|
def scope(acc_ref):
|
|
plgpu.wgmma(acc_ref, a_ref, b_ref)
|
|
return acc_ref[...]
|
|
|
|
o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
|
|
|
|
key1, key2 = jax.random.split(jax.random.key(42), 2)
|
|
a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
|
|
b_shape = (128, 192)
|
|
if rhs_transpose:
|
|
b_shape = b_shape[::-1]
|
|
b = jax.random.uniform(key2, shape=b_shape, dtype=dtype)
|
|
|
|
rhs_transforms = (plgpu.TilingTransform((8, elems_128b)),)
|
|
res = self.pallas_call(
|
|
kernel,
|
|
in_specs=[
|
|
plgpu.GPUBlockSpec(
|
|
(64, 128),
|
|
lambda i, j: (i, j),
|
|
transforms=(
|
|
plgpu.TilingTransform((8, elems_128b)),
|
|
plgpu.SwizzleTransform(128),
|
|
),
|
|
),
|
|
plgpu.GPUBlockSpec(
|
|
b_shape,
|
|
lambda *i: i,
|
|
transforms=(*rhs_transforms, plgpu.SwizzleTransform(128)),
|
|
),
|
|
],
|
|
out_specs=plgpu.GPUBlockSpec((64, 192), lambda *i: i),
|
|
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
|
|
grid=(1, 1),
|
|
)(a, b)
|
|
np.testing.assert_allclose(
|
|
res, a @ (b.T if rhs_transpose else b), rtol=1e-3
|
|
)
|
|
|
|
def test_wgmma_registers(self):
|
|
def kernel(a_ref, b_ref, o_ref):
|
|
def scope(acc_ref):
|
|
plgpu.wgmma(acc_ref, a_ref[...], b_ref)
|
|
return acc_ref[...]
|
|
o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
|
|
|
|
key1, key2 = jax.random.split(jax.random.key(42), 2)
|
|
a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
|
|
b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16)
|
|
|
|
transforms = ()
|
|
if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane:
|
|
transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
|
|
res = self.pallas_call(
|
|
kernel,
|
|
in_specs=[
|
|
plgpu.GPUBlockSpec(transforms=transforms),
|
|
plgpu.GPUBlockSpec(transforms=transforms),
|
|
],
|
|
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
|
|
)(a, b)
|
|
np.testing.assert_allclose(res, a @ b, rtol=1e-3)
|
|
|
|
def test_wgmma_registers_init(self):
|
|
# ``pl.run_state`` is not supported in WG semantics.
|
|
self.skip_if_wg_semantics()
|
|
|
|
def kernel(a_ref, b_ref, i_ref, o_ref):
|
|
def scope(acc_ref):
|
|
plgpu.wgmma(acc_ref, a_ref[...], b_ref)
|
|
o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...]))
|
|
|
|
key1, key2, key3 = jax.random.split(jax.random.key(42), 3)
|
|
a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
|
|
b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16)
|
|
i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10
|
|
|
|
transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
|
|
res = self.pallas_call(
|
|
kernel,
|
|
in_specs=[
|
|
plgpu.GPUBlockSpec(transforms=transforms),
|
|
plgpu.GPUBlockSpec(transforms=transforms),
|
|
plgpu.GPUBlockSpec(transforms=transforms),
|
|
],
|
|
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16),
|
|
)(a, b, i)
|
|
np.testing.assert_allclose(res, i + a @ b, rtol=2e-3)
|
|
|
|
def test_wgmma_sliced_ref(self):
|
|
self.skip_if_wg_semantics() # Needs WGMMA to support slices.
|
|
|
|
def kernel(a_ref, b_ref, o_ref):
|
|
def scope(acc_ref):
|
|
plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0])
|
|
return acc_ref[...]
|
|
|
|
o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
|
|
|
|
key1, key2 = jax.random.split(jax.random.key(42), 2)
|
|
a = jax.random.uniform(key1, shape=(2, 64, 128), dtype=jnp.float16)
|
|
b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16)
|
|
|
|
transforms = ()
|
|
if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane:
|
|
transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
|
|
|
|
res = self.pallas_call(
|
|
kernel,
|
|
in_specs=[
|
|
plgpu.GPUBlockSpec(transforms=transforms),
|
|
plgpu.GPUBlockSpec(transforms=transforms),
|
|
],
|
|
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
|
|
)(a, b)
|
|
np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3)
|
|
|
|
  def test_wgmma_sliced_acc(self):
    self.skip_if_wg_semantics()  # Needs WGMMA to support slices.

    swizzle = 128
    elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize

    def kernel(a_ref, b_ref, o_ref):
      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref, b_ref)
        return acc_ref[:, :64], acc_ref[:, 64:]

      o_ref[:, :64], o_ref[:, 64:] = pl.run_scoped(
          scope, plgpu.ACC((64, 128), jnp.float32)
      )

    key1, key2 = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
    b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16)

    transforms = ()
    if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane:
      transforms = (
          plgpu.TilingTransform((8, elems_128b)),
          plgpu.SwizzleTransform(128),
      )
    res = self.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec(
                (64, 128), lambda *ij: ij, transforms=transforms
            ),
            plgpu.GPUBlockSpec(
                (128, 128), lambda *ij: ij, transforms=transforms
            ),
        ],
        out_specs=plgpu.GPUBlockSpec((64, 128), lambda *ij: ij),
        out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32),
        grid=(1, 1),
    )(a, b)
    np.testing.assert_allclose(res, a @ b, rtol=1e-3)


class PallasCallSm90AWGTest(
    PallasCallSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
):
  ...


class PallasCallSm100ATest(PallasSm100ATest):

  def test_tmem_alloc(self):

    @functools.partial(
        self.kernel,
        out_shape=jnp.zeros((128, 128), jnp.float32),
        scratch_shapes=[
            plgpu.TMEM((128, 128), jnp.float32),
            plgpu.SMEM((128, 128), jnp.float32),
        ],
        num_threads=1,
        thread_name="x",
    )
    def kernel(y_ref, tmem_ref, smem_ref):
      # Issue a write so the TMEM load is not DCE'd.
      smem_ref[...] = tmem_ref[...]
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(smem_ref, y_ref)
      plgpu.wait_smem_to_gmem(0)

    # Test that this runs without errors.
    jax.block_until_ready(kernel())


class PallasCallSm100AWGTest(
    PallasCallSm100ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
):
  ...


class PipelineTest(PallasTest):

  def test_pipeline_mode(self):
    def body(x_ref, y_ref, o_ref):
      x = x_ref[:]
      y = y_ref[:]
      o_ref[:] = x + y

    data_size = 64 * 256
    block_size = 256

    x = jnp.arange(data_size, dtype=jnp.float32)
    y = jnp.arange(data_size, dtype=jnp.float32)
    in_specs = [
        pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(2)),
        pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(1)),
    ]
    out_specs = pl.BlockSpec((block_size,), lambda *i: i)

    @jax.jit
    def vadd(x, y):
      return self.pallas_call(
          body,
          out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
          in_specs=in_specs,
          out_specs=out_specs,
          grid=data_size // block_size,
      )(x, y)

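    # ``pl.Buffered`` pipeline modes are not implemented by the Mosaic GPU
    # lowering yet, so the call is expected to fail with the error below.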
    with self.assertRaisesRegex(Exception, "Pipeline mode is not supported"):
      vadd(x, y)

  def test_manual(self):
    max_concurrent_steps = 2
    num_steps = 4

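    # A hand-rolled double-buffered pipeline: the kernel cycles through
    # ``max_concurrent_steps`` SMEM slots, with one barrier per slot tracking
    # the in-flight GMEM->SMEM copy for that slot.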
    def kernel(x_gmem, o_gmem):
      return pl.run_scoped(
          functools.partial(scoped_kernel, x_gmem, o_gmem),
          plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32),
          plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32),
          plgpu.Barrier(1, num_barriers=max_concurrent_steps),
      )

    def scoped_kernel(x_gmem, o_gmem, x_smem, o_smem, barrier):
      gmem_slice = pl.ds(pl.program_id(0) * 32, 32)

      def body(step, _):
        slot = step % max_concurrent_steps

        # Wait for the current GMEM->SMEM copy to complete.
        plgpu.barrier_wait(barrier.at[slot])
        # Wait for the previous output SMEM->GMEM copy to complete.
        plgpu.wait_smem_to_gmem(max_concurrent_steps - 1)

        o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0

        plgpu.commit_smem()
        plgpu.copy_smem_to_gmem(
            o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)]
        )

        fetch_step = step + max_concurrent_steps
        fetch_slot = slot  # (x + y) % y == x % y
        jax.lax.cond(
            fetch_step < num_steps,
            lambda: plgpu.copy_gmem_to_smem(
                x_gmem.at[gmem_slice, pl.ds(fetch_step * 16, 16)],
                x_smem.at[fetch_slot],
                barrier.at[fetch_slot],
            ),
            lambda: None,
        )
        return ()

      # Initialize the pipeline.
      for slot in range(min(max_concurrent_steps, num_steps)):
        plgpu.copy_gmem_to_smem(
            x_gmem.at[gmem_slice, pl.ds(slot * 16, 16)],
            x_smem.at[slot],
            barrier.at[slot],
        )

      jax.lax.fori_loop(0, num_steps, body, ())

      # Finalize the pipeline.
      plgpu.wait_smem_to_gmem(0)

    x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32)
    kernel_fn = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(4, 1),
    )
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)

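  # ``plgpu.emit_pipeline`` automates the slot and barrier bookkeeping that
  # ``test_manual`` above spells out by hand.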
  @parameterized.parameters(
      ((),),
      ((plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)),),
  )
  def test_emit(self, transforms):
    if transforms:
      self.skip_if_wg_semantics()

    num_steps = 4

    def kernel(x_gmem, o_gmem):
      plgpu.emit_pipeline(
          kernel_body,
          in_specs=[
              plgpu.GPUBlockSpec(
                  (64, 64), lambda i: (0, i), transforms=transforms
              )
          ],
          out_specs=[
              plgpu.GPUBlockSpec(
                  (64, 64), lambda i: (0, i), transforms=transforms
              )
          ],
          grid=(num_steps,),
          max_concurrent_steps=2,
      )(x_gmem, o_gmem)

    def kernel_body(_, x_smem, o_smem):
      # +1 for the indexing done by ``emit_pipeline``.
      self.assertLen(x_smem.transforms, len(transforms) + 1)
      o_smem[...] = x_smem[...] + 1.0

    x = jnp.arange(64 * num_steps * 64)
    x = x.reshape(-1, num_steps * 64).astype(jnp.float32)
    kernel_fn = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)

  def test_nested_emit(self):
    num_steps = 4

    def kernel(x_gmem, o_gmem):
      plgpu.emit_pipeline(
          nested_kernel,
          in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
          out_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
          grid=(),
      )(x_gmem, o_gmem)

    def nested_kernel(_, x_gmem, o_gmem):
      plgpu.emit_pipeline(
          nested_kernel_body,
          in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
          out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
          grid=(num_steps,),
          max_concurrent_steps=2,
      )(x_gmem, o_gmem)

    def nested_kernel_body(_, x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    x = jnp.arange(32 * num_steps * 16)
    x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
    kernel_fn = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)

  def test_emit_with_grid_invariant_output(self):
    num_steps = 4

    def kernel(x_gmem, o_gmem):
      plgpu.emit_pipeline(
          kernel_body,
          in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
          out_specs=[pl.BlockSpec((32, 16), lambda i: (0, 0))],
          grid=(num_steps,),
          max_concurrent_steps=2,
      )(x_gmem, o_gmem)

    def kernel_body(_, x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    x = jnp.arange(32 * num_steps * 16)
    x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
    kernel_fn = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    y = jnp.empty_like(x)
    for i in range(num_steps):
      i_slice = slice(16 * i, 16 * (i + 1))
      y = y.at[:, :16].set(x[:, i_slice] + 1)
    # We only compare the elements in the first 16 columns, because the rest
    # are never written to.
    np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16])

  def test_emit_with_parallel_grid(self):
    num_steps1 = 4
    num_steps2 = 5

    def kernel(x_gmem, o_gmem):
      pid = pl.program_id(0)
      plgpu.emit_pipeline(
          kernel_body,
          in_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))],
          out_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))],
          grid=(num_steps2,),
          max_concurrent_steps=2,
      )(x_gmem, o_gmem)

    def kernel_body(_, x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    x = jnp.arange(num_steps1 * 32 * num_steps2 * 16)
    x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32)
    kernel_fn = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(num_steps1,),
    )
    y = x + 1.0
    np.testing.assert_array_equal(kernel_fn(x), y)

  def test_emit_with_2d_grid(self):
    num_steps1 = 4
    num_steps2 = 5

    def kernel(x_gmem, o_gmem):
      plgpu.emit_pipeline(
          kernel_body,
          in_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))],
          out_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))],
          grid=(num_steps1, num_steps2),
          max_concurrent_steps=2,
      )(x_gmem, o_gmem)

    def kernel_body(_, x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8)
    x = x.reshape(-1, num_steps1 * 16, num_steps2 * 8).astype(jnp.float32)
    kernel_fn = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)


class PipelineWGTest(
    PipelineTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
):
  ...


class PipelineSm90ATest(PallasSm90ATest):

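  # Pipelines the K loop of a matmul. ``delay_release=1`` keeps each pair of
  # operand buffers alive for one extra step, so that ``wgmma_wait(1)`` can
  # leave one WGMMA in flight without the pipeline refilling its inputs.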
  def test_realistic_matmul(self):
    self.skip_if_wg_semantics()  # Needs WGMMA to support slices.

    dtype = jnp.float16
    swizzle = 128
    elems_128b = swizzle // jnp.dtype(dtype).itemsize
    grid_m, grid_k, grid_n = 132, 10, 4
    tile_m = tile_n = 128
    assert tile_m % elems_128b == 0
    tile_k = elems_128b
    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n

    transforms = ()
    if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane:
      transforms = (
          plgpu.TilingTransform((8, elems_128b)),
          plgpu.SwizzleTransform(128),
      )

    def kernel(a_gmem, b_gmem, o_smem, acc):
      def kernel_body(_, a_smem, b_smem):
        assert a_smem.shape == (tile_m, tile_k)
        assert b_smem.shape == (tile_k, tile_n)
        plgpu.wgmma(acc, a_smem, b_smem)
        plgpu.wgmma_wait(1)

      pid_m = pl.program_id(0)
      pid_n = pl.program_id(1)
      plgpu.emit_pipeline(
          kernel_body,
          in_specs=[
              plgpu.GPUBlockSpec(
                  (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
              ),
              plgpu.GPUBlockSpec(
                  (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms
              ),
          ],
          grid=(grid_k,),
          max_concurrent_steps=2,
          delay_release=1,
      )(a_gmem, b_gmem)

      o_smem[...] = acc[...].astype(dtype)

    key1, key2 = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
    b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)

    res = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=plgpu.GMEM),
            pl.BlockSpec(memory_space=plgpu.GMEM),
        ],
        out_specs=plgpu.GPUBlockSpec(
            (tile_m, tile_n), lambda m, n: (m, n), transforms=transforms
        ),
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
        scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
        grid=(grid_m, grid_n),
    )(a, b)
    np.testing.assert_array_equal(res, a @ b)


class PipelineSm90AWGTest(
    PipelineSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
):
  ...


class WarpSpecializedPipelineTest(PallasTest):

  @parameterized.product(
      m=[512], n=[512], manual_consumed_barriers=[False, True]
  )
  def test_pipelined_copy(self, m, n, manual_consumed_barriers):
    self.skip_if_wg_semantics()  # Times out!

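    # With ``manual_consumed_barriers=True`` the pipeline passes one consumed
    # barrier per input to the kernel, and the kernel must arrive on it once
    # it is done reading the corresponding SMEM buffer.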
    x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16)
    blk_m = blk_n = 64

    def copy_kernel(_, x_smem, o_smem, o_last_block_smem, *consumed_barriers):
      # TODO(justinfu): Have each wg compute a separate slice
      # after multiple-indexers are supported.
      # This is currently a race, but the values written are the same.
      o_smem[...] = x_smem[...]
      o_last_block_smem[...] = x_smem[...]
      if manual_consumed_barriers:
        [x_barrier] = consumed_barriers
        plgpu.barrier_arrive(x_barrier)

    spec = pl.BlockSpec(
        block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j)
    )
    pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
        copy_kernel,
        grid=(m // blk_m, n // blk_n),
        memory_registers=40,
        max_concurrent_steps=2,
        num_compute_wgs=2,
        wg_axis="wg",
        manual_consumed_barriers=manual_consumed_barriers,
        in_specs=[spec],
        out_specs=[
            spec,
            # Create an index-invariant output.
            pl.BlockSpec(
                block_shape=(blk_m, blk_n), index_map=lambda i, j: (0, 0)
            ),
        ],
    )
    kernel = self.kernel(
        pipeline,
        out_shape=(
            jax.ShapeDtypeStruct((m, n), jnp.float16),
            jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float16),
        ),
        compiler_params=plgpu.GPUCompilerParams(approx_math=True),
        grid=(1,),
        grid_names=("_",),
        num_threads=3,
        thread_name="wg",
    )
    out, out_last_block = kernel(x)
    np.testing.assert_array_equal(out, x)
    np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:])

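  # Note that ``num_threads`` is ``num_compute_wgs + 1`` here (and ``3`` in
  # the test above): the extra warpgroup is the memory thread that issues the
  # copies on behalf of the warp-specialized pipeline.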
  def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2):
    self.skip_if_wg_semantics()  # Crashes!

    blk_m = blk_n = 64
    spec = pl.BlockSpec(
        block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j)
    )

    def tiled_add_kernel(_, x_smem, y_smem, o_smem):
      # TODO(justinfu): Have each wg compute a separate slice
      # after multiple-indexers are supported.
      # This is currently a race, but the values written are the same.
      o_smem[...] = x_smem[...] + y_smem[...]

    pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
        tiled_add_kernel,
        grid=(m // blk_m, n // blk_n),
        max_concurrent_steps=2,
        num_compute_wgs=num_compute_wgs,
        memory_registers=40,
        wg_axis="wg",
        in_specs=[spec, spec],
        out_specs=[spec],
    )
    kernel = self.kernel(
        pipeline,
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
        compiler_params=plgpu.GPUCompilerParams(approx_math=True),
        grid=(1,),
        grid_names=("_",),
        num_threads=num_compute_wgs + 1,
        thread_name="wg",
    )
    x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32)
    y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32)
    np.testing.assert_allclose(kernel(x, y), x + y, atol=1e-4)

  def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2):
    self.skip_if_wg_semantics()  # `plgpu.layout_cast` is not supported.

    blk_m = blk_n = 64

    @functools.partial(
        self.kernel,
        out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32),
        scratch_shapes=[
            plgpu.SMEM((blk_m, blk_n), jnp.float32),
        ],
        compiler_params=plgpu.GPUCompilerParams(approx_math=True),
        grid=(1,),
        grid_names=("_",),
        num_threads=num_compute_wgs + 1,
        thread_name="wg",
    )
    def kernel(x_gmem, acc_gmem, acc_smem):
      def _compute_thread():
        # Cast the init value to the same layout as x_smem, so the pipeline
        # loop carry has a constant signature.
        o_acc = plgpu.layout_cast(
            jnp.full((blk_m, blk_n), 0, dtype=jnp.float32),
            plgpu.Layout.WG_STRIDED((blk_m, blk_n), vec_size=2),
        )
        carry_init = (o_acc,)
        # Pass control to the pipeline emitter and return the final carry.
        final_carry = (yield carry_init)
        o_final, = final_carry
        # Note that both compute WGs are doing identical work so the potential
        # race condition on the store here won't affect the result.
        acc_smem[...] = o_final
        plgpu.commit_smem()
        plgpu.copy_smem_to_gmem(acc_smem, acc_gmem)
        plgpu.wait_smem_to_gmem(0)

      def tiled_acc_kernel(_, x_smem, carry):
        o_carry, = carry
        new_carry = x_smem[...] + o_carry
        return (new_carry,)

      pipeline = mgpu_pipeline.emit_pipeline_warp_specialized(
          tiled_acc_kernel,
          grid=(m // blk_m, n // blk_n),
          max_concurrent_steps=2,
          num_compute_wgs=num_compute_wgs,
          memory_registers=40,
          wg_axis="wg",
          carry_coroutine=_compute_thread,
          in_specs=[
              pl.BlockSpec(
                  block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j)
              )
          ],
          out_specs=[],
      )
      pipeline(x_gmem)

    x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32)
    ref = jnp.sum(jnp.stack(np.split(x, m // blk_m, axis=0)), axis=0)
    ref = jnp.sum(jnp.stack(np.split(ref, n // blk_n, axis=1)), axis=0)
    np.testing.assert_allclose(kernel(x), ref, atol=1e-4)


class WarpSpecializedPipelineWGTest(
    WarpSpecializedPipelineTest,
    thread_semantics=plgpu.ThreadSemantics.Warpgroup,
):
  ...


class CoreMapTest(PallasTest):

  def test_multiple_wg(self):

    @functools.partial(
        self.kernel,
        out_shape=jnp.zeros((2, 128), np.int32),
        num_threads=2,
        thread_name="wg",
    )
    def kernel(o_ref):
      wg_idx = jax.lax.axis_index("wg")
      o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,))

    np.testing.assert_array_equal(
        kernel(), np.repeat(np.arange(2), 128).reshape(2, 128)
    )

  def test_multiple_wg_with_grid(self):

    @functools.partial(
        self.kernel,
        out_shape=jnp.zeros((4, 2, 128), np.int32),
        grid=(2, 2),
        grid_names=("x", "y"),
        num_threads=2,
        thread_name="wg",
    )
    def kernel(o_ref):
      xy_idx = jax.lax.axis_index(("x", "y"))
      yx_idx = jax.lax.axis_index(("y", "x"))
      wg_idx = jax.lax.axis_index("wg")
      num_wgs = jax.lax.psum(1, "wg")
      o_ref[xy_idx, wg_idx] = jnp.broadcast_to(
          yx_idx * num_wgs + wg_idx, (128,)
      )

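    # For block (x, y): xy_idx = 2*x + y picks the output row, while the
    # stored value is 2*yx_idx + wg = 2*(2*y + x) + wg, which produces the
    # permuted sequence below.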
    np.testing.assert_array_equal(
        kernel(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128)
    )

  def test_multiple_wg_with_squashed_grid(self):
    # Tests whether a grid with >3 logical dimensions is correctly squashed
    # to 3 CUDA grid dimensions.
    b = 4
    x_dim = 3
    y_dim = 5
    z_dim = 7
    num_threads = 2

    @functools.partial(
        self.kernel,
        out_shape=jnp.zeros(
            (b, x_dim, y_dim, z_dim, num_threads, 128), np.int32
        ),
        grid=(b, x_dim, y_dim, z_dim),
        grid_names=("b", "x", "y", "z"),
        num_threads=num_threads,
        thread_name="wg",
    )
    def kernel(o_ref):
      b_idx = jax.lax.axis_index("b")
      x_idx = jax.lax.axis_index("x")
      y_idx = jax.lax.axis_index("y")
      z_idx = jax.lax.axis_index("z")
      wg_idx = jax.lax.axis_index("wg")
      bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg"))
      o_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to(
          bxyzw_idx, (128,)
      )

    result = kernel()[:, :, :, :, :, 0]
    ref = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape(
        result.shape
    )
    np.testing.assert_array_equal(result, ref)

  def test_cross_wg_barrier(self):
    self.skip_if_wg_semantics()  # Times out!

    @functools.partial(
        self.kernel,
        out_shape=jnp.zeros((2, 128), np.int32),
        # Each warpgroup is a single logical thread!
        scratch_shapes=[plgpu.Barrier(num_arrivals=2)],
        num_threads=2,
        thread_name="wg",
    )
    def kernel(o_ref, barrier):
      plgpu.barrier_arrive(barrier)
      plgpu.barrier_wait(barrier)
      wg_idx = jax.lax.axis_index("wg")
      o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,))

    np.testing.assert_array_equal(
        kernel(), np.repeat([0, 1], 128).reshape(2, 128)
    )

  def test_cluster(self):
    self.skip_if_wg_semantics()  # Needs debug_print in the MGPU dialect.

    @functools.partial(
        self.kernel,
        out_shape=jnp.zeros(128, np.int32),
        grid=(2,),
        grid_names=("x",),
        cluster=(2,),
        cluster_names=("cluster",),
    )
    def kernel(ref):
      block_idx = jax.lax.axis_index("x")
      cluster_idx = jax.lax.axis_index("cluster")
      pl.debug_print("block: {} cluster: {}", block_idx, cluster_idx)

      ref[...] = ref[...]

    with self.capture_stdout() as output:
      jax.block_until_ready(kernel())
    self.assertEqual(
        set(output().splitlines()),
        {
            "block: 0 cluster: 0",
            "block: 1 cluster: 0",
            "block: 0 cluster: 1",
            "block: 1 cluster: 1",
        },
    )

  def test_realistic_matmul_with_cluster(self):
    self.skip_if_wg_semantics()  # Needs WGMMA to support slices.

    dtype = jnp.float16
    swizzle = 128
    elems_128b = swizzle // jnp.dtype(dtype).itemsize
    grid_m, grid_k, grid_n = 132, 10, 32
    # TODO(slebedev): Remove ``grid_tile_n`` to simplify the test.
    grid_tile_n = 4
    assert grid_n % grid_tile_n == 0
    cluster_m = 2
    cluster_n = 2
    cluster_tile_n = min(cluster_n, grid_tile_n)
    tile_m = tile_n = 128
    assert tile_m % elems_128b == 0
    tile_k = elems_128b
    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n

    transforms = (
        plgpu.TilingTransform((8, elems_128b)),
        plgpu.SwizzleTransform(128),
    )

    max_concurrent_steps = 2
    delay_release = 1

    @functools.partial(
        self.kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), dtype),
        scratch_shapes=[
            plgpu.SMEM(
                (max_concurrent_steps, tile_m, tile_k),
                dtype,
                transforms=transforms,
            ),
            plgpu.SMEM(
                (max_concurrent_steps, tile_k, tile_n),
                dtype,
                transforms=transforms,
            ),
            plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms),
            plgpu.ACC((tile_m, tile_n), jnp.float32),
            plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps),
            plgpu.ClusterBarrier(
                collective_axes=(("x", "z"), "y"),
                num_barriers=max_concurrent_steps,
            ),
        ],
        grid=(grid_tile_n, grid_m, grid_n // grid_tile_n),
        grid_names=("tile_n", "m", "n"),
        cluster=(cluster_tile_n, cluster_m, cluster_n // cluster_tile_n),
        cluster_names=("x", "y", "z"),
    )
    def kernel(
        a_gmem,
        b_gmem,
        o_gmem,
        a_smem,
        b_smem,
        o_smem,
        acc,
        barrier,
        cluster_barrier,
    ):
      m_slice = pl.ds(lax.axis_index("m") * tile_m, tile_m)
      n_slice = pl.ds(
          (lax.axis_index("tile_n") + lax.axis_index("n") * grid_tile_n)
          * tile_n,
          tile_n,
      )

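      # The copies below are collective: every block along the given cluster
      # axes receives the same tile (via TMA multicast), so a slot may only
      # be refilled once all blocks in the cluster have arrived on its
      # cluster barrier.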
      def fetch(step, slot):
        if not isinstance(slot, int):  # Skip in initialization.
          plgpu.barrier_arrive(cluster_barrier.at[slot])
          plgpu.barrier_wait(cluster_barrier.at[slot])

        k_slice = pl.ds(step * tile_k, tile_k)
        plgpu.copy_gmem_to_smem(
            a_gmem.at[m_slice, k_slice],
            a_smem.at[slot],
            barrier.at[slot],
            collective_axes=("x", "z"),
        )
        plgpu.copy_gmem_to_smem(
            b_gmem.at[k_slice, n_slice],
            b_smem.at[slot],
            barrier.at[slot],
            collective_axes="y",
        )

      # Initialize the pipeline.
      for slot in range(min(max_concurrent_steps, grid_k)):
        fetch(slot, slot)

      def body(step, _):
        slot = step % max_concurrent_steps
        plgpu.barrier_wait(barrier.at[slot])

        plgpu.wgmma(acc, a_smem.at[slot], b_smem.at[slot])
        plgpu.wgmma_wait(delay_release)

        fetch_step = step + (max_concurrent_steps - delay_release)
        fetch_slot = lax.rem(fetch_step, max_concurrent_steps)
        jax.lax.cond(
            lax.bitwise_and(step >= delay_release, fetch_step < grid_k),
            lambda: fetch(fetch_step, fetch_slot),
            lambda: None,
        )
        return ()

      jax.lax.fori_loop(0, grid_k, body, ())

      # Finalize the pipeline.
      o_smem[...] = acc[...].astype(dtype)
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
      plgpu.wait_smem_to_gmem(0)

    key1, key2 = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
    b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)
    np.testing.assert_array_equal(kernel(a, b), a @ b)


class CoreMapWGTest(
    CoreMapTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
):
  ...


class ExamplesTest(PallasTest):

  # Basic
  def test_stage0(self):
    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)

    @functools.partial(self.kernel, out_shape=x)
    def kernel(l_ref, r_ref, o_ref):
      o_ref[...] = l_ref[...] + r_ref[...]

    np.testing.assert_allclose(kernel(x, x), x + x)

  # Multi-block kernels
  def test_stage1(self):
    row_block = 64
    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)

    @functools.partial(
        self.kernel, out_shape=x, grid=(2,), grid_names=("rows",)
    )
    def kernel(l_ref, r_ref, o_ref):
      my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block)
      o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice]

    np.testing.assert_allclose(kernel(x, x), x + x)

  # Async copies
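  # The barrier below is created with ``num_arrivals=2`` because two async
  # GMEM->SMEM copies signal it; ``barrier_wait`` returns once both land.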
  def test_stage3(self):
    row_block, col_block = 64, 128

    @functools.partial(
        self.kernel,
        out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16),
        scratch_shapes=[
            *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3),
            plgpu.Barrier(num_arrivals=2),
        ],
        grid=(2,),
        grid_names=("rows",),
    )
    def kernel(l_ref, r_ref, o_ref, l_smem, r_smem, o_smem, barrier):
      my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block)
      plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier)
      plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier)
      plgpu.barrier_wait(barrier)
      o_smem[...] = l_smem[...] + r_smem[...]
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice])
      plgpu.wait_smem_to_gmem(0)

    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
    np.testing.assert_allclose(kernel(x, x), x + x)

  # Pipelining
  def test_stage4(self):
    row_block, col_block = 64, 32
    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)

    @functools.partial(
        self.kernel, out_shape=x, grid=(2,), grid_names=("rows",)
    )
    def kernel(l_ref, r_ref, o_ref):
      def compute(_, l_smem, r_smem, o_smem):
        o_smem[...] = l_smem[...] + r_smem[...]

      r = lax.axis_index("rows")
      block = pl.BlockSpec((row_block, col_block), lambda c: (r, c))
      plgpu.emit_pipeline(
          compute,
          grid=(l_ref.shape[1] // col_block,),
          in_specs=[block] * 2,
          out_specs=[block],
      )(l_ref, r_ref, o_ref)

    np.testing.assert_allclose(kernel(x, x), x + x)

  # Transforms
  def test_stage5(self):
    self.skip_if_wg_semantics()  # Needs WGMMA to support slices.

    row_block, col_block = 64, 32
    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)

    @functools.partial(
        self.kernel, out_shape=x, grid=(2,), grid_names=("rows",)
    )
    def kernel(l_ref, r_ref, o_ref):
      def compute(_, l_smem, r_smem, o_smem):
        o_smem[...] = l_smem[...] + r_smem[...]

      r = lax.axis_index("rows")
      block = plgpu.GPUBlockSpec(
          (row_block, col_block),
          lambda c: (r, c),
          transforms=(
              plgpu.TilingTransform((8, 32)),
              plgpu.SwizzleTransform(64),
          ),
      )
      plgpu.emit_pipeline(
          compute,
          grid=(l_ref.shape[1] // col_block,),
          in_specs=[block] * 2,
          out_specs=[block],
      )(l_ref, r_ref, o_ref)

    np.testing.assert_allclose(kernel(x, x), x + x)

  def test_semaphore_lowering(self):
    # This is a smoke test until we add support for lowering of semaphore ops.
    def body(i_ref1, i_ref2, o_ref, sem_ref):
      del i_ref2  # Only here to have a different number of inputs and outputs.
      assert sem_ref.shape == (4,)
      assert jnp.issubdtype(sem_ref.dtype, pl.semaphore)
      o_ref[...] = i_ref1[...]

    x = jnp.arange(128, dtype=jnp.float32).reshape((128,))
    kernel = self.pallas_call(
        body,
        out_shape=x,
        scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))],
    )
    text = jax.jit(kernel).lower(x, x).as_text()
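    # The semaphore scratch buffer is threaded through the custom call as an
    # extra input/output pair, which is why it must show up as an aliased
    # operand in the lowered StableHLO.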
    self.assertIn(
        r"output_operand_aliases ="
        r" [#stablehlo.output_operand_alias<output_tuple_indices = [1],"
        r" operand_index = 2, operand_tuple_indices = []>]",
        text,
    )
    self.assertIn(
        r"(tensor<128xf32>, tensor<128xf32>, tensor<4xi32>) ->"
        r" (tensor<128xf32>, tensor<4xi32>)",
        text,
    )


class ExamplesWGTest(
    ExamplesTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
):
  ...


class ExamplesSm90ATest(PallasSm90ATest):

  # WGMMA
  def test_stage6(self):
    self.skip_if_wg_semantics()  # Needs WGMMA to support slices.

    m_block = n_block = 64
    k_block = 32
    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)

    @functools.partial(
        self.kernel, out_shape=x, grid=(2, 2), grid_names=("m", "n")
    )
    def kernel(l_ref, r_ref, o_ref):
      def compute(_, l_smem, r_smem, o_smem):
        def do_wgmma(acc_ref):
          plgpu.wgmma(acc_ref, l_smem, r_smem)
          return acc_ref[...]

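        # The output block index below ignores the pipeline step ``k``, so
        # the same ``o_smem`` block is revisited on every step and ``+=``
        # accumulates the partial products over the K dimension.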
        o_smem[...] += pl.run_scoped(
            do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16)
        )

      m = lax.axis_index("m")
      n = lax.axis_index("n")
      lo_transforms = (
          plgpu.TilingTransform((8, 32)),
          plgpu.SwizzleTransform(64),
      )
      r_transforms = (
          plgpu.TilingTransform((8, 32)),
          plgpu.SwizzleTransform(64),
      )
      plgpu.emit_pipeline(
          compute,
          grid=(l_ref.shape[1] // k_block,),
          in_specs=[
              plgpu.GPUBlockSpec(
                  (m_block, k_block), lambda k: (m, k), transforms=lo_transforms
              ),
              plgpu.GPUBlockSpec(
                  (k_block, n_block), lambda k: (k, n), transforms=r_transforms
              ),
          ],
          out_specs=[
              plgpu.GPUBlockSpec(
                  (m_block, n_block), lambda k: (m, n), transforms=lo_transforms
              )
          ],
      )(l_ref, r_ref, o_ref)

    np.testing.assert_allclose(kernel(x, x), x @ x)

  # TODO(apaszke): Clusters and multicast


class ExamplesSm90AWGTest(
    ExamplesSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup
):
  ...


if __name__ == "__main__":
  absltest.main()