rocm_jax/tests/pallas/tpu_pallas_test.py
2025-01-13 13:22:21 -08:00

2693 lines
87 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test TPU-specific extensions to pallas_call."""
import contextlib
import functools
import gc
import io
import math
import re
import sys
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import checkify
from jax._src import state
from jax._src import test_util as jtu
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_extension
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax._src.state import utils as state_utils
from jax._src.state import discharge as state_discharge
from jax.experimental import mesh_utils
from jax.experimental import mosaic
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu import example_kernel
from jax.extend import linear_util as lu
import jax.numpy as jnp
import numpy as np
# Configure JAX from absl flags before any test runs.
jax.config.parse_flags_with_absl()

# Short aliases used by the tests in this file.
partial = functools.partial
P = jax.sharding.PartitionSpec
@contextlib.contextmanager
def string_stdout():
  """Redirects stdout to a string.

  Yields:
    The ``io.StringIO`` buffer that receives everything written to
    ``sys.stdout`` inside the ``with`` block.

  ``sys.stdout`` is restored when the block exits, including when the body
  raises, so a failing assertion cannot leave later output swallowed by the
  buffer.
  """
  stringio = io.StringIO()
  # redirect_stdout restores the previous sys.stdout on exit, even on error.
  with contextlib.redirect_stdout(stringio):
    yield stringio
class PallasBaseTest(jtu.JaxTestCase):
  """Shared base for the Pallas TPU tests.

  Tests are skipped unless running on TPU or with INTERPRET set.
  """

  # Subclasses set this to True to run kernels through the Pallas interpreter.
  INTERPRET: bool = False

  def setUp(self):
    runnable = jtu.test_device_matches(['tpu']) or self.INTERPRET
    if not runnable:
      self.skipTest('Test requires TPUs, or interpret mode')
    super().setUp()
    # Drop cached kernel traces so each test traces its kernels afresh.
    _trace_kernel_to_jaxpr.cache_clear()

  def pallas_call(self, *args, **kwargs):
    """Calls `pl.pallas_call` with `interpret` fixed to this class's INTERPRET."""
    return pl.pallas_call(*args, interpret=self.INTERPRET, **kwargs)
class PallasCallScalarPrefetchTest(PallasBaseTest):
  """Tests for `pltpu.PrefetchScalarGridSpec`.

  Scalar-prefetch operands are passed first to the kernel and are also
  available to every BlockSpec index map (the extra `s_ref` arguments below).
  """

  def test_trivial_scalar_prefetch(self):
    """Gathers 8-row blocks of `x` in the order given by the prefetched `s`."""
    def body(_, x_ref, o_ref):
      o_ref[...] = x_ref[...]

    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))

    # Input block for grid step i is row-block s[i].
    def _x_transform(i, s_ref):
      s = pl.load(s_ref, (i,))
      return (s, 0)

    out = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            in_specs=[
                pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform),
            ],
            out_specs=pl.BlockSpec(
                (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
            ),
            grid=8,
        ),
    )(s, x)
    np.testing.assert_allclose(out, x.reshape((8, 8, -1))[s].reshape(x.shape))

  def test_trivial_scalar_prefetch_with_windowless_args(self):
    """With no in/out specs, `x` is passed whole and copied through."""
    def body(_, x_ref, o_ref):
      o_ref[...] = x_ref[...]

    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
    out = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
        ),
    )(s, x)
    np.testing.assert_array_equal(out, x)

  @jtu.parameterized_filterable(
      kwargs=[
          dict(scratch=scratch, vmap=vmap, dyn_grid=dyn_grid)
          for scratch in [True, False]
          for vmap in [False, True]
          for dyn_grid in [False, True]
      ]
  )
  def test_scalar_prefetch_calling_convention(
      self, *,
      scratch: bool, vmap: bool, dyn_grid: bool):
    # Tests that we correctly process all the various inputs and outputs:
    # dynamic_grid_dims, index, inputs, outputs, scratch.
    # The single prefetch operand is a pytree `(s, arange(3), None)`; the
    # kernel receives it unflattened as `s_refs`.
    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
      self.skipTest("TODO: dslice(start, 1) raises error about slice inputs being int32 and int64")
    to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
    if vmap:
      x_shape = (4, 16, 128)
    else:
      x_shape = (16, 128)
    x = np.arange(math.prod(x_shape), dtype=np.float32).reshape(x_shape)

    def f(x, grid_size, to_store):
      s = jnp.array([1, 0], jnp.int32)  # iteration 0 -> 1, iteration 1 -> 0
      @functools.partial(
          self.pallas_call,
          out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=1,  # 1 pytree
              grid=(grid_size,),
              in_specs=[pl.BlockSpec((8, 128),
                                     lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0)),
                        pl.BlockSpec((1, 128), lambda i, s_ref: (0, 0))],
              out_specs=pl.BlockSpec((32, 128),
                                     lambda i, s_ref: (pl.load(s_ref[0], i), 0)),
              scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch
                              else []),
          ),
      )
      def kernel(s_refs, src, to_store, dst, *scratch_refs):
        s_ref, s2, s3 = s_refs
        assert s_ref.shape == (2,)
        assert s2.shape == (3,)
        assert s3 is None
        # Step 0 writes output block 1 (rows 32..63) at local row 1 -> row 33;
        # step 1 writes output block 0 at local row 0 -> row 0.
        store_idx = s_ref[pl.program_id(0)]
        pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store[...])
      # Pass a pytree of scalar
      return kernel((s, np.arange(3, dtype=np.int32), None), x, to_store)

    # Under jit the grid size is a traced value, i.e. a dynamic grid.
    if dyn_grid:
      f = jax.jit(f)
    if vmap:
      res = jax.vmap(lambda x: f(x, 2, to_store))(x)
    else:
      res = f(x, 2, to_store)

    if vmap:
      for i in range(x.shape[0]):
        self.assertAllClose(res[i, 0:1], to_store)
        self.assertAllClose(res[i, 33:34], to_store)
    else:
      self.assertAllClose(res[0:1], to_store)
      self.assertAllClose(res[33:34], to_store)

  def test_with_unhashable_grid_spec(self):
    # Make sure that we don't crash when the GridSpec has non-hashable parts
    @functools.partial(
        self.pallas_call,
        out_shape=[[jax.ShapeDtypeStruct((8, 128), np.int32)]],
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,  # 1 pytree
            grid=(1,),
            in_specs=[[pl.BlockSpec((8, 128),
                                    lambda i, s_ref: (0, 0))]],
            out_specs=[[pl.BlockSpec((8, 128),
                                     lambda i, s_ref: (0, 0))]],
            scratch_shapes=[[pltpu.SemaphoreType.REGULAR((3,))]],
        ),
    )
    def kernel(s_ref, x_ref, o_ref, scratch_ref):
      # Each operand group arrives with the list structure it was given.
      assert isinstance(s_ref, list)
      assert isinstance(x_ref, list)
      assert isinstance(o_ref, list)
      assert isinstance(scratch_ref, list)
      o_ref[0][...] = x_ref[0][...]

    x_shape = (8, 128)
    s = np.array([0, 1], np.int32)
    x = np.arange(math.prod(x_shape), dtype=np.int32).reshape(x_shape)
    res = kernel([s, s], [x])
    self.assertIsInstance(res, tuple)  # Even though we asked for a list!
    self.assertAllClose(res[0][0], x)

  def test_vmap_scalar_prefetch(self):
    """Block gather by `s`, vmapped over a batch of 2 in `x` (`s` is closed over)."""
    def body(_, x_ref, o_ref):
      o_ref[...] = x_ref[...]

    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128))

    def _x_transform(i, s_ref):
      s = pl.load(s_ref, (i,))
      return (s, 0)

    def f(x):
      return self.pallas_call(
          body,
          out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=1,
              in_specs=[
                  pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform),
              ],
              out_specs=pl.BlockSpec(
                  (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
              ),
              grid=8),
      )(s, x)

    np.testing.assert_allclose(
        jax.vmap(f)(x), x.reshape((2, 8, 8, -1))[:, s].reshape(x.shape)
    )

  def test_multiple_scalar_prefetch(self):
    """Two prefetch operands: `s1` selects input blocks, `s2` output blocks."""
    def body(s1_ref, s2_ref, x_ref, o_ref):
      del s1_ref, s2_ref
      o_ref[...] = x_ref[...]

    s1 = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    s2 = jnp.array([7, 6, 5, 4, 3, 2, 1, 0], jnp.int32)
    x = jnp.arange(64 * 128, dtype=jnp.int32).reshape((64, 128))

    def _x_transform(i, s1_ref, _):
      return s1_ref[i], 0

    def _o_transform(i, _, s2_ref):
      return s2_ref[i], 0

    out = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct((64, 128), jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=2,
            in_specs=[
                pl.BlockSpec((8, 128), _x_transform),
            ],
            out_specs=pl.BlockSpec((8, 128), _o_transform),
            grid=8,
        ),
    )(s1, s2, x)
    # s2 is 7..0, so the gathered blocks land in reverse order.
    out_ref = x.reshape((8, 8, -1))[s1][::-1].reshape((64, 128))
    np.testing.assert_allclose(out, out_ref)

  def test_scalar_interpreter(self):
    """Runs a prefetched 'program': opcode 0 adds the step index, 1 multiplies."""
    program = jnp.array([0, 0, 1, 0, 1, 1], jnp.int32)
    x = jnp.arange(8 * 8 * 128.0, dtype=jnp.float32).reshape(8 * 8, 128)

    def body(sprogram_ref, x_ref, o_ref, state_ref):
      x = x_ref[...]

      def add_branch_fn(j):
        state_ref[...] += jnp.float32(j)
        return ()

      def mult_branch_fn(j):
        state_ref[...] *= jnp.float32(j)
        return ()

      def single_inst(i, _):
        _ = jax.lax.switch(
            sprogram_ref[i],
            (
                add_branch_fn,
                mult_branch_fn,
            ),
            i,
        )

      # We can't use for loop state right now, because Pallas functionalizes it,
      # and Mosaic support for returning values from scf.if is incomplete.
      state_ref[...] = x
      lax.fori_loop(0, sprogram_ref.shape[0], single_inst, None, unroll=True)
      o_ref[...] = state_ref[...]

    # Ignore the scratch output.
    # (The second output, with a constant index map, serves as `state_ref`.)
    out, _ = self.pallas_call(
        body,
        out_shape=[
            jax.ShapeDtypeStruct(x.shape, jnp.float32),
            jax.ShapeDtypeStruct((8, 128), jnp.float32),
        ],
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            in_specs=[pl.BlockSpec((8, 128), lambda i, *_: (i, 0))],
            out_specs=[
                pl.BlockSpec((8, 128), lambda i, *_: (i, 0)),
                pl.BlockSpec((8, 128), lambda *_: (0, 0)),
            ],
            grid=8,
        ),
    )(program, x)

    # Reference: apply the same program on the host.
    expected = x
    for i, p in enumerate(program):
      if p == 0:
        expected += i
      elif p == 1:
        expected *= i

    np.testing.assert_allclose(out, expected)

  def test_scalar_interpreter_dynamic_loop(self):
    """fori_loop trip count is read from a prefetched scalar (5)."""
    loop_end = jnp.array([5], jnp.int32)

    def body(loop_end_ref, out_ref):
      out_ref[...] = jnp.zeros_like(out_ref)

      def loop_body(i, carry):
        del i, carry
        out_ref[...] += 1

      lax.fori_loop(0, loop_end_ref[0], loop_body, None)

    out = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            out_specs=pl.BlockSpec((8, 128), lambda *_: (0, 0)),
            grid=1,
        ),
    )(loop_end)
    expected_out = jnp.ones((8, 128), jnp.float32) * 5
    np.testing.assert_allclose(out, expected_out)

  def test_vmap_scalar_prefetch_1sized(self):
    """vmap over a size-1 batch axis on both the prefetch operand and `x`."""
    def body(_, x_ref, o_ref):
      o_ref[...] = x_ref[...]

    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))

    def _x_transform(i, s_ref):
      s = pl.load(s_ref, (i,))
      return (s, 0)

    s = s[None]
    x = x[None]
    out = jax.vmap(
        self.pallas_call(
            body,
            out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=1,
                in_specs=[
                    pl.BlockSpec((x.shape[1] // 8, x.shape[2]), _x_transform),
                ],
                out_specs=pl.BlockSpec(
                    (x.shape[1] // 8, x.shape[2]), lambda i, _: (i, 0)
                ),
                grid=8,
            ),
        )
    )(s, x)
    np.testing.assert_allclose(
        out, x.reshape((1, 8, 8, -1))[:, s].reshape(x.shape)
    )

  def test_nontrivial_vmap_scalar_prefetch(self):
    """vmap over a batch of 2 with per-example prefetch scalars."""
    def body(_, x_ref, o_ref):
      o_ref[...] = x_ref[...]

    s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128))

    def _x_transform(i, s_ref):
      s = pl.load(s_ref, (i,))
      return (s, 0)

    s = jnp.tile(s[None], [2, 1])

    @jax.jit
    @jax.vmap
    def kernel(s, x):
      return self.pallas_call(
          body,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=1,
              in_specs=[
                  pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform),
              ],
              out_specs=pl.BlockSpec(
                  (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
              ),
              grid=8,
          ),
          compiler_params=pltpu.TPUCompilerParams(
              allow_input_fusion=[False, True]
          ),
      )(s, x)

    first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:])
    second = x[1, ...].reshape((1, 8, 8, -1))[:, s[1, ...]].reshape(x.shape[1:])
    expected = jnp.stack([first, second])
    np.testing.assert_allclose(kernel(s, x), expected)

  def test_input_output_aliasing_with_scalar_prefetch(self):
    """Aliasing counts the prefetch operand: input 1 (`x`) aliases output 0."""
    x = jnp.ones((32, 1024, 1024))
    expected = x + 1

    def kernel(_, x_ref, y_ref):
      y_ref[...] = x_ref[...] + 1.

    @partial(jax.jit, donate_argnums=(0,))
    def f(x):
      return self.pallas_call(
          kernel,
          out_shape=x,
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=1,
              in_specs=[
                  pl.BlockSpec((None, 1024, 1024), lambda i, _: (i, 0, 0))
              ],
              out_specs=pl.BlockSpec(
                  (None, 1024, 1024), lambda i, _: (i, 0, 0)
              ),
              grid=(x.shape[0],),
          ),
          input_output_aliases={1: 0},
      )(jnp.array([1, 2, 3]), x)

    o = f(x)
    np.testing.assert_array_equal(o, expected)
    # The compiled executable should alias exactly the bytes of `x`.
    compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile()
    mem_analysis = compiled.memory_analysis()
    expected_num_bytes = np.prod(x.shape) * x.dtype.itemsize
    self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes)
