mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
jax/pallas support ellipsis indexing
PiperOrigin-RevId: 634922391
This commit is contained in:
parent
02c19e9600
commit
641d5c8be3
@ -17,7 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, List
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import tree_util
|
||||
@ -190,9 +190,17 @@ class NDIndexer:
|
||||
if len(indices) == 1 and indices[0] is ...:
|
||||
indices = (slice(None),) * len(shape)
|
||||
if any(idx is ... for idx in indices):
|
||||
# TODO(sharadmv,mattjj): support patterns that include ellipsis in them
|
||||
# e.g. x[0, ..., 1].
|
||||
raise NotImplementedError("Ellipsis in indexer not supported yet.")
|
||||
new_indices : List[Any] = []
|
||||
num_ellipsis = sum(1 for idx in indices if idx is ...)
|
||||
if num_ellipsis > 1:
|
||||
raise ValueError("Only one ellipsis is supported.")
|
||||
for idx in indices:
|
||||
if idx is ...:
|
||||
expand = (slice(None),) * (len(shape) - len(indices) + 1)
|
||||
new_indices.extend(expand)
|
||||
else:
|
||||
new_indices.append(idx)
|
||||
indices = tuple(new_indices)
|
||||
if len(indices) > len(shape):
|
||||
raise ValueError("`indices` must not be longer than `shape`: "
|
||||
f"{indices=}, {shape=}")
|
||||
|
@ -246,5 +246,61 @@ class IndexerTest(jtu.JaxTestCase):
|
||||
interpret=True,
|
||||
)(x, y)
|
||||
|
||||
def test_ellipsis_indexing_iterpret_only(self):
|
||||
# Interpreter only test! YMMV actually compiling this.
|
||||
def permute_columns_in_row_kernel(left, right, new_left, new_right):
|
||||
shape = left.shape
|
||||
k = shape[-1]
|
||||
ndim = len(shape)
|
||||
left_slices = [
|
||||
left[..., :1],
|
||||
right[..., :1],
|
||||
left[..., 1:k-1]
|
||||
]
|
||||
right_slices = [
|
||||
right[..., 1:k],
|
||||
left[..., k-1:k]
|
||||
]
|
||||
new_left[...] = np.concatenate(left_slices, axis=ndim - 1)
|
||||
new_right[...] = np.concatenate(right_slices, axis=ndim - 1)
|
||||
|
||||
left = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
|
||||
right = jnp.array([[7, 8, 9], [10, 11, 12]], dtype=jnp.float32)
|
||||
|
||||
output_shape = left.shape
|
||||
|
||||
# hack to reuse the same fn for np cat
|
||||
import jax.numpy as np # noqa: F811
|
||||
left_out, right_out = pl.pallas_call(
|
||||
permute_columns_in_row_kernel,
|
||||
grid=(1,),
|
||||
out_shape=[
|
||||
jax.ShapeDtypeStruct(output_shape, jnp.float32),
|
||||
jax.ShapeDtypeStruct(output_shape, jnp.float32)
|
||||
],
|
||||
in_specs=[
|
||||
pl.BlockSpec(lambda i: (0, 0), left.shape),
|
||||
pl.BlockSpec(lambda i: (0, 0), right.shape)
|
||||
],
|
||||
out_specs=[
|
||||
pl.BlockSpec(lambda i: (0, 0), output_shape),
|
||||
pl.BlockSpec(lambda i: (0, 0), output_shape)
|
||||
],
|
||||
interpret=True,
|
||||
)(left, right)
|
||||
|
||||
|
||||
import numpy as np # noqa: F811
|
||||
left_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
|
||||
right_np = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)
|
||||
left_out_np = left_np.copy()
|
||||
right_out_np = right_np.copy()
|
||||
|
||||
|
||||
permute_columns_in_row_kernel(left_np, right_np, left_out_np, right_out_np)
|
||||
np.testing.assert_array_equal(left_out_np, left_out)
|
||||
np.testing.assert_array_equal(right_out_np, right_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user