Merge pull request #12091 from jakevdp:getitem-slice

PiperOrigin-RevId: 470066448
This commit is contained in:
jax authors 2022-08-25 13:16:44 -07:00
commit 1a365346e8
2 changed files with 69 additions and 22 deletions

View File

@ -77,8 +77,8 @@ from jax._src.numpy.util import ( # noqa: F401
_register_stackable, _stackable, _where, _wraps)
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis)
from jax._src.util import (unzip2, unzip3, prod as _prod, subvals, safe_zip,
ceil_of_ratio, canonicalize_axis as _canonicalize_axis)
from jax.experimental.array import Array
newaxis = None
@ -3587,30 +3587,50 @@ def take_along_axis(arr, indices, axis: Optional[int],
### Indexing
def _is_integer_index(idx: Any) -> bool:
return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))
def _is_inbound_integer_index(idx: Any, size: int) -> bool:
return _is_integer_index(idx) and -size <= idx < size
def _is_nonreversing_static_slice(idx: Any) -> bool:
return (isinstance(idx, slice) and
_all(i is None or _is_integer_index(i)
for i in [idx.start, idx.stop, idx.step]) and
(idx.step is None or idx.step > 0))
def _attempt_rewriting_take_via_slice(arr, idx):
# attempt to compute _rewriting_take via lax.slice(); return None if not possible.
idx = idx if isinstance(idx, tuple) else (idx,)
if not _all(isinstance(i, int) for i in arr.shape):
return None
if len(idx) > arr.ndim:
return None
if not _all(_is_inbound_integer_index(i, size) or _is_nonreversing_static_slice(i)
for i, size in zip(idx, arr.shape)):
return None
sqeeze_dimensions = [i for i, ind in enumerate(idx) if not isinstance(ind, slice)]
idx += (arr.ndim - len(idx)) * (slice(None),)
slices = []
for ind, size in safe_zip(idx, arr.shape):
if isinstance(ind, slice):
slices.append(ind.indices(size))
else:
ind = ind + size if ind < 0 else ind
slices.append((ind, ind + 1, 1))
return lax.squeeze(lax.slice(arr, *unzip3(slices)), sqeeze_dimensions)
def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
# Computes arr[idx].
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.
# Handle some special cases, falling back if error messages might differ.
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
return lax.index_in_dim(arr, idx, keepdims=False)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
(type(idx.stop) is int or idx.stop is None) and
(type(idx.step) is int or idx.step is None)):
n = arr.shape[0]
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else n
step = idx.step if idx.step is not None else 1
if (0 <= start < n and 0 <= stop <= n and 0 < step and
(start, stop, step) != (0, n, 1)):
return lax.slice_in_dim(arr, start, stop, step)
# For simplicity of generated primitives, we call lax.slice in the simplest
# cases: i.e. non-dynamic arrays indexed with only integers and slices.
result = _attempt_rewriting_take_via_slice(arr, idx)
if result is not None:
return result
# TODO(mattjj,dougalm): expand dynamic shape indexing support
if (jax.config.jax_dynamic_shapes and type(idx) is slice and idx.step is None

View File

@ -860,6 +860,30 @@ class IndexingTest(jtu.JaxTestCase):
self.assertAllClose(expected, primals)
self.assertAllClose(np.zeros_like(x), tangents)
def testSimpleIndexingUsesSlice(self):
jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
jaxpr = jax.make_jaxpr(lambda x: x[:, 1::2])(jnp.ones((3, 4)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)
jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)
jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)
def testTrivialGatherIsntGenerated(self):
# https://github.com/google/jax/issues/1621
jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
@ -867,9 +891,12 @@ class IndexingTest(jtu.JaxTestCase):
self.assertNotIn('gather', str(jaxpr))
jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)