class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
  """Re-runs every scalar-prefetch test through the Pallas interpreter."""

  INTERPRET: bool = True
class PallasCallDynamicGridTest(PallasBaseTest):
  """Tests for grids whose sizes are traced values, and for pl.num_programs."""

  def test_can_query_grid_statically_via_num_programs(self):
    """With a static grid, pl.num_programs is a Python int inside the kernel."""
    def kernel(_):
      num_programs = pl.num_programs(0)
      self.assertIsInstance(num_programs, int)
      self.assertEqual(num_programs, 2)

    self.pallas_call(kernel, out_shape=None, grid=(2,))()

  def test_can_query_grid_statically_via_num_programs_in_block_spec(self):
    """Same as above, but queried from a BlockSpec index map."""
    def kernel(*_):
      pass

    def x_index_map(_):
      num_programs = pl.num_programs(0)
      self.assertIsInstance(num_programs, int)
      self.assertEqual(num_programs, 2)
      return 0, 0

    self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec((8, 128), x_index_map)],
        out_shape=None,
        grid=(2,),
    )(jnp.ones((8, 128)))

  def test_dynamic_grid_has_dynamic_size(self):
    """A static grid dim gives an int; a traced grid dim gives a jax.Array."""
    def kernel(_):
      num_programs = pl.num_programs(0)
      self.assertIsInstance(num_programs, int, msg=type(num_programs))
      self.assertEqual(num_programs, 2)
      num_programs = pl.num_programs(1)
      self.assertIsInstance(num_programs, jax.Array)

    @jax.jit
    def outer(x):
      # `x` is traced under jit, so grid dim 1 is dynamic.
      self.pallas_call(kernel, out_shape=None, grid=(2, x))()

    outer(2)

  def test_dynamic_grid(self):
    """Grid of `steps * 2` (traced); the output accumulates once per step."""
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(y_ref):
      @pl.when(pl.program_id(0) == 0)
      def _init():
        y_ref[...] = jnp.zeros_like(y_ref)

      y_ref[...] += 1

    @jax.jit
    def dynamic_kernel(steps):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          out_shape=result_ty,
      )()

    # steps=4 -> 8 grid steps -> every element equals 8.
    np.testing.assert_array_equal(
        dynamic_kernel(jnp.int32(4)), np.full(shape, 8.0, np.float32)
    )

  def test_dynamic_grid_overflow(self):
    # If we pad statically the dynamic grid dims to max int32, then the product
    # of this grid size will overflow int64 and can cause failing checks in XLA.
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(y_ref):
      @pl.when(sum(pl.program_id(i) for i in range(3)) == 0)
      def _init():
        y_ref[...] = jnp.zeros_like(y_ref)

      y_ref[...] += 1

    @jax.jit
    def dynamic_kernel(steps):
      return self.pallas_call(
          kernel,
          grid=(steps * 2, steps + 1, 3),
          out_specs=pl.BlockSpec(shape, lambda *_: (0, 0)),
          out_shape=result_ty,
      )()

    # steps=4 -> grid (8, 5, 3) -> 120 total steps.
    np.testing.assert_array_equal(
        dynamic_kernel(jnp.int32(4)), np.full(shape, 120.0, np.float32)
    )

  # TODO(apaszke): Add tests for scalar_prefetch too
  def test_dynamic_grid_scalar_input(self):
    """An SMEM input combined with a dynamic grid."""
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(scalar_input_ref, output_ref):
      output_ref[...] = jnp.full_like(output_ref, scalar_input_ref[0, 0])

    @jax.jit
    def dynamic_kernel(steps):
      return self.pallas_call(
          kernel,
          out_shape=result_ty,
          in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          grid=(steps * 2,),
      )(jnp.array([[42]], dtype=jnp.int32))

    np.testing.assert_array_equal(
        dynamic_kernel(jnp.int32(4)), np.full(shape, 42.0, np.float32)
    )

  def test_vmap_trivial_dynamic_grid(self):
    """vmap over a batch of 1 where the grid size is itself batched."""
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(x_ref, y_ref):
      @pl.when(pl.program_id(0) == 0)
      def _init():
        y_ref[...] = x_ref[...]

      y_ref[...] += 1

    @jax.jit
    @jax.vmap
    def dynamic_kernel(steps, x):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          in_specs=[pl.BlockSpec(shape, lambda i: (0, 0))],
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          out_shape=result_ty,
      )(x)

    x = jnp.arange(8 * 128., dtype=jnp.float32).reshape((1, *shape))
    np.testing.assert_array_equal(
        dynamic_kernel(jnp.array([4], jnp.int32), x), x + 8.0
    )

  def test_vmap_nontrivial_dynamic_grid(self):
    # Dynamic grid doesn't support vmapping over multiple distinct grid values
    # at the moment.
    # NOTE(review): this comment looks stale -- the assertions below expect
    # distinct per-example grid sizes (8 and 16 steps) to work. Confirm and
    # update or remove the comment.
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(y_ref):
      @pl.when(pl.program_id(0) == 0)
      def _init():
        y_ref[...] = jnp.zeros_like(y_ref)

      y_ref[...] += 1

    @jax.jit
    @jax.vmap
    def dynamic_kernel(steps):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          out_shape=result_ty,
      )()

    out = dynamic_kernel(jnp.array([4, 8], jnp.int32))
    first = jnp.full(shape, fill_value=8.0, dtype=jnp.float32)
    second = jnp.full(shape, fill_value=16.0, dtype=jnp.float32)
    expected_out = jnp.stack([first, second], axis=0)
    np.testing.assert_array_equal(out, expected_out)

  def test_vmap_dynamic_grid(self):
    """vmap over `x` only (in_axes=(0, None)); the dynamic step count is shared."""
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(x_ref, y_ref):
      @pl.when(pl.program_id(0) == 0)
      def _init():
        y_ref[...] = x_ref[...]

      y_ref[...] += jnp.float32(1.)

    @jax.jit
    def dynamic_kernel(x, steps):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          out_shape=result_ty,
      )(x)

    x = jnp.arange(4 * 8 * 128., dtype=jnp.float32).reshape((4, *shape))
    np.testing.assert_array_equal(
        jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, jnp.int32(4)),
        x + 8,
    )

  def test_num_programs(self):
    """pl.num_programs(0) written to SMEM equals the dynamic grid size."""
    def kernel(y_ref):
      y_ref[0, 0] = pl.num_programs(0)

    @jax.jit
    def dynamic_kernel(steps):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
          out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32),
      )()

    self.assertEqual(dynamic_kernel(np.int32(4)), 8)

  @parameterized.parameters(range(1, 4))
  def test_vmap_num_programs(self, num_vmaps):
    """pl.num_programs stays 8 under 1-3 nested vmaps of axis_size 2."""
    result_ty = jax.ShapeDtypeStruct((8, 128), jnp.int32)

    def kernel(y_ref):
      y_ref[...] = jnp.full_like(y_ref, pl.num_programs(0))

    kernel_call = self.pallas_call(
        kernel,
        grid=(8,),
        out_specs=pl.BlockSpec(result_ty.shape, lambda i: (0, 0)),
        out_shape=result_ty,
    )

    out_shape = (*(2 for _ in range(num_vmaps)), *result_ty.shape)
    f = kernel_call
    for _ in range(num_vmaps):
      # Default argument binds the current `f`, avoiding late binding.
      f = lambda impl=f: jax.vmap(impl, axis_size=2)()
    out = jax.jit(f)()
    np.testing.assert_array_equal(out, np.full(out_shape, 8.0))

  def test_num_programs_block_spec(self):
    """pl.num_programs is usable inside an index map under a dynamic grid."""
    def kernel(x_ref, y_ref):
      y_ref[...] = x_ref[...]

    @jax.jit
    def dynamic_kernel(steps, x):
      return self.pallas_call(
          kernel,
          grid=(steps * 2,),
          in_specs=[
              pl.BlockSpec(
                  (8, 128),
                  # Should always evaluate to (1, 0)
                  lambda i: (1 + 8 - pl.num_programs(0), 0),
              )
          ],
          out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)),
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
      )(x)

    x = np.arange(4 * 8 * 128., dtype=np.int32).reshape((4 * 8, 128))
    np.testing.assert_array_equal(dynamic_kernel(np.int32(4), x), x[8:16])
class PallasCallDynamicGridInterpretTest(PallasCallDynamicGridTest):
  """Re-runs every dynamic-grid test through the Pallas interpreter."""

  INTERPRET = True
class PallasCallDMATest(PallasBaseTest):
def setUp(self):
  super().setUp()
  # DMA tests only run on TPU v4 or newer; everything else is skipped.
  if jtu.is_device_tpu_at_least(4):
    return
  self.skipTest('DMAs not supported on TPU generations <= 3')
def test_can_have_unspecified_memory_spaces(self):
  """Operands placed in the pl.ANY memory space compile and run."""
  def kernel(x_ref, y_ref):
    # Just test whether things compile
    del x_ref, y_ref

  x = jnp.ones((8, 128), dtype=jnp.float32)
  y = self.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
      out_specs=pl.BlockSpec(memory_space=pl.ANY),
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(x)
  jax.block_until_ready(y)
