Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
Move some backwards compatibility tests from jax_triton to jax/pallas.
While doing this I moved `matmul.py` to `jax/experimental/pallas/ops/tpu`.

PiperOrigin-RevId: 660341331
This commit is contained in:
parent 28ca734d9b
commit 3e5e947542
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
jax/experimental/pallas/ops/tpu/matmul.py (new file, 85 lines)
@@ -0,0 +1,85 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example matmul TPU kernel.

See discussion in https://jax.readthedocs.io/en/latest/pallas/tpu/matmul.html.
"""

import functools

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp


def matmul_kernel(x_tile_ref, y_tile_ref, o_tile_ref, acc_ref):
  @pl.when(pl.program_id(2) == 0)
  def init():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  acc_ref[...] = acc_ref[...] + jnp.dot(
      x_tile_ref[...],
      y_tile_ref[...],
      preferred_element_type=acc_ref.dtype,
  )
  # It is possible to make this conditional but in general this bundle packs
  # quite well for a simple matmul kernel
  o_tile_ref[...] = acc_ref[...].astype(o_tile_ref.dtype)


@functools.partial(
    jax.jit, static_argnames=["block_shape", "block_k", "debug", "out_dtype"]
)
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    block_shape,
    block_k: int = 256,
    out_dtype: jnp.dtype | None = None,
    debug: bool = False,
) -> jax.Array:
  if out_dtype is None:
    if x.dtype != y.dtype:
      # TODO(tlongeri): Maybe we could use a deduction similar to jnp.dot
      raise TypeError(
          f"Cannot deduce output dtype for different input dtypes: {x.dtype},"
          f" {y.dtype}"
      )
    out_dtype = x.dtype
  acc_dtype = jnp.float32
  if x.dtype in [jnp.int8, jnp.int4, jnp.uint8, jnp.uint4]:
    acc_dtype = jnp.int32

  l, r = block_shape
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), out_dtype),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          in_specs=[
              pl.BlockSpec((l, block_k), lambda i, _, k: (i, k)),
              pl.BlockSpec((block_k, r), lambda _, j, k: (k, j)),
          ],
          out_specs=pl.BlockSpec((l, r), lambda i, j, k: (i, j)),
          grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k),
          scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)],
      ),
      compiler_params=dict(
          mosaic=dict(dimension_semantics=("parallel", "parallel", "arbitrary"))
      ),
      debug=debug,
  )(x, y)
@@ -203,6 +203,7 @@ jax_test(
        "//jax:internal_export_back_compat_test_util",
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu_ops",  # build_cleaner: keep
    ],
)
@@ -17,15 +17,21 @@ See the export_back_compat_test_util module docstring for how to setup and
update these tests.
"""

from absl.testing import absltest
import math

from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax._src.internal_test_util.export_back_compat_test_data.pallas import cuda_add_one
from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_matmul
from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_semaphore_dma
from jax._src.internal_test_util.export_back_compat_test_data.pallas import triton_add_one
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu import matmul
import jax.numpy as jnp


config.parse_flags_with_absl()
@@ -36,14 +42,12 @@ class CompatTest(bctu.CompatTestBase):
  def setUp(self):
    if jax.config.x64_enabled:
      self.skipTest("Only works in 32-bit")
    if not jtu.test_device_matches(["gpu"]):
      self.skipTest("Only works on GPU")
    if (jtu.test_device_matches(["cuda"]) and
        not jtu.is_cuda_compute_capability_at_least("8.0")):
      self.skipTest("Only works on GPUs with capability >= sm80")
    super().setUp()

  def test_cuda_add_one(self):
  def test_triton_add_one(self):
    def func(x):
      def add_one(x_ref, o_ref):
        o_ref[0] = x_ref[0] + 1
@@ -52,10 +56,53 @@ class CompatTest(bctu.CompatTestBase):
          in_specs=[pl.BlockSpec((1,), lambda i: i)],
          out_specs=pl.BlockSpec((1,), lambda i: i),
          grid=8)(x)
    data = self.load_testdata(cuda_add_one.data_2024_05_02)
    data = self.load_testdata(triton_add_one.data_2024_05_02)

    self.run_one_test(func, data)

  @jax.default_matmul_precision("bfloat16")
  def test_mosaic_matmul(self):
    dtype = jnp.float32
    def func():
      # Build the inputs here, to reduce the size of the golden inputs.
      x_shape = (1024, 512)
      bias = 1.0
      scale = 1e-3
      x = bias + scale * jnp.arange(
          math.prod(x_shape), dtype=dtype).reshape(x_shape)
      y = x[:512, :256]
      res = matmul.matmul(x, y, block_shape=(256, 256))
      # Keep only slices of the output, to reduce the size of the goldens.
      return res[::16, ::16]

    data = self.load_testdata(mosaic_matmul.data_2023_09_22)
    self.run_one_test(func, data, rtol=2e-7)

  def test_mosaic_semaphore_dma(self):
    if not (jtu.test_device_matches(["tpu"]) and
            jtu.is_device_tpu_at_least(4)):
      # TODO: crashes during compilation on TPU v4
      self.skipTest("Only works on TPU v5+")

    # The signatures of TPU ops for semaphore and DMA have changed.
    # This test ensures that the new signatures are backwards compatible.
    def func():
      def dma_kernel(x, y):
        def body(dma_sem, sem):
          pltpu.async_copy(x, y, dma_sem).wait()
          pltpu.semaphore_signal(sem)
          pltpu.semaphore_wait(sem)
        pl.run_scoped(
            body, pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.REGULAR
        )
      x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
      y = pl.pallas_call(dma_kernel, out_shape=x)(x)
      return jnp.array_equal(x, y).astype(jnp.float32)

    data = self.load_testdata(
        mosaic_semaphore_dma.semaphore_and_dma_2024_04_22)
    self.run_one_test(func, data)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())