jax/pallas support ellipsis indexing

PiperOrigin-RevId: 634922391
This commit is contained in:
jax authors 2024-05-17 16:57:05 -07:00 committed by jax authors
parent 02c19e9600
commit 641d5c8be3
2 changed files with 68 additions and 4 deletions

View File

@ -17,7 +17,7 @@
from __future__ import annotations
import dataclasses
from typing import Any, Union
from typing import Any, Union, List
from jax._src import core
from jax._src import tree_util
@ -190,9 +190,17 @@ class NDIndexer:
if len(indices) == 1 and indices[0] is ...:
indices = (slice(None),) * len(shape)
if any(idx is ... for idx in indices):
# TODO(sharadmv,mattjj): support patterns that include ellipsis in them
# e.g. x[0, ..., 1].
raise NotImplementedError("Ellipsis in indexer not supported yet.")
new_indices : List[Any] = []
num_ellipsis = sum(1 for idx in indices if idx is ...)
if num_ellipsis > 1:
raise ValueError("Only one ellipsis is supported.")
for idx in indices:
if idx is ...:
expand = (slice(None),) * (len(shape) - len(indices) + 1)
new_indices.extend(expand)
else:
new_indices.append(idx)
indices = tuple(new_indices)
if len(indices) > len(shape):
raise ValueError("`indices` must not be longer than `shape`: "
f"{indices=}, {shape=}")

View File

@ -246,5 +246,61 @@ class IndexerTest(jtu.JaxTestCase):
interpret=True,
)(x, y)
def test_ellipsis_indexing_iterpret_only(self):
# Interpreter only test! YMMV actually compiling this.
def permute_columns_in_row_kernel(left, right, new_left, new_right):
shape = left.shape
k = shape[-1]
ndim = len(shape)
left_slices = [
left[..., :1],
right[..., :1],
left[..., 1:k-1]
]
right_slices = [
right[..., 1:k],
left[..., k-1:k]
]
new_left[...] = np.concatenate(left_slices, axis=ndim - 1)
new_right[...] = np.concatenate(right_slices, axis=ndim - 1)
left = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
right = jnp.array([[7, 8, 9], [10, 11, 12]], dtype=jnp.float32)
output_shape = left.shape
# hack to reuse the same fn for np cat
import jax.numpy as np # noqa: F811
left_out, right_out = pl.pallas_call(
permute_columns_in_row_kernel,
grid=(1,),
out_shape=[
jax.ShapeDtypeStruct(output_shape, jnp.float32),
jax.ShapeDtypeStruct(output_shape, jnp.float32)
],
in_specs=[
pl.BlockSpec(lambda i: (0, 0), left.shape),
pl.BlockSpec(lambda i: (0, 0), right.shape)
],
out_specs=[
pl.BlockSpec(lambda i: (0, 0), output_shape),
pl.BlockSpec(lambda i: (0, 0), output_shape)
],
interpret=True,
)(left, right)
import numpy as np # noqa: F811
left_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
right_np = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)
left_out_np = left_np.copy()
right_out_np = right_np.copy()
permute_columns_in_row_kernel(left_np, right_np, left_out_np, right_out_np)
np.testing.assert_array_equal(left_out_np, left_out)
np.testing.assert_array_equal(right_out_np, right_out)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())