def test_run_scoped_tracks_effects(self):
  """Ref effects inside a run_scoped body surface on the outer jaxpr.

  Exactly two effects are expected: a read of y_ref (input 1) and a write of
  x_ref (input 0). The scoped temp_ref adds nothing to the set.
  """
  def kernel(x_ref, y_ref):
    def body(temp_ref):
      temp_ref[...] = jnp.ones_like(temp_ref)
      x_ref[...] = 4 * y_ref[...] + temp_ref[...]

    pl.run_scoped(body, pltpu.VMEM((8,), jnp.float32))
    return []

  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(kernel),
      [
          state.shaped_array_ref((8,), jnp.float32),
          state.shaped_array_ref((8,), jnp.float32),
      ],
  )
  expected_effects = {state.ReadEffect(1), state.WriteEffect(0)}
  self.assertSetEqual(jaxpr.effects, expected_effects)
def test_scoped_allocation(self):
  """A VMEM scratch buffer from run_scoped can be written, then read."""
  def kernel(y_ref):
    def body(x_ref):
      x_ref[...] = jnp.ones_like(x_ref)
      y_ref[...] = 4 * x_ref[...]

    pl.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32))

  o = self.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )()
  np.testing.assert_allclose(o, 4 * np.ones_like(o))
def test_run_scoped_can_return_scalar_value(self):
  """run_scoped returns the body's scalar result (0 + 1 + 2 == 3)."""
  def kernel(y_ref):
    def body(x_ref):
      x_ref[0] = 0
      x_ref[0] += 1
      return x_ref[0] + 2

    out = pl.run_scoped(body, pltpu.SMEM((1,), jnp.int32))
    y_ref[0] = out

  o = self.pallas_call(
      kernel,
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
      ),
      out_shape=jax.ShapeDtypeStruct((1,), jnp.int32),
  )()
  np.testing.assert_allclose(o, jnp.array([3], jnp.int32))
def test_run_scoped_can_return_scalar_values(self):
  """run_scoped can return a tuple of scalars, here (3, 1)."""
  def kernel(y_ref):
    def body(x_ref):
      x_ref[0] = 0
      x_ref[0] += 1
      return x_ref[0] + 2, x_ref[0]

    out = pl.run_scoped(body, pltpu.SMEM((1,), jnp.int32))
    y_ref[0], y_ref[1] = out

  o = self.pallas_call(
      kernel,
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
      ),
      out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
  )()
  np.testing.assert_allclose(o, jnp.array([3, 1], jnp.int32))
def test_run_scoped_can_return_vector_values(self):
  """run_scoped can return a (16, 128) vector value computed from VMEM."""
  def kernel(y_ref):
    def body(x_ref):
      x_ref[...] = jnp.ones_like(x_ref)
      return x_ref[...] + 1

    out = pl.run_scoped(body, pltpu.VMEM((16, 128), jnp.int32))
    y_ref[...] = out

  o = self.pallas_call(
      kernel,
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
      ),
      out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int32),
  )()
  np.testing.assert_allclose(o, jnp.full((16, 128), 2, dtype=jnp.int32))
def test_run_scoped_can_return_padded_vector_values(self):
  """Same as above with 17 rows, which is not a multiple of 8."""
  def kernel(y_ref):
    def body(x_ref):
      x_ref[...] = jnp.ones_like(x_ref)
      return x_ref[...] + 1

    out = pl.run_scoped(body, pltpu.VMEM((17, 128), jnp.int32))
    y_ref[...] = out

  o = self.pallas_call(
      kernel,
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
      ),
      out_shape=jax.ShapeDtypeStruct((17, 128), jnp.int32),
  )()
  np.testing.assert_allclose(o, jnp.full((17, 128), 2, dtype=jnp.int32))
def test_nested_scoped_allocation(self):
  """A run_scoped body can itself call run_scoped; the inner write is visible."""
  def kernel(y_ref):
    def body(x_ref):
      x_ref[...] = jnp.zeros_like(x_ref)

      def inner_body(z_ref):
        z_ref[...] = jnp.ones_like(z_ref)
        x_ref[...] = z_ref[...]

      pl.run_scoped(inner_body, pltpu.VMEM((8, 128), jnp.float32))
      y_ref[...] = 4 * x_ref[...]

    pl.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32))

  o = self.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )()
  np.testing.assert_allclose(o, 4 * np.ones_like(o))
def test_run_scoped_partial_discharge(self):
  """Discharging only b_ref of a jaxpr that writes refs inside run_scoped.

  With should_discharge=[False, True], a_ref stays an AbstractRef input while
  b_ref becomes a ShapedArray input that is returned as the single output.
  Only the write to a_ref (input 0) remains as an effect.
  """
  def f(a_ref, b_ref):
    def scope():
      a_ref[...] = jnp.ones(4, jnp.float32)
      b_ref[...] = jnp.ones(4, jnp.float32)
      return []

    pl.run_scoped(scope)
    return []

  aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
  aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
  in_avals = [aref1, aref2]
  stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                        in_avals)
  discharged_jaxpr, _ = state_discharge.discharge_state(
      stateful_jaxpr, consts=(), should_discharge=[False, True])
  self.assertLen(discharged_jaxpr.invars, 2)
  self.assertLen(discharged_jaxpr.outvars, 1)
  self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.AbstractRef)
  self.assertIsInstance(discharged_jaxpr.invars[1].aval, jax.core.ShapedArray)
  self.assertEqual(discharged_jaxpr.effects, {state.WriteEffect(0)})
def test_can_allocate_semaphore(self):
  """run_scoped can allocate a single DMA semaphore."""
  def kernel(y_ref):
    def body(sem1):
      pass

    pl.run_scoped(body, pltpu.SemaphoreType.DMA)

  jax.block_until_ready(self.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )())
def test_can_allocate_multiple_semaphores(self):
  """run_scoped can allocate a DMA and a REGULAR semaphore together."""
  def kernel(y_ref):
    def body(sem1, sem2):
      pass

    pl.run_scoped(body, pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.REGULAR)

  jax.block_until_ready(self.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )())
def test_can_allocate_semaphore_array(self):
  """run_scoped can allocate shaped semaphore arrays.

  In interpret mode these arrays have integer dtypes; when compiled they have
  pltpu.dma_semaphore / pltpu.semaphore dtypes.
  """
  def kernel(y_ref):
    def body(dma_sems, sems):
      self.assertTupleEqual(dma_sems.shape, (4,))
      self.assertTupleEqual(sems.shape, (3,))
      if self.INTERPRET:
        self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.integer))
        self.assertTrue(jnp.issubdtype(sems.dtype, jnp.integer))
      else:
        self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore))
        self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore))

    pl.run_scoped(
        body, pltpu.SemaphoreType.DMA((4,)), pltpu.SemaphoreType.REGULAR((3,))
    )

  jax.block_until_ready(self.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )())
def test_can_allocate_scratch_semaphore_array(self):
  """Like the run_scoped variant, but the semaphores come from scratch_shapes."""
  def kernel(y_ref, dma_sems, sems):
    self.assertTupleEqual(dma_sems.shape, (4,))
    self.assertTupleEqual(sems.shape, (3,))
    if self.INTERPRET:
      self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.integer))
      self.assertTrue(jnp.issubdtype(sems.dtype, jnp.integer))
    else:
      self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore))
      self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore))

  jax.block_until_ready(
      self.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=0,
              scratch_shapes=[
                  pltpu.SemaphoreType.DMA((4,)),
                  pltpu.SemaphoreType.REGULAR((3,)),
              ],
          ),
      )()
  )
def test_can_wait_on_semaphore(self):
  """Balanced signal/wait sequences on a REGULAR semaphore complete."""
  def kernel(y_ref):
    # One signal, one wait.
    def body(sem):
      pltpu.semaphore_signal(sem)
      pltpu.semaphore_wait(sem)

    pl.run_scoped(body, pltpu.SemaphoreType.REGULAR)

    # Signal by 2, then two waits (each presumably consuming 1).
    def body2(sem):
      pltpu.semaphore_signal(sem, 2)
      pltpu.semaphore_wait(sem)
      pltpu.semaphore_wait(sem)

    pl.run_scoped(body2, pltpu.SemaphoreType.REGULAR)

    # Three signals, then three waits.
    def body3(sem):
      pltpu.semaphore_signal(sem)
      pltpu.semaphore_signal(sem)
      pltpu.semaphore_signal(sem)
      pltpu.semaphore_wait(sem)
      pltpu.semaphore_wait(sem)
      pltpu.semaphore_wait(sem)

    pl.run_scoped(body3, pltpu.SemaphoreType.REGULAR)

  # TODO(b/345534352): Add interpret support for semaphore signal/wait.
  jax.block_until_ready(self.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )())
def test_can_wait_on_semaphore_array(self):
  """Same signal/wait patterns as above, each on one element of a (3,) array."""
  def kernel(y_ref):
    def body(sems):
      pltpu.semaphore_signal(sems.at[0])
      pltpu.semaphore_wait(sems.at[0])

      pltpu.semaphore_signal(sems.at[1], 2)
      pltpu.semaphore_wait(sems.at[1])
      pltpu.semaphore_wait(sems.at[1])

      pltpu.semaphore_signal(sems.at[2])
      pltpu.semaphore_signal(sems.at[2])
      pltpu.semaphore_signal(sems.at[2])
      pltpu.semaphore_wait(sems.at[2])
      pltpu.semaphore_wait(sems.at[2])
      pltpu.semaphore_wait(sems.at[2])

    pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((3,)))

  # TODO(b/345534352): Add interpret support for semaphore signal/wait.
  jax.block_until_ready(self.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )())
def test_can_wait_on_semaphore_array_with_dynamic_index(self):
  """Semaphore array indexed by the (dynamic) program id over a grid of 4."""
  def kernel(y_ref):
    i = pl.program_id(0)

    def body(sems):
      pltpu.semaphore_signal(sems.at[i, 0])
      pltpu.semaphore_wait(sems.at[i, 0])

      pltpu.semaphore_signal(sems.at[i, 1], 2)
      pltpu.semaphore_wait(sems.at[i, 1])
      pltpu.semaphore_wait(sems.at[i, 1])

      pltpu.semaphore_signal(sems.at[i, 2])
      pltpu.semaphore_signal(sems.at[i, 2])
      pltpu.semaphore_signal(sems.at[i, 2])
      pltpu.semaphore_wait(sems.at[i, 2])
      pltpu.semaphore_wait(sems.at[i, 2])
      pltpu.semaphore_wait(sems.at[i, 2])

    pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((4, 3)))

  jax.block_until_ready(
      self.pallas_call(
          kernel,
          in_specs=[],
          out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)),
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
          grid=4,
      )()
  )
def test_can_read_semaphore(self):
  """semaphore_read returns the value just signalled (r * n + c per element)."""
  m, n = 2, 3

  def kernel(y_ref):
    def body(sems):
      for r in range(m):
        for c in range(n):
          v = r * n + c
          pltpu.semaphore_signal(sems.at[r, c],v)
          y_ref[r, c] = pltpu.semaphore_read(sems.at[r, c])
          # Wait by the same amount, presumably bringing the count back to 0.
          pltpu.semaphore_wait(sems.at[r, c], v)

    pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((m, n)))

  y = jax.block_until_ready(
      self.pallas_call(
          kernel,
          out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
          out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32),
      )()
  )
  np.testing.assert_array_equal(
      y, jnp.arange(m * n).astype(jnp.int32).reshape((m, n))
  )
def test_can_read_dma_semaphore(self):
  """After an HBM->HBM copy is waited on, its DMA semaphore reads back as 0."""
  def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem):
    # Sentinel value: the final assertion only passes if it is overwritten.
    sem_val_ref[0, 0] = 123
    pltpu.async_copy(x_hbm_ref, y_hbm_ref, dma_sem).wait()
    sem_val_ref[0, 0] = pltpu.semaphore_read(dma_sem)

  x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
  y, sem_val = jax.block_until_ready(
      self.pallas_call(
          kernel,
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=0,
              in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
              out_specs=[
                  pl.BlockSpec(memory_space=pl.ANY),
                  pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
              ],
              scratch_shapes=[pltpu.SemaphoreType.DMA],
          ),
          out_shape=[
              jax.ShapeDtypeStruct((8, 128), jnp.int32),
              jax.ShapeDtypeStruct((1, 1), jnp.int32),
          ],
      )(x)
  )
  np.testing.assert_array_equal(y, x)
  np.testing.assert_array_equal(sem_val, 0)
