# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test TPU-specific extensions to pallas_call."""

import contextlib
import functools
import gc
import io
import math
import re
import sys

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import checkify
from jax._src import state
from jax._src import test_util as jtu
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_extension
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax._src.state import utils as state_utils
from jax._src.state import discharge as state_discharge
from jax.experimental import mesh_utils
from jax.experimental import mosaic
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu import example_kernel
from jax.extend import linear_util as lu
import jax.numpy as jnp
import numpy as np

jax.config.parse_flags_with_absl()

P = jax.sharding.PartitionSpec

partial = functools.partial


@contextlib.contextmanager
def string_stdout():
  """Redirects stdout to a string."""
  initial_stdout = sys.stdout
  stringio = io.StringIO()
  sys.stdout = stringio
  yield stringio
  sys.stdout = initial_stdout


class PallasBaseTest(jtu.JaxTestCase):
  INTERPRET: bool = False

  def setUp(self):
    if not jtu.test_device_matches(['tpu']) and not self.INTERPRET:
      self.skipTest('Test requires TPUs, or interpret mode')
    super().setUp()
    _trace_kernel_to_jaxpr.cache_clear()

  def pallas_call(self, *args, **kwargs):
    return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)


class PallasCallScalarPrefetchTest(PallasBaseTest):

  def test_trivial_scalar_prefetch(self):
    def body(_, x_ref, o_ref):
      o_ref[...] = x_ref[...]

    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))

    def _x_transform(i, s_ref):
      s = pl.load(s_ref, (i,))
      return (s, 0)

    out = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            in_specs=[
                pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform),
            ],
            out_specs=pl.BlockSpec(
                (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
            ),
            grid=8,
        ),
    )(s, x)
    np.testing.assert_allclose(out, x.reshape((8, 8, -1))[s].reshape(x.shape))

  def test_trivial_scalar_prefetch_with_windowless_args(self):
    def body(_, x_ref, o_ref):
      o_ref[...] = x_ref[...]
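
    # No in_specs/out_specs are given, so the arguments are passed to the
    # kernel whole, without any windowing.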
    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))

    out = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
        ),
    )(s, x)
    np.testing.assert_array_equal(out, x)

  @jtu.parameterized_filterable(
      kwargs=[
          dict(scratch=scratch, vmap=vmap, dyn_grid=dyn_grid)
          for scratch in [True, False]
          for vmap in [False, True]
          for dyn_grid in [False, True]
      ]
  )
  def test_scalar_prefetch_calling_convention(
      self, *, scratch: bool, vmap: bool, dyn_grid: bool):
    # Tests that we process correctly all the various inputs and outputs:
    # dynamic_grid_dims, index, inputs, outputs, scratch.
    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
      self.skipTest("TODO: dslice(start, 1) raises error about slice inputs being int32 and int64")

    to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
    if vmap:
      x_shape = (4, 16, 128)
    else:
      x_shape = (16, 128)
    x = np.arange(math.prod(x_shape), dtype=np.float32).reshape(x_shape)

    def f(x, grid_size, to_store):
      s = jnp.array([1, 0], jnp.int32)  # iteration 0 -> 1, iteration 1 -> 0
      @functools.partial(
          self.pallas_call,
          out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=1,  # 1 pytree
              grid=(grid_size,),
              in_specs=[
                  pl.BlockSpec((8, 128),
                               lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0)),
                  pl.BlockSpec((1, 128), lambda i, s_ref: (0, 0)),
              ],
              out_specs=pl.BlockSpec((32, 128),
                                     lambda i, s_ref: (pl.load(s_ref[0], i), 0)),
              scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch
                              else []),
          ),
      )
      def kernel(s_refs, src, to_store, dst, *scratch_refs):
        s_ref, s2, s3 = s_refs
        assert s_ref.shape == (2,)
        assert s2.shape == (3,)
        assert s3 is None
        store_idx = s_ref[pl.program_id(0)]
        pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store[...])
      # Pass a pytree of scalar
      return kernel((s, np.arange(3, dtype=np.int32), None), x, to_store)

    if dyn_grid:
      f = jax.jit(f)
    if vmap:
      res = jax.vmap(lambda x: f(x, 2, to_store))(x)
    else:
      res = f(x, 2, to_store)

    if vmap:
      for i in range(x.shape[0]):
        self.assertAllClose(res[i, 0:1], to_store)
        self.assertAllClose(res[i, 33:34], to_store)
    else:
      self.assertAllClose(res[0:1], to_store)
      self.assertAllClose(res[33:34], to_store)

  def test_with_unhashable_grid_spec(self):
    # Make sure that we don't crash when the GridSpec has non-hashable parts
    @functools.partial(
        self.pallas_call,
        out_shape=[[jax.ShapeDtypeStruct((8, 128), np.int32)]],
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,  # 1 pytree
            grid=(1,),
            in_specs=[[pl.BlockSpec((8, 128), lambda i, s_ref: (0, 0))]],
            out_specs=[[pl.BlockSpec((8, 128), lambda i, s_ref: (0, 0))]],
            scratch_shapes=[[pltpu.SemaphoreType.REGULAR((3,))]],
        ),
    )
    def kernel(s_ref, x_ref, o_ref, scratch_ref):
      assert isinstance(s_ref, list)
      assert isinstance(x_ref, list)
      assert isinstance(o_ref, list)
      assert isinstance(scratch_ref, list)
      o_ref[0][...] = x_ref[0][...]

    x_shape = (8, 128)
    s = np.array([0, 1], np.int32)
    x = np.arange(math.prod(x_shape), dtype=np.int32).reshape(x_shape)
    res = kernel([s, s], [x])
    self.assertIsInstance(res, tuple)  # Even though we asked for a list!
    self.assertAllClose(res[0][0], x)

  def test_vmap_scalar_prefetch(self):
    def body(_, x_ref, o_ref):
      o_ref[...] = x_ref[...]

    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128))

    def _x_transform(i, s_ref):
      s = pl.load(s_ref, (i,))
      return (s, 0)

    def f(x):
      return self.pallas_call(
          body,
          out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=1,
              in_specs=[
                  pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform),
              ],
              out_specs=pl.BlockSpec(
                  (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
              ),
              grid=8),
      )(s, x)

    np.testing.assert_allclose(
        jax.vmap(f)(x), x.reshape((2, 8, 8, -1))[:, s].reshape(x.shape)
    )

  def test_multiple_scalar_prefetch(self):
    def body(s1_ref, s2_ref, x_ref, o_ref):
      del s1_ref, s2_ref
      o_ref[...] = x_ref[...]

    s1 = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    s2 = jnp.array([7, 6, 5, 4, 3, 2, 1, 0], jnp.int32)
    x = jnp.arange(64 * 128, dtype=jnp.int32).reshape((64, 128))

    def _x_transform(i, s1_ref, _):
      return s1_ref[i], 0

    def _o_transform(i, _, s2_ref):
      return s2_ref[i], 0

    out = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct((64, 128), jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=2,
            in_specs=[
                pl.BlockSpec((8, 128), _x_transform),
            ],
            out_specs=pl.BlockSpec((8, 128), _o_transform),
            grid=8,
        ),
    )(s1, s2, x)
    out_ref = x.reshape((8, 8, -1))[s1][::-1].reshape((64, 128))
    np.testing.assert_allclose(out, out_ref)

  def test_scalar_interpreter(self):
    program = jnp.array([0, 0, 1, 0, 1, 1], jnp.int32)
    x = jnp.arange(8 * 8 * 128.0, dtype=jnp.float32).reshape(8 * 8, 128)

    def body(sprogram_ref, x_ref, o_ref, state_ref):
      x = x_ref[...]

      def add_branch_fn(j):
        state_ref[...] += jnp.float32(j)
        return ()

      def mult_branch_fn(j):
        state_ref[...] *= jnp.float32(j)
        return ()

      def single_inst(i, _):
        _ = jax.lax.switch(
            sprogram_ref[i],
            (
                add_branch_fn,
                mult_branch_fn,
            ),
            i,
        )

      # We can't use for loop state right now, because Pallas functionalizes it,
      # and Mosaic support for returning values from scf.if is incomplete.
      state_ref[...] = x
      lax.fori_loop(0, sprogram_ref.shape[0], single_inst, None, unroll=True)
      o_ref[...] = state_ref[...]

    # Ignore the scratch output.
    out, _ = self.pallas_call(
        body,
        out_shape=[
            jax.ShapeDtypeStruct(x.shape, jnp.float32),
            jax.ShapeDtypeStruct((8, 128), jnp.float32),
        ],
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            in_specs=[pl.BlockSpec((8, 128), lambda i, *_: (i, 0))],
            out_specs=[
                pl.BlockSpec((8, 128), lambda i, *_: (i, 0)),
                pl.BlockSpec((8, 128), lambda *_: (0, 0)),
            ],
            grid=8,
        ),
    )(program, x)

    expected = x
    for i, p in enumerate(program):
      if p == 0:
        expected += i
      elif p == 1:
        expected *= i

    np.testing.assert_allclose(out, expected)

  def test_scalar_interpreter_dynamic_loop(self):
    loop_end = jnp.array([5], jnp.int32)

    def body(loop_end_ref, out_ref):
      out_ref[...] = jnp.zeros_like(out_ref)

      def loop_body(i, carry):
        del i, carry
        out_ref[...] += 1

      lax.fori_loop(0, loop_end_ref[0], loop_body, None)

    out = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            out_specs=pl.BlockSpec((8, 128), lambda *_: (0, 0)),
            grid=1,
        ),
    )(loop_end)

    expected_out = jnp.ones((8, 128), jnp.float32) * 5
    np.testing.assert_allclose(out, expected_out)

  def test_vmap_scalar_prefetch_1sized(self):
    def body(_, x_ref, o_ref):
      o_ref[...] = x_ref[...]

    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))

    def _x_transform(i, s_ref):
      s = pl.load(s_ref, (i,))
      return (s, 0)

    s = s[None]
    x = x[None]

    out = jax.vmap(
        self.pallas_call(
            body,
            out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=1,
                in_specs=[
                    pl.BlockSpec((x.shape[1] // 8, x.shape[2]), _x_transform),
                ],
                out_specs=pl.BlockSpec(
                    (x.shape[1] // 8, x.shape[2]), lambda i, _: (i, 0)
                ),
                grid=8,
            ),
        )
    )(s, x)
    np.testing.assert_allclose(
        out, x.reshape((1, 8, 8, -1))[:, s].reshape(x.shape)
    )

  def test_nontrivial_vmap_scalar_prefetch(self):
    def body(_, x_ref, o_ref):
      o_ref[...] = x_ref[...]

    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128))

    def _x_transform(i, s_ref):
      s = pl.load(s_ref, (i,))
      return (s, 0)

    s = jnp.tile(s[None], [2, 1])

    @jax.jit
    @jax.vmap
    def kernel(s, x):
      return self.pallas_call(
          body,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=1,
              in_specs=[
                  pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform),
              ],
              out_specs=pl.BlockSpec(
                  (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
              ),
              grid=8,
          ),
          compiler_params=pltpu.TPUCompilerParams(
              allow_input_fusion=[False, True]
          ),
      )(s, x)

    first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:])
    second = x[1, ...].reshape((1, 8, 8, -1))[:, s[1, ...]].reshape(x.shape[1:])

    expected = jnp.stack([first, second])
    np.testing.assert_allclose(kernel(s, x), expected)

  def test_input_output_aliasing_with_scalar_prefetch(self):
    x = jnp.ones((32, 1024, 1024))
    expected = x + 1

    def kernel(_, x_ref, y_ref):
      y_ref[...] = x_ref[...] + 1.

    @partial(jax.jit, donate_argnums=(0,))
    def f(x):
      return self.pallas_call(
          kernel,
          out_shape=x,
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=1,
              in_specs=[
                  pl.BlockSpec((None, 1024, 1024), lambda i, _: (i, 0, 0))
              ],
              out_specs=pl.BlockSpec(
                  (None, 1024, 1024), lambda i, _: (i, 0, 0)
              ),
              grid=(x.shape[0],),
          ),
          input_output_aliases={1: 0},
      )(jnp.array([1, 2, 3]), x)

    o = f(x)
    np.testing.assert_array_equal(o, expected)

    compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile()
    mem_analysis = compiled.memory_analysis()
    expected_num_bytes = np.prod(x.shape) * x.dtype.itemsize
    self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes)


class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
  INTERPRET: bool = True


class PallasCallDynamicGridTest(PallasBaseTest):

  def test_can_query_grid_statically_via_num_programs(self):
    def kernel(_):
      num_programs = pl.num_programs(0)
      self.assertIsInstance(num_programs, int)
      self.assertEqual(num_programs, 2)

    self.pallas_call(kernel, out_shape=None, grid=(2,))()

  def test_can_query_grid_statically_via_num_programs_in_block_spec(self):
    def kernel(*_):
      pass

    def x_index_map(_):
      num_programs = pl.num_programs(0)
      self.assertIsInstance(num_programs, int)
      self.assertEqual(num_programs, 2)
      return 0, 0

    self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec((8, 128), x_index_map)],
        out_shape=None,
        grid=(2,),
    )(jnp.ones((8, 128)))

  def test_dynamic_grid_has_dynamic_size(self):
    def kernel(_):
      num_programs = pl.num_programs(0)
      self.assertIsInstance(num_programs, int, msg=type(num_programs))
      self.assertEqual(num_programs, 2)
      num_programs = pl.num_programs(1)
      self.assertIsInstance(num_programs, jax.Array)

    @jax.jit
    def outer(x):
      self.pallas_call(kernel, out_shape=None, grid=(2, x))()

    outer(2)
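
  # Passing a traced value (rather than a Python int) as a grid dimension
  # makes that dimension dynamic; pl.num_programs then returns a jax.Array
  # instead of a static int (see test_dynamic_grid_has_dynamic_size above).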
  def test_dynamic_grid(self):
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(y_ref):
      @pl.when(pl.program_id(0) == 0)
      def _init():
        y_ref[...] = jnp.zeros_like(y_ref)
      y_ref[...] += 1

    @jax.jit
    def dynamic_kernel(steps):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          out_shape=result_ty,
      )()

    np.testing.assert_array_equal(
        dynamic_kernel(jnp.int32(4)), np.full(shape, 8.0, np.float32)
    )

  def test_dynamic_grid_overflow(self):
    # If we pad statically the dynamic grid dims to max int32, then the product
    # of this grid size will overflow int64 and can cause failing checks in XLA.
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(y_ref):
      @pl.when(sum(pl.program_id(i) for i in range(3)) == 0)
      def _init():
        y_ref[...] = jnp.zeros_like(y_ref)
      y_ref[...] += 1

    @jax.jit
    def dynamic_kernel(steps):
      return self.pallas_call(
          kernel,
          grid=(steps * 2, steps + 1, 3),
          out_specs=pl.BlockSpec(shape, lambda *_: (0, 0)),
          out_shape=result_ty,
      )()

    np.testing.assert_array_equal(
        dynamic_kernel(jnp.int32(4)), np.full(shape, 120.0, np.float32)
    )

  # TODO(apaszke): Add tests for scalar_prefetch too
  def test_dynamic_grid_scalar_input(self):
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(scalar_input_ref, output_ref):
      output_ref[...] = jnp.full_like(output_ref, scalar_input_ref[0, 0])

    @jax.jit
    def dynamic_kernel(steps):
      return self.pallas_call(
          kernel,
          out_shape=result_ty,
          in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          grid=(steps * 2,),
      )(jnp.array([[42]], dtype=jnp.int32))

    np.testing.assert_array_equal(
        dynamic_kernel(jnp.int32(4)), np.full(shape, 42.0, np.float32)
    )

  def test_vmap_trivial_dynamic_grid(self):
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(x_ref, y_ref):
      @pl.when(pl.program_id(0) == 0)
      def _init():
        y_ref[...] = x_ref[...]
      y_ref[...] += 1

    @jax.jit
    @jax.vmap
    def dynamic_kernel(steps, x):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          in_specs=[pl.BlockSpec(shape, lambda i: (0, 0))],
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          out_shape=result_ty,
      )(x)

    x = jnp.arange(8 * 128., dtype=jnp.float32).reshape((1, *shape))
    np.testing.assert_array_equal(
        dynamic_kernel(jnp.array([4], jnp.int32), x), x + 8.0
    )

  def test_vmap_nontrivial_dynamic_grid(self):
    # Dynamic grid doesn't support vmapping over multiple distinct grid values
    # at the moment.
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(y_ref):
      @pl.when(pl.program_id(0) == 0)
      def _init():
        y_ref[...] = jnp.zeros_like(y_ref)
      y_ref[...] += 1

    @jax.jit
    @jax.vmap
    def dynamic_kernel(steps):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          out_shape=result_ty,
      )()

    out = dynamic_kernel(jnp.array([4, 8], jnp.int32))
    first = jnp.full(shape, fill_value=8.0, dtype=jnp.float32)
    second = jnp.full(shape, fill_value=16.0, dtype=jnp.float32)
    expected_out = jnp.stack([first, second], axis=0)
    np.testing.assert_array_equal(out, expected_out)

  def test_vmap_dynamic_grid(self):
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(x_ref, y_ref):
      @pl.when(pl.program_id(0) == 0)
      def _init():
        y_ref[...] = x_ref[...]
      y_ref[...] += jnp.float32(1.)
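
    # Only x is batched; the grid size argument is shared across the batch
    # via in_axes=(0, None) in the assertion below.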
    @jax.jit
    def dynamic_kernel(x, steps):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          out_shape=result_ty,
      )(x)

    x = jnp.arange(4 * 8 * 128., dtype=jnp.float32).reshape((4, *shape))
    np.testing.assert_array_equal(
        jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, jnp.int32(4)),
        x + 8,
    )

  def test_num_programs(self):
    def kernel(y_ref):
      y_ref[0, 0] = pl.num_programs(0)

    @jax.jit
    def dynamic_kernel(steps):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
          out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32),
      )()

    self.assertEqual(dynamic_kernel(np.int32(4)), 8)

  @parameterized.parameters(range(1, 4))
  def test_vmap_num_programs(self, num_vmaps):
    result_ty = jax.ShapeDtypeStruct((8, 128), jnp.int32)

    def kernel(y_ref):
      y_ref[...] = jnp.full_like(y_ref, pl.num_programs(0))

    kernel_call = self.pallas_call(
        kernel,
        grid=(8,),
        out_specs=pl.BlockSpec(result_ty.shape, lambda i: (0, 0)),
        out_shape=result_ty,
    )

    out_shape = (*(2 for _ in range(num_vmaps)), *result_ty.shape)
    f = kernel_call
    for _ in range(num_vmaps):
      f = lambda impl=f: jax.vmap(impl, axis_size=2)()
    out = jax.jit(f)()
    np.testing.assert_array_equal(out, np.full(out_shape, 8.0))

  def test_num_programs_block_spec(self):
    def kernel(x_ref, y_ref):
      y_ref[...] = x_ref[...]

    @jax.jit
    def dynamic_kernel(steps, x):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          in_specs=[
              pl.BlockSpec(
                  (8, 128),
                  # Should always evaluate to (1, 0)
                  lambda i: (1 + 8 - pl.num_programs(0), 0),
              )
          ],
          out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)),
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
      )(x)

    x = np.arange(4 * 8 * 128., dtype=np.int32).reshape((4 * 8, 128))
    np.testing.assert_array_equal(dynamic_kernel(np.int32(4), x), x[8:16])


class PallasCallDynamicGridInterpretTest(PallasCallDynamicGridTest):
  INTERPRET = True


class PallasCallDMATest(PallasBaseTest):

  def setUp(self):
    super().setUp()
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('DMAs not supported on TPU generations <= 3')

  def test_can_have_unspecified_memory_spaces(self):
    def kernel(x_ref, y_ref):
      # Just test whether things compile
      del x_ref, y_ref

    x = jnp.ones((8, 128), dtype=jnp.float32)
    y = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    jax.block_until_ready(y)

  def test_run_scoped_tracks_effects(self):
    def kernel(x_ref, y_ref):
      def body(temp_ref):
        temp_ref[...] = jnp.ones_like(temp_ref)
        x_ref[...] = 4 * y_ref[...] + temp_ref[...]

      pl.run_scoped(body, pltpu.VMEM((8,), jnp.float32))
      return []

    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(kernel),
        [
            state.shaped_array_ref((8,), jnp.float32),
            state.shaped_array_ref((8,), jnp.float32),
        ],
    )
    expected_effects = {state.ReadEffect(1), state.WriteEffect(0)}
    self.assertSetEqual(jaxpr.effects, expected_effects)

  def test_scoped_allocation(self):
    def kernel(y_ref):
      def body(x_ref):
        x_ref[...] = jnp.ones_like(x_ref)
        y_ref[...] = 4 * x_ref[...]
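
      # run_scoped allocates the VMEM scratch ref only for the duration of
      # body and releases it when body returns.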
      pl.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32))

    o = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )()
    np.testing.assert_allclose(o, 4 * np.ones_like(o))

  def test_run_scoped_can_return_scalar_value(self):
    def kernel(y_ref):
      def body(x_ref):
        x_ref[0] = 0
        x_ref[0] += 1
        return x_ref[0] + 2

      out = pl.run_scoped(body, pltpu.SMEM((1,), jnp.int32))
      y_ref[0] = out

    o = self.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
        ),
        out_shape=jax.ShapeDtypeStruct((1,), jnp.int32),
    )()
    np.testing.assert_allclose(o, jnp.array([3], jnp.int32))

  def test_run_scoped_can_return_scalar_values(self):
    def kernel(y_ref):
      def body(x_ref):
        x_ref[0] = 0
        x_ref[0] += 1
        return x_ref[0] + 2, x_ref[0]

      out = pl.run_scoped(body, pltpu.SMEM((1,), jnp.int32))
      y_ref[0], y_ref[1] = out

    o = self.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
        ),
        out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
    )()
    np.testing.assert_allclose(o, jnp.array([3, 1], jnp.int32))

  def test_run_scoped_can_return_vector_values(self):
    def kernel(y_ref):
      def body(x_ref):
        x_ref[...] = jnp.ones_like(x_ref)
        return x_ref[...] + 1

      out = pl.run_scoped(body, pltpu.VMEM((16, 128), jnp.int32))
      y_ref[...] = out

    o = self.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        ),
        out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int32),
    )()
    np.testing.assert_allclose(o, jnp.full((16, 128), 2, dtype=jnp.int32))

  def test_run_scoped_can_return_padded_vector_values(self):
    def kernel(y_ref):
      def body(x_ref):
        x_ref[...] = jnp.ones_like(x_ref)
        return x_ref[...] + 1

      out = pl.run_scoped(body, pltpu.VMEM((17, 128), jnp.int32))
      y_ref[...] = out

    o = self.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        ),
        out_shape=jax.ShapeDtypeStruct((17, 128), jnp.int32),
    )()
    np.testing.assert_allclose(o, jnp.full((17, 128), 2, dtype=jnp.int32))

  def test_nested_scoped_allocation(self):
    def kernel(y_ref):
      def body(x_ref):
        x_ref[...] = jnp.zeros_like(x_ref)

        def inner_body(z_ref):
          z_ref[...] = jnp.ones_like(z_ref)
          x_ref[...] = z_ref[...]

        pl.run_scoped(inner_body, pltpu.VMEM((8, 128), jnp.float32))
        y_ref[...] = 4 * x_ref[...]

      pl.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32))

    o = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )()
    np.testing.assert_allclose(o, 4 * np.ones_like(o))

  def test_run_scoped_partial_discharge(self):
    def f(a_ref, b_ref):
      def scope():
        a_ref[...] = jnp.ones(4, jnp.float32)
        b_ref[...] = jnp.ones(4, jnp.float32)
        return []

      pl.run_scoped(scope)
      return []

    aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
    aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
    in_avals = [aref1, aref2]
    stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                          in_avals)
    discharged_jaxpr, _ = state_discharge.discharge_state(
        stateful_jaxpr, consts=(), should_discharge=[False, True])
    self.assertLen(discharged_jaxpr.invars, 2)
    self.assertLen(discharged_jaxpr.outvars, 1)
    self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.AbstractRef)
    self.assertIsInstance(discharged_jaxpr.invars[1].aval, jax.core.ShapedArray)
    self.assertEqual(discharged_jaxpr.effects, {state.WriteEffect(0)})

  def test_can_allocate_semaphore(self):
    def kernel(y_ref):
      def body(sem1):
        pass

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    jax.block_until_ready(self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )())

  def test_can_allocate_multiple_semaphores(self):
    def kernel(y_ref):
      def body(sem1, sem2):
        pass

      pl.run_scoped(body, pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.REGULAR)

    jax.block_until_ready(self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )())

  def test_can_allocate_semaphore_array(self):
    def kernel(y_ref):
      def body(dma_sems, sems):
        self.assertTupleEqual(dma_sems.shape, (4,))
        self.assertTupleEqual(sems.shape, (3,))
        if self.INTERPRET:
          self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.integer))
          self.assertTrue(jnp.issubdtype(sems.dtype, jnp.integer))
        else:
          self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore))
          self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore))

      pl.run_scoped(
          body, pltpu.SemaphoreType.DMA((4,)), pltpu.SemaphoreType.REGULAR((3,))
      )

    jax.block_until_ready(self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )())

  def test_can_allocate_scratch_semaphore_array(self):
    def kernel(y_ref, dma_sems, sems):
      self.assertTupleEqual(dma_sems.shape, (4,))
      self.assertTupleEqual(sems.shape, (3,))
      if self.INTERPRET:
        self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.integer))
        self.assertTrue(jnp.issubdtype(sems.dtype, jnp.integer))
      else:
        self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore))
        self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore))

    jax.block_until_ready(
        self.pallas_call(
            kernel,
            out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=0,
                scratch_shapes=[
                    pltpu.SemaphoreType.DMA((4,)),
                    pltpu.SemaphoreType.REGULAR((3,)),
                ],
            ),
        )()
    )

  def test_can_wait_on_semaphore(self):
    def kernel(y_ref):
      def body(sem):
        pltpu.semaphore_signal(sem)
        pltpu.semaphore_wait(sem)

      pl.run_scoped(body, pltpu.SemaphoreType.REGULAR)

      def body2(sem):
        pltpu.semaphore_signal(sem, 2)
        pltpu.semaphore_wait(sem)
        pltpu.semaphore_wait(sem)

      pl.run_scoped(body2, pltpu.SemaphoreType.REGULAR)

      def body3(sem):
        pltpu.semaphore_signal(sem)
        pltpu.semaphore_signal(sem)
        pltpu.semaphore_signal(sem)
        pltpu.semaphore_wait(sem)
        pltpu.semaphore_wait(sem)
        pltpu.semaphore_wait(sem)

      pl.run_scoped(body3, pltpu.SemaphoreType.REGULAR)

    # TODO(b/345534352): Add interpret support for semaphore signal/wait.
    jax.block_until_ready(self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )())

  def test_can_wait_on_semaphore_array(self):
    def kernel(y_ref):
      def body(sems):
        pltpu.semaphore_signal(sems.at[0])
        pltpu.semaphore_wait(sems.at[0])

        pltpu.semaphore_signal(sems.at[1], 2)
        pltpu.semaphore_wait(sems.at[1])
        pltpu.semaphore_wait(sems.at[1])

        pltpu.semaphore_signal(sems.at[2])
        pltpu.semaphore_signal(sems.at[2])
        pltpu.semaphore_signal(sems.at[2])
        pltpu.semaphore_wait(sems.at[2])
        pltpu.semaphore_wait(sems.at[2])
        pltpu.semaphore_wait(sems.at[2])

      pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((3,)))

    # TODO(b/345534352): Add interpret support for semaphore signal/wait.
    jax.block_until_ready(self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )())

  def test_can_wait_on_semaphore_array_with_dynamic_index(self):
    def kernel(y_ref):
      i = pl.program_id(0)

      def body(sems):
        pltpu.semaphore_signal(sems.at[i, 0])
        pltpu.semaphore_wait(sems.at[i, 0])

        pltpu.semaphore_signal(sems.at[i, 1], 2)
        pltpu.semaphore_wait(sems.at[i, 1])
        pltpu.semaphore_wait(sems.at[i, 1])

        pltpu.semaphore_signal(sems.at[i, 2])
        pltpu.semaphore_signal(sems.at[i, 2])
        pltpu.semaphore_signal(sems.at[i, 2])
        pltpu.semaphore_wait(sems.at[i, 2])
        pltpu.semaphore_wait(sems.at[i, 2])
        pltpu.semaphore_wait(sems.at[i, 2])

      pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((4, 3)))

    jax.block_until_ready(
        self.pallas_call(
            kernel,
            in_specs=[],
            out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)),
            out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
            grid=4,
        )()
    )

  def test_can_read_semaphore(self):
    m, n = 2, 3

    def kernel(y_ref):
      def body(sems):
        for r in range(m):
          for c in range(n):
            v = r * n + c
            pltpu.semaphore_signal(sems.at[r, c], v)
            y_ref[r, c] = pltpu.semaphore_read(sems.at[r, c])
            pltpu.semaphore_wait(sems.at[r, c], v)

      pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((m, n)))

    y = jax.block_until_ready(
        self.pallas_call(
            kernel,
            out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
            out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32),
        )()
    )
    np.testing.assert_array_equal(
        y, jnp.arange(m * n).astype(jnp.int32).reshape((m, n))
    )

  def test_can_read_dma_semaphore(self):
    def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem):
      sem_val_ref[0, 0] = 123
      pltpu.async_copy(x_hbm_ref, y_hbm_ref, dma_sem).wait()
      sem_val_ref[0, 0] = pltpu.semaphore_read(dma_sem)

    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
    y, sem_val = jax.block_until_ready(
        self.pallas_call(
            kernel,
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=0,
                in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
                out_specs=[
                    pl.BlockSpec(memory_space=pl.ANY),
                    pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
                ],
                scratch_shapes=[pltpu.SemaphoreType.DMA],
            ),
            out_shape=[
                jax.ShapeDtypeStruct((8, 128), jnp.int32),
                jax.ShapeDtypeStruct((1, 1), jnp.int32),
            ],
        )(x)
    )
    np.testing.assert_array_equal(y, x)
    np.testing.assert_array_equal(sem_val, 0)

  def test_hbm_hbm_dma(self):
    def kernel(x_hbm_ref, y_hbm_ref):
      def body(sem):
        pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :],
                         y_hbm_ref.at[:, pl.ds(128)], sem).wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(8 * 128.).reshape((8, 128))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_array_equal(y, x)

  def test_cannot_dma_with_nonscalar_semaphore_ref(self):
    def kernel(x_hbm_ref, y_hbm_ref):
      def body(sem):
        pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :],
                         y_hbm_ref.at[:, pl.ds(128)], sem).wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA((1,)))

    with self.assertRaisesRegex(ValueError, 'Cannot signal'):
      x = jnp.arange(8 * 128.).reshape((8, 128))
      self.pallas_call(
          kernel,
          in_specs=[
              pl.BlockSpec(memory_space=pl.ANY),
          ],
          out_specs=pl.BlockSpec(memory_space=pl.ANY),
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
      )(x)

  def test_dma_with_scalar_semaphore_ref(self):
    def kernel(x_hbm_ref, y_hbm_ref):
      def body(sem):
        pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :],
                         y_hbm_ref.at[:, pl.ds(128)], sem.at[0]).wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA((1,)))

    x = jnp.arange(8 * 128.).reshape((8, 128))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_array_equal(y, x)

  def test_hbm_hbm_grid_dma(self):
    # When using the grid, we have to emit Mosaic window_params. Test that they
    # work correctly with ANY memory space operands.
    def kernel(x_hbm_ref, y_hbm_ref):
      i = pl.program_id(0)

      def body(sem):
        pltpu.async_copy(
            x_hbm_ref.at[pl.ds(i, 1)], y_hbm_ref.at[pl.ds(i, 1)], sem
        ).wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.float32),
        grid=(2,),
    )(x)
    np.testing.assert_allclose(y, x)

  def test_hbm_vmem_dma(self):
    def kernel(x_hbm_ref, y_ref):
      def body(x_ref, sem):
        pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :],
                         x_ref.at[:, pl.ds(128)], sem).wait()
        y_ref[...] = x_ref[...]

      pl.run_scoped(
          body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA
      )

    x = jnp.arange(8 * 128.).reshape((8, 128))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_allclose(y, x)

  def test_vmem_hbm_dma(self):
    def kernel(x_ref, y_hbm_ref):
      def body(y_ref, sem):
        y_ref[...] = x_ref[...]
        pltpu.async_copy(y_ref, y_hbm_ref, sem).wait()

      pl.run_scoped(
          body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA
      )

    x = jnp.arange(8 * 128.).reshape((8, 128))
    y = self.pallas_call(
        kernel,
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_allclose(y, x)

  def test_vmem_hbm_vmem_dma(self):
    def kernel(x_hbm_ref, y_hbm_ref):
      def body(x_ref, y_ref, sem):
        pltpu.async_copy(x_hbm_ref, x_ref, sem).wait()
        y_ref[...] = x_ref[...]
        pltpu.async_copy(y_ref, y_hbm_ref, sem).wait()

      pl.run_scoped(
          body,
          pltpu.VMEM((8, 128), jnp.float32),
          pltpu.VMEM((8, 128), jnp.float32),
          pltpu.SemaphoreType.DMA,
      )

    x = jnp.arange(8 * 128.).reshape((8, 128))
    y = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_allclose(y, x)

  def test_hbm_smem_dma(self):
    def kernel(x_hbm_ref, y_ref):
      def body(x_ref, sem):
        pltpu.async_copy(x_hbm_ref, x_ref, sem).wait()
        y_ref[...] = x_ref[0, 0] * jnp.ones_like(y_ref)

      pl.run_scoped(
          body, pltpu.SMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA
      )

    x = 4 * jnp.ones((8, 128), jnp.float32)
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_allclose(y, x)

  def test_smem_hbm_dma(self):
    def kernel(x_ref, y_hbm_ref):
      def body(y_ref, sem):
        y_ref[0, 0] = 0.0
        y_ref[0, 1] = x_ref[4, 4]
        pltpu.async_copy(y_ref, y_hbm_ref, sem).wait()

      pl.run_scoped(
          body, pltpu.SMEM((1, 2), jnp.float32), pltpu.SemaphoreType.DMA
      )

    x = jnp.arange(8 * 128.).reshape((8, 128))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
        ],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((1, 2), jnp.float32),
    )(x)
    expected = jnp.zeros_like(x[0:1, 0:2]).at[0, 1].set(x[4, 4])
    np.testing.assert_allclose(y, expected)

  def test_vmem_vmem_dma(self):
    def kernel(x_ref, y_ref):
      def body(sem):
        pltpu.async_copy(x_ref, y_ref, sem).wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(8 * 128.).reshape((8, 128))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        ],
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_allclose(y, x)

  def test_hbm_vmem_dma_slicing(self):
    def kernel(x_hbm_ref, y_ref):
      def body(sem):
        dma1 = pltpu.async_copy(
            x_hbm_ref.at[pl.ds(0, 8)], y_ref.at[pl.ds(0, 8)], sem
        )
        dma2 = pltpu.async_copy(
            x_hbm_ref.at[pl.ds(8, 8)], y_ref.at[pl.ds(8, 8)], sem
        )
        dma1.wait()
        dma2.wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(2 * 8 * 128.).reshape((16, 128))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
    )(x)
    np.testing.assert_allclose(y, x)

  def test_hbm_vmem_dma_indexing(self):
    def kernel(x_hbm_ref, y_ref):
      def body(sem):
        dma1 = pltpu.async_copy(
            x_hbm_ref.at[0], y_ref.at[pl.ds(0, 8)], sem
        )
        dma2 = pltpu.async_copy(
            x_hbm_ref.at[1], y_ref.at[pl.ds(8, 8)], sem
        )
        dma1.wait()
        dma2.wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
    )(x)
    np.testing.assert_allclose(y, x.reshape((16, 128)))

  def test_hbm_vmem_dma_multiple_indexing(self):
    if self.INTERPRET:
      self.skipTest('Multiple indexing not supported in interpret mode.')

    def kernel(x_hbm_ref, y_ref):
      def body(sem):
        for i in range(3):
          dma1 = pltpu.async_copy(
              x_hbm_ref.at[pl.ds(i, 1)].at[0, 0], y_ref.at[i].at[pl.ds(0, 8)],
              sem
          )
          dma2 = pltpu.async_copy(
              x_hbm_ref.at[pl.ds(i, 1)].at[0, 1], y_ref.at[i].at[pl.ds(8, 8)],
              sem
          )
          dma1.wait()
          dma2.wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(3 * 2 * 8 * 128.).reshape((3, 2, 8, 128))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        out_shape=jax.ShapeDtypeStruct((3, 16, 128), jnp.float32),
    )(x)
    np.testing.assert_allclose(y, x.reshape((3, 16, 128)))

  def test_cannot_squeeze_lane_sublane(self):
    if self.INTERPRET:
      self.skipTest('Only works on Mosaic TPU.')

    def kernel(x_hbm_ref, y_ref):
      def body(sem):
        dma1 = pltpu.async_copy(
            x_hbm_ref.at[:, :, 0], y_ref.at[pl.ds(0, 8)], sem
        )
        dma2 = pltpu.async_copy(
            x_hbm_ref.at[:, :, 1], y_ref.at[pl.ds(8, 8)], sem
        )
        dma1.wait()
        dma2.wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128))
    with self.assertRaises(Exception):
      _ = self.pallas_call(
          kernel,
          in_specs=[
              pl.BlockSpec(memory_space=pl.ANY),
          ],
          out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
          out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
      )(x)

  def test_hoisted_scratch_space(self):
    def kernel(x_ref, y_ref, scratch_ref):
      i = pl.program_id(0)

      @pl.when(i == 0)
      def _():
        scratch_ref[...] = x_ref[...]

      scratch_ref[...] += jnp.ones_like(scratch_ref)

      @pl.when(i == 2)
      def _():
        y_ref[...] = scratch_ref[...]

    x = jnp.arange(8 * 128.).reshape((8, 128))
    y = self.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec((8, 128), lambda i: (0, 0)),
            ],
            scratch_shapes=[pltpu.VMEM((8, 128), jnp.float32)],
            out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)),
            grid=(3,),
        ),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_array_equal(y, x + 3)

  def test_hoisted_smem_space(self):
    # TODO(sharadmv,apaszke): enable SMEM scratch spaces
    # TODO(sharadmv,apaszke): add support for ()-shaped SMEM refs
    self.skipTest('Currently doesn\'t work')

    def kernel(y_ref, scratch_ref):
      scratch_ref[0, 0] = pl.program_id(0)
      y_ref[...] = jnp.broadcast_to(scratch_ref[0, 0], y_ref.shape)

    y = pl.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[],
            scratch_shapes=[pltpu.SMEM((1, 1), jnp.int32)],
            out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
            grid=(2,),
        ),
        debug=True,
        out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.int32),
    )()
    expected = jnp.broadcast_to(jnp.arange(2, dtype=jnp.int32)[..., None, None],
                                (2, 8, 128))
    np.testing.assert_array_equal(y, expected)

  def test_hoisted_semaphore(self):
    def kernel(x_bbm_ref, y_ref, sem, dma_sem):
      pltpu.semaphore_signal(sem)
      pltpu.semaphore_wait(sem)
      pltpu.async_copy(x_bbm_ref, y_ref, dma_sem).wait()

    x = jnp.arange(8 * 128.).reshape((8, 128))
    y = self.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(memory_space=pl.ANY),
            ],
            scratch_shapes=[pltpu.SemaphoreType.REGULAR,
                            pltpu.SemaphoreType.DMA],
            out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        ),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_array_equal(y, x)

  def test_large_array_indexing(self):
    n = 6
    dtype = jnp.bfloat16
    # This test sometimes OOMs on smaller chips. We garbage collect
    # to increase the chance there is 6GB memory available.
    gc.collect()
    x = jax.lax.broadcasted_iota(dtype, (n, 1024 * 1024, 512), 0)

    def kernel(index, x, y, sem):
      pltpu.async_copy(x.at[index[0]], y.at[:], sem).wait()

    run = self.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
            out_specs=pl.BlockSpec(memory_space=pl.ANY),
            scratch_shapes=[pltpu.SemaphoreType.DMA],
        ),
        out_shape=jax.ShapeDtypeStruct(x.shape[1:], dtype),
    )

    for i in range(x.shape[0]):
      y = run(jnp.array([i], dtype=jnp.int32), x)
      np.testing.assert_array_equal(y, i)
      del y

  def test_dynamic_dma_on_2nd_minor(self):
    def kernel(array, data, index, size, _, sem):
      pltpu.async_copy(
          data.at[pl.ds(0, size[0])], array.at[pl.ds(index[0], size[0])], sem
      ).wait()

    def run(array, data, index, size):
      return pl.pallas_call(
          kernel,
          out_shape=array,
          in_specs=[
              pl.BlockSpec(memory_space=pltpu.ANY),
              pl.BlockSpec(memory_space=pltpu.VMEM),
              pl.BlockSpec(memory_space=pltpu.SMEM),
              pl.BlockSpec(memory_space=pltpu.SMEM),
          ],
          scratch_shapes=[
              pltpu.SemaphoreType.DMA,
          ],
          out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
          input_output_aliases={0: 0},
      )(array, data, index, size)

    array = jnp.zeros((1024, 128), jnp.int32)
    data = jnp.ones((8, 128), jnp.int32)
    index = jnp.array([3], jnp.int32)
    size = jnp.array([5], jnp.int32)

    expected = array.at[index[0] : index[0] + size[0]].set(
        data[index[0] : index[0] + size[0]]
    )
    result = run(array, data, index, size)
    np.testing.assert_array_equal(result, expected)


class PallasCallDMAInterpretTest(PallasCallDMATest):
  INTERPRET = True

  def test_interpret_local_dma(self):
    # We run this test in interpret mode to test semaphore counting.
    # On a physical device the values update asynchronously so we cannot
    # deterministically check the values.
    def test_kernel(x_ref, o_ref, sem_out_ref, copy_sem):
      o_ref[...] = jnp.zeros_like(o_ref[...])
      input_to_output_copy = pltpu.make_async_copy(
          src_ref=x_ref.at[0:8],
          dst_ref=o_ref.at[0:8],
          sem=copy_sem.at[0],
      )
      input_to_output_copy.start()
      sem_out_ref[0, :] = jnp.ones_like(
          sem_out_ref[0, :]) * pltpu.semaphore_read(copy_sem.at[0])
      input_to_output_copy.wait()
      sem_out_ref[1, :] = jnp.ones_like(
          sem_out_ref[0, :]) * pltpu.semaphore_read(copy_sem.at[0])

    out_shape = (jax.ShapeDtypeStruct((16, 128), jnp.int32),
                 jax.ShapeDtypeStruct((2, 1), jnp.int32))
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        scratch_shapes=[pltpu.SemaphoreType.DMA(2,)],
    )

    kernel = pl.pallas_call(
        test_kernel,
        out_shape=out_shape,
        grid_spec=grid_spec,
        interpret=True,
    )
    x = jax.random.randint(
        jax.random.key(0), shape=(16, 128), minval=0, maxval=128)

    result, semaphores = kernel(x)
    np.testing.assert_array_equal(result[0:8], x[0:8])
    np.testing.assert_array_equal(result[8:], jnp.zeros_like(result[8:]))

    # Make sure semaphores have the correct value before and after DMA wait.
    result_sem_pre_wait = semaphores[0, 0]
    np.testing.assert_array_equal(result_sem_pre_wait, result[0:8].size)
    result_sem_post_wait = semaphores[1, 0]
    np.testing.assert_array_equal(result_sem_post_wait, 0)

  def test_interpreter_semaphore_counting(self):
    # We run this test in interpret mode because the kernel exits with
    # non-zero values. In normal Pallas this would crash the kernel.
    def test_kernel(o_ref, sem_ref):
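      # Interpret-mode semaphores behave as plain counters: signal(n) adds n
      # and wait(n) subtracts n, so the reads after the waits below can
      # observe negative values.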
      o_ref[...] = jnp.zeros_like(o_ref)
      pltpu.semaphore_signal(sem_ref.at[0], 1)
      pltpu.semaphore_signal(sem_ref.at[1], 2)
      pltpu.semaphore_signal(sem_ref.at[2], 3)
      pltpu.semaphore_signal(sem_ref.at[3], 4)
      o_ref[0, 0] = pltpu.semaphore_read(sem_ref.at[0])
      o_ref[1, 0] = pltpu.semaphore_read(sem_ref.at[1])
      o_ref[2, 0] = pltpu.semaphore_read(sem_ref.at[2])
      o_ref[3, 0] = pltpu.semaphore_read(sem_ref.at[3])
      pltpu.semaphore_wait(sem_ref.at[0], 4)
      pltpu.semaphore_wait(sem_ref.at[1], 3)
      pltpu.semaphore_wait(sem_ref.at[2], 2)
      pltpu.semaphore_wait(sem_ref.at[3], 1)
      o_ref[4, 0] = pltpu.semaphore_read(sem_ref.at[0])
      o_ref[5, 0] = pltpu.semaphore_read(sem_ref.at[1])
      o_ref[6, 0] = pltpu.semaphore_read(sem_ref.at[2])
      o_ref[7, 0] = pltpu.semaphore_read(sem_ref.at[3])

    out_shape = jax.ShapeDtypeStruct((8, 1), jnp.int32)
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        scratch_shapes=[pltpu.SemaphoreType.DMA(4,)],
    )
    results = pl.pallas_call(
        test_kernel,
        out_shape=out_shape,
        grid_spec=grid_spec,
        interpret=True,
    )()
    expected = jnp.array([1, 2, 3, 4, -3, -1, 1, 3]).reshape(out_shape.shape)
    np.testing.assert_array_equal(results, expected)


class PallasCallTest(PallasBaseTest):

  def test_cost_analysis(self):
    def kernel(x, y):
      y[:] = x[:]

    x = jnp.arange(1024.).reshape(8, 128)
    f = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        cost_estimate=pl.CostEstimate(
            flops=1234, transcendentals=21, bytes_accessed=12345
        ),
    )
    (analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
    self.assertEqual(analysis_result['flops'], 1234)
    self.assertEqual(analysis_result['transcendentals'], 21)
    self.assertEqual(analysis_result['bytes accessed'], 12345)

  def test_cost_analysis_vmap(self):
    def kernel(x, y):
      y[:] = x[:]

    batch_size = 3
    x = jnp.arange(batch_size * 1024.).reshape(batch_size, 8, 128)
    f = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        cost_estimate=pl.CostEstimate(
            flops=1234, transcendentals=21, bytes_accessed=12345
        ),
    )
    f = jax.vmap(f)
    (analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
    self.assertEqual(analysis_result['flops'], batch_size * 1234)
    self.assertEqual(analysis_result['transcendentals'], batch_size * 21)
    self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345)

  def test_vmem_limit(self):
    shape = (128, 128)

    def kernel(x_ref, y_ref):
      y_ref[...] = x_ref[...]

    x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    with self.assertRaises(xla_extension.XlaRuntimeError):
      self.pallas_call(
          kernel,
          out_shape=x,
          compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=256),
      )(x)
    self.pallas_call(
        kernel,
        out_shape=x,
        compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=int(2**18)),
    )(x)

  def test_allow_input_fusion(self):
    shape = (3, 128, 128)

    def kernel(x_ref, y_ref):
      y_ref[...] = x_ref[...]

    def f(x, y):
      z = jax.numpy.add(x, y)
      return self.pallas_call(
          kernel,
          grid=(3,),
          in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))],
          out_specs=pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0)),
          out_shape=x,
          compiler_params=pltpu.TPUCompilerParams(allow_input_fusion=[True]),
      )(z)

    x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    y = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    out = f(x, y)
    expected = x + y
    np.testing.assert_array_equal(out, expected)
    compiled = jax.jit(f).lower(x, y).compile().as_text()
    assert re.search(r'fusion.*kind=kCustom.*fused_computation', compiled)

  def test_set_internal_scratch_size(self):
    shape = (128, 128)

    def kernel(x_ref, y_ref):
      y_ref[...] = x_ref[...]

    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    requested_bytes = 128 * 4
    with self.assertRaisesRegex(
        Exception,
        f'Requested internal scratch size {requested_bytes} needs to be at'
        ' least',
    ):
      self.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
          compiler_params=pltpu.TPUCompilerParams(
              internal_scratch_in_bytes=requested_bytes,
          ),
      )(x)

  @parameterized.product(dtype=[jnp.bfloat16, jnp.float32])
  def test_pltpu_repeat(self, dtype):
    def test_kernel(x_ref, o_ref):
      x = x_ref[...]
      o_ref[...] = pltpu.repeat(x, 2, axis=1)

    @jax.jit
    def test(x: jax.Array) -> jax.Array:
      return pl.pallas_call(
          test_kernel,
          out_shape=jax.ShapeDtypeStruct([x.shape[0], x.shape[1] * 2], x.dtype),
      )(x)

    x = jnp.arange(2048, dtype=dtype).reshape((8, 256))
    y = test(x)
    np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1))

  def test_masked_store(self):
    if jtu.jaxlib_version() <= (0, 4, 35):
      self.skipTest("Test requires masked store support")
    shape = (16, 256)
    mask_shape = (10, 130)
    mask_start = (4, 5)
    dtype = jnp.float32

    def body(scalar_ref, x_ref, o_ref):
      o_ref[...] = jnp.full(shape, -1, dtype=dtype)
      b0, b1 = scalar_ref[0], scalar_ref[1]
      e0, e1 = b0 + mask_shape[0], b1 + mask_shape[1]
      iota0 = lax.broadcasted_iota(jnp.int32, shape, 0)
      iota1 = lax.broadcasted_iota(jnp.int32, shape, 1)
      mask0 = jnp.logical_and(b0 <= iota0, iota0 < e0)
      mask1 = jnp.logical_and(b1 <= iota1, iota1 < e1)
      pl.store(
          o_ref,
          (slice(None), slice(None)),
          x_ref[...],
          mask=jnp.logical_and(mask0, mask1),
      )

    s = jnp.array(mask_start, jnp.int32)
    x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)
    out = pl.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
        ),
    )(s, x)
    slices = tuple(slice(b, b + l) for b, l in zip(mask_start, mask_shape))
    expected = jnp.full(shape, -1, dtype=dtype)
    expected = expected.at[slices].set(x[slices])
    np.testing.assert_array_equal(out, expected)


class PallasUXTest(PallasBaseTest):

  def test_mlir_location(self):
    # Make sure that MLIR locations are correctly propagated to primitives.
    args = (jax.ShapeDtypeStruct((8, 128), jnp.float32),)
    f = example_kernel.double
    as_tpu_kernel = mosaic.as_tpu_kernel

    def capture_as_tpu_kernel(module, *args, **kwargs):
      asm = module.operation.get_asm(enable_debug_info=True)
      self.assertIn('example_kernel.py":25', asm)
      return as_tpu_kernel(module, *args, **kwargs)

    mosaic.as_tpu_kernel = capture_as_tpu_kernel
    try:
      jax.jit(f).lower(*args)
    finally:
      mosaic.as_tpu_kernel = as_tpu_kernel


class PallasMegacoreTest(PallasBaseTest):

  def test_megacore_splitting(self):
    # We want to make sure a 3-sized dimension is split across megacore
    # correctly, and if we combine the (3, 3) dimensions together it is still
    # correct.
    def matmul_kernel(x_ref, y_ref, z_ref):
      @pl.when(pl.program_id(2) == 0)
      def _():
        z_ref[...] = jnp.zeros_like(z_ref)

      z_ref[...] += x_ref[...] @ y_ref[...]
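
    # The output block is zeroed when the reduction index k == 0
    # (pl.program_id(2)) and partial products accumulate over the k dimension.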
    k1, k2 = jax.random.split(jax.random.key(0))
    x = jax.random.uniform(k1, (3, 3, 512, 512))
    y = jax.random.uniform(k2, (3, 3, 512, 512))

    z = jax.vmap(
        jax.vmap(
            pl.pallas_call(
                matmul_kernel,
                out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32),
                grid=(4, 4, 4),
                in_specs=[
                    pl.BlockSpec((128, 128), lambda i, j, k: (i, k)),
                    pl.BlockSpec((128, 128), lambda i, j, k: (k, j)),
                ],
                out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)),
                debug=True,
            )
        )
    )(x, y)
    np.testing.assert_allclose(
        z, jax.vmap(jax.vmap(jnp.dot))(x, y), rtol=1e-6
    )


class PallasCallVmapTest(PallasBaseTest):

  def test_scratch_input_vmap(self):
    """Test that vmapping a kernel with scratch inputs works correctly."""
    # Scratch inputs are only available for PallasTPU. This is why this test
    # does not live with the other vmap tests in:
    # jax/tests/pallas/pallas_test.py
    def add_one_with_scratch(x_ref, o_ref, scratch_ref):
      scratch_ref[...] = jnp.ones_like(scratch_ref[...])
      o_ref[...] = x_ref[...] + scratch_ref[...]

    tile_size = 128
    tile_shape = (tile_size, tile_size)
    array_shape = (2 * tile_size, 2 * tile_size)
    vmapped_add_one_with_scratch = jax.vmap(
        pl.pallas_call(
            add_one_with_scratch,
            out_shape=jax.ShapeDtypeStruct(array_shape, jnp.int32),
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=0,
                in_specs=[pl.BlockSpec(tile_shape, lambda i, j: (i, j))],
                out_specs=pl.BlockSpec(tile_shape, lambda i, j: (i, j)),
                scratch_shapes=[pltpu.VMEM(tile_shape, dtype=jnp.int32)],
                grid=(2, 2),
            ),
        )
    )

    x = jnp.broadcast_to(jnp.arange(array_shape[0]), (10, *array_shape))

    out = vmapped_add_one_with_scratch(x)
    out_ref = x + 1

    np.testing.assert_array_equal(out, out_ref, strict=True)


class PallasCallDynamicDMATest(PallasBaseTest):

  def setUp(self):
    super().setUp()
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('DMAs not supported on TPU generations <= 3')

  def test_simple_tile_aligned_dynamic_size_dma(self):
    def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
      size = size_smem_ref[0]
      pltpu.async_copy(
          x_hbm_ref.at[pl.ds(0, size)],
          o_hbm_ref.at[pl.ds(0, size)], sem).wait()

    x = jnp.tile(jnp.arange(8, dtype=jnp.int32)[:, None, None], [1, 8, 128])
    o = jnp.zeros((8, 8, 128), dtype=jnp.int32)
    size = jnp.array([4], dtype=jnp.int32)

    out = pl.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM),
                      pl.BlockSpec(memory_space=pltpu.ANY),
                      pl.BlockSpec(memory_space=pltpu.ANY)],
            out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
            scratch_shapes=[pltpu.SemaphoreType.DMA]
        ),
        out_shape=o,
        input_output_aliases={2: 0},
    )(size, x, o)
    expected = o.at[:4].set(x.at[:4].get())
    np.testing.assert_array_equal(out, expected)

  def test_simple_dynamic_size_dma(self):
    self.skipTest("doesn't work yet.")

    def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
      size = size_smem_ref[0]
      pltpu.async_copy(
          x_hbm_ref.at[pl.ds(0, size)],
          o_hbm_ref.at[pl.ds(0, size)], sem).wait()

    x = jnp.arange(8, dtype=jnp.int32)
    o = jnp.zeros(8, dtype=jnp.int32)
    size = jnp.array([4], dtype=jnp.int32)

    out = pl.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM),
                      pl.BlockSpec(memory_space=pltpu.ANY),
                      pl.BlockSpec(memory_space=pltpu.ANY)],
            out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
            scratch_shapes=[pltpu.SemaphoreType.DMA]
        ),
        out_shape=o,
        input_output_aliases={2: 0},
    )(size, x, o)
    expected = o.at[:4].set(x.at[:4].get())
    np.testing.assert_array_equal(out, expected)


class PallasCallRefTransformTest(PallasBaseTest):
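  """Tests for applying ref transforms (bitcast, reshape) to TPU refs."""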

  @parameterized.product(slice_first=[True, False])
  def test_dma_bitcasted_ref(self, slice_first):
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('DMAs not supported on TPU generations <= 3')

    def kernel(x_hbm_ref, y_hbm_ref):
      def body(sem):
        ref = (
            x_hbm_ref.at[:8, :, :128].bitcast(jnp.int16)
            if slice_first
            else x_hbm_ref.bitcast(jnp.int16).at[:8, :, :128]
        )
        pltpu.async_copy(ref, y_hbm_ref.at[...], sem).wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 1, 256))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 2, 128), jnp.int16),
    )(x)
    expected = (
        state_utils.bitcast(x[:8, :, :128], jnp.int16)
        if slice_first
        else state_utils.bitcast(x, jnp.int16)[:8, :, :128]
    )
    np.testing.assert_array_equal(y, expected)

  @parameterized.product(slice_first=[True, False])
  def test_load_bitcasted_ref(self, slice_first: bool):
    def kernel(x_ref, y_ref):
      ref = (
          x_ref.at[:8, :128].bitcast(jnp.int16)
          if slice_first
          else x_ref.bitcast(jnp.int16).at[:16, :128]
      )
      y_ref[...] = ref[...]

    x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 256))
    y = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int16),
    )(x)
    expected = (
        state_utils.bitcast(x[:8, :128], jnp.int16)
        if slice_first
        else state_utils.bitcast(x, jnp.int16)[:16, :128]
    )
    np.testing.assert_array_equal(y, expected)

  @parameterized.product(slice_first=[True, False])
  def test_store_bitcasted_ref(self, slice_first):
    def kernel(x_ref, y_ref):
      ref = (
          y_ref.at[:8, :128].bitcast(jnp.bfloat16)
          if slice_first
          else y_ref.bitcast(jnp.bfloat16).at[:16, :128]
      )
      ref[...] = x_ref[...]

    x = jnp.arange(16 * 128, dtype=jnp.bfloat16).reshape((16, 128))
    y = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
    )(x)
    expected = state_utils.bitcast(x, jnp.int32)
    np.testing.assert_array_equal(y[:8, :128], expected)

  @parameterized.product(slice_first=[True, False])
  def test_dma_reshaped_ref(self, slice_first):
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('DMAs not supported on TPU generations <= 3')

    def kernel(x_hbm_ref, y_hbm_ref):
      def body(sem):
        ref = (
            x_hbm_ref.at[:8, :, :].reshape(8, 128)
            if slice_first
            else x_hbm_ref.reshape(16, 128).at[:8, :]
        )
        pltpu.async_copy(ref, y_hbm_ref.reshape(8, 128).at[...], sem).wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(16 * 128, dtype=jnp.int32).reshape(16, 1, 128)
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 1, 128), jnp.int32),
    )(x)
    expected = (
        x[:8, :, :128].reshape((8, 128))
        if slice_first
        else x.reshape(16, 128)[:8, :128]
    ).reshape(8, 1, 128)
    np.testing.assert_array_equal(y, expected)

  def test_load_reshaped_ref(self):
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('No expected (1, 128) tiling')

    def kernel(x_ref, y_ref):
      y_ref[...] = x_ref.reshape(5, 128)[...]

    x = jnp.arange(5 * 128, dtype=jnp.int32).reshape(5, 1, 128)
    y = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((5, 128), jnp.int32),
    )(x)
    expected = x.reshape(5, 128)
    np.testing.assert_array_equal(y, expected)

  def test_store_reshaped_ref(self):
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('No expected (1, 128) tiling')

    def kernel(x_ref, y_ref):
      y_ref.reshape(5, 128)[...] = x_ref[...]

    x = jnp.arange(5 * 128, dtype=jnp.int32).reshape(5, 128)
    y = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((5, 1, 128), jnp.int32),
    )(x)
    expected = x.reshape(5, 1, 128)
    np.testing.assert_array_equal(y, expected)

  def test_multiple_ref_transforms(self):
    def kernel(x_ref, y_ref):
      ref = (
          x_ref.at[:16, :256]     # i32(16, 256)
          .bitcast(jnp.int16)     # i16(32, 256)
          .reshape((2, 16, 256))  # i16(2, 16, 256)
          .bitcast(jnp.float16)   # bf16(2, 16, 256)
          .at[1:, :, :]           # bf16(1, 16, 256)
          .reshape((16, 256))     # bf16(16, 256)
          .at[:, :128]            # bf16(16, 128)
          .bitcast(jnp.int32)     # i32(8, 128)
      )
      y_ref[...] = ref[...]

    x = jnp.arange(32 * 256, dtype=jnp.int32).reshape((32, 256))
    y = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
    )(x)
    np.testing.assert_array_equal(y, x[8:16, :128])


class PallasCallPrintTest(PallasBaseTest):

  def test_debug_print(self):
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    def kernel(x_ref, o_ref):
      pl.debug_print('It works!')

    x = jnp.arange(8 * 128, dtype=jnp.float32).reshape((8, 128))
    compiled_kernel = (
        jax.jit(kernel)
        .lower(x)
        .compile({'xla_tpu_enable_log_recorder': 'true'})
    )
    with jtu.capture_stderr() as get_output:
      jax.block_until_ready(compiled_kernel(x))
    self.assertIn('It works!', get_output())

  def test_debug_print_with_values(self):
    @functools.partial(
        self.pallas_call,
        in_specs=(pl.BlockSpec(memory_space=pltpu.SMEM),),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    def kernel(x_ref, o_ref):
      pl.debug_print('x[0] == {}', x_ref[0])

    x = jnp.array([42, 24]).astype(jnp.int32)
    compiled_kernel = (
        jax.jit(kernel)
        .lower(x)
        .compile({'xla_tpu_enable_log_recorder': 'true'})
    )
    with jtu.capture_stderr() as get_output:
      jax.block_until_ready(compiled_kernel(x))
    self.assertIn('x[0] == 42', get_output())

  @parameterized.named_parameters(
      (f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype)
      for shape in (
          (2, 8, 128),
          # test unaligned shapes
          (3,),
          (3, 4),
          (2, 3, 4),
          (2, 9, 129),
      )
      for dtype in (jnp.int32, jnp.uint32, jnp.float32)
  )
  def test_debug_print_vector(self, shape, dtype):
    # TODO(ayx): Remove after this date.
    if not jtu.if_cloud_tpu_at_least(2025, 1, 16):
      self.skipTest("Requires libtpu built after 2025-01-16")

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    def kernel(x_ref, o_ref):
      pl.debug_print("{}", x_ref[...])
      o_ref[...] = x_ref[...]

    n = np.prod(shape)
    x = jnp.arange(n, dtype=dtype).reshape(shape)
    compiled_kernel = (
        jax.jit(kernel)
        .lower(x)
        .compile({"xla_tpu_enable_log_recorder": "true"})
    )
    with jtu.capture_stderr() as get_output:
      jax.block_until_ready(compiled_kernel(x))
    output = get_output()
    numbers = [
        int(num)
        for line in output.splitlines()
        if (match := re.search(r"\{(.*)", line))  # extract contents after `{`
        for num in re.findall(r"\d+", match.group(1))
    ]
    # Check if the numbers in the output match the values generated by `arange`.
    self.assertLen(numbers, n)
    self.assertTrue(all(num == i for i, num in enumerate(numbers)))


class PallasCallTraceTest(PallasBaseTest):

  def test_trace_start_stop_match(self):
    def kernel(o_ref):
      with jax.named_scope('scope1'):
        o_ref[...] = jnp.zeros_like(o_ref[...])

    with string_stdout() as msg:
      _ = self.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
          debug=True,
      )()
      # TODO(justinfu): Add an official lowering API to get the MLIR.
class PallasCallTraceTest(PallasBaseTest):

  def test_trace_start_stop_match(self):
    def kernel(o_ref):
      with jax.named_scope('scope1'):
        o_ref[...] = jnp.zeros_like(o_ref[...])

    with string_stdout() as msg:
      _ = self.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
          debug=True,
      )()
      # TODO(justinfu): Add an official lowering API to get the MLIR.
      debug_string = msg.getvalue()
    num_start = debug_string.count('tpu.trace_start')
    num_stop = debug_string.count('tpu.trace_stop')
    self.assertEqual(num_start, 1)
    self.assertEqual(num_stop, 1)

  def test_run_scoped(self):
    def kernel(o_ref):
      def scope1():
        with jax.named_scope('scope1'):
          o_ref[...] = jnp.zeros_like(o_ref[...])
      pl.run_scoped(scope1)

      def scope2():
        with jax.named_scope('scope2'):
          o_ref[...] = o_ref[...] + 1
      pl.run_scoped(scope2)

    with string_stdout() as msg:
      _ = self.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
          debug=True,
      )()
      # TODO(justinfu): Add an official lowering API to get the MLIR.
      debug_string = msg.getvalue()
    num_start = debug_string.count('tpu.trace_start')
    num_stop = debug_string.count('tpu.trace_stop')
    self.assertEqual(num_start, 2)
    self.assertEqual(num_stop, 2)


class PallasCallTPUBooleanTest(PallasBaseTest):
  """Tests for loading/storing from bool memrefs on TPUs.

  We specifically test bools because they have special handling.
  Bools are stored as integers inside of memrefs, and we typecast to/from
  bools automatically on load.
  """

  INTERPRET: bool = False

  @parameterized.parameters((False,), (True,))
  def test_scalar_bool_load_store(self, value):
    def kernel(x_ref, o_ref):
      o_ref[0, 0] = jnp.logical_not(x_ref[0, 0])

    input_arr = jnp.array([[value]])
    output_shape = jax.ShapeDtypeStruct((1, 1), jnp.bool_)
    result = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
        out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
        out_shape=output_shape,
    )(input_arr)
    np.testing.assert_array_equal(result, jnp.logical_not(input_arr))

  @parameterized.parameters((False,), (True,))
  def test_scalar_bool_run_scoped(self, value):
    if self.INTERPRET:
      self.skipTest('run_scoped not supported in interpret mode.')

    def kernel(x_ref, o_ref):
      def inner_scope(scoped_ref):
        scoped_ref[0, 0] = jnp.logical_not(x_ref[0, 0])
        o_ref[0, 0] = scoped_ref[0, 0]
      pl.run_scoped(inner_scope, pltpu.SMEM((1, 1), dtype=jnp.bool_))

    input_arr = jnp.array([[value]])
    output_shape = jax.ShapeDtypeStruct((1, 1), jnp.bool_)
    result = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
        out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
        out_shape=output_shape,
    )(input_arr)
    np.testing.assert_array_equal(result, jnp.logical_not(input_arr))

  def test_vector_bool_load_store(self):
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]

    input_arr = jax.random.bernoulli(jax.random.key(0), p=0.5, shape=(8, 128))
    output_shape = jax.ShapeDtypeStruct((8, 128), jnp.bool_)
    result = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)],
        out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
        out_shape=output_shape,
    )(input_arr)
    np.testing.assert_array_equal(result, input_arr)
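  # Per the class docstring, a bool memref is backed by integer storage with
  # casts on store and load. A pure-numpy model of that contract (a sketch of
  # the intended semantics only, not of the actual lowering):
  @staticmethod
  def _bool_memref_model(values):
    stored = np.asarray(values).astype(np.int32)  # store: bool -> int
    return stored.astype(np.bool_)  # load: int -> bool
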
  def test_vector_bool_masking_with_indexing(self):
    def kernel(mask_ref, true_ref, false_ref, o_ref):
      o_ref[0, ...] = jnp.where(
          mask_ref[0, ...], true_ref[0, ...], false_ref[0, ...])

    key = jax.random.key(0)
    k1, k2, k3 = jax.random.split(key, 3)
    values_1 = jax.random.normal(k1, (1, 256, 256), jnp.float32)
    values_2 = jax.random.normal(k2, (1, 256, 256), jnp.float32)
    mask = jax.random.bernoulli(k3, p=0.5, shape=(1, 256, 256))
    output_shape = jax.ShapeDtypeStruct((1, 256, 256), jnp.float32)
    result = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.VMEM),
            pl.BlockSpec(memory_space=pltpu.VMEM),
            pl.BlockSpec(memory_space=pltpu.VMEM),
        ],
        out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
        out_shape=output_shape,
    )(mask, values_1, values_2)
    expected = jnp.where(mask, values_1, values_2)
    np.testing.assert_array_equal(result, expected)

  def test_bool_dma_not_implemented(self):
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('DMAs not supported on TPU generations <= 3')
    if self.INTERPRET:
      self.skipTest('Test only applies to non-interpret mode.')
    num_devices = jax.local_device_count()

    def kernel(x_ref, o_ref, send_sem, recv_sem):
      index = lax.axis_index('x')
      neighbor = lax.rem(index + 1, num_devices)
      copy = pltpu.make_async_remote_copy(
          x_ref, o_ref, send_sem, recv_sem, device_id=(0, neighbor))
      copy.start()
      copy.wait()

    input_arr = jnp.ones((8, 128), dtype=jnp.bool_)
    output_shape = jax.ShapeDtypeStruct((8, 128), jnp.bool_)
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        grid=(1,),
        scratch_shapes=[pltpu.SemaphoreType.DMA] * 2,
    )
    test_fn = self.pallas_call(
        kernel,
        grid_spec=grid_spec,
        out_shape=output_shape,
    )
    with self.assertRaisesRegex(
        Exception, 'DMAs with bool dtypes are not supported.'):
      devices = mesh_utils.create_device_mesh((num_devices,))
      mesh = jax.sharding.Mesh(devices, ('x',))
      sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
      input_arr = jax.device_put(input_arr, sharding)
      jax.jit(
          shard_map.shard_map(
              test_fn,
              mesh=mesh,
              in_specs=P(None, 'x'),
              out_specs=P(None, 'x'),
              check_rep=False,
          )
      )(input_arr)


class PallasCallTPUBooleanInterpretTest(PallasCallTPUBooleanTest):
  INTERPRET: bool = True


class PallasCallTPUCheckifyTest(PallasBaseTest):

  @parameterized.parameters((2,), (5,), (6,), (7,))
  def test_checkify_with_scalar_prefetch(self, threshold):
    def body(scalar_ref, x_ref, o_ref):
      scalar = scalar_ref[pl.program_id(0)]
      o_ref[...] = x_ref[...]
      checkify.check(scalar < threshold, 'failed on value {x}', x=scalar)

    s = jnp.array([4, 3, 2, 6, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))

    def _x_transform(i, s_ref):
      s = pl.load(s_ref, (i,))
      return (s, 0)

    pallas_call = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            in_specs=[
                pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform),
            ],
            out_specs=pl.BlockSpec(
                (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
            ),
            grid=8,
        ),
    )
    checked_call = checkify.checkify(pallas_call)
    err, out = checked_call(s, x)
    expected_error_value = s[jnp.argmax(s >= threshold)]
    with self.assertRaisesRegex(
        checkify.JaxRuntimeError, f'failed on value {expected_error_value}'):
      err.throw()
    np.testing.assert_allclose(out, x.reshape((8, 8, -1))[s].reshape(x.shape))
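  # Every test in this class follows the same recipe: `checkify.checkify`
  # turns a kernel containing `checkify.check` into a function that returns
  # an explicit (error, output) pair, and `err.throw()` re-raises any
  # on-device check failure as a JaxRuntimeError on the host:
  #
  #   checked_call = checkify.checkify(pallas_call)
  #   err, out = checked_call(*args)
  #   err.throw()  # raises iff some check failed
  #
  # The output is still computed, so tests can assert on it after throw().
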
  def test_checkify_with_scratch(self):
    def body(x_ref, o_ref, scratch_ref):
      scratch_ref[...] = x_ref[...]
      o_ref[...] = scratch_ref[...]
      all_nequal = ~jnp.all(o_ref[...] == x_ref[...])
      checkify.check(all_nequal, 'x_ref equals o_ref id=({x}, {y})',
                     x=pl.program_id(0), y=pl.program_id(1))

    x = jax.random.uniform(jax.random.key(0), (128, 512), dtype=jnp.float32)
    pallas_call = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec((32, 128), lambda i, j: (i, j)),
            ],
            out_specs=pl.BlockSpec((32, 128), lambda i, j: (i, j)),
            scratch_shapes=[pltpu.VMEM((32, 128), dtype=jnp.float32)],
            grid=(4, 4),
        ),
    )
    checked_call = checkify.checkify(pallas_call)
    err, out = checked_call(x)
    with self.assertRaisesRegex(
        checkify.JaxRuntimeError, r'x_ref equals o_ref id=\(0, 0\)'):
      err.throw()
    np.testing.assert_allclose(out, x)

  @parameterized.parameters((4,), (9,))
  def test_checkify_with_dynamic_grid(self, iteration):
    grid_size = 4
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(y_ref):
      @pl.when(pl.program_id(0) == 0)
      def _init():
        y_ref[...] = jnp.zeros_like(y_ref)
      y_ref[...] += 1

      @pl.when(pl.program_id(0) == iteration)
      def _():
        checkify.check(False, f"error on iteration {iteration}")

    @jax.jit
    def dynamic_kernel(steps):
      pallas_call = self.pallas_call(
          kernel,
          grid=(steps * 2,),
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          out_shape=result_ty,
      )
      return checkify.checkify(pallas_call)()

    err, result = dynamic_kernel(jnp.int32(grid_size))
    if iteration < grid_size * 2:
      with self.assertRaisesRegex(
          checkify.JaxRuntimeError, f"error on iteration {iteration}"):
        err.throw()
    np.testing.assert_array_equal(
        result, np.full(shape, grid_size * 2.0, np.float32)
    )


class PallasCallTPUCheckifyInterpretTest(PallasCallTPUCheckifyTest):
  INTERPRET: bool = True


class PrettyPrintingTest(PallasBaseTest):

  @parameterized.parameters(
      (
          lambda i: (i, pl.ds(0, 8), pl.ds(0, 128)),
          'dma_start c[d,:,:] -> e[...] f',
      ),
      (
          lambda i: (0, pl.ds(i, 8), pl.ds(0, 128)),
          'dma_start c[0,d:d+8,:] -> e[...] f',
      ),
      (
          lambda i: (i, pl.ds(2, 4), pl.ds(0, 100)),
          'dma_start c[d,2:6,:100] -> e[...] f',
      ),
      (
          lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)),
          'dma_start c[d,2:,4:104] -> e[...] f',
      ),
  )
  def test_dma_custom_pretty_print(self, indexer, expected):
    def body(x_hbm_ref, i):
      def inner(x_ref, sem):
        pltpu.async_copy(x_hbm_ref.at[indexer(i)], x_ref, sem).wait()

      pl.run_scoped(
          inner, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA
      )
      return []

    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body),
        [
            state.shaped_array_ref((2, 8, 128), jnp.int32),
            jax.core.ShapedArray((), jnp.int32),
        ],
    )
    self.assertIn(expected, jaxpr.pretty_print(use_color=False))


def only_passes_in_interpret(unless_generation: int | None = None):
  def decorator(f):
    def wrapper(self):
      if self.INTERPRET or (
          unless_generation is not None
          and jtu.is_device_tpu_at_least(unless_generation)
      ):
        f(self)
      else:
        with self.assertRaises(Exception):
          f(self)
    return wrapper
  return decorator
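
# Usage sketch for the decorator above: a test for a known bug asserts that
# it *fails* on hardware (optionally only below a given TPU generation) while
# still running normally in interpret mode, so fixing the bug flips the
# hardware path from assertRaises to a regular run:
#
#   @only_passes_in_interpret(unless_generation=5)
#   def test_known_bug(self):
#     ...
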
class MiscellaneousTest(PallasBaseTest):
  """Tests for reported bugs. Only pass in interpret mode unless fixed."""

  def test_float32_stack(self):
    x = np.arange(128, dtype=jnp.float32).reshape(1, 128)
    y = x + 128

    def kernel(x_ref, y_ref, out_ref):
      out_ref[...] = jnp.stack([x_ref[...], y_ref[...]], axis=1)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((1, 2, 128), jnp.float32)
    )(x, y)
    np.testing.assert_array_equal(out, np.stack([x, y], axis=1))

  @only_passes_in_interpret()
  def test_lane_to_chunk_reshape_bf16(self):
    """b/348038320"""
    x = np.arange(256 * 1024, dtype=jnp.bfloat16).reshape(1, 256, 1024)

    def kernel(x_ref, out_ref):
      out_ref[...] = jnp.reshape(x_ref[...], (1, 256, 8, 128))

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((1, 256, 8, 128), jnp.bfloat16)
    )(x)
    np.testing.assert_array_equal(out, np.reshape(x, (1, 256, 8, 128)))

  def test_lane_to_chunk_broadcast_fp32(self):
    x = np.arange(256 * 128, dtype=jnp.float32).reshape(1, 256, 128)

    def kernel(x_ref, out_ref):
      out_ref[...] = jnp.broadcast_to(
          jnp.expand_dims(x_ref[...], 2), (1, 256, 8, 128)
      )

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((1, 256, 8, 128), jnp.float32)
    )(x)
    np.testing.assert_array_equal(
        out, np.broadcast_to(np.expand_dims(x, 2), (1, 256, 8, 128))
    )

  @only_passes_in_interpret()
  def test_lane_dynamic_slice(self):
    """b/346849973"""
    x = np.arange(128, dtype=jnp.float32)

    def kernel(x_ref, out_ref):
      out_ref[...] = lax.dynamic_slice_in_dim(x_ref[...], 64, 1, 0)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((1,), jnp.float32)
    )(x)
    np.testing.assert_array_equal(out, x[64:65])

  def test_lane_broadcast_bf16(self):
    x = np.arange(256, dtype=jnp.bfloat16).reshape(256, 1)

    def kernel(x_ref, out_ref):
      out_ref[...] = jnp.broadcast_to(x_ref[...], (256, 512))

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((256, 512), jnp.bfloat16)
    )(x)
    np.testing.assert_array_equal(out, np.broadcast_to(x, (256, 512)))

  def test_bfloat16_to_uint32_bitcast(self):
    x = np.arange(16 * 2 * 256, dtype=jnp.bfloat16).reshape(16, 2, 256)

    def kernel(x_ref, out_ref):
      out_ref[...] = pltpu.bitcast(x_ref[...], jnp.uint32)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((16, 1, 256), jnp.uint32)
    )(x)
    np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32))
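  # Shape arithmetic for the bitcast test above: two bfloat16 elements pack
  # into one uint32, halving the second-to-last dimension, so (16, 2, 256)
  # bf16 becomes (16, 1, 256) u32. The host-side reference follows the same
  # rule (illustrative, mirroring the test body):
  #
  #   x = np.arange(16 * 2 * 256, dtype=jnp.bfloat16).reshape(16, 2, 256)
  #   assert state_utils.bitcast(x, jnp.uint32).shape == (16, 1, 256)
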
  @only_passes_in_interpret()
  def test_roll_partial(self):
    """b/337384645"""
    x = np.arange(8192, dtype=jnp.float32).reshape(128, 64)

    def kernel(x_ref, out_ref):
      out_ref[...] = pltpu.roll(x_ref[...], 3, 1)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((128, 64), jnp.float32)
    )(x)
    np.testing.assert_array_equal(out, np.roll(x, 3, 1))

  @only_passes_in_interpret()
  def test_retiling1(self):
    """b/352626602"""
    x = np.arange(1024, dtype=jnp.bfloat16).reshape(1024)

    def kernel(x_ref, out_ref):
      out_ref[:, :] = jnp.reshape(x_ref[:].astype(jnp.float32), (8, 128))

    out = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_array_equal(out, np.reshape(x, (8, 128)))

  def test_retiling2(self):
    x = np.arange(1 * 8 * 1024, dtype=jnp.bfloat16).reshape(1, 8, 1024)

    def kernel(x_ref, out_ref):
      out_ref[:, :, :] = jnp.reshape(
          x_ref[:, 7, :].astype(jnp.float32), (1, 8, 128)
      )

    out = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((1, 8, 128), jnp.float32),
    )(x)
    np.testing.assert_array_equal(out, np.reshape(x[:, 7, :], (1, 8, 128)))

  def test_sublane_adding_shape_cast_f32(self):
    x = np.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)

    def kernel(x_ref, out_ref):
      out_ref[:, 0, :] = x_ref[:, :]

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((8, 1, 128), jnp.float32)
    )(x)
    np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128)))

  @only_passes_in_interpret()
  def test_sublane_adding_shape_cast_bf16(self):
    """b/352833257"""
    x = np.arange(8 * 128, dtype=jnp.bfloat16).reshape(8, 128)

    def kernel(x_ref, out_ref):
      out_ref[:, 0, :] = x_ref[:, :]

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((8, 1, 128), jnp.bfloat16)
    )(x)
    np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128)))

  def test_mixed_strides(self):
    x = np.zeros((8, 128), dtype=jnp.float32)
    y = np.zeros((8, 2, 128), dtype=jnp.bfloat16)

    def kernel(x_ref, y_ref, out_ref):
      out_ref[:, :] = x_ref[:, :] + y_ref[:, 1, :].astype(jnp.float32)

    out = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x, y)
    np.testing.assert_array_equal(out, np.zeros((8, 128), dtype=jnp.float32))

  def test_sum(self):
    x = np.zeros((8, 2, 8, 128), dtype=jnp.float32)

    def kernel(x_ref, out_ref):
      out_ref[:, :, :] = jnp.sum(x_ref[:, :, :, :], 2)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((8, 2, 128), jnp.float32)
    )(x)
    np.testing.assert_array_equal(
        out, np.zeros((8, 2, 128), dtype=jnp.float32)
    )

  @only_passes_in_interpret()
  def test_transpose(self):
    """b/356475128"""
    x = np.zeros((8, 2, 8, 128), dtype=jnp.float32)

    def kernel(x_ref, out_ref):
      out_ref[:, :, :, :] = jnp.transpose(x_ref[:, :, :, :], (0, 2, 1, 3))

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((8, 8, 2, 128), jnp.float32)
    )(x)
    np.testing.assert_array_equal(
        out, np.zeros((8, 8, 2, 128), dtype=jnp.float32)
    )


class MiscellaneousInterpretTest(MiscellaneousTest):
  INTERPRET: bool = True


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())