Rewrite simple slicing to the static slicing primitive whenever possible

This makes it a lot easier to handle within Pallas and Mosaic.

PiperOrigin-RevId: 563128943
This commit is contained in:
Adam Paszke 2023-09-06 09:32:25 -07:00 committed by jax authors
parent b18dc111f9
commit bb8d5a0121
3 changed files with 43 additions and 13 deletions

View File

@ -26,6 +26,7 @@ import jax
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
@ -1151,6 +1152,21 @@ slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice')
ad.deflinear2(slice_p, _slice_transpose_rule)
batching.primitive_batchers[slice_p] = _slice_batching_rule
# Override the standard impl to defer to dynamic_slice whenever possible.
# This lets us reuse the same program for many applications of slicing for as
# long as they have the same output shape. Note that this only applies in
# op-by-op mode and inside larger computations we still lower to static slices.
@slice_p.def_impl
def _slice_impl(x, start_indices, limit_indices, strides):
if strides is not None:
return dispatch.apply_primitive(
slice_p, x, start_indices=start_indices,
limit_indices=limit_indices, strides=strides)
slice_sizes = tuple(np.array(limit_indices) - np.array(start_indices))
return dispatch.apply_primitive(dynamic_slice_p, x, *start_indices,
slice_sizes=slice_sizes)
def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
strides = strides or [1] * len(start_indices)
aval_out, = ctx.avals_out

View File

@ -4216,11 +4216,19 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) ->
assert np.shape(ind) == () # checked above
start_indices.append(ind)
slice_sizes.append(1)
# Try to use static slicing when possible.
if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices):
int_start_indices = [int(i) for i in start_indices] # type: ignore
int_limit_indices = [i + s for i, s in zip(int_start_indices, slice_sizes)]
arr = lax.slice(
arr, start_indices=int_start_indices, limit_indices=int_limit_indices)
else:
# We must be careful with dtypes because dynamic_slice requires all
# start indices to have matching types.
if len(start_indices) > 1:
start_indices = util.promote_dtypes(*start_indices)
arr = lax.dynamic_slice(arr, start_indices=start_indices, slice_sizes=slice_sizes)
arr = lax.dynamic_slice(
arr, start_indices=start_indices, slice_sizes=slice_sizes)
if int_indices:
arr = lax.squeeze(arr, tuple(int_indices))
return arr

View File

@ -885,22 +885,22 @@ class IndexingTest(jtu.JaxTestCase):
def testSimpleIndexingUsesSlice(self):
jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 7)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.dynamic_slice_p)
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p)
jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
# Simple reverses lower to lax.rev_p
@ -908,6 +908,12 @@ class IndexingTest(jtu.JaxTestCase):
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)
# Non-static indices produce a dynamic slice
jaxpr = jax.make_jaxpr(lambda x, i: x[i])(jnp.ones((4,)), 2)
self.assertEqual(len(jaxpr.jaxpr.eqns), 6)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
def testTrivialGatherIsntGenerated(self):
# https://github.com/google/jax/issues/1621
jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))