def test_hbm_hbm_dma(self):
  """HBM->HBM async_copy between sliced views of pl.ANY operands."""
  def kernel(x_hbm_ref, y_hbm_ref):
    def body(sem):
      # The result is asserted to equal x, so these views cover the full
      # (8, 128) arrays.
      pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)],
                       sem).wait()

    pl.run_scoped(body, pltpu.SemaphoreType.DMA)

  x = jnp.arange(8 * 128.).reshape((8, 128))
  y = self.pallas_call(
      kernel,
      in_specs=[
          pl.BlockSpec(memory_space=pl.ANY),
      ],
      out_specs=pl.BlockSpec(memory_space=pl.ANY),
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(x)
  np.testing.assert_array_equal(y, x)
def test_cannot_dma_with_nonscalar_semaphore_ref(self):
  """Passing a whole (1,)-shaped DMA semaphore array to async_copy fails.

  test_dma_with_scalar_semaphore_ref covers the working variant, which
  indexes the array with `sem.at[0]`.
  """
  def kernel(x_hbm_ref, y_hbm_ref):
    def body(sem):
      pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)],
                       sem).wait()

    pl.run_scoped(body, pltpu.SemaphoreType.DMA((1,)))

  # Build the input outside the assertion so that only the pallas_call can
  # satisfy the expected ValueError.
  x = jnp.arange(8 * 128.).reshape((8, 128))
  with self.assertRaisesRegex(ValueError, 'Cannot signal'):
    self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
def test_dma_with_scalar_semaphore_ref(self):
  """DMA using one element (`sem.at[0]`) of a (1,)-shaped semaphore array."""
  def kernel(x_hbm_ref, y_hbm_ref):
    def body(sem):
      pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)],
                       sem.at[0]).wait()

    pl.run_scoped(body, pltpu.SemaphoreType.DMA((1,)))

  x = jnp.arange(8 * 128.).reshape((8, 128))
  y = self.pallas_call(
      kernel,
      in_specs=[
          pl.BlockSpec(memory_space=pl.ANY),
      ],
      out_specs=pl.BlockSpec(memory_space=pl.ANY),
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(x)
  np.testing.assert_array_equal(y, x)
def test_hbm_hbm_grid_dma(self):
  # When using the grid, we have to emit Mosaic window_params. Test that they
  # work correctly with ANY memory space operands.
  def kernel(x_hbm_ref, y_hbm_ref):
    i = pl.program_id(0)

    # Grid step i copies leading-axis slice i.
    def body(sem):
      pltpu.async_copy(
          x_hbm_ref.at[pl.ds(i, 1)], y_hbm_ref.at[pl.ds(i, 1)], sem
      ).wait()

    pl.run_scoped(body, pltpu.SemaphoreType.DMA)

  x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128))
  y = self.pallas_call(
      kernel,
      in_specs=[
          pl.BlockSpec(memory_space=pl.ANY),
      ],
      out_specs=pl.BlockSpec(memory_space=pl.ANY),
      out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.float32),
      grid=(2,),
  )(x)
  np.testing.assert_allclose(y, x)
def test_hbm_vmem_dma(self):
  """DMA from an HBM (pl.ANY) input into scoped VMEM, then copy to the output."""
  def kernel(x_hbm_ref, y_ref):
    def body(x_ref, sem):
      pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], x_ref.at[:, pl.ds(128)],
                       sem).wait()
      y_ref[...] = x_ref[...]

    pl.run_scoped(
        body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA
    )

  x = jnp.arange(8 * 128.).reshape((8, 128))
  y = self.pallas_call(
      kernel,
      in_specs=[
          pl.BlockSpec(memory_space=pl.ANY),
      ],
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(x)
  np.testing.assert_allclose(y, x)
def test_vmem_hbm_dma(self):
  """Copy the input into scoped VMEM, then DMA it to a pl.ANY output."""
  def kernel(x_ref, y_hbm_ref):
    def body(y_ref, sem):
      y_ref[...] = x_ref[...]
      pltpu.async_copy(y_ref, y_hbm_ref, sem).wait()

    pl.run_scoped(
        body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA
    )

  x = jnp.arange(8 * 128.).reshape((8, 128))
  y = self.pallas_call(
      kernel,
      out_specs=pl.BlockSpec(memory_space=pl.ANY),
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(x)
  np.testing.assert_allclose(y, x)
def test_vmem_hbm_vmem_dma(self):
  """Round trip HBM -> VMEM -> VMEM -> HBM, reusing one DMA semaphore."""
  def kernel(x_hbm_ref, y_hbm_ref):
    def body(x_ref, y_ref, sem):
      pltpu.async_copy(x_hbm_ref, x_ref, sem).wait()
      y_ref[...] = x_ref[...]
      pltpu.async_copy(y_ref, y_hbm_ref, sem).wait()

    pl.run_scoped(
        body,
        pltpu.VMEM((8, 128), jnp.float32),
        pltpu.VMEM((8, 128), jnp.float32),
        pltpu.SemaphoreType.DMA,
    )

  x = jnp.arange(8 * 128.).reshape((8, 128))
  y = self.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
      out_specs=pl.BlockSpec(memory_space=pl.ANY),
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(x)
  np.testing.assert_allclose(y, x)
def test_hbm_smem_dma(self):
  """DMAs an HBM operand into SMEM and broadcasts one of its scalars."""
  tile = (8, 128)

  def kernel(x_hbm_ref, y_ref):
    def body(x_ref, sem):
      pltpu.async_copy(x_hbm_ref, x_ref, sem).wait()
      scalar = x_ref[0, 0]
      y_ref[...] = scalar * jnp.ones_like(y_ref)

    pl.run_scoped(body, pltpu.SMEM(tile, jnp.float32), pltpu.SemaphoreType.DMA)

  # A constant-filled operand, so broadcasting x[0, 0] reproduces it exactly.
  operand = 4 * jnp.ones(tile, jnp.float32)
  result = self.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
      out_shape=jax.ShapeDtypeStruct(tile, jnp.float32),
  )(operand)
  np.testing.assert_allclose(result, operand)
def test_smem_hbm_dma(self):
  """Fills a tiny SMEM scratch with scalars and DMAs it out to HBM."""
  out_shape = (1, 2)

  def kernel(x_ref, y_hbm_ref):
    def body(y_ref, sem):
      y_ref[0, 0] = 0.0
      y_ref[0, 1] = x_ref[4, 4]
      pltpu.async_copy(y_ref, y_hbm_ref, sem).wait()

    pl.run_scoped(
        body, pltpu.SMEM(out_shape, jnp.float32), pltpu.SemaphoreType.DMA
    )

  operand = jnp.arange(8 * 128.).reshape((8, 128))
  result = self.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)],
      out_specs=pl.BlockSpec(memory_space=pl.ANY),
      out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
  )(operand)
  # Element (0, 0) is zero; element (0, 1) carries operand[4, 4].
  want = jnp.zeros_like(operand[0:1, 0:2]).at[0, 1].set(operand[4, 4])
  np.testing.assert_allclose(result, want)
def test_vmem_vmem_dma(self):
  """DMA directly between a VMEM input block and a VMEM output block."""
  vmem = pltpu.TPUMemorySpace.VMEM

  def kernel(x_ref, y_ref):
    def body(sem):
      copy = pltpu.async_copy(x_ref, y_ref, sem)
      copy.wait()

    pl.run_scoped(body, pltpu.SemaphoreType.DMA)

  operand = jnp.arange(8 * 128.).reshape((8, 128))
  result = self.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=vmem)],
      out_specs=pl.BlockSpec(memory_space=vmem),
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(operand)
  np.testing.assert_allclose(result, operand)
def test_hbm_vmem_dma_slicing(self):
  """Two 8-row slab DMAs, both started before either is waited on."""
  def kernel(x_hbm_ref, y_ref):
    def body(sem):
      # Start every copy first, then wait on them in the same order.
      copies = [
          pltpu.async_copy(
              x_hbm_ref.at[pl.ds(start, 8)], y_ref.at[pl.ds(start, 8)], sem
          )
          for start in (0, 8)
      ]
      for copy in copies:
        copy.wait()

    pl.run_scoped(body, pltpu.SemaphoreType.DMA)

  src = jnp.arange(2 * 8 * 128.).reshape((16, 128))
  out = self.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
      out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
      out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
  )(src)
  np.testing.assert_allclose(out, src)
def test_hbm_vmem_dma_indexing(self):
  """Integer-indexed HBM sources land in consecutive 8-row VMEM windows."""
  def kernel(x_hbm_ref, y_ref):
    def body(sem):
      copies = [
          pltpu.async_copy(x_hbm_ref.at[k], y_ref.at[pl.ds(8 * k, 8)], sem)
          for k in range(2)
      ]
      for copy in copies:
        copy.wait()

    pl.run_scoped(body, pltpu.SemaphoreType.DMA)

  src = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128))
  out = self.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
      out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
      out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
  )(src)
  np.testing.assert_allclose(out, src.reshape((16, 128)))
def test_hbm_vmem_dma_multiple_indexing(self):
  """Chained `.at[...]` indexers on both the DMA source and destination."""
  if self.INTERPRET:
    self.skipTest('Multiple indexing not supported in interpret mode.')

  def kernel(x_hbm_ref, y_ref):
    def body(sem):
      for i in range(3):
        # Move the two (8, 128) slabs of x[i] into consecutive 8-row
        # windows of y[i]; start both copies before waiting on either.
        copies = [
            pltpu.async_copy(
                x_hbm_ref.at[pl.ds(i, 1)].at[0, half],
                y_ref.at[i].at[pl.ds(8 * half, 8)],
                sem,
            )
            for half in range(2)
        ]
        for copy in copies:
          copy.wait()

    pl.run_scoped(body, pltpu.SemaphoreType.DMA)

  src = jnp.arange(3 * 2 * 8 * 128.).reshape((3, 2, 8, 128))
  out = self.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
      out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
      out_shape=jax.ShapeDtypeStruct((3, 16, 128), jnp.float32),
  )(src)
  np.testing.assert_allclose(out, src.reshape((3, 16, 128)))
def test_cannot_squeeze_lane_sublane(self):
  """Integer-indexing away the last dim of a DMA source must raise."""
  if self.INTERPRET:
    self.skipTest('Only works on Mosaic TPU.')

  def kernel(x_hbm_ref, y_ref):
    def body(sem):
      copies = [
          pltpu.async_copy(
              x_hbm_ref.at[:, :, col], y_ref.at[pl.ds(8 * col, 8)], sem
          )
          for col in range(2)
      ]
      for copy in copies:
        copy.wait()

    pl.run_scoped(body, pltpu.SemaphoreType.DMA)

  src = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128))
  with self.assertRaises(Exception):
    call = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
    )
    _ = call(src)
def test_hoisted_scratch_space(self):
  """VMEM scratch keeps its contents across the 3 grid steps."""
  def kernel(x_ref, y_ref, scratch_ref):
    step = pl.program_id(0)

    @pl.when(step == 0)
    def _():
      scratch_ref[...] = x_ref[...]

    scratch_ref[...] += jnp.ones_like(scratch_ref)

    @pl.when(step == 2)
    def _():
      y_ref[...] = scratch_ref[...]

  block = (8, 128)
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=0,
      in_specs=[pl.BlockSpec(block, lambda i: (0, 0))],
      scratch_shapes=[pltpu.VMEM(block, jnp.float32)],
      out_specs=pl.BlockSpec(block, lambda i: (0, 0)),
      grid=(3,),
  )
  src = jnp.arange(8 * 128.).reshape(block)
  out = self.pallas_call(
      kernel,
      grid_spec=grid_spec,
      out_shape=jax.ShapeDtypeStruct(block, jnp.float32),
  )(src)
  # One increment per grid step.
  np.testing.assert_array_equal(out, src + 3)
