From d5ee729a7f5d583c8e212c6a9f1b98221b13cbdc Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 25 Aug 2022 09:13:42 -0700 Subject: [PATCH] generate lax.slice instead of lax.gather for more indexing cases --- jax/_src/numpy/lax_numpy.py | 60 +++++++++++++++++++++----------- tests/lax_numpy_indexing_test.py | 31 +++++++++++++++-- 2 files changed, 69 insertions(+), 22 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f065945d0..a1c1fe881 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -77,8 +77,8 @@ from jax._src.numpy.util import ( # noqa: F401 _register_stackable, _stackable, _where, _wraps) from jax._src.numpy.vectorize import vectorize from jax._src.ops import scatter -from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio, - canonicalize_axis as _canonicalize_axis) +from jax._src.util import (unzip2, unzip3, prod as _prod, subvals, safe_zip, + ceil_of_ratio, canonicalize_axis as _canonicalize_axis) from jax.experimental.array import Array newaxis = None @@ -3587,30 +3587,50 @@ def take_along_axis(arr, indices, axis: Optional[int], ### Indexing +def _is_integer_index(idx: Any) -> bool: + return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_)) + +def _is_inbound_integer_index(idx: Any, size: int) -> bool: + return _is_integer_index(idx) and -size <= idx < size + +def _is_nonreversing_static_slice(idx: Any) -> bool: + return (isinstance(idx, slice) and + _all(i is None or _is_integer_index(i) + for i in [idx.start, idx.stop, idx.step]) and + (idx.step is None or idx.step > 0)) + +def _attempt_rewriting_take_via_slice(arr, idx): + # attempt to compute _rewriting_take via lax.slice(); return None if not possible. + idx = idx if isinstance(idx, tuple) else (idx,) + if not _all(isinstance(i, int) for i in arr.shape): + return None + if len(idx) > arr.ndim: + return None + if not _all(_is_inbound_integer_index(i, size) or _is_nonreversing_static_slice(i) + for i, size in zip(idx, arr.shape)): + return None + sqeeze_dimensions = [i for i, ind in enumerate(idx) if not isinstance(ind, slice)] + idx += (arr.ndim - len(idx)) * (slice(None),) + slices = [] + for ind, size in safe_zip(idx, arr.shape): + if isinstance(ind, slice): + slices.append(ind.indices(size)) + else: + ind = ind + size if ind < 0 else ind + slices.append((ind, ind + 1, 1)) + return lax.squeeze(lax.slice(arr, *unzip3(slices)), sqeeze_dimensions) + def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, mode=None, fill_value=None): # Computes arr[idx]. # All supported cases of indexing can be implemented as an XLA gather, # followed by an optional reverse and broadcast_in_dim. - # Handle some special cases, falling back if error messages might differ. - if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and - not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)): - if 0 <= idx < arr.shape[0]: - return lax.index_in_dim(arr, idx, keepdims=False) - if (arr.ndim > 0 and isinstance(arr.shape[0], int) and - isinstance(idx, slice) and - (type(idx.start) is int or idx.start is None) and - (type(idx.stop) is int or idx.stop is None) and - (type(idx.step) is int or idx.step is None)): - n = arr.shape[0] - start = idx.start if idx.start is not None else 0 - stop = idx.stop if idx.stop is not None else n - step = idx.step if idx.step is not None else 1 - if (0 <= start < n and 0 <= stop <= n and 0 < step and - (start, stop, step) != (0, n, 1)): - return lax.slice_in_dim(arr, start, stop, step) - + # For simplicity of generated primitives, we call lax.slice in the simplest + # cases: i.e. non-dynamic arrays indexed with only integers and slices. + result = _attempt_rewriting_take_via_slice(arr, idx) + if result is not None: + return result # TODO(mattjj,dougalm): expand dynamic shape indexing support if (jax.config.jax_dynamic_shapes and type(idx) is slice and idx.step is None diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index e5c72e216..297f3b855 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -860,6 +860,30 @@ class IndexingTest(jtu.JaxTestCase): self.assertAllClose(expected, primals) self.assertAllClose(np.zeros_like(x), tangents) + def testSimpleIndexingUsesSlice(self): + jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4))) + self.assertEqual(len(jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p) + + jaxpr = jax.make_jaxpr(lambda x: x[:, 1::2])(jnp.ones((3, 4))) + self.assertEqual(len(jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p) + + jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5))) + self.assertEqual(len(jaxpr.jaxpr.eqns), 2) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p) + self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p) + + jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5))) + self.assertEqual(len(jaxpr.jaxpr.eqns), 2) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p) + self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p) + + jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5))) + self.assertEqual(len(jaxpr.jaxpr.eqns), 2) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p) + self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p) + def testTrivialGatherIsntGenerated(self): # https://github.com/google/jax/issues/1621 jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4)) @@ -867,9 +891,12 @@ class IndexingTest(jtu.JaxTestCase): self.assertNotIn('gather', str(jaxpr)) jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4)) - self.assertEqual(len(jaxpr.jaxpr.eqns), 0) + self.assertEqual(len(jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p) + jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4)) - self.assertEqual(len(jaxpr.jaxpr.eqns), 0) + self.assertEqual(len(jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p) jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4)) self.assertEqual(len(jaxpr.jaxpr.eqns), 1)