Move some backwards compatibility tests from jax_triton to jax/pallas.

While doing this I moved `matmul.py` to `jax/experimental/pallas/ops/tpu`

PiperOrigin-RevId: 660341331
This commit is contained in:
George Necula 2024-08-07 04:59:19 -07:00 committed by jax authors
parent 28ca734d9b
commit 3e5e947542
6 changed files with 574 additions and 7 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,85 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example matmul TPU kernel.
See discussion in https://jax.readthedocs.io/en/latest/pallas/tpu/matmul.html.
"""
import functools
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
def matmul_kernel(x_tile_ref, y_tile_ref, o_tile_ref, acc_ref):
@pl.when(pl.program_id(2) == 0)
def init():
acc_ref[...] = jnp.zeros_like(acc_ref)
acc_ref[...] = acc_ref[...] + jnp.dot(
x_tile_ref[...],
y_tile_ref[...],
preferred_element_type=acc_ref.dtype,
)
# It is possible to make this conditional but in general this bundle packs
# quite well for a simple matmul kernel
o_tile_ref[...] = acc_ref[...].astype(o_tile_ref.dtype)
@functools.partial(
jax.jit, static_argnames=["block_shape", "block_k", "debug", "out_dtype"]
)
def matmul(
x: jax.Array,
y: jax.Array,
*,
block_shape,
block_k: int = 256,
out_dtype: jnp.dtype | None = None,
debug: bool = False,
) -> jax.Array:
if out_dtype is None:
if x.dtype != y.dtype:
# TODO(tlongeri): Maybe we could use a deduction similar to jnp.dot
raise TypeError(
f"Cannot deduce output dtype for different input dtypes: {x.dtype},"
f" {y.dtype}"
)
out_dtype = x.dtype
acc_dtype = jnp.float32
if x.dtype in [jnp.int8, jnp.int4, jnp.uint8, jnp.uint4]:
acc_dtype = jnp.int32
l, r = block_shape
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), out_dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec((l, block_k), lambda i, _, k: (i, k)),
pl.BlockSpec((block_k, r), lambda _, j, k: (k, j)),
],
out_specs=pl.BlockSpec((l, r), lambda i, j, k: (i, j)),
grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k),
scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)],
),
compiler_params=dict(
mosaic=dict(dimension_semantics=("parallel", "parallel", "arbitrary"))
),
debug=debug,
)(x, y)

View File

@ -203,6 +203,7 @@ jax_test(
"//jax:internal_export_back_compat_test_util",
"//jax:pallas",
"//jax:pallas_gpu", # build_cleaner: keep
"//jax:pallas_tpu_ops", # build_cleaner: keep
],
)

View File

@ -17,15 +17,21 @@ See the export_back_compat_test_util module docstring for how to setup and
update these tests.
"""
from absl.testing import absltest
import math
from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax._src.internal_test_util.export_back_compat_test_data.pallas import cuda_add_one
from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_matmul
from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_semaphore_dma
from jax._src.internal_test_util.export_back_compat_test_data.pallas import triton_add_one
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu import matmul
import jax.numpy as jnp
config.parse_flags_with_absl()
@ -36,14 +42,12 @@ class CompatTest(bctu.CompatTestBase):
def setUp(self):
if jax.config.x64_enabled:
self.skipTest("Only works in 32-bit")
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Only works on GPU")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on GPUs with capability >= sm80")
super().setUp()
def test_cuda_add_one(self):
def test_triton_add_one(self):
def func(x):
def add_one(x_ref, o_ref):
o_ref[0] = x_ref[0] + 1
@ -52,10 +56,53 @@ class CompatTest(bctu.CompatTestBase):
in_specs=[pl.BlockSpec((1,), lambda i: i)],
out_specs=pl.BlockSpec((1,), lambda i: i),
grid=8)(x)
data = self.load_testdata(cuda_add_one.data_2024_05_02)
data = self.load_testdata(triton_add_one.data_2024_05_02)
self.run_one_test(func, data)
@jax.default_matmul_precision("bfloat16")
def test_mosaic_matmul(self):
dtype = jnp.float32
def func():
# Build the inputs here, to reduce the size of the golden inputs.
x_shape = (1024, 512)
bias = 1.0
scale = 1e-3
x = bias + scale * jnp.arange(
math.prod(x_shape), dtype=dtype).reshape(x_shape)
y = x[:512, :256]
res = matmul.matmul(x, y, block_shape=(256, 256))
# Keep only slices of the output, to reduce the size of the goldens.
return res[::16, ::16]
data = self.load_testdata(mosaic_matmul.data_2023_09_22)
self.run_one_test(func, data, rtol=2e-7)
def test_mosaic_semaphore_dma(self):
if not (jtu.test_device_matches(["tpu"]) and
jtu.is_device_tpu_at_least(4)):
# TODO: crashes during compilation on TPU v4
self.skipTest("Only works on TPU v5+")
# The signatures of TPU ops for semaphore and DMA have changed.
# This test ensures that the new signatures are backwards compatible.
def func():
def dma_kernel(x, y):
def body(dma_sem, sem):
pltpu.async_copy(x, y, dma_sem).wait()
pltpu.semaphore_signal(sem)
pltpu.semaphore_wait(sem)
pl.run_scoped(
body, pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.REGULAR
)
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
y = pl.pallas_call(dma_kernel, out_shape=x)(x)
return jnp.array_equal(x, y).astype(jnp.float32)
data = self.load_testdata(
mosaic_semaphore_dma.semaphore_and_dma_2024_04_22)
self.run_one_test(func, data)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())