def test_hoisted_smem_space(self):
  """SMEM scratch written per grid step and broadcast into the output."""
  # TODO(sharadmv,apaszke): enable SMEM scratch spaces
  # TODO(sharadmv,apaszke): add support for ()-shaped SMEM refs
  self.skipTest("Currently doesn't work")

  def kernel(y_ref, scratch_ref):
    scratch_ref[0, 0] = pl.program_id(0)
    y_ref[...] = jnp.broadcast_to(scratch_ref[0, 0], y_ref.shape)

  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=0,
      in_specs=[],
      scratch_shapes=[pltpu.SMEM((1, 1), jnp.int32)],
      out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
      grid=(2,),
  )
  out = pl.pallas_call(
      kernel,
      grid_spec=grid_spec,
      debug=True,
      out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.int32),
  )()
  step_ids = jnp.arange(2, dtype=jnp.int32)[..., None, None]
  expected = jnp.broadcast_to(step_ids, (2, 8, 128))
  np.testing.assert_array_equal(out, expected)
def test_hoisted_semaphore(self):
  """Semaphores from scratch_shapes work for signal/wait and for a DMA."""
  def kernel(x_hbm_ref, y_ref, sem, dma_sem):
    pltpu.semaphore_signal(sem)
    pltpu.semaphore_wait(sem)
    pltpu.async_copy(x_hbm_ref, y_ref, dma_sem).wait()

  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=0,
      in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
      scratch_shapes=[pltpu.SemaphoreType.REGULAR, pltpu.SemaphoreType.DMA],
      out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
  )
  src = jnp.arange(8 * 128.).reshape((8, 128))
  out = self.pallas_call(
      kernel,
      grid_spec=grid_spec,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(src)
  np.testing.assert_array_equal(out, src)
def test_large_array_indexing(self):
  """DMAs one slice of a large bf16 array, chosen by a prefetched index."""
  num_slices = 6
  dtype = jnp.bfloat16
  # This test sometimes OOMs on smaller chips. Collect garbage first so there
  # is a better chance that 6GB of memory is available.
  gc.collect()
  x = jax.lax.broadcasted_iota(dtype, (num_slices, 1024 * 1024, 512), 0)

  def kernel(index, x, y, sem):
    pltpu.async_copy(x.at[index[0]], y.at[:], sem).wait()

  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=1,
      in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
      out_specs=pl.BlockSpec(memory_space=pl.ANY),
      scratch_shapes=[pltpu.SemaphoreType.DMA],
  )
  run = self.pallas_call(
      kernel,
      grid_spec=grid_spec,
      out_shape=jax.ShapeDtypeStruct(x.shape[1:], dtype),
  )
  for slice_idx in range(x.shape[0]):
    # Every element of slice k of the iota equals k.
    out = run(jnp.array([slice_idx], dtype=jnp.int32), x)
    np.testing.assert_array_equal(out, slice_idx)
    del out  # Drop the large result before producing the next one.
def test_dynamic_dma_on_2nd_minor(self):
  """DMAs a runtime-sized run of rows to a runtime row offset.

  Both the destination row offset (`index`) and the number of rows (`size`)
  are read from SMEM, so the DMA slices the second-minor dimension with a
  dynamic start and a dynamic size.
  """
  def kernel(array, data, index, size, _, sem):
    # `_` is the output ref; the output is aliased to `array`
    # (input_output_aliases={0: 0}), so writing `array` fills the result.
    pltpu.async_copy(
        data.at[pl.ds(0, size[0])], array.at[pl.ds(index[0], size[0])], sem
    ).wait()

  def run(array, data, index, size):
    return pl.pallas_call(
        kernel,
        out_shape=array,
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.ANY),
            pl.BlockSpec(memory_space=pltpu.VMEM),
            pl.BlockSpec(memory_space=pltpu.SMEM),
            pl.BlockSpec(memory_space=pltpu.SMEM),
        ],
        scratch_shapes=[
            pltpu.SemaphoreType.DMA,
        ],
        out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
        input_output_aliases={0: 0},
    )(array, data, index, size)

  array = jnp.zeros((1024, 128), jnp.int32)
  # Distinct values per row, so copying from the wrong source offset fails
  # the test instead of silently passing (all-ones data would hide it).
  data = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
  index = jnp.array([3], jnp.int32)
  size = jnp.array([5], jnp.int32)
  start, count = int(index[0]), int(size[0])
  # The kernel reads data[0:count] (source offset 0) and writes it to
  # array[start:start + count].
  expected = array.at[start : start + count].set(data[:count])
  result = run(array, data, index, size)
  np.testing.assert_array_equal(result, expected)
class PallasCallDMAInterpretTest(PallasCallDMATest):
  """Re-runs the DMA tests in interpret mode, plus interpreter-only checks."""

  INTERPRET = True

  def test_interpret_local_dma(self):
    # Runs in interpret mode to test semaphore counting: on a physical device
    # the values update asynchronously, so they cannot be checked
    # deterministically.
    def test_kernel(x_ref, o_ref, sem_out_ref, copy_sem):
      o_ref[...] = jnp.zeros_like(o_ref[...])
      sem = copy_sem.at[0]
      copy = pltpu.make_async_copy(
          src_ref=x_ref.at[0:8],
          dst_ref=o_ref.at[0:8],
          sem=sem,
      )
      copy.start()
      # Row 0 records the semaphore value while the copy is in flight ...
      sem_out_ref[0, :] = (
          jnp.ones_like(sem_out_ref[0, :]) * pltpu.semaphore_read(sem))
      copy.wait()
      # ... and row 1 records it after the wait.
      sem_out_ref[1, :] = (
          jnp.ones_like(sem_out_ref[1, :]) * pltpu.semaphore_read(sem))

    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
        scratch_shapes=[pltpu.SemaphoreType.DMA(2,)],
    )
    kernel = pl.pallas_call(
        test_kernel,
        out_shape=(
            jax.ShapeDtypeStruct((16, 128), jnp.int32),
            jax.ShapeDtypeStruct((2, 1), jnp.int32),
        ),
        grid_spec=grid_spec,
        interpret=True,
    )
    x = jax.random.randint(
        jax.random.key(0), shape=(16, 128), minval=0, maxval=128)
    result, semaphores = kernel(x)

    # Only the first 8 rows are copied; the rest keep their zero fill.
    np.testing.assert_array_equal(result[0:8], x[0:8])
    np.testing.assert_array_equal(result[8:], jnp.zeros_like(result[8:]))
    # Before the wait the semaphore is expected to equal the number of copied
    # elements; after the wait it is expected to be back at zero.
    np.testing.assert_array_equal(semaphores[0, 0], result[0:8].size)
    np.testing.assert_array_equal(semaphores[1, 0], 0)

  def test_interpreter_semaphore_counting(self):
    # Runs in interpret mode because the kernel exits with non-zero
    # semaphore values. In normal Pallas this would crash the kernel.
    num_sems = 4

    def test_kernel(o_ref, sem_ref):
      o_ref[...] = jnp.zeros_like(o_ref)
      for k in range(num_sems):
        pltpu.semaphore_signal(sem_ref.at[k], k + 1)
      for k in range(num_sems):
        o_ref[k, 0] = pltpu.semaphore_read(sem_ref.at[k])
      for k in range(num_sems):
        pltpu.semaphore_wait(sem_ref.at[k], num_sems - k)
      for k in range(num_sems):
        o_ref[num_sems + k, 0] = pltpu.semaphore_read(sem_ref.at[k])

    out_shape = jax.ShapeDtypeStruct((8, 1), jnp.int32)
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        scratch_shapes=[pltpu.SemaphoreType.DMA(4,)],
    )
    results = pl.pallas_call(
        test_kernel,
        out_shape=out_shape,
        grid_spec=grid_spec,
        interpret=True,
    )()
    # Signals 1..4 then waits 4..1 leave 1-4, 2-3, 3-2, 4-1.
    expected = jnp.array([1, 2, 3, 4, -3, -1, 1, 3]).reshape(out_shape.shape)
    np.testing.assert_array_equal(results, expected)
class PallasCallTest(PallasBaseTest):
  """Tests for TPU-specific pallas_call options (cost, VMEM, fusion, ...)."""

  def test_cost_analysis(self):
    """A CostEstimate given to pallas_call shows up in XLA cost analysis."""
    def kernel(x, y):
      y[:] = x[:]

    x = jnp.arange(1024.).reshape(8, 128)
    f = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        cost_estimate=pl.CostEstimate(
            flops=1234, transcendentals=21, bytes_accessed=12345
        ),
    )
    (analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
    self.assertEqual(analysis_result['flops'], 1234)
    self.assertEqual(analysis_result['transcendentals'], 21)
    self.assertEqual(analysis_result['bytes accessed'], 12345)

  def test_cost_analysis_vmap(self):
    """Under vmap the reported cost is the per-call estimate times batch."""
    def kernel(x, y):
      y[:] = x[:]

    batch_size = 3
    x = jnp.arange(batch_size * 1024.).reshape(batch_size, 8, 128)
    f = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        cost_estimate=pl.CostEstimate(
            flops=1234, transcendentals=21, bytes_accessed=12345
        ),
    )
    f = jax.vmap(f)
    (analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
    self.assertEqual(analysis_result['flops'], batch_size * 1234)
    self.assertEqual(analysis_result['transcendentals'], batch_size * 21)
    self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345)

  def test_vmem_limit(self):
    """A too-small vmem_limit_bytes fails; a large enough one succeeds."""
    shape = (128, 128)

    def kernel(x_ref, y_ref):
      y_ref[...] = x_ref[...]

    x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    with self.assertRaises(xla_extension.XlaRuntimeError):
      self.pallas_call(
          kernel,
          out_shape=x,
          compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=256),
      )(x)
    self.pallas_call(
        kernel,
        out_shape=x,
        compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=int(2**18)),
    )(x)

  def test_allow_input_fusion(self):
    """With allow_input_fusion, the producing add is fused into the call."""
    shape = (3, 128, 128)

    def kernel(x_ref, y_ref):
      y_ref[...] = x_ref[...]

    def f(x, y):
      z = jax.numpy.add(x, y)
      return self.pallas_call(
          kernel,
          grid=(3,),
          in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))],
          out_specs=pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0)),
          out_shape=x,
          compiler_params=pltpu.TPUCompilerParams(allow_input_fusion=[True]),
      )(z)

    x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    y = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    out = f(x, y)
    expected = x + y
    np.testing.assert_array_equal(out, expected)
    compiled = jax.jit(f).lower(x, y).compile().as_text()
    # Use a unittest assertion rather than a bare `assert`: it survives
    # `python -O` and prints the HLO text on failure.
    self.assertRegex(compiled, r'fusion.*kind=kCustom.*fused_computation')

  def test_set_internal_scratch_size(self):
    """A too-small internal scratch size is rejected with a clear error."""
    shape = (128, 128)

    def kernel(x_ref, y_ref):
      y_ref[...] = x_ref[...]

    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    requested_bytes = 128 * 4
    with self.assertRaisesRegex(
        Exception,
        f'Requested internal scratch size {requested_bytes} needs to be at'
        ' least',
    ):
      self.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
          compiler_params=pltpu.TPUCompilerParams(
              internal_scratch_in_bytes=requested_bytes,
          ),
      )(x)

  @parameterized.product(dtype=[jnp.bfloat16, jnp.float32])
  def test_pltpu_repeat(self, dtype):
    """pltpu.repeat(x, 2, axis=1) equals concatenating x with itself."""
    def test_kernel(x_ref, o_ref):
      x = x_ref[...]
      o_ref[...] = pltpu.repeat(x, 2, axis=1)

    @jax.jit
    def test(x: jax.Array) -> jax.Array:
      return pl.pallas_call(
          test_kernel,
          out_shape=jax.ShapeDtypeStruct([x.shape[0], x.shape[1] * 2], x.dtype),
      )(x)

    x = jnp.arange(2048, dtype=dtype).reshape((8, 256))
    y = test(x)
    np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1))

  def test_masked_store(self):
    """A masked pl.store only writes the elements selected by the mask."""
    if jtu.jaxlib_version() <= (0, 4, 35):
      self.skipTest("Test requires masked store support")
    shape = (16, 256)
    mask_shape = (10, 130)
    mask_start = (4, 5)
    dtype = jnp.float32

    def body(scalar_ref, x_ref, o_ref):
      o_ref[...] = jnp.full(shape, -1, dtype=dtype)
      # The mask window's start comes from the prefetched scalars.
      b0, b1 = scalar_ref[0], scalar_ref[1]
      e0, e1 = b0 + mask_shape[0], b1 + mask_shape[1]
      iota0 = lax.broadcasted_iota(jnp.int32, shape, 0)
      iota1 = lax.broadcasted_iota(jnp.int32, shape, 1)
      mask0 = jnp.logical_and(b0 <= iota0, iota0 < e0)
      mask1 = jnp.logical_and(b1 <= iota1, iota1 < e1)
      pl.store(
          o_ref,
          (slice(None), slice(None)),
          x_ref[...],
          mask=jnp.logical_and(mask0, mask1),
      )

    s = jnp.array(mask_start, jnp.int32)
    x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)
    out = pl.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
        ),
    )(s, x)
    slices = tuple(slice(b, b + l) for b, l in zip(mask_start, mask_shape))
    expected = jnp.full(shape, -1, dtype=dtype)
    expected = expected.at[slices].set(x[slices])
    np.testing.assert_array_equal(out, expected)
