rocm_jax/tests/pallas/mosaic_gpu_test.py
Adam Paszke 2db03ba54b [Pallas:MGPU] Add support for grid dims in GPUMesh
Of course, no communication can happen across grid dimensions (unlike across the WG dim),
but we need to be able to launch multiple blocks somehow.

PiperOrigin-RevId: 688488660
2024-10-22 04:10:46 -07:00
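
A minimal sketch of what this enables (abbreviated from test_multiple_wg_with_grid
below; only the mesh construction and axis queries are shown):

    mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg"))

    @pl.core_map(mesh)  # launches a 2x2 grid of blocks, each with 2 warpgroup threads
    def kernel():
      block_idx = jax.lax.axis_index(("x", "y"))  # grid axes: indexable, but no cross-block communication
      wg_idx = jax.lax.axis_index("wg")           # WG axis: communication within a block is allowed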


# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
import re
import traceback
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
import jax.numpy as jnp
import numpy as np
jax.config.parse_flags_with_absl()
class PallasTest(jtu.JaxTestCase):
def setUp(self):
if config.enable_x64.value:
self.skipTest("Only works on x32 at the moment")
if not jtu.is_cuda_compute_capability_at_least("9.0"):
self.skipTest("Only works on a GPU with capability >= sm90")
super().setUp()
class PallasCallTest(PallasTest):
@parameterized.named_parameters(
("add_one", lambda x: x + 1.),
("logistic", jax.lax.logistic),
("exp", jax.lax.exp),
("square", lambda x: x ** 2),
("rsqrt", jax.lax.rsqrt),
)
def test_unary_ops(self, unary):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[...] = unary(x_ref[...])
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), unary(x))
def test_add_first(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = x_ref[...] + y_ref[0]
x = jnp.arange(256).astype(jnp.float32)
y = jnp.flip(x).reshape(1, 256)
np.testing.assert_array_equal(kernel(x, y), x + y[0])
def test_add_xy(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = x_ref[...] + y_ref[...]
x = jnp.arange(256).astype(jnp.float32)
y = x + 1
np.testing.assert_array_equal(kernel(x, y), x + y)
def test_add_xy_indexed(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
idx = jnp.sum(y_ref[...])
o_ref[...] = x_ref[idx]
x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
y = jnp.zeros(128, dtype=jnp.int32)
np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)])
def test_add_one_grid(self):
@functools.partial(
pl.pallas_call,
in_specs=[pl.BlockSpec((128,), lambda *i: i)],
out_specs=pl.BlockSpec((128,), lambda *i: i),
out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32),
grid=2,
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1.0
x = jnp.arange(128 * 2).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
def test_add_one_grid_with_scratch(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32),
in_specs=[pl.BlockSpec((128,), lambda *i: i)],
out_specs=pl.BlockSpec((128,), lambda *i: i),
scratch_shapes=[plgpu.SMEM((128,), jnp.float32)],
grid=2,
)
def kernel(x_ref, o_ref, scratch_ref):
scratch_ref[...] = x_ref[...] + 1
o_ref[...] = scratch_ref[...]
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
@parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16])
def test_add_one_grid_pipelined(self, max_concurrent_steps):
@functools.partial(
pl.pallas_call,
in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "sequential"],
max_concurrent_steps=max_concurrent_steps,
),
grid=(2, 4),
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1.0
x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
def test_add_one_grid_pipelined_program_id(self):
@functools.partial(
pl.pallas_call,
out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)),
out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "sequential"],
max_concurrent_steps=2,
),
grid=(4, 4),
)
def kernel(o_ref):
o_ref[...] = jnp.broadcast_to(pl.program_id(1), o_ref.shape)
np.testing.assert_array_equal(
kernel(),
jnp.repeat(jnp.repeat(jnp.arange(4), 16)[None], 16, axis=0),
)
def test_add_one_grid_pipelined_sequential_invariant_output(self):
@functools.partial(
pl.pallas_call,
in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))],
out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)),
out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "sequential"],
max_concurrent_steps=2,
),
grid=(2, 4),
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + 1.0
x = jnp.arange(32 * 2 * 64).reshape((32 * 2, 64)).astype(jnp.float32)
y = jnp.empty_like(x)
for i in range(2):
i_slice = slice(32 * i, 32 * (i + 1))
for j in range(4):
j_slice = slice(16 * j, 16 * (j + 1))
y = y.at[i_slice, :16].set(x[i_slice, j_slice] + 1)
# We only compare the elements in the first 16 columns, because the rest
# are never written to.
np.testing.assert_array_equal(kernel(x)[:, :16], y[:, :16])
@parameterized.product(indexer=[..., slice(128), slice(None, 128)])
def test_copy_smem_to_gmem(self, indexer):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
scratch_shapes=[plgpu.SMEM((256,), jnp.float32)],
)
def kernel(x_ref, o_ref_gmem, scratch_ref):
scratch_ref[...] = x_ref[...] + 1
plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer])
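# copy_smem_to_gmem is asynchronous; wait_smem_to_gmem(0) below blocks until no
# SMEM->GMEM copies remain in flight, so the output is complete when the kernel returns.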
plgpu.wait_smem_to_gmem(0)
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)
@parameterized.product(indexer=[..., slice(128), slice(None, 128)])
def test_copy_gmem_to_smem(self, indexer):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
scratch_shapes=[
plgpu.SMEM((256,), jnp.float32),
plgpu.Barrier(num_arrivals=1),
],
)
def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
plgpu.copy_gmem_to_smem(
x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier=barrier_ref
)
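# The GMEM->SMEM copy is asynchronous and signals the barrier on completion, so
# barrier_wait blocks until scratch_ref actually holds the data.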
plgpu.barrier_wait(barrier_ref)
o_ref[...] = scratch_ref[...] + 1
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)
@parameterized.product(indexer=[0, 1, 2, 3])
def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
scratch_shapes=[
plgpu.SMEM((128,), jnp.float32),
plgpu.Barrier(num_arrivals=1, num_barriers=4),
],
)
def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
plgpu.copy_gmem_to_smem(
x_ref_gmem, scratch_ref, barrier=barrier_ref.at[indexer]
)
plgpu.barrier_wait(barrier_ref.at[indexer])
o_ref[...] = scratch_ref[...] + 1
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
@parameterized.named_parameters(("_g2s", False), ("_s2g", True))
def test_copy_with_transforms(self, to_smem):
def kernel(x_ref, o_ref, barrier_ref):
if to_smem:
plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier=barrier_ref)
plgpu.barrier_wait(barrier_ref)
else:
plgpu.copy_smem_to_gmem(x_ref, o_ref)
plgpu.wait_smem_to_gmem(0)
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
out_spec = plgpu.GPUBlockSpec(
(128, 128),
lambda: (0, 0),
transforms=(
plgpu.TilingTransform((64, 32)),
plgpu.SwizzleTransform(128),
),
memory_space=plgpu.SMEM,
)
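# These transforms store the (128, 128) block as (2, 4) tiles of shape (64, 32) in
# SMEM, with a 128-byte swizzle applied within each tile (interpretation of the
# transform parameters; the test itself only checks a round-trip copy).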
if not to_smem:
in_spec, out_spec = out_spec, in_spec
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
in_specs=(in_spec,),
out_specs=out_spec,
scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
np.testing.assert_array_equal(f(x), x)
def test_scoped_copy_with_transforms(self):
ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))
def kernel(x_ref, o_ref, barrier_ref):
def body(tmp_ref):
plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref)
plgpu.barrier_wait(barrier_ref)
o_ref[...] = tmp_ref[...] * 2
pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
out_spec = plgpu.GPUBlockSpec(
(128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM,
)
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
in_specs=(in_spec,),
out_specs=out_spec,
scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
np.testing.assert_array_equal(f(x), x * 2)
def test_copy_with_transforms_and_indexing(self):
def kernel(x_ref, o_ref, barrier_ref):
for i in range(2):
plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier=barrier_ref)
plgpu.barrier_wait(barrier_ref)
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
out_spec = plgpu.GPUBlockSpec(
(2, 128, 128),
lambda: (0, 0, 0),
transforms=(
plgpu.TilingTransform((64, 32)),
plgpu.TransposeTransform((0, 2, 1, 3, 4)),
plgpu.SwizzleTransform(128),
),
memory_space=plgpu.SMEM,
)
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32),
in_specs=(in_spec,),
out_specs=out_spec,
scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0))
def test_indexing_before_transpose(self):
def kernel(x_ref, o_ref, barrier_ref):
for i in range(2):
plgpu.copy_gmem_to_smem(
x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier=barrier_ref
)
plgpu.barrier_wait(barrier_ref)
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
out_spec = plgpu.GPUBlockSpec(
(2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM,
)
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32),
in_specs=(in_spec,),
out_specs=out_spec,
scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
)
x = jnp.arange(2 * 64 * 128, dtype=jnp.float32).reshape(2, 64, 128)
xt = x.transpose((1, 0, 2))
np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0))
def test_copy_gmem_to_smem_in_run_scoped(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
)
def kernel(x_ref_gmem, o_ref):
def body(barrier_ref):
def inner_body(scratch_ref):
plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
plgpu.barrier_wait(barrier_ref)
o_ref[...] = scratch_ref[...] + 1
pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))
pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
def test_add_doubled_sum(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[...] + jnp.sum(x_ref[...]) + jnp.sum(x_ref[...])
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + x.sum()*2)
@parameterized.parameters(False, True)
def test_rsqrt(self, approx_math):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math),
)
def kernel(x_ref, o_ref):
o_ref[...] = jax.lax.rsqrt(x_ref[...])
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_allclose(kernel(x), jax.lax.rsqrt(x))
@parameterized.product(input_factor=[0.001, 1, 10, 100, 100])
def test_layer_norm(self, input_factor):
eps = 1e-5
gamma = 1.0
beta = 1.0
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def layer_norm(x_ref, o_ref):
x_mean = jnp.mean(x_ref[...])
x_centered = x_ref[...] - x_mean
o_ref[...] = (
x_centered * jax.lax.rsqrt(jnp.mean(x_centered**2) + eps) * gamma
+ beta
)
def layer_norm_np(x):
x_mean = np.mean(x)
x_centered = x - x_mean
return (x_centered / np.sqrt(np.mean(x_centered**2) + eps) * gamma) + beta
# Ones are always fully precise
x = jnp.ones((256,)).astype(jnp.float32) * input_factor
np.testing.assert_allclose(layer_norm(x), layer_norm_np(x))
# Random inputs (and anything else) are not.
x = (
jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32)
* input_factor
)
# TODO(cperivol): find out why in this particular case we have a small-ish error.
rtol = 1e-07 if input_factor > 10 else 5e-5
np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=rtol)
def test_print(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, o_ref):
del x_ref, o_ref
pl.debug_print("It works!")
x = jnp.arange(256).astype(jnp.float32)
with jtu.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertEqual(output(), "It works!\n")
def test_print_wgmma_tiled_layout(self):
shape = (128, 64)
size = math.prod(shape)
def kernel(x_ref, o_ref):
pl.debug_print("{}", x_ref[...])
spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)))
x = jnp.arange(size, dtype=jnp.float32).reshape(shape)
f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec)
with jtu.capture_stdout() as get_output:
jax.block_until_ready(f(x))
output = get_output()
results = re.findall(r"\[(\d+), (\d+)\]/\[128, 64\]: (\d+)", output)
self.assertLen(results, size)
for i, j, v in results:
i, j, v = map(int, (i, j, v))
self.assertEqual(v, i * shape[1] + j)
def test_print_scalar(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
)
def kernel(x_ref, o_ref):
del o_ref
pl.debug_print("x.sum() = {}", x_ref[...].sum())
x = jnp.arange(256)
with jtu.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertIn(f"x.sum() = {x.sum()}", output())
def test_print_scalar_array(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
)
def kernel(x_ref, o_ref):
del o_ref
pl.debug_print("x.sum() = {}", x_ref[...].sum() + 1)
x = jnp.arange(256)
with jtu.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertIn(f"x.sum() = {x.sum() + 1}", output())
def test_print_array(self):
in_shape = [2, 1, 64, 64]
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32),
)
def kernel(x_ref, o_ref):
del o_ref
pl.debug_print("x: {}", x_ref[...])
x = jnp.arange(math.prod(in_shape)).reshape(in_shape)
with jtu.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output())
def test_run_scoped(self):
def kernel(x_ref, o_ref):
def body(tmp_ref):
self.assertEqual(tmp_ref.shape, (8, 128))
tmp_ref[...] = x_ref[...] + 1.0
return tmp_ref[...]
tmp = pl.run_scoped(body, plgpu.SMEM((8, 128), jnp.float32))
self.assertEqual(tmp.shape, (8, 128))
o_ref[...] = tmp
inp = np.ones((8, 128))
f = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)
o = f(inp)
np.testing.assert_array_equal(o, inp + 1.0)
def test_program_id(self):
@functools.partial(
pl.pallas_call,
in_specs=(),
out_specs=pl.BlockSpec((128,), lambda *i: i),
out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
grid=2,
)
def kernel(o_ref):
o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
np.testing.assert_array_equal(
kernel(),
jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32),
)
def test_program_id_in_block_spec(self):
@functools.partial(
pl.pallas_call,
out_specs=pl.BlockSpec((128,), lambda *_: pl.program_id(0)),
out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
grid=2,
)
def kernel(o_ref):
del o_ref
# ``assertRaises`` has no way of asserting against the exception cause, so we
# have to use ``traceback.format_exception`` manually.
with self.assertRaises(Exception) as exc_info:
kernel()
self.assertIn(
"not supported in this context",
"".join(traceback.format_exception(exc_info.exception)),
)
def test_num_programs(self):
@functools.partial(
pl.pallas_call,
in_specs=(),
out_specs=pl.BlockSpec((128,), lambda *i: i),
out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
grid=2,
)
def kernel(o_ref):
o_ref[...] = jnp.full(o_ref.shape, pl.num_programs(0))
np.testing.assert_array_equal(
kernel(),
jnp.full([256], 2, dtype=jnp.int32),
)
def test_swizzled_blockspec_shapes(self):
spec = plgpu.GPUBlockSpec(
(128, 64),
lambda *i: i,
transforms=(
plgpu.TilingTransform((64, 64)),
plgpu.SwizzleTransform(128),
),
)
@functools.partial(
pl.pallas_call,
in_specs=[spec],
out_specs=spec,
out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16),
grid=(2, 2),
)
def kernel(x_ref, o_ref):
assert x_ref.shape == (128, 64), x_ref.shape
o_ref[...] = x_ref[...]
x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
np.testing.assert_array_equal(kernel(x), x)
def test_fori_loop_array(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, o_ref):
# Equivalent to x_ref[...] + 2 + 3.
o_ref[...] = jax.lax.fori_loop(2, 4, lambda i, x: x + i, x_ref[...])
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 2.0 + 3.0)
def test_fori_loop_scalar(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(o_ref):
# Equivalent to 2 + 3.
o_ref[...] = jax.lax.broadcast(
jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0.0), o_ref.shape
)
np.testing.assert_array_equal(
kernel(), jnp.full([256], 5.0, dtype=jnp.float32)
)
def test_fori_loop_indexed_store(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
def body(idx, _):
o_ref[idx] = x_ref[idx] + y_ref[idx]
return ()
jax.lax.fori_loop(0, 4, body, ())
x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
y = x + 1
np.testing.assert_array_equal(kernel(x, y), x + y)
def test_cond(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
)
def kernel(x_ref, o_ref):
acc = x_ref[...].sum()
jax.lax.cond(
acc % 2 == 0,
lambda: pl.debug_print("acc * 2: {}", acc * 2),
lambda: pl.debug_print("acc: {}", acc),
)
o_ref[...] = jnp.broadcast_to(acc, o_ref.shape)
x = jnp.arange(256)
with jtu.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertIn("acc * 2:", output())
@parameterized.parameters(jnp.float16, jnp.float32)
def test_wgmma(self, dtype):
# TensorCores can only fuse transposes of 16-bit values, and RHS
# is expected to be column major by default.
rhs_transpose = jnp.dtype(dtype).itemsize != 2
swizzle = 128
elems_128b = swizzle // jnp.dtype(dtype).itemsize
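# elems_128b is the number of elements covered by the 128-byte swizzle: 64 for
# f16, 32 for f32. The tile minor dimensions below are sized to match it.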
def kernel(a_ref, b_ref, o_ref):
if rhs_transpose:
b_ref = plgpu.transpose_ref(b_ref, (1, 0))
def scope(acc_ref):
plgpu.wgmma(acc_ref, a_ref, b_ref)
return acc_ref[...]
o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
b_shape = (128, 192)
if rhs_transpose:
b_shape = b_shape[::-1]
b = jax.random.uniform(key2, shape=b_shape, dtype=dtype)
rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),)
if rhs_transpose:
rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),)
res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec(
(64, 128),
lambda i, j: (i, j),
transforms=(
plgpu.TilingTransform((64, elems_128b)),
plgpu.SwizzleTransform(128),
),
),
plgpu.GPUBlockSpec(
b_shape,
lambda *i: i,
transforms=(*rhs_transforms, plgpu.SwizzleTransform(128)),
),
],
out_specs=plgpu.GPUBlockSpec((64, 192), lambda *i: i),
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
grid=(1, 1),
)(a, b)
np.testing.assert_allclose(
res, a @ (b.T if rhs_transpose else b), rtol=1e-3
)
def test_wgmma_registers(self):
def kernel(a_ref, b_ref, o_ref):
def scope(acc_ref):
plgpu.wgmma(acc_ref, a_ref[...], b_ref)
return acc_ref[...]
o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16)
transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128))
res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms),
plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms),
],
out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
)(a, b)
np.testing.assert_allclose(res, a @ b, rtol=1e-3)
def test_wgmma_sliced_ref(self):
def kernel(a_ref, b_ref, o_ref):
def scope(acc_ref):
plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0])
return acc_ref[...]
o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(2, 64, 128), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16)
res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec(
(2, 64, 128), lambda: (0, 0, 0),
transforms=(
plgpu.TilingTransform((64, 64)),
plgpu.SwizzleTransform(128),
),
),
plgpu.GPUBlockSpec(
(2, 128, 192), lambda: (0, 0, 0),
transforms=(
plgpu.TilingTransform((64, 64)),
plgpu.SwizzleTransform(128),
),
),
],
out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
)(a, b)
np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3)
def test_wgmma_sliced_acc(self):
swizzle = 128
elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize
def kernel(a_ref, b_ref, o_ref):
def scope(acc_ref):
plgpu.wgmma(acc_ref, a_ref, b_ref)
return acc_ref[:, :64], acc_ref[:, 64:]
o_ref[:, :64], o_ref[:, 64:] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16)
res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec(
(64, 128),
lambda i, j: (i, j),
transforms=(
plgpu.TilingTransform((64, elems_128b)),
plgpu.SwizzleTransform(128),
),
),
plgpu.GPUBlockSpec(
(128, 128),
lambda *i: i,
transforms=(
plgpu.TilingTransform((elems_128b, elems_128b)),
plgpu.SwizzleTransform(128),
),
),
],
out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i),
out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32),
grid=(1, 1),
)(a, b)
np.testing.assert_allclose(res, a @ b, rtol=1e-3)
def test_input_output_aliases(self):
# Note that we're writing to the input ref; input_output_aliases makes it alias
# the output, so b should come back as all ones.
def kernel(a_ref, b_ref):
del b_ref
a_ref[...] = jnp.ones_like(a_ref)
a = np.zeros((64, 64), dtype=jnp.float32)
b = pl.pallas_call(
kernel,
in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM),
input_output_aliases={0: 0},
out_shape=a,
)(a)
np.testing.assert_array_equal(b, np.ones_like(a))
def test_realistic_matmul(self):
dtype = jnp.float16
swizzle = 128
elems_128b = swizzle // jnp.dtype(dtype).itemsize
grid_m, grid_k, grid_n = 132, 10, 4
tile_m = tile_n = 128
tile_k = elems_128b
m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
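# Resulting problem size: m = 132 * 128 = 16896, k = 10 * 64 = 640, n = 4 * 128 = 512
# (for f16, tile_k = elems_128b = 64). grid_m = 132 presumably matches the SM count
# of an H100.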
def kernel(a_ref, b_ref, o_ref, acc_ref):
# Make sure tiling does not alter the shape of references
assert a_ref.shape == (tile_m, tile_k)
assert b_ref.shape == (tile_k, tile_n)
assert o_ref.shape == acc_ref.shape == (tile_m, tile_n)
plgpu.wgmma(acc_ref, a_ref, b_ref)
is_last_step = pl.program_id(2) == grid_k - 1
@pl.when(is_last_step)
def _epilogue():
o_ref[...] = acc_ref[...].astype(dtype)
plgpu.wgmma_wait(1) # We don't await the last WGMMA, hence delay_release=1
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)
res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec(
(tile_m, tile_k),
lambda m, n, k: (m, k),
transforms=(
plgpu.TilingTransform((64, elems_128b)),
plgpu.SwizzleTransform(128),
),
),
plgpu.GPUBlockSpec(
(tile_k, tile_n),
lambda m, n, k: (k, n),
transforms=(
plgpu.TilingTransform((elems_128b, elems_128b)),
plgpu.SwizzleTransform(128),
),
),
],
out_specs=plgpu.GPUBlockSpec(
(tile_m, tile_n),
lambda m, n, k: (m, n),
transforms=(
plgpu.TilingTransform((64, elems_128b)),
plgpu.SwizzleTransform(128),
),
),
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
grid=(grid_m, grid_n, grid_k),
compiler_params=plgpu.GPUCompilerParams(
dimension_semantics=["parallel", "parallel", "sequential"],
max_concurrent_steps=2,
delay_release=1,
),
)(a, b)
np.testing.assert_allclose(res, a @ b, rtol=1e-3)
def test_slicing(self):
left = upper = slice(None, 64)
right = lower = slice(64, None)
# We rotate the four quadrants of the input clockwise.
def rotate(src, dst):
dst[upper, left] = src[lower, left]
dst[upper, right] = src[upper, left]
dst[lower, right] = src[upper, right]
dst[lower, left] = src[lower, right]
x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
spec = plgpu.GPUBlockSpec(
(128, 128),
lambda: (0, 0),
transforms=(
plgpu.TilingTransform((64, 64)),
plgpu.SwizzleTransform(128),
),
)
f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec)
expected = np.empty_like(x)
rotate(x, expected)
np.testing.assert_array_equal(f(x), expected)
def test_layout_cast(self, shape=(256, 64)):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
)
def kernel(o_ref):
o_ref[...] = plgpu.layout_cast(jnp.full(shape, 42.0), plgpu.Layout.WGMMA)
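# layout_cast only changes the register layout (here to the one WGMMA uses); the
# values are unchanged, so the output should still be all 42s.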
x = jnp.full(shape, 42.0)
np.testing.assert_array_equal(kernel(), x)
class PipelineTest(PallasTest):
def test_manual(self, max_concurrent_steps=2, num_steps=4):
def kernel(x_gmem, o_gmem):
return pl.run_scoped(
functools.partial(scoped_kernel, x_gmem, o_gmem),
plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32),
plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32),
plgpu.Barrier(1, num_barriers=max_concurrent_steps),
)
def scoped_kernel(x_gmem, o_gmem, x_smem, o_smem, barrier):
gmem_slice = pl.ds(pl.program_id(0) * 32, 32)
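# Each program owns a 32-row stripe of the array; the loop below walks over it in
# 16-column tiles, rotating through max_concurrent_steps SMEM slots.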
def body(step, _):
slot = step % max_concurrent_steps
# Wait for the current GMEM->SMEM copy to complete.
plgpu.barrier_wait(barrier.at[slot])
# Wait for the previous output SMEM->GMEM copy to complete.
plgpu.wait_smem_to_gmem(max_concurrent_steps - 1)
o_smem[...] = x_smem[...] + 1.0
plgpu.copy_smem_to_gmem(
o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)]
)
fetch_step = step + max_concurrent_steps
fetch_slot = slot # (x + y) % y == x % y
jax.lax.cond(
fetch_step < num_steps,
lambda: plgpu.copy_gmem_to_smem(
x_gmem.at[gmem_slice, pl.ds(fetch_step * 16, 16)],
x_smem.at[fetch_slot],
barrier=barrier.at[fetch_slot],
),
lambda: None,
)
return ()
# Initialize the pipeline.
for slot in range(min(max_concurrent_steps, num_steps)):
plgpu.copy_gmem_to_smem(
x_gmem.at[gmem_slice, pl.ds(slot * 16, 16)],
x_smem.at[slot],
barrier=barrier.at[slot],
)
jax.lax.fori_loop(0, num_steps, body, ())
# Finalize the pipeline.
plgpu.wait_smem_to_gmem(0)
x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32)
kernel_fn = pl.pallas_call(
kernel,
in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
grid=(4, 1),
)
np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
class CoreMapTest(PallasTest):
def test_multiple_wg(self):
mesh = plgpu.GPUMesh(num_threads=2, axis_names=("y",))
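# num_threads=2 launches two Pallas threads per block (the WG dimension mentioned
# in the commit message above); the "y" axis indexes them.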
@jax.jit
def f():
@pl.run_state
def inner(y_ref):
@pl.core_map(mesh)
def kernel():
wg_idx = jax.lax.axis_index("y")
y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,))
y_init = jnp.zeros((2, 128), np.int32)
return inner(y_init)
np.testing.assert_array_equal(
f(), np.repeat(np.arange(2), 128).reshape(2, 128)
)
def test_multiple_wg_with_grid(self):
mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg"))
@jax.jit
def f():
@pl.run_state
def inner(y_ref):
@pl.core_map(mesh)
def kernel():
xy_idx = jax.lax.axis_index(("x", "y"))
yx_idx = jax.lax.axis_index(("y", "x"))
wg_idx = jax.lax.axis_index("wg")
num_wgs = jax.lax.psum(1, "wg")
y_ref[xy_idx, wg_idx] = jnp.broadcast_to(
yx_idx * num_wgs + wg_idx, (128,)
)
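# Rows are indexed by xy_idx (x-major order over the grid) while the stored value
# is derived from yx_idx (y-major), which is why the expected rows below read
# 0,1 / 4,5 / 2,3 / 6,7.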
y_init = jnp.zeros((4, 2, 128), np.int32)
return inner(y_init)
np.testing.assert_array_equal(
f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128)
)
if __name__ == "__main__":
absltest.main()