diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 6c06c6e08..97f7d0070 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -26,6 +26,7 @@ import jax from jax._src import ad_util from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src import source_info_util from jax._src import util @@ -1151,6 +1152,21 @@ slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice') ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule +# Override the standard impl to defer to dynamic_slice whenever possible. +# This lets us reuse the same program for many applications of slicing for as +# long as they have the same output shape. Note that this only applies in +# op-by-op mode and inside larger computations we still lower to static slices. +@slice_p.def_impl +def _slice_impl(x, start_indices, limit_indices, strides): + if strides is not None: + return dispatch.apply_primitive( + slice_p, x, start_indices=start_indices, + limit_indices=limit_indices, strides=strides) + slice_sizes = tuple(np.array(limit_indices) - np.array(start_indices)) + return dispatch.apply_primitive(dynamic_slice_p, x, *start_indices, + slice_sizes=slice_sizes) + + def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): strides = strides or [1] * len(start_indices) aval_out, = ctx.avals_out diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 69a0231e7..26a4b777f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4216,11 +4216,19 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> assert np.shape(ind) == () # checked above start_indices.append(ind) slice_sizes.append(1) - # We must be careful with dtypes because dynamic_slice requires all - # start indices to have matching types. - if len(start_indices) > 1: - start_indices = util.promote_dtypes(*start_indices) - arr = lax.dynamic_slice(arr, start_indices=start_indices, slice_sizes=slice_sizes) + # Try to use static slicing when possible. + if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices): + int_start_indices = [int(i) for i in start_indices] # type: ignore + int_limit_indices = [i + s for i, s in zip(int_start_indices, slice_sizes)] + arr = lax.slice( + arr, start_indices=int_start_indices, limit_indices=int_limit_indices) + else: + # We must be careful with dtypes because dynamic_slice requires all + # start indices to have matching types. + if len(start_indices) > 1: + start_indices = util.promote_dtypes(*start_indices) + arr = lax.dynamic_slice( + arr, start_indices=start_indices, slice_sizes=slice_sizes) if int_indices: arr = lax.squeeze(arr, tuple(int_indices)) return arr diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 0716195c6..b74793cc6 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -885,22 +885,22 @@ class IndexingTest(jtu.JaxTestCase): def testSimpleIndexingUsesSlice(self): jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4))) - self.assertEqual(len(jaxpr.jaxpr.eqns), 7) - self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.dynamic_slice_p) + self.assertEqual(len(jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5))) - self.assertEqual(len(jaxpr.jaxpr.eqns), 11) - self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p) + self.assertEqual(len(jaxpr.jaxpr.eqns), 2) + self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5))) - self.assertEqual(len(jaxpr.jaxpr.eqns), 11) - self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p) + self.assertEqual(len(jaxpr.jaxpr.eqns), 2) + self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5))) - self.assertEqual(len(jaxpr.jaxpr.eqns), 11) - self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p) + self.assertEqual(len(jaxpr.jaxpr.eqns), 2) + self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) # Simple reverses lower to lax.rev_p @@ -908,6 +908,12 @@ class IndexingTest(jtu.JaxTestCase): self.assertEqual(len(jaxpr.jaxpr.eqns), 1) self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p) + # Non-static indices produce a dynamic slice + jaxpr = jax.make_jaxpr(lambda x, i: x[i])(jnp.ones((4,)), 2) + self.assertEqual(len(jaxpr.jaxpr.eqns), 6) + self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + def testTrivialGatherIsntGenerated(self): # https://github.com/google/jax/issues/1621 jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))