mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 08:46:06 +00:00
733 lines
25 KiB
Python
733 lines
25 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests for Pallas indexing logic and abstractions."""
|
|
|
|
from __future__ import annotations
|
|
import sys
|
|
import unittest
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
import jax
|
|
from jax import random
|
|
from jax._src import test_util as jtu
|
|
from jax._src import util
|
|
from jax._src.state import indexing
|
|
import numpy as np
|
|
import jax.numpy as jnp
|
|
from jax.experimental import pallas as pl
|
|
|
|
if sys.platform != "win32":
|
|
from jax.experimental.pallas import tpu as pltpu
|
|
else:
|
|
pltpu = None
|
|
|
|
try:
|
|
import hypothesis as hp
|
|
except (ModuleNotFoundError, ImportError):
|
|
raise unittest.SkipTest("tests depend on hypothesis library")
|
|
|
|
import hypothesis.extra.numpy as hnp
|
|
import hypothesis.strategies as hps
|
|
|
|
|
|
# Let absl own command-line flag parsing for JAX's config flags.
jax.config.parse_flags_with_absl()
# Cap the number of generated examples per hypothesis-driven test at 100.
jtu.setup_hypothesis(max_examples=100)


# Short module-level aliases for the state-indexing helpers used throughout
# the tests below.
Slice = indexing.Slice
NDIndexer = indexing.NDIndexer
ds = indexing.ds
|
|
|
|
|
|
# Each case is `(in_shape, indexers, out_shape)`:
#   * `in_shape`: shape of the array/Ref being indexed.
#   * `indexers`: a sequence of indexers applied one after another. The tests
#     apply all but the last through `ref.at[...]` and use the last one for the
#     actual load/store.
#   * `out_shape`: shape expected after every indexer has been applied.
_INDEXING_TEST_CASES = [
    ((4, 8, 128), (...,), (4, 8, 128)),
    ((4, 8, 128), (0,), (8, 128)),
    ((4, 8, 128), (pl.ds(1, 2),), (2, 8, 128)),
    ((4, 8, 128), (pl.ds(2, 2),), (2, 8, 128)),
    ((4, 8, 128), (pl.ds(2, 2), 0), (8, 128)),
    ((4, 8, 128), (pl.ds(2, 2), 1), (8, 128)),
    ((4, 8, 128), (slice(2, 4), 1), (8, 128)),
    ((4, 8, 128), (slice(2, 4), slice(0, 1), 0), (8, 128)),
    ((4, 8, 128), ((0, pl.ds(0, 8), pl.ds(0, 128)), ...), (8, 128)),
    ((4, 8, 128), (..., (0, pl.ds(0, 8), pl.ds(0, 128)), ...), (8, 128)),
]
|
|
|
|
|
|
def _maybe_ds_to_slice(x: int | slice | indexing.Slice) -> int | slice:
|
|
if isinstance(x, indexing.Slice):
|
|
return slice(x.start, x.start + x.size)
|
|
return x
|
|
|
|
|
|
def int_indexer_strategy(dim) -> hps.SearchStrategy[int]:
  """Returns a strategy producing scalar integer indices below `dim`.

  The lower bound is the smallest int32 value, so negative integers are drawn
  as well; the upper bound is `dim - 1`.
  """
  lowest = np.iinfo(np.int32).min
  highest = dim - 1
  return hps.integers(min_value=lowest, max_value=highest)
|
|
|
|
|
|
@hps.composite
def slice_indexer_strategy(draw, dim) -> Slice | slice:
  """Draws a slice-like indexer for a dimension of extent `dim`.

  A start is drawn from `int_indexer_strategy(dim)` and a size from
  `[0, dim - start]`. The same window is then returned either as a `Slice` or
  as a builtin `slice`, chosen by hypothesis.
  """
  begin = draw(int_indexer_strategy(dim))
  length = draw(hps.integers(min_value=0, max_value=dim - begin))
  as_ds = hps.just(Slice(begin, length))
  as_builtin = hps.just(slice(begin, begin + length))
  return draw(hps.one_of(as_ds, as_builtin))
|
|
|
|
|
|
@hps.composite
def array_indexer_strategy(draw, shape) -> jax.Array:
  """Draws an int32 array indexer whose shape is broadcastable to `shape`.

  For every dimension of `shape` a boolean is drawn; when it is true that
  dimension is collapsed to extent 1, otherwise the original extent is kept.
  The array itself is generated by `hypothesis.extra.numpy.arrays`.
  """
  extents = []
  for full_extent in shape:
    collapse = draw(hps.booleans())
    extents.append(1 if collapse else full_extent)
  return draw(hnp.arrays(dtype=np.dtype("int32"), shape=tuple(extents)))
|
|
|
|
|
|
@hps.composite
def indexer_strategy(draw, dim, int_indexer_shape) -> int | Slice | jax.Array:
  """Draws one indexer for a dimension of extent `dim`.

  The result comes from one of three strategies: a scalar integer, a slice
  (`Slice` or builtin `slice`), or an integer array built from
  `int_indexer_shape`.
  """
  alternatives = (
      int_indexer_strategy(dim),
      slice_indexer_strategy(dim),
      array_indexer_strategy(int_indexer_shape),
  )
  return draw(hps.one_of(*alternatives))
|
|
|
|
|
|
@hps.composite
def nd_indexer_strategy(draw, shape) -> NDIndexer:
  """Draws an `NDIndexer` for an array of the given `shape`.

  Between zero and `len(shape)` leading dimensions receive an explicit indexer;
  all array indexers are generated from one shared `int_indexer_shape`. The
  drawn indices are handed to `NDIndexer.from_indices_shape` with `shape`.
  """
  rank = len(shape)
  num_indices = draw(hps.integers(min_value=0, max_value=rank))
  int_indexer_shape = draw(hnp.array_shapes())
  drawn = []
  for dim in shape[:num_indices]:
    drawn.append(draw(indexer_strategy(dim, int_indexer_shape)))
  return NDIndexer.from_indices_shape(tuple(drawn), shape)
|
|
|
|
|
|
class PallasBaseTest(jtu.JaxTestCase):
  """Shared base for tests that launch kernels through `pl.pallas_call`.

  Subclasses flip `INTERPRET` to True to run the same tests with
  `interpret=True`. With `INTERPRET = False` the tests are skipped unless the
  test device is a TPU.
  """

  INTERPRET = False

  def setUp(self):
    # Non-interpret runs are only attempted on TPU; skip everywhere else.
    compiled_mode = not self.INTERPRET
    if compiled_mode and not jtu.test_device_matches(["tpu"]):
      self.skipTest("Only interpret mode supported on non-TPU")

    super().setUp()

  @classmethod
  def pallas_call(cls, *args, **kwargs):
    # Forward everything to `pl.pallas_call`, pinning `interpret` to the
    # class-level setting.
    interpret = cls.INTERPRET
    return pl.pallas_call(*args, interpret=interpret, **kwargs)
|
|
|
|
|
|
class IndexerTest(jtu.JaxTestCase):
  """These are unit tests for the indexer logic, not using pallas_call."""

  def test_simple_ndindexer(self):
    # Scalar indices on every dimension leave a rank-0 indexer shape.
    indices = (0, 0)
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), ())

  def test_invalid_ndindexer(self):
    # Supplying more indices than there are dimensions must be rejected.
    indices = (0, 0, 0)
    shape = (5, 5)
    with self.assertRaisesRegex(
        ValueError, "`indices` must not be longer than `shape`"
    ):
      _ = NDIndexer.from_indices_shape(indices, shape)

  @parameterized.parameters(
      ((4, 0), (3, 5)),
      ((slice(3, 2), 0), (3, 5)),
      ((Slice(2, 2), 0), (3, 5)),
  )
  def test_invalid_ndindexer_oob(self, indices, shape):
    # Each parameterized indexer is expected to raise an out-of-bounds error
    # for the leading dimension of extent 3.
    with self.assertRaisesRegex(ValueError, "Out of bound"):
      _ = NDIndexer.from_indices_shape(indices, shape)

  def test_ndindexer_with_padding(self):
    # With no indices at all, the indexer shape equals the full array shape.
    indices = ()
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), shape)

  def test_ndindexer_with_ellipsis(self):
    # `...` covers the leading dimension; the trailing scalar removes the last.
    indices = (..., 4)
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5,))

  def test_ndindexer_with_slices(self):
    # The second slice's stop (7) exceeds the dimension (6); the expected
    # extent of 2 shows it is clipped to the array bounds.
    indices = (slice(2, 3), slice(4, 7))
    shape = (5, 6)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (1, 2))

  def test_ndindexer_with_arrays(self):
    # Array indexers of identical shape yield exactly that shape.
    indices = (np.arange(10), np.arange(10))
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (10,))

    indices = (np.ones((10, 20)), np.ones((10, 20)))
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20))

  def test_ndindexer_with_arrays_and_broadcasting(self):
    # Array indexers with compatible shapes are broadcast against each other.
    indices = (np.arange(10)[None], np.arange(20)[:, None])
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (20, 10))

    indices = (np.arange(10)[:, None], np.arange(20)[None, :])
    shape = (5, 5)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20))

  def test_ndindexer_with_arrays_and_invalid_broadcasting(self):
    # Shapes (1, 10) and (1, 20) cannot be broadcast together.
    indices = (np.arange(10)[None], np.arange(20)[None, :])
    shape = (5, 5)
    with self.assertRaisesRegex(
        ValueError, "Cannot broadcast shapes for indexing"
    ):
      # The result is never used: construction is expected to raise.
      _ = NDIndexer.from_indices_shape(indices, shape)

  def test_indexer_with_all_types(self):
    # Mixes scalar, slice and array indexers. In every expected shape below
    # the array-indexer dimensions come first, followed by the slice extents.
    indices = (0, slice(10), np.arange(5))
    shape = (2, 3, 4)
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 3))

    indices = (0, slice(2, 10), np.arange(5))
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 1))

    indices = (0, 1, np.arange(5))
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5,))

    indices = (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None])
    indexer = NDIndexer.from_indices_shape(indices, shape)
    self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 2))

  @hp.given(hps.data())
  def test_ndindexer(self, data):
    # Property test over randomly generated shapes and indexers.
    shape = data.draw(hnp.array_shapes())
    indexer = data.draw(nd_indexer_strategy(shape))
    # Split the normalized indices into `Slice`s and everything else.
    is_int_indexer = [not isinstance(idx, Slice) for idx in indexer.indices]
    rest_indexers, int_indexers = util.partition_list(
        is_int_indexer, indexer.indices
    )
    # The first non-`Slice` indexer's shape is taken as the expected shared
    # integer-indexer shape.
    if int_indexers:
      expected_int_indexer_shape = int_indexers[0].shape
    else:
      expected_int_indexer_shape = ()
    self.assertTupleEqual(
        indexer.int_indexer_shape, expected_int_indexer_shape
    )
    for idx in rest_indexers:
      self.assertIsInstance(idx, (np.ndarray, Slice))
      if isinstance(idx, np.ndarray):
        self.assertTupleEqual(idx.shape, ())
        self.assertEqual(idx.dtype, np.dtype("int32"))
    rest_shape = tuple(
        r.size for r in rest_indexers if not isinstance(r, np.ndarray)
    )
    # Overall shape = integer-indexer shape followed by the slice sizes.
    self.assertTupleEqual((*indexer.int_indexer_shape, *rest_shape),
                          indexer.get_indexer_shape())
|
|
|
|
|
|
class IndexerOpsTest(PallasBaseTest):
  """Load/store indexing tests that run real kernels through `pallas_call`.

  Runs with `INTERPRET = False` (inherited); a subclass re-runs the same tests
  in interpret mode. Individual tests skip themselves in whichever mode they
  do not support.
  """

  def test_multi_indexing_interpreter_only(self):
    if not self.INTERPRET:
      self.skipTest("Only supported in interpret mode")
    # Interpret only test! YMMV actually compiling this.
    # NOTE: this test has no assertions; it only checks that chained `.at`
    # indexing on Refs traces and runs.
    def permute(left, right, left_out_ref, right_out_ref):
      left_out = jnp.zeros_like(left)
      left_out = left_out.at[:, 0].set(left[:, 0])
      left_out = left_out.at[:, 1].set(right[:, 0])
      left_out = left_out.at[:, 2:].set(left[:, 1:-1])

      right_out = jnp.zeros_like(right)
      right_out = right_out.at[:, :-1].set(right[:, 1:])
      right_out = right_out.at[:, -1].set(left[:, -1])

      left_out_ref[...] = left_out
      right_out_ref[...] = right_out

    def invoke_permutes(x_ref, y_ref, x_out_ref, y_out_ref):
      shape = x_ref.shape
      _, n = shape[-2], shape[-1]
      # Restrict every Ref to its top-left (n // 2, n // 2) window before
      # handing it to `permute`.
      x_ref = x_ref.at[: n // 2, : n // 2]
      y_ref = y_ref.at[: n // 2, : n // 2]
      x_out_ref = x_out_ref.at[: n // 2, : n // 2]
      y_out_ref = y_out_ref.at[: n // 2, : n // 2]
      permute(x_ref, y_ref, x_out_ref, y_out_ref)

    n = 8
    x = jnp.ones([n, n])
    y = jnp.ones([n, n])
    jitted_permute = jax.jit(invoke_permutes)
    grid = (1,)
    pl.pallas_call(
        jitted_permute,
        grid=grid,
        out_shape=[
            jax.ShapeDtypeStruct(x.shape, x.dtype),
            # Second output mirrors `y` (shape and dtype), matching its specs.
            jax.ShapeDtypeStruct(y.shape, y.dtype),
        ],
        in_specs=[
            pl.BlockSpec(x.shape, lambda i: (0, 0)),
            pl.BlockSpec(y.shape, lambda i: (0, 0)),
        ],
        out_specs=[
            pl.BlockSpec(x.shape, lambda i: (0, 0)),
            pl.BlockSpec(y.shape, lambda i: (0, 0)),
        ],
        interpret=True,
    )(x, y)

  def test_multi_indexing_destination_ref(self):
    if not self.INTERPRET:
      self.skipTest("Only supported in interpret mode")
    def kernel(x_ref, o_ref):
      o_ref[...] = jnp.zeros_like(o_ref)
      # Three chained `.at` views: a slice, then a scalar, then a 2D window.
      new_o_ref = o_ref.at[pl.ds(0, 8)].at[0].at[pl.ds(0, 4), pl.ds(0, 4)]
      new_o_ref[...] = x_ref[...]

    x = jax.random.normal(jax.random.key(0), shape=(4, 4))
    result = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((16, 16, 16), x.dtype),
        interpret=True,
    )(x)
    # The chained view is expected to address `[0, 0:4, 0:4]` of the output.
    expected = jnp.zeros((16, 16, 16)).at[0, 0:4, 0:4].set(x)
    np.testing.assert_array_equal(result, expected)

  def test_ellipsis_indexing_iterpret_only(self):
    if not self.INTERPRET:
      self.skipTest("Only supported in interpret mode")
    # Interpret only test! YMMV actually compiling this.
    def permute_columns_in_row_kernel(left, right, new_left, new_right):
      shape = left.shape
      k = shape[-1]
      ndim = len(shape)
      left_slices = [
          left[..., :1],
          right[..., :1],
          left[..., 1:k-1]
      ]
      right_slices = [
          right[..., 1:k],
          left[..., k-1:k]
      ]
      # `np` is resolved from the enclosing test's scope at call time, so it
      # is whichever module the test has bound when the kernel runs.
      new_left[...] = np.concatenate(left_slices, axis=ndim - 1)
      new_right[...] = np.concatenate(right_slices, axis=ndim - 1)

    left = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
    right = jnp.array([[7, 8, 9], [10, 11, 12]], dtype=jnp.float32)

    output_shape = left.shape

    # hack to reuse the same fn for np cat
    # (`np` is bound to `jax.numpy` while the kernel runs under pallas_call.)
    import jax.numpy as np  # noqa: F811
    left_out, right_out = pl.pallas_call(
        permute_columns_in_row_kernel,
        grid=(1,),
        out_shape=[
            jax.ShapeDtypeStruct(output_shape, jnp.float32),
            jax.ShapeDtypeStruct(output_shape, jnp.float32)
        ],
        in_specs=[
            pl.BlockSpec(left.shape, lambda i: (0, 0)),
            pl.BlockSpec(right.shape, lambda i: (0, 0))
        ],
        out_specs=[
            pl.BlockSpec(output_shape, lambda i: (0, 0)),
            pl.BlockSpec(output_shape, lambda i: (0, 0))
        ],
        interpret=True,
    )(left, right)

    # Rebind `np` to NumPy and run the very same function on NumPy arrays to
    # obtain the reference result (it writes into the `*_out_np` arrays).
    import numpy as np  # noqa: F811
    left_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    right_np = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)
    left_out_np = left_np.copy()
    right_out_np = right_np.copy()

    permute_columns_in_row_kernel(left_np, right_np, left_out_np, right_out_np)
    np.testing.assert_array_equal(left_out_np, left_out)
    np.testing.assert_array_equal(right_out_np, right_out)

  @hp.given(hps.data())
  def test_vmap_nd_indexing(self, data):
    # The test is unconditionally skipped; the body below is kept so it can be
    # re-enabled later and is not executed today.
    self.skipTest("TODO(necula): enable this test; was in jax_triton.")
    vmap_shape = data.draw(hnp.array_shapes(min_dims=1, max_dims=3, min_side=2),
                           label="vmap_shape")
    el_shape = data.draw(hnp.array_shapes(min_dims=2), label="el_shape")
    # TODO(sharadmv,apaszke): enable rank 0 and rank 1 Refs
    # hp.assume(len(el_shape) >= 2)
    nd_indexer = data.draw(nd_indexer_strategy(el_shape), label="nd_indexer")
    expected_shape = jax.eval_shape(lambda x: x[nd_indexer],
                                    jax.ShapeDtypeStruct(el_shape, jnp.float32))

    ref = lambda x: x[nd_indexer]
    def kernel(x_ref, y_ref):
      x = pl.load(x_ref, nd_indexer)
      pl.store(y_ref, (slice(None),) * len(y_ref.shape), x)
    func = pl.pallas_call(kernel, out_shape=expected_shape)

    shape = el_shape
    # Wrap both the reference and the kernel in one `vmap` per batch dim,
    # inserting each batch dim at a drawn position.
    for vmap_dim in vmap_shape[::-1]:
      index = data.draw(hps.integers(min_value=0,
                                     max_value=max(0, len(shape) - 2)),
                        label="index")
      # hp.assume(index <= max(0, len(shape) - 2))
      # TODO(sharadmv,apaszke): enable vmapping over batch axes in 2 minormost
      # dimensions
      shape = (*shape[:index], vmap_dim, *shape[index:])
      ref = jax.vmap(ref, in_axes=index, out_axes=0)
      func = jax.vmap(func, in_axes=index, out_axes=0)
    key = random.PRNGKey(0)
    x = random.normal(key, shape, dtype=jnp.float32)
    expected = ref(x)
    y = func(x)
    np.testing.assert_array_equal(y, expected)

  @parameterized.product(
      indexer_type=["state", "pallas"],
      case=_INDEXING_TEST_CASES,
  )
  def test_can_load_with_ref_at(self, indexer_type, case):
    if self.INTERPRET:
      self.skipTest("TODO: fails in interpret mode.")
    in_shape, indexers, out_shape = case
    dtype = jnp.float32
    def body(x_ref, y_ref):
      # All but the last indexer narrow the Ref through `.at`; the last one
      # performs the load, either with Ref `[]` syntax or with `pl.load`.
      for indexer in indexers[:-1]:
        x_ref = x_ref.at[indexer]
      if indexer_type == "state":
        x = x_ref[indexers[-1]]
        y_ref[...] = x
      elif indexer_type == "pallas":
        x = pl.load(x_ref, indexers[-1])
        pl.store(y_ref, ..., x)

    x = random.normal(random.key(0), in_shape, dtype=dtype)
    # Reference: apply the same indexers one by one to the plain array.
    y = x
    for indexer in indexers:
      if not isinstance(indexer, tuple):
        indexer = (indexer,)
      indexer = tuple(map(_maybe_ds_to_slice, indexer))
      y = y[indexer]
    assert y.shape == out_shape
    out = self.pallas_call(body, out_shape=y)(x)
    self.assertAllClose(out, y)

  @parameterized.product(
      indexer_type=["state", "pallas"],
      case=_INDEXING_TEST_CASES,
  )
  def test_can_store_with_ref_at(self, indexer_type, case):
    if self.INTERPRET:
      self.skipTest("TODO: fails in interpret mode.")
    in_shape, indexers, val_shape = case
    dtype = jnp.float32
    def body(x_ref, y_ref):
      # Zero the whole output first so untouched elements compare equal to the
      # zero-initialized NumPy reference.
      y_ref[...] = jnp.zeros_like(y_ref)
      for indexer in indexers[:-1]:
        y_ref = y_ref.at[indexer]
      if indexer_type == "state":
        x = x_ref[...]
        y_ref[indexers[-1]] = x
      elif indexer_type == "pallas":
        x = pl.load(x_ref, ...)
        pl.store(y_ref, indexers[-1], x)

    val = random.normal(random.key(0), val_shape, dtype=dtype)
    # Use NumPy arrays to do nested indexing and mutation. This is really
    # annoying to do in vanilla JAX.
    x = np.zeros(in_shape, dtype=dtype)
    y = x
    for indexer in indexers:
      if not isinstance(indexer, tuple):
        indexer = (indexer,)
      indexer = tuple(map(_maybe_ds_to_slice, indexer))
      y = y[indexer]
    assert y.shape == val_shape
    # Assigning through `y` fills in the reference array `x`, which is what
    # the kernel output is compared against.
    y[...] = val
    out = self.pallas_call(body, out_shape=x)(val)
    self.assertAllClose(out, x)

  @parameterized.product(
      indexer_type=["state", "pallas"],
      slice_type=["slice", "ds"],
  )
  @hp.given(
      ref_shape=hps.sampled_from(((8, 8, 32), (7, 7, 33))),
      indices=hps.tuples(
          hps.integers(0, 6), hps.integers(0, 6), hps.integers(0, 31)
      ),
      strides=hps.tuples(
          hps.integers(1, 10), hps.integers(1, 10), hps.integers(1, 10)
      ),
  )
  def test_strided_load_and_store(
      self, indexer_type, slice_type, ref_shape, indices, strides
  ):
    if self.INTERPRET:
      self.skipTest("TODO: fails in interpret mode.")
    # Append a trailing dimension of 128 that is indexed densely from 0.
    ref_shape = (*ref_shape, 128)
    indices = (*indices, 0)
    strides = (*strides, 1)
    # Number of elements selected per dim: ceil((length - start) / stride).
    vec_shape = [
        (l - i + s - 1) // s for l, i, s in zip(ref_shape, indices, strides)
    ]
    dtype = jnp.float32

    def body(x_ref, y_ref1, y_ref2):
      if slice_type == "slice":
        slices = tuple(
            slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides)
        )
      else:
        slices = tuple(
            pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides)
        )
      # y_ref1 receives the strided load; y_ref2 receives a strided store of
      # that same data.
      if indexer_type == "state":
        y_ref1[...] = x_ref[slices]
        y_ref2[slices] = y_ref1[...]
      elif indexer_type == "pallas":
        pl.store(y_ref1, ..., pl.load(x_ref, slices))
        pl.store(y_ref2, slices, pl.load(y_ref1, ...))

    x = random.normal(random.key(0), ref_shape, dtype=dtype)
    y1, y2 = self.pallas_call(
        body,
        out_shape=[
            jax.ShapeDtypeStruct(vec_shape, dtype),
            jax.ShapeDtypeStruct(ref_shape, dtype),
        ],
    )(x)
    slices = tuple(
        slice(i, l, s) for l, i, s in zip(ref_shape, indices, strides)
    )
    expected = x[slices]
    self.assertAllClose(y1, expected, err_msg="Strided Load Error")
    # Only the strided positions of y2 are compared; the rest is never written.
    self.assertAllClose(
        y2[slices], expected, err_msg="Strided Store Error"
    )

  def test_load_with_dynamic_2nd_minor_index(self):
    if pltpu is None:
      self.skipTest("No TPU module available.")
    # We can take any dynamic index on the 2nd minor dimension as long as
    # the minormost dimsize is vreg lane count.
    m, n = 32, 128
    k = 10
    start = 2

    def kernel(x_ref, indices, y_ref):
      # The start row is read from the `indices` Ref at run time.
      y_ref[...] = pl.load(x_ref, pl.ds(indices[0], k))

    x = jnp.arange(m * n, dtype=jnp.int32).reshape((m, n))
    indices = jnp.array([start])

    res = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((k, n), jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(memory_space=pltpu.VMEM),
                pl.BlockSpec(memory_space=pltpu.SMEM),
            ],
        ),
    )(x, indices)
    self.assertAllClose(res, x[start : start + k, :], atol=0., rtol=0.)

  def test_store_with_dynamic_2nd_minor_index(self):
    if pltpu is None:
      self.skipTest("No TPU module available.")
    # We can take any dynamic index on the 2nd minor dimension as long as
    # the minormost dimsize is vreg lane count.
    m, n = 10, 128
    k = 32
    start = 2

    def kernel(x_ref, indices, y_ref):
      pl.store(y_ref, pl.ds(indices[0], m), x_ref[...])

    x = jnp.arange(m * n, dtype=jnp.int32).reshape((m, n))
    indices = jnp.array([start])

    res = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((k, n), jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(memory_space=pltpu.VMEM),
                pl.BlockSpec(memory_space=pltpu.SMEM),
            ],
        ),
    )(x, indices)
    # Only the `m` rows written by the kernel are checked.
    self.assertAllClose(res[start : start + m, :], x, atol=0., rtol=0.)

  def test_load_one_row_with_dynamic_2nd_minor_index(self):
    if pltpu is None:
      self.skipTest("No TPU module available.")
    # This test triggers strided load. We can take any dynamic index on the
    # 2nd minor dimension as long as we load one row on the 2nd minor dim.
    b, m, n = 4, 16, 256
    start = 3

    def kernel(x_ref, indices, y_ref):
      y_ref[...] = x_ref[:, pl.ds(indices[0], 1), :]

    x = jnp.arange(b * m * n, dtype=jnp.int32).reshape((b, m, n))
    indices = jnp.array([start])

    res = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((b, 1, n), jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(memory_space=pltpu.VMEM),
                pl.BlockSpec(memory_space=pltpu.SMEM),
            ],
        ),
    )(x, indices)
    self.assertAllClose(res, x[:, start : start + 1, :], atol=0., rtol=0.)

  def test_store_one_row_with_dynamic_2nd_minor_index(self):
    if pltpu is None:
      self.skipTest("No TPU module available.")
    # This test triggers strided store. We can take any dynamic index on the
    # 2nd minor dimension as long as we store one row on the 2nd minor dim.
    b, m, n = 4, 16, 256
    start = 3

    def kernel(x_ref, indices, y_ref):
      y_ref[:, pl.ds(indices[0], 1), :] = x_ref[...]

    x = jnp.arange(b * 1 * n, dtype=jnp.int32).reshape((b, 1, n))
    indices = jnp.array([start])

    res = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((b, m, n), jnp.int32),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(memory_space=pltpu.VMEM),
                pl.BlockSpec(memory_space=pltpu.SMEM),
            ],
        ),
    )(x, indices)
    # Only the single row written by the kernel is checked.
    self.assertAllClose(res[:, start : start + 1, :], x, atol=0., rtol=0.)
|
|
|
|
|
|
class IndexerOpsInterpretTest(IndexerOpsTest):
  """Re-runs every `IndexerOpsTest` test with `INTERPRET` set to True."""
  INTERPRET = True
|
|
|
|
|
|
# Each case is `(in_shape, indexing_func)`. `indexing_func(arr, a, b, c, d)`
# indexes `arr` (an array or a Ref of shape `in_shape`) and may use any of the
# four helper index arrays `a`, `b`, `c`, `d` that the test supplies. Cases that
# are commented out are currently disabled (see the TODO below).
# TODO(ayx): Fix all test cases here
_ADVANCED_INDEXER_TEST_CASES = [
    # integer
    ((3, 2), lambda arr, a, b, c, d: arr[2]),
    # slice
    ((12, 12), lambda arr, a, b, c, d: arr[::4, ::4]),
    ((16, 16), lambda arr, a, b, c, d: arr[1:14:2, 2:13:4]),
    ((8, 2), lambda arr, a, b, c, d: arr[1::3, :]),
    # array
    ((4, 3), lambda arr, a, b, c, d: arr[a]),
    ((4, 3, 2), lambda arr, a, b, c, d: arr[c, c]),
    # integer + 1-D array
    ((4, 3), lambda arr, a, b, c, d: arr[2, a]),
    ((4, 3), lambda arr, a, b, c, d: arr[a, 2]),
    # slice + 1-D array
    ((4, 3), lambda arr, a, b, c, d: arr[a, :]),
    # ((4, 3), lambda arr, a, b, c, d: arr[:, a]),
    ((6, 8, 3), lambda arr, a, b, c, d: arr[c, ::3]),
    # ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]),
    # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]),
    # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]),
    ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]),
    ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]),
    # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]),
    ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]),
    ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, a]),
    # slice + array w/ broadcasting
    ((8, 8, 3, 6), lambda arr, a, b, c, d: \
      arr[b[:, None], ::4, a[None], a[:, None]]),
    # integer + slice + 1-D array
    ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, ::2, a]),
    ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, a, ::2]),
    # boolean
    # ((6, 2), lambda arr, a, b, c, d: arr[d]),
    # ((8, 6), lambda arr, a, b, c, d: arr[::4, d]),
]
|
|
|
|
|
|
class AdvancedIndexerOpsTest(PallasBaseTest):
  """Compares advanced indexing of a Ref inside a kernel with array indexing.

  Every case applies the same indexing function once to a `jnp` array and once
  to the input Ref inside a kernel, and the two results must be equal.
  """

  def setUp(self):
    if jtu.test_device_matches(["tpu"]):
      self.skipTest("Advanced indexers are not supported on TPU")

    # The four index arrays shared by all advanced-indexing test cases:
    # three int32 arrays (`a`, `b`, `c`) and one boolean mask (`d`).
    int_dtype = jnp.int32
    self.a = jnp.array([1, 1, 1, 1, 1], dtype=int_dtype)
    self.b = jnp.array([1, 2, 2, 2, 2], dtype=int_dtype)
    self.c = jnp.array([1, 0, 2, 2, -1, 1], dtype=int_dtype)
    self.d = jnp.array([1, 0, 0, 0, 0, 1], dtype=jnp.bool_)

    super().setUp()

  @parameterized.parameters(_ADVANCED_INDEXER_TEST_CASES)
  def test_advanced_indexer(self, in_shape: tuple[int, ...], indexing_func):
    index_arrays = (self.a, self.b, self.c, self.d)

    x = jnp.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape)
    # Reference result: index the plain array directly.
    expected = indexing_func(x, *index_arrays)

    # The kernel always receives all four index arrays so that every test case
    # can share one signature, even when a case ignores some of them.
    def kernel(x_ref, a_ref, b_ref, c_ref, d_ref, o_ref):
      loaded = [ref[...] for ref in (a_ref, b_ref, c_ref, d_ref)]
      o_ref[...] = indexing_func(x_ref, *loaded)

    out_struct = jax.ShapeDtypeStruct(expected.shape, jnp.float32)
    actual = self.pallas_call(kernel, out_shape=out_struct)(x, *index_arrays)

    np.testing.assert_array_equal(actual, expected)
|
|
|
|
|
|
class AdvancedIndexerOpsInterpretTest(AdvancedIndexerOpsTest):
  """Re-runs every `AdvancedIndexerOpsTest` test with `INTERPRET` set to True."""
  INTERPRET = True
|
|
|
|
|
|
# Run the tests through absl's test runner, using JAX's test loader, when this
# file is executed as a script.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
|