# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for TPU specific operations within pallas_call."""

import functools
import math
import sys
import unittest

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.pallas import utils as pallas_utils
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np

if sys.platform != "win32":
  from jax.experimental.pallas import tpu as pltpu
else:
  pltpu = None

try:
  import hypothesis as hp
except (ModuleNotFoundError, ImportError):
  raise unittest.SkipTest("tests depend on hypothesis library")
import hypothesis.strategies as hps

jax.config.parse_flags_with_absl()
jtu.setup_hypothesis(max_examples=100)

_JAX_DTYPES = (
    jnp.float32,
    jnp.bfloat16,
    jnp.int32,
    jnp.int16,
    jnp.int8,
    jnp.bool_,
)


class PallasBaseTest(jtu.JaxTestCase):
  INTERPRET = False

  def setUp(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Test only supported on TPU.")

    super().setUp()

  @classmethod
  def pallas_call(cls, *args, **kwargs):
    return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs)


class OpsTest(PallasBaseTest):

  @parameterized.product(
      from_dtype=_JAX_DTYPES, to_dtype=_JAX_DTYPES, is_ref_bitcast=[False, True]
  )
  def test_bitcast(self, from_dtype, to_dtype, is_ref_bitcast):
    if not jtu.is_device_tpu_at_least(version=4):
      self.skipTest("Run on TPUv4+ to have expected memory layout")
    if from_dtype == to_dtype:
      self.skipTest("No bitcast needed")
    if from_dtype == jnp.bool_ or to_dtype == jnp.bool_:
      self.skipTest("Bitcasting with bool is not supported")

    def kernel(x_ref, y_ref):
      if is_ref_bitcast:
        y_ref[...] = x_ref.bitcast(to_dtype)[...]
      else:
        y_ref[...] = pltpu.bitcast(x_ref[...], to_dtype)

    m, n = 1, 256
    in_packing = 32 // pallas_utils.dtype_bitwidth(from_dtype)
    out_packing = 32 // pallas_utils.dtype_bitwidth(to_dtype)
    in_shape = (m * in_packing, n)
    out_shape = (m * out_packing, n)
    inp = np.arange(np.prod(in_shape), dtype=from_dtype).reshape(in_shape)
    out = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(out_shape, to_dtype),
    )(inp)
    if not self.INTERPRET:
      out_interpret = pl.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct(out_shape, to_dtype),
          interpret=True,
      )(inp)
      self.assertAllClose(out, out_interpret)
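
  # A worked example of the shape arithmetic in test_bitcast above: bitcasting
  # int16 -> int32 gives in_packing = 32 // 16 = 2 and out_packing = 32 // 32
  # = 1, so with m, n = 1, 256 a (2, 256) int16 input becomes a (1, 256) int32
  # output and the total number of bits is unchanged.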

  @parameterized.product(is_dynamic=(False, True))
  @hp.given(
      axis=hps.integers(0, 3),
      shift=hps.integers(0, 3),
      stride=hps.one_of(hps.just(None), hps.integers(0, 2)),
      # Striding along the minormost dimension is not supported.
      stride_axis=hps.one_of(hps.just(None), hps.integers(0, 2)),
  )
  @hp.example(3, 9, 1, 2)
  @hp.example(3, 9, 2, 2)
  @hp.example(0, 9, 0, 1)
  @hp.example(0, 9, 1, 1)
  def test_roll(self, is_dynamic, axis, shift, stride, stride_axis):
    if (stride is None) != (stride_axis is None):
      self.skipTest(
          "Roll op requires both stride and stride_axis to be either specified"
          " or not specified."
      )
    if (not jtu.is_device_tpu(version=5)) and stride_axis == 2:
      self.skipTest(
          "Roll op with stride axis on 2nd minor requires at least TPU v5"
      )
    shape = (4, 4, 32, 512)

    def kernel(s_ref, x_ref, y_ref):
      amt = s_ref[0] if is_dynamic else shift
      y_ref[...] = pltpu.roll(
          x_ref[...], amt, axis, stride=stride, stride_axis=stride_axis
      )

    def roll(x, shift, axis, stride=None, stride_axis=None):
      assert (stride is None) == (stride_axis is None)
      if stride is None:
        return np.roll(x, shift, axis)
      outputs = [
          np.roll(xs, shift + i * stride, axis)
          for i, xs in enumerate(np.split(x, x.shape[stride_axis], stride_axis))
      ]
      return np.concatenate(outputs, stride_axis)

    inp = np.arange(np.prod(shape), dtype=jnp.int32).reshape(shape)
    ref = roll(inp, shift, axis, stride, stride_axis)
    dynamic_shift = jnp.array([abs(shift)], jnp.int32)
    for interpret in [False, True]:
      out = pl.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct(shape, jnp.int32),
          grid_spec=pltpu.PrefetchScalarGridSpec(num_scalar_prefetch=1),
          interpret=interpret,
      )(dynamic_shift, inp)
      np.testing.assert_array_equal(out, ref, err_msg=f"{interpret=}")

  def test_interleave_vectors(self):
    if not jtu.is_device_tpu_at_least(version=4):
      self.skipTest("Expect TPUv4+")

    def kernel(x_ref, y_ref, out_ref):
      x = pltpu.bitcast(x_ref[...].astype(jnp.float32), jnp.int32)
      y = pltpu.bitcast(y_ref[...].astype(jnp.float32), jnp.int32)
      shift = jax.lax.broadcast(16, x.shape)
      out_ref[...] = pltpu.bitcast(
          y | jax.lax.shift_right_logical(x, shift), jnp.bfloat16
      )

    m, n = 16, 128
    inp = np.arange(m * n * 2, dtype=jnp.bfloat16).reshape(m, n * 2)
    x, y = np.split(inp, 2, axis=1)
    out = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((m * 2, n), jnp.bfloat16),
    )(x, y)
    np.testing.assert_array_equal(out, inp.reshape(m * 2, n))
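
  # A sketch of the bit trick in test_interleave_vectors above, as implied by
  # its expected output: upcasting bf16 -> f32 leaves each value's 16 bf16
  # bits in the upper half of a 32-bit word, so after the logical right shift
  # every int32 lane holds y's bits in the high half and x's bits in the low
  # half; bitcasting back to bf16 then doubles the sublane count, interleaving
  # rows of x and y.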

  @parameterized.parameters([jnp.int32, jnp.int16, jnp.int8, jnp.int4])
  def test_row_broadcast(self, dtype):
    if not jtu.if_cloud_tpu_at_least(2025, 1, 10):
      self.skipTest("Requires libtpu built after 2025-01-10")
    if not self.INTERPRET and jtu.get_tpu_version() < 5:
      self.skipTest("Requires TPUv5+")

    def kernel(x_ref, y_ref):
      y_ref[...] = jnp.broadcast_to(x_ref[pl.ds(3, 1)], y_ref.shape).astype(
          y_ref.dtype
      )

    m, n = 4, 1152
    x = jax.random.randint(
        jax.random.key(12), (m, n), minval=-1000, maxval=1000, dtype=jnp.int32
    ).astype(dtype)
    y = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32)
    )(x)
    np.testing.assert_array_equal(y, jnp.broadcast_to(x[3:4], y.shape))

  def test_tpu_unsigned_int(self):
    self.skipTest(
        "TODO(apaszke): Unsigned upcasts were implemented incorrectly"
    )

    def body(x_ref, o_ref):
      # Test cast from uint16 -> uint32
      ux = lax.convert_element_type(x_ref[...], jnp.uint32)
      res = ux + 1
      # Test cast from uint32 -> float32
      o_ref[...] = res.astype(jnp.float32)

    out = jax.ShapeDtypeStruct((8, 128), jnp.float32)
    x = jnp.arange(8 * 128, dtype=jnp.uint16).reshape((8, 128))
    result = self.pallas_call(body, out_shape=out)(x)
    np.testing.assert_array_equal(result, x.astype(jnp.float32) + 1.0)

  def test_tpu_signed_int_upcast(self):
    if not jtu.is_device_tpu_at_least(version=5):
      self.skipTest("TPUv5+ needed for integer matmuls")

    def body(x_ref, o_ref):
      # Test cast from int4 -> int8
      ux = lax.convert_element_type(x_ref[...], jnp.int8)
      o_ref[...] = jax.lax.dot(ux, ux, preferred_element_type=jnp.int32)

    out = jax.ShapeDtypeStruct((128, 128), jnp.int32)
    x = jnp.arange(128 * 128, dtype=jnp.int4).reshape((128, 128))
    result = self.pallas_call(body, out_shape=out)(x)
    np.testing.assert_array_equal(
        result,
        jax.lax.dot(
            x.astype(jnp.int8),
            x.astype(jnp.int8),
            preferred_element_type=jnp.int32,
        ),
    )

  def test_select_with_scalar_condition(self):
    def kernel(cond, lhs, rhs, out):
      out[:] = jax.lax.select(cond[0] != 0, lhs[:], rhs[:])

    def run(cond, lhs, rhs):
      return self.pallas_call(
          kernel,
          out_shape=lhs,
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=0,
              in_specs=[
                  pl.BlockSpec(memory_space=pltpu.SMEM),
                  pl.BlockSpec(memory_space=pltpu.VMEM),
                  pl.BlockSpec(memory_space=pltpu.VMEM),
              ],
          ),
          name="select_kernel",
      )(cond, lhs, rhs)

    cond = jnp.array([1], dtype=jnp.int32)
    lhs = jnp.zeros((8, 128), dtype=jnp.float32)
    rhs = jnp.ones((8, 128), dtype=jnp.float32)
    assert (run(cond, lhs, rhs) == lhs).all()

  def test_logical_and_relayouted_mask(self):
    def get_mask(x_ref):
      x = x_ref[...] == 1
      iota = jax.lax.broadcasted_iota(jnp.int32, x_ref.shape, 1)
      iota = iota > 7
      return jnp.logical_and(x, iota)

    def body(x_ref, y_ref):
      y_ref[...] = jnp.where(get_mask(x_ref), 0.0, -1.0)

    shape = (2, 512)
    out = jax.ShapeDtypeStruct(shape, jnp.float32)
    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(shape)
    result = self.pallas_call(body, out_shape=out)(x)
    expected = jnp.ones(x.shape, dtype=jnp.float32)
    expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0))
    np.testing.assert_array_equal(result, expected)

  @parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int16, jnp.int8])
  def test_cast_vector_to_mask(self, dtype):
    if not jtu.if_cloud_tpu_at_least(2025, 1, 22):
      self.skipTest("Requires libtpu built after 2025-01-22")
    shape = (128, 128)
    bitwidth = pallas_utils.dtype_bitwidth(dtype)
    if jtu.get_tpu_version() < 5 and bitwidth < 32:
      self.skipTest(
          f"Not implemented: cast vector to mask with bitwidth == {bitwidth}"
      )

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    def kernel(x_ref, mask_ref, o_ref):
      zeros = jnp.zeros_like(x_ref)
      o_ref[...] = jnp.where(mask_ref[...], x_ref[...], zeros)

    mask = jax.random.bernoulli(jax.random.key(1234), 0.5, shape).astype(dtype)
    x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + 1
    out = kernel(x, mask)
    expected = jnp.where(mask, x, jnp.zeros_like(x))
    self.assertArraysEqual(out, expected)
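
  # Note on test_cast_vector_to_mask above: the mask operand is a vector of
  # zeros and ones in `dtype` (bernoulli samples cast to that dtype), and
  # jnp.where treats nonzero lanes as true, which is what exercises the
  # vector -> mask cast the test name refers to.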

  @parameterized.product(
      dtype=[jnp.float32, jnp.bfloat16, jnp.int32],
      axis=[0, 1, 2],
      reduce_func=[jnp.sum, jnp.max, jnp.min],
  )
  def test_reduction(self, dtype, axis, reduce_func):
    if dtype == jnp.int32:
      # TODO(apaszke): Remove after 12 weeks have passed.
      if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
        self.skipTest("Requires libtpu built after 2024-12-19")
      if axis == 2:
        self.skipTest("Int32 reduction on minor is not supported.")
    # TODO(b/384127570): fix bfloat16 reduction.
    if dtype == jnp.bfloat16 and reduce_func != jnp.sum:
      self.skipTest("b/384127570")

    in_shape = (2, 16, 128)
    out_shape = list(in_shape)
    out_shape[axis] = 1

    def kernel(x, out):
      out[:] = reduce_func(x[:], axis, keepdims=True)

    x = jnp.arange(np.prod(in_shape), dtype=dtype).reshape(in_shape)
    result = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
    )(x)
    expected = reduce_func(x, axis, keepdims=True)
    np.testing.assert_array_equal(result, expected)

  @parameterized.product(
      msk_dtype=[jnp.float32, jnp.bfloat16, jnp.int8],
      dtype=[jnp.float32, jnp.bfloat16],
  )
  def test_i1_relayout_with_bitwidth_change(self, msk_dtype, dtype):
    if not jtu.if_cloud_tpu_at_least(2025, 1, 25):
      self.skipTest("Requires libtpu built after 2025-01-25")
    shape = (129, 129)
    msk_bitwidth = pallas_utils.dtype_bitwidth(msk_dtype)
    bitwidth = pallas_utils.dtype_bitwidth(dtype)
    if jtu.get_tpu_version() < 5 and msk_bitwidth < 32:
      self.skipTest(
          "Not implemented: cast vector to mask with bitwidth =="
          f" {msk_bitwidth}"
      )
    if jtu.get_tpu_version() < 5 and bitwidth < 32:
      self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}")

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    def kernel(x_ref, mask_ref, o_ref):
      zeros = jnp.zeros_like(x_ref)
      o_ref[...] = jnp.where(mask_ref[...], x_ref[...], zeros)

    mask = jax.random.bernoulli(jax.random.key(1234), 0.5, shape).astype(
        msk_dtype
    )
    x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + 1
    out = kernel(x, mask)
    expected = jnp.where(mask, x, jnp.zeros_like(x))
    self.assertArraysEqual(out, expected)

  @parameterized.product(
      target=(jnp.int8,),  # TODO(apaszke): Add int4.
      round=(False, True),
  )
  def test_quantize(self, target, round):
    if not jtu.if_cloud_tpu_at_least(2025, 1, 15):
      self.skipTest("Requires libtpu built after 2025-01-15")
    if not jtu.is_device_tpu_at_least(version=6):
      self.skipTest("Requires TPUv6+")
    shape = (256, 256)
    # NOTE: 256 * 256 == 2 ** 16, so those are all bf16 values.
    x = lax.bitcast_convert_type(
        np.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape),
        jnp.bfloat16,
    )
    round_fn = jnp.rint if round else lambda x: x

    def kernel(x_ref, o_ref):
      o_ref[...] = round_fn(x_ref[...]).astype(target)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct(shape, target)
    )(x)
    ref = jax.jit(lambda x: round_fn(x).astype(target))(x)
    np.testing.assert_array_equal(out, ref)

  @parameterized.product(axis=[0, 1], mode=["promise_in_bounds", None])
  def test_dynamic_gather_along_axis(self, axis, mode):
    if not jtu.if_cloud_tpu_at_least(2025, 2, 5):
      self.skipTest("Requires libtpu built after 2025-02-05")
    if (axis == 0 and not jtu.is_device_tpu_at_least(version=5)) or (
        axis == 1 and not jtu.is_device_tpu_at_least(version=4)
    ):
      self.skipTest("Requires TPUv5+ for axis=0 and TPUv4+ for axis=1")
    dtype = jnp.int32
    shape = (8, 128)

    def kernel(x, indices, out):
      out[...] = jnp.take_along_axis(x[...], indices[...], axis, mode=mode)

    x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
    idx = jax.random.randint(
        key=jax.random.key(1234),
        shape=shape,
        minval=0,
        maxval=shape[axis],
        dtype=jnp.int32,
    )
    actual = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct(shape, dtype)
    )(x, idx)
    expected = np.take_along_axis(x, idx, axis=axis)
    np.testing.assert_array_equal(actual, expected)
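
  # A worked example of the gather in test_dynamic_gather_along_axis above
  # (standard take_along_axis semantics, matching the np reference in the
  # test):
  #   axis == 1: out[i, j] == x[i, idx[i, j]]
  #   axis == 0: out[i, j] == x[idx[i, j], j]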

  @parameterized.product(dtype=[jnp.float32, jnp.bfloat16])
  def test_float_div(self, dtype):
    if not jtu.if_cloud_tpu_at_least(2025, 2, 13):
      self.skipTest("Requires libtpu built after 2025-02-13")
    if not jtu.is_device_tpu_at_least(version=4):
      self.skipTest("Requires TPUv4+")
    kwargs = {}
    if jtu.get_tpu_version() == 6:
      kwargs.update(dict(rtol=1e-2))

    def kernel(x, y, out):
      out[:] = jax.lax.div(x[:], y[:])

    run = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
    )
    k1, k2 = jax.random.split(jax.random.key(1234), 2)
    x = jax.random.normal(k1, (8, 128), dtype=dtype)
    y = jax.random.normal(k2, (8, 128), dtype=dtype)
    np.testing.assert_allclose(run(x, y), jax.lax.div(x, y), **kwargs)

  @parameterized.product(
      dtype=[jnp.float32, jnp.bfloat16, jnp.int8],
  )
  def test_concat_mask(self, dtype):
    if not jtu.if_cloud_tpu_at_least(2025, 2, 19):
      self.skipTest("Requires libtpu built after 2025-02-19")
    bitwidth = pallas_utils.dtype_bitwidth(dtype)
    if jtu.get_tpu_version() < 5 and bitwidth < 32:
      self.skipTest(
          f"Not implemented: cast vector to mask with bitwidth == {bitwidth}"
      )
    shape = (128, 128)

    def kernel(x, out):
      mask = x[...] != 0
      concated_mask = jnp.concatenate([mask, mask], axis=0)
      concated_x = jnp.concatenate([x[:], x[:]], axis=0)
      out[:] = lax.select(
          concated_mask, concated_x, jnp.zeros_like(concated_x)
      )

    x = jax.random.normal(jax.random.key(1234), shape, dtype=jnp.float32)
    if dtype == jnp.int8:
      x = (x * 100).astype(jnp.int8)
    else:
      x = x.astype(dtype)
    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((shape[0] * 2, shape[1]), dtype)
    )(x)
    concated_mask = jnp.concatenate([x != 0, x != 0], axis=0)
    concated_x = jnp.concatenate([x, x], axis=0)
    expected = lax.select(
        concated_mask, concated_x, jnp.zeros_like(concated_x)
    )
    np.testing.assert_array_equal(out, expected)


class OpsInterpretTest(OpsTest):
  INTERPRET = True


if __name__ == "__main__":
  absltest.main()