class PallasUXTest(PallasBaseTest):
  """User-experience checks for the Pallas TPU lowering."""

  def test_mlir_location(self):
    """MLIR locations must be propagated to the lowered primitives."""
    original_as_tpu_kernel = mosaic.as_tpu_kernel

    def capture_as_tpu_kernel(module, *args, **kwargs):
      # Inspect the module with debug info before delegating to the real hook.
      asm = module.operation.get_asm(enable_debug_info=True)
      self.assertIn('example_kernel.py":25', asm)
      return original_as_tpu_kernel(module, *args, **kwargs)

    arg_specs = (jax.ShapeDtypeStruct((8, 128), jnp.float32),)
    mosaic.as_tpu_kernel = capture_as_tpu_kernel
    try:
      jax.jit(example_kernel.double).lower(*arg_specs)
    finally:
      # Always restore the module-level hook, even if lowering fails.
      mosaic.as_tpu_kernel = original_as_tpu_kernel
class PallasMegacoreTest(PallasBaseTest):
  """Megacore partitioning checks."""

  def test_megacore_splitting(self):
    """A 3-sized dim must be split correctly across megacore.

    This must also hold when the two (3, 3) vmapped dimensions are combined.
    """
    def matmul_kernel(x_ref, y_ref, z_ref):
      @pl.when(pl.program_id(2) == 0)
      def _():
        z_ref[...] = jnp.zeros_like(z_ref)

      z_ref[...] += x_ref[...] @ y_ref[...]

    block = (128, 128)
    matmul = pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32),
        grid=(4, 4, 4),
        in_specs=[
            pl.BlockSpec(block, lambda i, j, k: (i, k)),
            pl.BlockSpec(block, lambda i, j, k: (k, j)),
        ],
        out_specs=pl.BlockSpec(block, lambda i, j, k: (i, j)),
        debug=True,
    )
    key_x, key_y = jax.random.split(jax.random.key(0))
    x = jax.random.uniform(key_x, (3, 3, 512, 512))
    y = jax.random.uniform(key_y, (3, 3, 512, 512))
    z = jax.vmap(jax.vmap(matmul))(x, y)
    reference = jax.vmap(jax.vmap(jnp.dot))(x, y)
    np.testing.assert_allclose(z, reference, rtol=1e-6)
class PallasCallVmapTest(PallasBaseTest):
  """vmap tests that need TPU-only pallas_call features."""

  def test_scratch_input_vmap(self):
    """Test that vmap-ing a kernel with scratch inputs works correctly."""
    # Scratch inputs are only available for PallasTPU. This is why this test
    # does not live with the other vmap tests in:
    # jax/tests/pallas/pallas_test.py
    def add_one_with_scratch(x_ref, o_ref, scratch_ref):
      scratch_ref[...] = jnp.ones_like(scratch_ref[...])
      o_ref[...] = x_ref[...] + scratch_ref[...]

    tile_size = 128
    tile_shape = (tile_size, tile_size)
    array_shape = (2 * tile_size, 2 * tile_size)
    vmapped_add_one_with_scratch = jax.vmap(
        pl.pallas_call(
            add_one_with_scratch,
            out_shape=jax.ShapeDtypeStruct(array_shape, jnp.int32),
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=0,
                in_specs=[pl.BlockSpec(tile_shape, lambda i, j: (i, j))],
                out_specs=pl.BlockSpec(tile_shape, lambda i, j: (i, j)),
                scratch_shapes=[pltpu.VMEM(tile_shape, dtype=jnp.int32)],
                grid=(2, 2),
            ),
        )
    )

    # A batch of 10 identical arrays; every element is incremented by one.
    x = jnp.broadcast_to(jnp.arange(array_shape[0]), (10, *array_shape))
    out = vmapped_add_one_with_scratch(x)
    out_ref = x + 1
    np.testing.assert_array_equal(out, out_ref, strict=True)
class PallasCallDynamicDMATest(PallasBaseTest):
  """DMAs whose length is read at runtime from an SMEM operand."""

  def setUp(self):
    super().setUp()
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('DMAs not supported on TPU generations <= 3')

  def _run_prefix_copy(self, x, o, size):
    """Copies x[:size[0]] into the leading rows of an output aliased to `o`.

    Args:
      x: source array, passed in ANY memory space.
      o: destination array; input 2 is aliased to output 0.
      size: int32 array whose first element is the number of leading rows.

    Returns:
      The kernel output (the contents of the aliased `o` buffer).
    """
    def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
      count = size_smem_ref[0]
      pltpu.async_copy(
          x_hbm_ref.at[pl.ds(0, count)],
          o_hbm_ref.at[pl.ds(0, count)], sem).wait()

    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.SMEM),
            pl.BlockSpec(memory_space=pltpu.ANY),
            pl.BlockSpec(memory_space=pltpu.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
        scratch_shapes=[pltpu.SemaphoreType.DMA],
    )
    return pl.pallas_call(
        kernel,
        grid_spec=grid_spec,
        out_shape=o,
        input_output_aliases={2: 0},
    )(size, x, o)

  def test_simple_tile_aligned_dynamic_size_dma(self):
    x = jnp.tile(jnp.arange(8, dtype=jnp.int32)[:, None, None], [1, 8, 128])
    o = jnp.zeros((8, 8, 128), dtype=jnp.int32)
    out = self._run_prefix_copy(x, o, jnp.array([4], dtype=jnp.int32))
    expected = o.at[:4].set(x.at[:4].get())
    np.testing.assert_array_equal(out, expected)

  def test_simple_dynamic_size_dma(self):
    self.skipTest("doesn't work yet.")
    x = jnp.arange(8, dtype=jnp.int32)
    o = jnp.zeros(8, dtype=jnp.int32)
    out = self._run_prefix_copy(x, o, jnp.array([4], dtype=jnp.int32))
    expected = o.at[:4].set(x.at[:4].get())
    np.testing.assert_array_equal(out, expected)
class PallasCallRefTransformTest(PallasBaseTest):
  """Tests for bitcast/reshape/slice transforms applied to kernel refs."""

  @parameterized.product(slice_first=[True, False])
  def test_dma_bitcasted_ref(self, slice_first):
    """DMA from a sliced-and-bitcast (i32 -> i16) HBM ref, in either order."""
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('DMAs not supported on TPU generations <= 3')

    def kernel(x_hbm_ref, y_hbm_ref):
      def body(sem):
        ref = (
            x_hbm_ref.at[:8, :, :128].bitcast(jnp.int16)
            if slice_first
            else x_hbm_ref.bitcast(jnp.int16).at[:8, :, :128]
        )
        pltpu.async_copy(ref, y_hbm_ref.at[...], sem).wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 1, 256))
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 2, 128), jnp.int16),
    )(x)
    expected = (
        state_utils.bitcast(x[:8, :, :128], jnp.int16)
        if slice_first
        else state_utils.bitcast(x, jnp.int16)[:8, :, :128]
    )
    np.testing.assert_array_equal(y, expected)

  @parameterized.product(slice_first=[True, False])
  def test_load_bitcasted_ref(self, slice_first: bool):
    """Load from a sliced-and-bitcast (i32 -> i16) ref, in either order."""
    def kernel(x_ref, y_ref):
      ref = (
          x_ref.at[:8, :128].bitcast(jnp.int16)
          if slice_first
          else x_ref.bitcast(jnp.int16).at[:16, :128]
      )
      y_ref[...] = ref[...]

    x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 256))
    y = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int16),
    )(x)
    expected = (
        state_utils.bitcast(x[:8, :128], jnp.int16)
        if slice_first
        else state_utils.bitcast(x, jnp.int16)[:16, :128]
    )
    np.testing.assert_array_equal(y, expected)

  @parameterized.product(slice_first=[True, False])
  def test_store_bitcasted_ref(self, slice_first):
    """Store through a sliced-and-bitcast (i32 -> bf16) ref, either order."""
    def kernel(x_ref, y_ref):
      ref = (
          y_ref.at[:8, :128].bitcast(jnp.bfloat16)
          if slice_first
          else y_ref.bitcast(jnp.bfloat16).at[:16, :128]
      )
      ref[...] = x_ref[...]

    x = jnp.arange(16 * 128, dtype=jnp.bfloat16).reshape((16, 128))
    y = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
    )(x)
    # Only the written i32 window y[:8, :128] is compared.
    expected = state_utils.bitcast(x, jnp.int32)
    np.testing.assert_array_equal(y[:8, :128], expected)

  @parameterized.product(slice_first=[True, False])
  def test_dma_reshaped_ref(self, slice_first):
    """DMA between reshaped (and sliced) HBM refs, in either order."""
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('DMAs not supported on TPU generations <= 3')

    def kernel(x_hbm_ref, y_hbm_ref):
      def body(sem):
        ref = (
            x_hbm_ref.at[:8, :, :].reshape(8, 128)
            if slice_first
            else x_hbm_ref.reshape(16, 128).at[:8, :]
        )
        pltpu.async_copy(ref, y_hbm_ref.reshape(8, 128).at[...], sem).wait()

      pl.run_scoped(body, pltpu.SemaphoreType.DMA)

    x = jnp.arange(16 * 128, dtype=jnp.int32).reshape(16, 1, 128)
    y = self.pallas_call(
        kernel,
        in_specs=[
            pl.BlockSpec(memory_space=pl.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        out_shape=jax.ShapeDtypeStruct((8, 1, 128), jnp.int32),
    )(x)
    expected = (
        x[:8, :, :128].reshape((8, 128))
        if slice_first
        else x.reshape(16, 128)[:8, :128]
    ).reshape(8, 1, 128)
    np.testing.assert_array_equal(y, expected)

  def test_load_reshaped_ref(self):
    """Load through a (5, 1, 128) -> (5, 128) reshaped ref."""
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('No expected (1, 128) tiling')

    def kernel(x_ref, y_ref):
      y_ref[...] = x_ref.reshape(5, 128)[...]

    x = jnp.arange(5 * 128, dtype=jnp.int32).reshape(5, 1, 128)
    y = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((5, 128), jnp.int32),
    )(x)
    expected = x.reshape(5, 128)
    np.testing.assert_array_equal(y, expected)

  def test_store_reshaped_ref(self):
    """Store through a (5, 1, 128) -> (5, 128) reshaped ref."""
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('No expected (1, 128) tiling')

    def kernel(x_ref, y_ref):
      y_ref.reshape(5, 128)[...] = x_ref[...]

    x = jnp.arange(5 * 128, dtype=jnp.int32).reshape(5, 128)
    y = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((5, 1, 128), jnp.int32),
    )(x)
    expected = x.reshape(5, 1, 128)
    np.testing.assert_array_equal(y, expected)

  def test_multiple_ref_transforms(self):
    """A chain of slice/bitcast/reshape transforms composes correctly."""
    def kernel(x_ref, y_ref):
      ref = (
          x_ref.at[:16, :256]  # i32(16, 256)
          .bitcast(jnp.int16)  # i16(32, 256)
          .reshape((2, 16, 256))  # i16(2, 16, 256)
          .bitcast(jnp.float16)  # f16(2, 16, 256)
          .at[1:, :, :]  # f16(1, 16, 256)
          .reshape((16, 256))  # f16(16, 256)
          .at[:, :128]  # f16(16, 128)
          .bitcast(jnp.int32)  # i32(8, 128)
      )
      y_ref[...] = ref[...]

    x = jnp.arange(32 * 256, dtype=jnp.int32).reshape((32, 256))
    y = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
    )(x)
    # The second half of the 16-bit view is i32 rows 8..15; keep 128 columns.
    np.testing.assert_array_equal(y, x[8:16, :128])
