mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Merge pull request #12091 from jakevdp:getitem-slice
PiperOrigin-RevId: 470066448
This commit is contained in:
commit
1a365346e8
@ -77,8 +77,8 @@ from jax._src.numpy.util import ( # noqa: F401
|
||||
_register_stackable, _stackable, _where, _wraps)
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.ops import scatter
|
||||
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
|
||||
canonicalize_axis as _canonicalize_axis)
|
||||
from jax._src.util import (unzip2, unzip3, prod as _prod, subvals, safe_zip,
|
||||
ceil_of_ratio, canonicalize_axis as _canonicalize_axis)
|
||||
from jax.experimental.array import Array
|
||||
|
||||
newaxis = None
|
||||
@ -3587,30 +3587,50 @@ def take_along_axis(arr, indices, axis: Optional[int],
|
||||
|
||||
### Indexing
|
||||
|
||||
def _is_integer_index(idx: Any) -> bool:
|
||||
return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))
|
||||
|
||||
def _is_inbound_integer_index(idx: Any, size: int) -> bool:
|
||||
return _is_integer_index(idx) and -size <= idx < size
|
||||
|
||||
def _is_nonreversing_static_slice(idx: Any) -> bool:
|
||||
return (isinstance(idx, slice) and
|
||||
_all(i is None or _is_integer_index(i)
|
||||
for i in [idx.start, idx.stop, idx.step]) and
|
||||
(idx.step is None or idx.step > 0))
|
||||
|
||||
def _attempt_rewriting_take_via_slice(arr, idx):
|
||||
# attempt to compute _rewriting_take via lax.slice(); return None if not possible.
|
||||
idx = idx if isinstance(idx, tuple) else (idx,)
|
||||
if not _all(isinstance(i, int) for i in arr.shape):
|
||||
return None
|
||||
if len(idx) > arr.ndim:
|
||||
return None
|
||||
if not _all(_is_inbound_integer_index(i, size) or _is_nonreversing_static_slice(i)
|
||||
for i, size in zip(idx, arr.shape)):
|
||||
return None
|
||||
sqeeze_dimensions = [i for i, ind in enumerate(idx) if not isinstance(ind, slice)]
|
||||
idx += (arr.ndim - len(idx)) * (slice(None),)
|
||||
slices = []
|
||||
for ind, size in safe_zip(idx, arr.shape):
|
||||
if isinstance(ind, slice):
|
||||
slices.append(ind.indices(size))
|
||||
else:
|
||||
ind = ind + size if ind < 0 else ind
|
||||
slices.append((ind, ind + 1, 1))
|
||||
return lax.squeeze(lax.slice(arr, *unzip3(slices)), sqeeze_dimensions)
|
||||
|
||||
def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None, fill_value=None):
|
||||
# Computes arr[idx].
|
||||
# All supported cases of indexing can be implemented as an XLA gather,
|
||||
# followed by an optional reverse and broadcast_in_dim.
|
||||
|
||||
# Handle some special cases, falling back if error messages might differ.
|
||||
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
|
||||
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
|
||||
if 0 <= idx < arr.shape[0]:
|
||||
return lax.index_in_dim(arr, idx, keepdims=False)
|
||||
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
|
||||
isinstance(idx, slice) and
|
||||
(type(idx.start) is int or idx.start is None) and
|
||||
(type(idx.stop) is int or idx.stop is None) and
|
||||
(type(idx.step) is int or idx.step is None)):
|
||||
n = arr.shape[0]
|
||||
start = idx.start if idx.start is not None else 0
|
||||
stop = idx.stop if idx.stop is not None else n
|
||||
step = idx.step if idx.step is not None else 1
|
||||
if (0 <= start < n and 0 <= stop <= n and 0 < step and
|
||||
(start, stop, step) != (0, n, 1)):
|
||||
return lax.slice_in_dim(arr, start, stop, step)
|
||||
|
||||
# For simplicity of generated primitives, we call lax.slice in the simplest
|
||||
# cases: i.e. non-dynamic arrays indexed with only integers and slices.
|
||||
result = _attempt_rewriting_take_via_slice(arr, idx)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# TODO(mattjj,dougalm): expand dynamic shape indexing support
|
||||
if (jax.config.jax_dynamic_shapes and type(idx) is slice and idx.step is None
|
||||
|
@ -860,6 +860,30 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(expected, primals)
|
||||
self.assertAllClose(np.zeros_like(x), tangents)
|
||||
|
||||
def testSimpleIndexingUsesSlice(self):
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[:, 1::2])(jnp.ones((3, 4)))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)
|
||||
|
||||
def testTrivialGatherIsntGenerated(self):
|
||||
# https://github.com/google/jax/issues/1621
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
|
||||
@ -867,9 +891,12 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self.assertNotIn('gather', str(jaxpr))
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user