# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for TPU specific operations within pallas_call."""

import functools
import math
import sys
import unittest

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.pallas import utils as pallas_utils
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np

if sys.platform != "win32":
  from jax.experimental.pallas import tpu as pltpu
else:
  pltpu = None

try:
  import hypothesis as hp
except (ModuleNotFoundError, ImportError):
  raise unittest.SkipTest("tests depend on hypothesis library")

import hypothesis.strategies as hps

jax.config.parse_flags_with_absl()
jtu.setup_hypothesis(max_examples=100)

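# Dtypes exercised by the bitcast test below.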
_JAX_DTYPES = (
    jnp.float32,
    jnp.bfloat16,
    jnp.int32,
    jnp.int16,
    jnp.int8,
    jnp.bool_,
)


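# Base class shared by the tests in this file: it skips on non-TPU backends
# and routes pl.pallas_call through a class-level INTERPRET flag so that
# OpsInterpretTest at the bottom can rerun every test in interpret mode.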
class PallasBaseTest(jtu.JaxTestCase):
  INTERPRET = False

  def setUp(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Test only supported on TPU.")

    super().setUp()

  @classmethod
  def pallas_call(cls, *args, **kwargs):
    return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs)


class OpsTest(PallasBaseTest):

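  # Bitcasting between dtypes of different widths changes the number of rows:
  # in_packing/out_packing below keep the total number of bits equal on both
  # sides (e.g. one row of int32 becomes two rows of int16).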
  @parameterized.product(
      from_dtype=_JAX_DTYPES, to_dtype=_JAX_DTYPES, is_ref_bitcast=[False, True]
  )
  def test_bitcast(self, from_dtype, to_dtype, is_ref_bitcast):
    if not jtu.is_device_tpu_at_least(version=4):
      self.skipTest("Run on TPUv4+ to have expected memory layout")
    if from_dtype == to_dtype:
      self.skipTest("No bitcast needed")
    if from_dtype == jnp.bool_ or to_dtype == jnp.bool_:
      self.skipTest("Bitcasting with bool is not supported")

    def kernel(x_ref, y_ref):
      if is_ref_bitcast:
        y_ref[...] = x_ref.bitcast(to_dtype)[...]
      else:
        y_ref[...] = pltpu.bitcast(x_ref[...], to_dtype)

    m, n = 1, 256
    in_packing = 32 // pallas_utils.dtype_bitwidth(from_dtype)
    out_packing = 32 // pallas_utils.dtype_bitwidth(to_dtype)
    in_shape = (m * in_packing, n)
    out_shape = (m * out_packing, n)
    inp = np.arange(np.prod(in_shape), dtype=from_dtype).reshape(in_shape)
    out = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(out_shape, to_dtype),
    )(inp)
    if not self.INTERPRET:
      out_interpret = pl.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct(out_shape, to_dtype),
          interpret=True,
      )(inp)
      self.assertAllClose(out, out_interpret)

  @parameterized.product(is_dynamic=(False, True))
  @hp.given(
      axis=hps.integers(0, 3),
      shift=hps.integers(0, 3),
      stride=hps.one_of(hps.just(None), hps.integers(0, 2)),
      # Striding along the minormost dimension is not supported.
      stride_axis=hps.one_of(hps.just(None), hps.integers(0, 2)),
  )
  @hp.example(3, 9, 1, 2)
  @hp.example(3, 9, 2, 2)
  @hp.example(0, 9, 0, 1)
  @hp.example(0, 9, 1, 1)
  def test_roll(self, is_dynamic, axis, shift, stride, stride_axis):
    if (stride is None) != (stride_axis is None):
      self.skipTest(
          "Roll op requires stride and stride_axis to be either both"
          " specified or both unspecified."
      )
    if (not jtu.is_device_tpu(version=5)) and stride_axis == 2:
      self.skipTest(
          "Roll op with stride axis on 2nd minor requires at least TPU v5"
      )
    shape = (4, 4, 32, 512)

    def kernel(s_ref, x_ref, y_ref):
      amt = s_ref[0] if is_dynamic else shift
      y_ref[...] = pltpu.roll(
          x_ref[...], amt, axis, stride=stride, stride_axis=stride_axis
      )

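    # NumPy reference: without a stride this is a plain np.roll; with one,
    # slice i along stride_axis is rolled by shift + i * stride.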
    def roll(x, shift, axis, stride=None, stride_axis=None):
      assert (stride is None) == (stride_axis is None)
      if stride is None:
        return np.roll(x, shift, axis)
      outputs = [
          np.roll(xs, shift + i * stride, axis)
          for i, xs in enumerate(np.split(x, x.shape[stride_axis], stride_axis))
      ]
      return np.concatenate(outputs, stride_axis)

    inp = np.arange(np.prod(shape), dtype=jnp.int32).reshape(shape)
    ref = roll(inp, shift, axis, stride, stride_axis)
    dynamic_shift = jnp.array([abs(shift)], jnp.int32)
    for interpret in [False, True]:
      out = pl.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct(shape, jnp.int32),
          grid_spec=pltpu.PrefetchScalarGridSpec(num_scalar_prefetch=1),
          interpret=interpret,
      )(dynamic_shift, inp)
      np.testing.assert_array_equal(out, ref, err_msg=f"{interpret=}")

  def test_interleave_vectors(self):
    if not jtu.is_device_tpu_at_least(version=4):
      self.skipTest("Expect TPUv4+")

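    # Upcasting bf16 -> f32 places the bf16 bits in the upper half of each
    # 32-bit lane. Shifting x down by 16 and OR-ing with y packs one element
    # of each input per lane, so the final bitcast back to bf16 interleaves
    # the rows of x and y.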
    def kernel(x_ref, y_ref, out_ref):
      x = pltpu.bitcast(x_ref[...].astype(jnp.float32), jnp.int32)
      y = pltpu.bitcast(y_ref[...].astype(jnp.float32), jnp.int32)
      shift = jax.lax.broadcast(16, x.shape)
      out_ref[...] = pltpu.bitcast(
          y | jax.lax.shift_right_logical(x, shift), jnp.bfloat16
      )

    m, n = 16, 128
    inp = np.arange(m * n * 2, dtype=jnp.bfloat16).reshape(m, n * 2)
    x, y = np.split(inp, 2, axis=1)
    out = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((m * 2, n), jnp.bfloat16),
    )(x, y)
    np.testing.assert_array_equal(out, inp.reshape(m * 2, n))

  @parameterized.parameters([jnp.int32, jnp.int16, jnp.int8, jnp.int4])
  def test_row_broadcast(self, dtype):
    if not jtu.if_cloud_tpu_at_least(2025, 1, 10):
      self.skipTest("Requires libtpu built after 2025-01-10")
    if not self.INTERPRET and jtu.get_tpu_version() < 5:
      self.skipTest("Requires TPUv5+")
    def kernel(x_ref, y_ref):
      y_ref[...] = jnp.broadcast_to(x_ref[pl.ds(3, 1)], y_ref.shape).astype(
          y_ref.dtype
      )
    m, n = 4, 1152
    x = jax.random.randint(
        jax.random.key(12), (m, n), minval=-1000, maxval=1000, dtype=jnp.int32
    ).astype(dtype)
    y = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32)
    )(x)
    np.testing.assert_array_equal(y, jnp.broadcast_to(x[3:4], y.shape))

  def test_tpu_unsigned_int(self):
    self.skipTest("TODO(apaszke): Unsigned upcasts were implemented incorrectly")
    def body(x_ref, o_ref):
      # Test cast from uint16 -> uint32
      ux = lax.convert_element_type(x_ref[...], jnp.uint32)
      res = ux + 1
      # Test cast from uint32 -> float32
      o_ref[...] = res.astype(jnp.float32)
    out = jax.ShapeDtypeStruct((8, 128), jnp.float32)
    x = jnp.arange(8 * 128, dtype=jnp.uint16).reshape((8, 128))
    result = self.pallas_call(body, out_shape=out)(x)
    np.testing.assert_array_equal(result, x.astype(jnp.float32) + 1.0)

  def test_tpu_signed_int_upcast(self):
    if not jtu.is_device_tpu_at_least(version=5):
      self.skipTest("TPUv5+ needed for integer matmuls")

    def body(x_ref, o_ref):
      # Test cast from int4 -> int8
      ux = lax.convert_element_type(x_ref[...], jnp.int8)
      o_ref[...] = jax.lax.dot(ux, ux, preferred_element_type=jnp.int32)

    out = jax.ShapeDtypeStruct((128, 128), jnp.int32)
    x = jnp.arange(128 * 128, dtype=jnp.int4).reshape((128, 128))
    result = self.pallas_call(body, out_shape=out)(x)
    np.testing.assert_array_equal(
        result,
        jax.lax.dot(
            x.astype(jnp.int8),
            x.astype(jnp.int8),
            preferred_element_type=jnp.int32,
        ),
    )

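  # The scalar condition is passed through SMEM, while both operands live in
  # VMEM; lax.select on the scalar predicate picks one of the two blocks.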
  def test_select_with_scalar_condition(self):
    def kernel(cond, lhs, rhs, out):
      out[:] = jax.lax.select(cond[0] != 0, lhs[:], rhs[:])

    def run(cond, lhs, rhs):
      return self.pallas_call(
          kernel,
          out_shape=lhs,
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=0,
              in_specs=[
                  pl.BlockSpec(memory_space=pltpu.SMEM),
                  pl.BlockSpec(memory_space=pltpu.VMEM),
                  pl.BlockSpec(memory_space=pltpu.VMEM),
              ],
          ),
          name="select_kernel",
      )(cond, lhs, rhs)

    cond = jnp.array([1], dtype=jnp.int32)
    lhs = jnp.zeros((8, 128), dtype=jnp.float32)
    rhs = jnp.ones((8, 128), dtype=jnp.float32)

    assert (run(cond, lhs, rhs) == lhs).all()

  def test_logical_and_relayouted_mask(self):
    def get_mask(x_ref):
      x = x_ref[...] == 1
      iota = jax.lax.broadcasted_iota(jnp.int32, x_ref.shape, 1)
      iota = iota > 7
      return jnp.logical_and(x, iota)

    def body(x_ref, y_ref):
      y_ref[...] = jnp.where(get_mask(x_ref), 0.0, -1.0)

    shape = (2, 512)
    out = jax.ShapeDtypeStruct(shape, jnp.float32)
    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(shape)
    result = self.pallas_call(body, out_shape=out)(x)
    expected = jnp.ones(x.shape, dtype=jnp.float32)
    expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0))
    np.testing.assert_array_equal(result, expected)

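  # The mask arrives as a non-boolean vector; using it as the jnp.where
  # condition forces the vector-to-mask cast that this test exercises.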
  @parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int16, jnp.int8])
  def test_cast_vector_to_mask(self, dtype):
    if not jtu.if_cloud_tpu_at_least(2025, 1, 22):
      self.skipTest("Requires libtpu built after 2025-01-22")
    shape = (128, 128)
    bitwidth = pallas_utils.dtype_bitwidth(dtype)
    if jtu.get_tpu_version() < 5 and bitwidth < 32:
      self.skipTest(
          f"Not implemented: cast vector to mask with bitwidth == {bitwidth}"
      )

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    def kernel(x_ref, mask_ref, o_ref):
      zeros = jnp.zeros_like(x_ref)
      o_ref[...] = jnp.where(mask_ref[...], x_ref[...], zeros)

    mask = jax.random.bernoulli(jax.random.key(1234), 0.5, shape).astype(dtype)
    x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + 1

    out = kernel(x, mask)
    expected = jnp.where(mask, x, jnp.zeros_like(x))
    self.assertArraysEqual(out, expected)

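  # Reduces a (2, 16, 128) block along each axis with sum/max/min and
  # compares against the corresponding jnp reduction.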
  @parameterized.product(
      dtype=[jnp.float32, jnp.bfloat16, jnp.int32],
      axis=[0, 1, 2],
      reduce_func=[jnp.sum, jnp.max, jnp.min],
  )
  def test_reduction(self, dtype, axis, reduce_func):
    if dtype == jnp.int32:
      # TODO(apaszke): Remove after 12 weeks have passed.
      if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
        self.skipTest("Requires libtpu built after 2024-12-19")
      if axis == 2:
        self.skipTest("Int32 reduction on minor is not supported.")
    # TODO(b/384127570): fix bfloat16 reduction.
    if dtype == jnp.bfloat16 and reduce_func != jnp.sum:
      self.skipTest("b/384127570")
    in_shape = (2, 16, 128)
    out_shape = list(in_shape)
    out_shape[axis] = 1

    def kernel(x, out):
      out[:] = reduce_func(x[:], axis, keepdims=True)

    x = jnp.arange(np.prod(in_shape), dtype=dtype).reshape(in_shape)
    result = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
    )(x)
    expected = reduce_func(x, axis, keepdims=True)
    np.testing.assert_array_equal(result, expected)

  @parameterized.product(
      msk_dtype=[jnp.float32, jnp.bfloat16, jnp.int8],
      dtype=[jnp.float32, jnp.bfloat16],
  )
  def test_i1_relayout_with_bitwidth_change(self, msk_dtype, dtype):
    if not jtu.if_cloud_tpu_at_least(2025, 1, 25):
      self.skipTest("Requires libtpu built after 2025-01-25")
    shape = (129, 129)
    msk_bitwidth = pallas_utils.dtype_bitwidth(msk_dtype)
    bitwidth = pallas_utils.dtype_bitwidth(dtype)
    if jtu.get_tpu_version() < 5 and msk_bitwidth < 32:
      self.skipTest(
          "Not implemented: cast vector to mask with bitwidth =="
          f" {msk_bitwidth}"
      )
    if jtu.get_tpu_version() < 5 and bitwidth < 32:
      self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}")

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    def kernel(x_ref, mask_ref, o_ref):
      zeros = jnp.zeros_like(x_ref)
      o_ref[...] = jnp.where(mask_ref[...], x_ref[...], zeros)

    mask = jax.random.bernoulli(jax.random.key(1234), 0.5, shape).astype(
        msk_dtype
    )
    x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + 1

    out = kernel(x, mask)
    expected = jnp.where(mask, x, jnp.zeros_like(x))
    self.assertArraysEqual(out, expected)

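  # Quantization to int8: every bf16 bit pattern is cast (optionally after
  # rounding to nearest with jnp.rint) and checked against a jitted reference.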
  @parameterized.product(
      target=(jnp.int8,),  # TODO(apaszke): Add int4.
      round=(False, True),
  )
  def test_quantize(self, target, round):
    if not jtu.if_cloud_tpu_at_least(2025, 1, 15):
      self.skipTest("Requires libtpu built after 2025-01-15")
    if not jtu.is_device_tpu_at_least(version=6):
      self.skipTest("Requires TPUv6+")
    shape = (256, 256)
    # NOTE: 256 * 256 == 2 ** 16, so the input covers every bf16 bit pattern.
    x = lax.bitcast_convert_type(
        np.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape),
        jnp.bfloat16,
    )

    round_fn = jnp.rint if round else lambda x: x

    def kernel(x_ref, o_ref):
      o_ref[...] = round_fn(x_ref[...]).astype(target)
    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct(shape, target)
    )(x)

    ref = jax.jit(lambda x: round_fn(x).astype(target))(x)
    np.testing.assert_array_equal(out, ref)

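  # Dynamic gather with per-element indices along either axis of an (8, 128)
  # block, checked against np.take_along_axis.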
  @parameterized.product(axis=[0, 1], mode=["promise_in_bounds", None])
  def test_dynamic_gather_along_axis(self, axis, mode):
    if not jtu.if_cloud_tpu_at_least(2025, 2, 5):
      self.skipTest("Requires libtpu built after 2025-02-05")
    if (axis == 0 and not jtu.is_device_tpu_at_least(version=5)) or (
        axis == 1 and not jtu.is_device_tpu_at_least(version=4)
    ):
      self.skipTest("Requires TPUv5+ for axis=0 and TPUv4+ for axis=1")
    dtype = jnp.int32
    shape = (8, 128)

    def kernel(x, indices, out):
      out[...] = jnp.take_along_axis(x[...], indices[...], axis, mode=mode)

    x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
    idx = jax.random.randint(
        key=jax.random.key(1234),
        shape=shape,
        minval=0,
        maxval=shape[axis],
        dtype=jnp.int32,
    )
    actual = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct(shape, dtype)
    )(x, idx)
    expected = np.take_along_axis(x, idx, axis=axis)
    np.testing.assert_array_equal(actual, expected)

  @parameterized.product(dtype=[jnp.float32, jnp.bfloat16])
  def test_float_div(self, dtype):
    if not jtu.if_cloud_tpu_at_least(2025, 2, 13):
      self.skipTest("Requires libtpu built after 2025-02-13")
    if not jtu.is_device_tpu_at_least(version=4):
      self.skipTest("Requires TPUv4+")
    kwargs = {}
    if jtu.get_tpu_version() == 6:
      kwargs.update(dict(rtol=1e-2))
    def kernel(x, y, out):
      out[:] = jax.lax.div(x[:], y[:])

    run = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
    )
    k1, k2 = jax.random.split(jax.random.key(1234), 2)
    x = jax.random.normal(k1, (8, 128), dtype=dtype)
    y = jax.random.normal(k2, (8, 128), dtype=dtype)
    np.testing.assert_allclose(run(x, y), jax.lax.div(x, y), **kwargs)

  @parameterized.product(
      dtype=[jnp.float32, jnp.bfloat16, jnp.int8],
  )
  def test_concat_mask(self, dtype):
    if not jtu.if_cloud_tpu_at_least(2025, 2, 19):
      self.skipTest("Requires libtpu built after 2025-02-19")
    bitwidth = pallas_utils.dtype_bitwidth(dtype)
    if jtu.get_tpu_version() < 5 and bitwidth < 32:
      self.skipTest(
          f"Not implemented: cast vector to mask with bitwidth == {bitwidth}"
      )
    shape = (128, 128)

    def kernel(x, out):
      mask = x[...] != 0
      concated_mask = jnp.concatenate([mask, mask], axis=0)
      concated_x = jnp.concatenate([x[:], x[:]], axis=0)
      out[:] = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))

    x = jax.random.normal(jax.random.key(1234), shape, dtype=jnp.float32)
    if dtype == jnp.int8:
      x = (x * 100).astype(jnp.int8)
    else:
      x = x.astype(dtype)
    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((shape[0] * 2, shape[1]), dtype)
    )(x)
    concated_mask = jnp.concatenate([x != 0, x != 0], axis=0)
    concated_x = jnp.concatenate([x, x], axis=0)
    expected = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))
    np.testing.assert_array_equal(out, expected)


class OpsInterpretTest(OpsTest):
  INTERPRET = True


if __name__ == "__main__":
  absltest.main()