class PallasCallPrintTest(PallasBaseTest):
  """Tests for pl.debug_print in TPU kernels; output is read from stderr."""

  def test_debug_print(self):
    """A constant debug message appears on stderr."""
    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    def kernel(x_ref, o_ref):
      pl.debug_print('It works!')

    x = jnp.arange(8 * 128, dtype=jnp.float32).reshape((8, 128))
    # Compile with the TPU log recorder enabled so the print can be captured.
    compiled_kernel = (
        jax.jit(kernel)
        .lower(x)
        .compile({'xla_tpu_enable_log_recorder': 'true'})
    )
    with jtu.capture_stderr() as get_output:
      jax.block_until_ready(compiled_kernel(x))
    self.assertIn('It works!', get_output())

  def test_debug_print_with_values(self):
    """A formatted debug message with a scalar read from SMEM."""
    @functools.partial(
        self.pallas_call,
        in_specs=(pl.BlockSpec(memory_space=pltpu.SMEM),),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    def kernel(x_ref, o_ref):
      pl.debug_print('x[0] == {}', x_ref[0])

    x = jnp.array([42, 24]).astype(jnp.int32)
    compiled_kernel = (
        jax.jit(kernel)
        .lower(x)
        .compile({'xla_tpu_enable_log_recorder': 'true'})
    )
    with jtu.capture_stderr() as get_output:
      jax.block_until_ready(compiled_kernel(x))
    self.assertIn('x[0] == 42', get_output())

  @parameterized.named_parameters(
      (f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype)
      for shape in (
          (2, 8, 128),
          # test unaligned shapes
          (3,),
          (3, 4),
          (2, 3, 4),
          (2, 9, 129),
      )
      for dtype in (jnp.int32, jnp.uint32, jnp.float32)
  )
  def test_debug_print_vector(self, shape, dtype):
    """Printing a whole vector emits every element, in order."""
    # TODO(ayx): Remove after this date.
    if not jtu.if_cloud_tpu_at_least(2025, 1, 16):
      self.skipTest("Requires libtpu built after 2025-01-16")

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    def kernel(x_ref, o_ref):
      pl.debug_print("{}", x_ref[...])
      o_ref[...] = x_ref[...]

    n = np.prod(shape)
    x = jnp.arange(n, dtype=dtype).reshape(shape)
    compiled_kernel = (
        jax.jit(kernel)
        .lower(x)
        .compile({"xla_tpu_enable_log_recorder": "true"})
    )
    with jtu.capture_stderr() as get_output:
      jax.block_until_ready(compiled_kernel(x))
    output = get_output()
    numbers = [
        int(num)
        for line in output.splitlines()
        if (match := re.search(r"\{(.*)", line))  # extract contents after `{`
        for num in re.findall(r"\d+", match.group(1))
    ]
    # The printed numbers must be exactly the `arange` values, in order.
    # assertEqual (instead of assertTrue(all(...))) reports which elements
    # differ when the check fails.
    self.assertLen(numbers, n)
    self.assertEqual(numbers, list(range(n)))
class PallasCallTraceTest(PallasBaseTest):
  """Checks that jax.named_scope yields balanced trace start/stop ops."""

  def _count_trace_ops(self, kernel):
    """Runs `kernel` with debug=True and counts trace ops in the dump.

    Returns:
      A `(num_trace_start, num_trace_stop)` tuple.
    """
    with string_stdout() as msg:
      _ = self.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
          debug=True,
      )()
    # TODO(justinfu): Add an official lowering API to get the MLIR.
    dump = msg.getvalue()
    return dump.count('tpu.trace_start'), dump.count('tpu.trace_stop')

  def test_trace_start_stop_match(self):
    def kernel(o_ref):
      with jax.named_scope('scope1'):
        o_ref[...] = jnp.zeros_like(o_ref[...])

    num_start, num_stop = self._count_trace_ops(kernel)
    self.assertEqual(num_start, 1)
    self.assertEqual(num_stop, 1)

  def test_run_scoped(self):
    def kernel(o_ref):
      def scope1():
        with jax.named_scope('scope1'):
          o_ref[...] = jnp.zeros_like(o_ref[...])

      pl.run_scoped(scope1)

      def scope2():
        with jax.named_scope('scope2'):
          o_ref[...] = o_ref[...] + 1

      pl.run_scoped(scope2)

    num_start, num_stop = self._count_trace_ops(kernel)
    self.assertEqual(num_start, 2)
    self.assertEqual(num_stop, 2)
class PallasCallTPUBooleanTest(PallasBaseTest):
  """Tests for loading/storing from bool memrefs on TPUs.

  We specifically test bools because they have special handling.
  Bools are stored as integers inside of memrefs, and we typecast to/from
  bools automatically when loading and storing.
  """

  INTERPRET: bool = False

  @parameterized.parameters((False,), (True,))
  def test_scalar_bool_load_store(self, value):
    """Scalar bool round trip through SMEM, negated in the kernel."""
    def kernel(x_ref, o_ref):
      o_ref[0, 0] = jnp.logical_not(x_ref[0, 0])

    # Named input_arr so the builtin `input` is not shadowed.
    input_arr = jnp.array([[value]])
    output_shape = jax.ShapeDtypeStruct((1, 1), jnp.bool_)
    result = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
        out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
        out_shape=output_shape,
    )(input_arr)
    np.testing.assert_array_equal(result, jnp.logical_not(input_arr))

  @parameterized.parameters((False,), (True,))
  def test_scalar_bool_run_scoped(self, value):
    """Scalar bool stored to and loaded from a run_scoped SMEM buffer."""
    if self.INTERPRET:
      # This test is skipped in *interpret* mode; the message says so.
      self.skipTest('run_scoped not supported in interpret mode.')

    def kernel(x_ref, o_ref):
      def inner_scope(scoped_ref):
        scoped_ref[0, 0] = jnp.logical_not(x_ref[0, 0])
        o_ref[0, 0] = scoped_ref[0, 0]

      pl.run_scoped(inner_scope, pltpu.SMEM((1, 1), dtype=jnp.bool_))

    input_arr = jnp.array([[value]])
    output_shape = jax.ShapeDtypeStruct((1, 1), jnp.bool_)
    result = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
        out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
        out_shape=output_shape,
    )(input_arr)
    np.testing.assert_array_equal(result, jnp.logical_not(input_arr))

  def test_vector_bool_load_store(self):
    """An (8, 128) bool vector is copied through VMEM unchanged."""
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]

    input_arr = jax.random.bernoulli(jax.random.key(0), p=0.5, shape=(8, 128))
    output_shape = jax.ShapeDtypeStruct((8, 128), jnp.bool_)
    result = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)],
        out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
        out_shape=output_shape,
    )(input_arr)
    np.testing.assert_array_equal(result, input_arr)

  def test_vector_bool_masking_with_indexing(self):
    """A bool mask loaded with an index drives jnp.where in the kernel."""
    def kernel(mask_ref, true_ref, false_ref, o_ref):
      o_ref[0, ...] = jnp.where(
          mask_ref[0, ...], true_ref[0, ...], false_ref[0, ...])

    key = jax.random.key(0)
    k1, k2, k3 = jax.random.split(key, 3)
    values_1 = jax.random.normal(k1, (1, 256, 256), jnp.float32)
    values_2 = jax.random.normal(k2, (1, 256, 256), jnp.float32)
    mask = jax.random.bernoulli(k3, p=0.5, shape=(1, 256, 256))
    output_shape = jax.ShapeDtypeStruct((1, 256, 256), jnp.float32)
    result = self.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM),
                  pl.BlockSpec(memory_space=pltpu.VMEM),
                  pl.BlockSpec(memory_space=pltpu.VMEM),
                  ],
        out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
        out_shape=output_shape,
    )(mask, values_1, values_2)
    expected = jnp.where(mask, values_1, values_2)
    np.testing.assert_array_equal(result, expected)

  def test_bool_dma_not_implemented(self):
    """A remote DMA of a bool buffer is rejected with a clear error."""
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('DMAs not supported on TPU generations <= 3')
    if self.INTERPRET:
      self.skipTest('Test only applies to non-interpret mode.')
    num_devices = jax.local_device_count()

    def kernel(x_ref, o_ref, send_sem, recv_sem):
      # Send to the next device on the 'x' axis, wrapping around.
      index = lax.axis_index('x')
      neighbor = lax.rem(index + 1, num_devices)
      copy = pltpu.make_async_remote_copy(x_ref,
                                          o_ref,
                                          send_sem,
                                          recv_sem,
                                          device_id=(0, neighbor))
      copy.start()
      copy.wait()

    input_arr = jnp.ones((8, 128), dtype=jnp.bool_)
    output_shape = jax.ShapeDtypeStruct((8, 128), jnp.bool_)
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        grid=(1,),
        scratch_shapes=[pltpu.SemaphoreType.DMA] * 2,
    )
    test_fn = self.pallas_call(
        kernel,
        grid_spec=grid_spec,
        out_shape=output_shape,
    )
    with self.assertRaisesRegex(
        Exception, 'DMAs with bool dtypes are not supported.'):
      devices = mesh_utils.create_device_mesh((num_devices,))
      mesh = jax.sharding.Mesh(devices, ('x',))
      sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
      input_arr = jax.device_put(input_arr, sharding)
      jax.jit(
          shard_map.shard_map(
              test_fn,
              mesh=mesh,
              in_specs=P(None, 'x'),
              out_specs=P(None, 'x'),
              check_rep=False
          )
      )(input_arr)
class PallasCallTPUBooleanInterpretTest(PallasCallTPUBooleanTest):
  """Runs the bool load/store tests through the Pallas interpreter."""

  INTERPRET: bool = True
class PallasCallTPUCheckifyTest(PallasBaseTest):
  """Tests `checkify.check` inside TPU Pallas kernels wrapped by checkify."""

  @parameterized.parameters((2,), (5,), (6,), (7,))
  def test_checkify_with_scalar_prefetch(self, threshold):
    """A check on a prefetched scalar reports the first failing value.

    Every parameter is chosen so that at least one entry of `s` is
    >= threshold; otherwise `argmax` below would return index 0 and the
    expected-error assertion would be meaningless.
    """
    def body(scalar_ref, x_ref, o_ref):
      scalar = scalar_ref[pl.program_id(0)]
      o_ref[...] = x_ref[...]
      checkify.check(scalar < threshold, 'failed on value {x}', x=scalar)

    s = jnp.array([4, 3, 2, 6, 3, 5, 2, 7], jnp.int32)
    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))

    def _x_transform(i, s_ref):
      # The prefetched scalar s[i] picks which row-block of x step i reads.
      s = pl.load(s_ref, (i,))
      return (s, 0)

    pallas_call = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,
            in_specs=[
                pl.BlockSpec((x.shape[0] // 8, x.shape[1]), _x_transform),
            ],
            out_specs=pl.BlockSpec(
                (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
            ),
            grid=8,
        ),
    )
    checked_call = checkify.checkify(pallas_call)
    err, out = checked_call(s, x)
    # The test expects the error from the earliest grid step whose scalar
    # fails the check; argmax of the boolean mask gives that index.
    expected_error_value = s[jnp.argmax(s >= threshold)]
    with self.assertRaisesRegex(
        checkify.JaxRuntimeError, f'failed on value {expected_error_value}'):
      err.throw()
    # Output block i is input block s[i].
    np.testing.assert_allclose(out, x.reshape((8, 8, -1))[s].reshape(x.shape))

  def test_checkify_with_scratch(self):
    """A failing check in a kernel using VMEM scratch reports its grid ids."""
    def body(x_ref, o_ref, scratch_ref):
      scratch_ref[...] = x_ref[...]
      o_ref[...] = scratch_ref[...]
      # o_ref always equals x_ref here, so this check fails on the first step.
      all_nequal = ~jnp.all(o_ref[...] == x_ref[...])
      checkify.check(all_nequal, 'x_ref equals o_ref id=({x}, {y})',
                     x=pl.program_id(0), y=pl.program_id(1))

    x = jax.random.uniform(jax.random.key(0), (128, 512), dtype=jnp.float32)
    pallas_call = self.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec((32, 128), lambda i, j: (i, j)),
            ],
            out_specs=pl.BlockSpec((32, 128), lambda i, j: (i, j)),
            scratch_shapes=[pltpu.VMEM((32, 128), dtype=jnp.float32)],
            grid=(4, 4),
        ),
    )
    checked_call = checkify.checkify(pallas_call)
    err, out = checked_call(x)
    with self.assertRaisesRegex(
        checkify.JaxRuntimeError, r'x_ref equals o_ref id=\(0, 0\)'):
      err.throw()
    np.testing.assert_allclose(out, x)

  @parameterized.parameters((4,), (9,))
  def test_checkify_with_dynamic_grid(self, iteration):
    """A check fires only if its grid step lies inside a traced-size grid."""
    grid_size = 4
    shape = (8, 128)
    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

    def kernel(y_ref):
      @pl.when(pl.program_id(0) == 0)
      def _init():
        y_ref[...] = jnp.zeros_like(y_ref)

      y_ref[...] += 1

      @pl.when(pl.program_id(0) == iteration)
      def _():
        checkify.check(False, f"error on iteration {iteration}")

    @jax.jit
    def dynamic_kernel(steps):
      pallas_call = self.pallas_call(
          kernel,
          grid=(steps * 2,),
          out_specs=pl.BlockSpec(shape, lambda i: (0, 0)),
          out_shape=result_ty,
      )
      return checkify.checkify(pallas_call)()

    err, result = dynamic_kernel(jnp.int32(grid_size))
    if iteration < grid_size * 2:
      with self.assertRaisesRegex(
          checkify.JaxRuntimeError, f"error on iteration {iteration}"):
        err.throw()
    else:
      # The failing step is outside the grid, so no check may have fired.
      self.assertIsNone(err.get())
    # Every grid step increments the single shared output block once.
    np.testing.assert_array_equal(
        result, np.full(shape, grid_size * 2.0, np.float32)
    )
class PallasCallTPUCheckifyInterpretTest(PallasCallTPUCheckifyTest):
  """Re-runs the inherited checkify tests with `INTERPRET` set to True."""

  INTERPRET: bool = True
class PrettyPrintingTest(PallasBaseTest):
  """Checks how TPU DMA primitives render in a pretty-printed jaxpr."""

  @parameterized.parameters(
      (
          lambda i: (i, pl.ds(0, 8), pl.ds(0, 128)),
          'dma_start c[d,:,:] -> e[...] f',
      ),
      (
          lambda i: (0, pl.ds(i, 8), pl.ds(0, 128)),
          'dma_start c[0,d:d+8,:] -> e[...] f',
      ),
      (
          lambda i: (i, pl.ds(2, 4), pl.ds(0, 100)),
          'dma_start c[d,2:6,:100] -> e[...] f',
      ),
      (
          lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)),
          'dma_start c[d,2:,4:104] -> e[...] f',
      ),
  )
  def test_dma_custom_pretty_print(self, indexer, expected):
    """Traces an async copy of an indexed HBM ref and inspects the printout.

    `indexer` maps the traced scalar index to the slice of the source ref;
    `expected` is the substring the printed `dma_start` must contain.
    """
    src_aval = state.shaped_array_ref((2, 8, 128), jnp.int32)
    idx_aval = jax.core.ShapedArray((), jnp.int32)

    def body(src_hbm_ref, idx):
      def copy_slice(dst_vmem_ref, dma_sem):
        src_slice = src_hbm_ref.at[indexer(idx)]
        pltpu.async_copy(src_slice, dst_vmem_ref, dma_sem).wait()

      pl.run_scoped(
          copy_slice,
          pltpu.VMEM((8, 128), jnp.float32),
          pltpu.SemaphoreType.DMA,
      )
      return []

    traced, _, _, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [src_aval, idx_aval]
    )
    printed = traced.pretty_print(use_color=False)
    self.assertIn(expected, printed)
def only_passes_in_interpret(unless_generation: int | None = None):
def decorator(f):
def wrapper(self):
if self.INTERPRET or (
unless_generation is not None
and jtu.is_device_tpu_at_least(unless_generation)
):
f(self)
else:
with self.assertRaises(Exception):
f(self)
return wrapper
return decorator
class MiscellaneousTest(PallasBaseTest):
  """Regression tests for reported lowering bugs.

  Tests decorated with `only_passes_in_interpret` reproduce bugs that are not
  fixed yet and are expected to raise outside interpret mode; the remaining
  tests are expected to pass in every mode.
  """

  def test_float32_stack(self):
    x = np.arange(128, dtype=jnp.float32).reshape(1, 128)
    y = x + 128

    def kernel(x_ref, y_ref, out_ref):
      out_ref[...] = jnp.stack([x_ref[...], y_ref[...]], axis=1)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((1, 2, 128), jnp.float32)
    )(x, y)
    np.testing.assert_array_equal(out, np.stack([x, y], axis=1))

  @only_passes_in_interpret()
  def test_lane_to_chunk_reshape_bf16(self):
    """b/348038320"""
    x = np.arange(256 * 1024, dtype=jnp.bfloat16).reshape(1, 256, 1024)

    def kernel(x_ref, out_ref):
      out_ref[...] = jnp.reshape(x_ref[...], (1, 256, 8, 128))

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((1, 256, 8, 128), jnp.bfloat16)
    )(x)
    np.testing.assert_array_equal(out, np.reshape(x, (1, 256, 8, 128)))

  def test_lane_to_chunk_broadcast_fp32(self):
    x = np.arange(256 * 128, dtype=jnp.float32).reshape(1, 256, 128)

    def kernel(x_ref, out_ref):
      out_ref[...] = jnp.broadcast_to(
          jnp.expand_dims(x_ref[...], 2), (1, 256, 8, 128)
      )

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((1, 256, 8, 128), jnp.float32)
    )(x)
    np.testing.assert_array_equal(
        out, np.broadcast_to(np.expand_dims(x, 2), (1, 256, 8, 128))
    )

  @only_passes_in_interpret()
  def test_lane_dynamic_slice(self):
    """b/346849973"""
    x = np.arange(128, dtype=jnp.float32)

    def kernel(x_ref, out_ref):
      out_ref[...] = lax.dynamic_slice_in_dim(x_ref[...], 64, 1, 0)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((1,), jnp.float32)
    )(x)
    np.testing.assert_array_equal(out, x[64:65])

  def test_lane_broadcast_bf16(self):
    x = np.arange(256, dtype=jnp.bfloat16).reshape(256, 1)

    def kernel(x_ref, out_ref):
      out_ref[...] = jnp.broadcast_to(x_ref[...], (256, 512))

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((256, 512), jnp.bfloat16)
    )(x)
    np.testing.assert_array_equal(out, np.broadcast_to(x, (256, 512)))

  def test_bfloat16_to_uint32_bitcast(self):
    x = np.arange(16 * 2 * 256, dtype=jnp.bfloat16).reshape(16, 2, 256)

    def kernel(x_ref, out_ref):
      out_ref[...] = pltpu.bitcast(x_ref[...], jnp.uint32)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((16, 1, 256), jnp.uint32)
    )(x)
    np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32))

  @only_passes_in_interpret()
  def test_roll_partial(self):
    """b/337384645"""
    x = np.arange(8192, dtype=jnp.float32).reshape(128, 64)

    def kernel(x_ref, out_ref):
      out_ref[...] = pltpu.roll(x_ref[...], 3, 1)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((128, 64), jnp.float32)
    )(x)
    np.testing.assert_array_equal(out, np.roll(x, 3, 1))

  @only_passes_in_interpret()
  def test_retiling1(self):
    """b/352626602"""
    x = np.arange(1024, dtype=jnp.bfloat16)

    def kernel(x_ref, out_ref):
      out_ref[:, :] = jnp.reshape(x_ref[:].astype(jnp.float32), (8, 128))

    out = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x)
    np.testing.assert_array_equal(out, np.reshape(x, (8, 128)))

  def test_retiling2(self):
    x = np.arange(1 * 8 * 1024, dtype=jnp.bfloat16).reshape(1, 8, 1024)

    def kernel(x_ref, out_ref):
      out_ref[:, :, :] = jnp.reshape(
          x_ref[:, 7, :].astype(jnp.float32), (1, 8, 128)
      )

    out = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((1, 8, 128), jnp.float32),
    )(x)
    np.testing.assert_array_equal(out, np.reshape(x[:, 7, :], (1, 8, 128)))

  def test_sublane_adding_shape_cast_f32(self):
    x = np.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)

    def kernel(x_ref, out_ref):
      out_ref[:, 0, :] = x_ref[:, :]

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((8, 1, 128), jnp.float32)
    )(x)
    np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128)))

  @only_passes_in_interpret()
  def test_sublane_adding_shape_cast_bf16(self):
    """b/352833257"""
    x = np.arange(8 * 128, dtype=jnp.bfloat16).reshape(8, 128)

    def kernel(x_ref, out_ref):
      out_ref[:, 0, :] = x_ref[:, :]

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((8, 1, 128), jnp.bfloat16)
    )(x)
    np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128)))

  def test_mixed_strides(self):
    # Non-zero inputs: with all-zero data, reading the wrong stride of `y`
    # would still produce the expected (zero) output.
    x = np.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
    y = np.arange(8 * 2 * 128, dtype=jnp.bfloat16).reshape(8, 2, 128)

    def kernel(x_ref, y_ref, out_ref):
      out_ref[:, :] = x_ref[:, :] + y_ref[:, 1, :].astype(jnp.float32)

    out = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )(x, y)
    # Integer-valued operands well below 2**24, so the f32 sum is exact.
    np.testing.assert_array_equal(out, x + y[:, 1, :].astype(np.float32))

  def test_sum(self):
    # Non-zero inputs so that reducing over the wrong axis is detected.
    x = np.arange(8 * 2 * 8 * 128, dtype=jnp.float32).reshape(8, 2, 8, 128)

    def kernel(x_ref, out_ref):
      out_ref[:, :, :] = jnp.sum(x_ref[:, :, :, :], 2)

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((8, 2, 128), jnp.float32)
    )(x)
    # Sums of 8 integers < 2**14 stay far below 2**24, so f32 is exact.
    np.testing.assert_array_equal(out, np.sum(x, axis=2))

  @only_passes_in_interpret()
  def test_transpose(self):
    """b/356475128"""
    # Non-zero inputs so that an incorrect permutation is detected.
    x = np.arange(8 * 2 * 8 * 128, dtype=jnp.float32).reshape(8, 2, 8, 128)

    def kernel(x_ref, out_ref):
      out_ref[:, :, :, :] = jnp.transpose(x_ref[:, :, :, :], (0, 2, 1, 3))

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((8, 8, 2, 128), jnp.float32)
    )(x)
    np.testing.assert_array_equal(out, np.transpose(x, (0, 2, 1, 3)))
class MiscellaneousInterpretTest(MiscellaneousTest):
  """Re-runs the inherited regression tests with `INTERPRET` set to True."""

  INTERPRET: bool = True
if __name__ == '__main__':
  # Run this module's tests through absltest, using JAX's test loader.
  test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=test_loader)