mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Rewrite simple slicing to the static slicing primitive whenever possible
This makes it a lot easier to handle within Pallas and Mosaic. PiperOrigin-RevId: 563128943
This commit is contained in:
parent
b18dc111f9
commit
bb8d5a0121
@ -26,6 +26,7 @@ import jax
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
@ -1151,6 +1152,21 @@ slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice')
|
||||
ad.deflinear2(slice_p, _slice_transpose_rule)
|
||||
batching.primitive_batchers[slice_p] = _slice_batching_rule
|
||||
|
||||
# Override the standard impl to defer to dynamic_slice whenever possible.
|
||||
# This lets us reuse the same program for many applications of slicing for as
|
||||
# long as they have the same output shape. Note that this only applies in
|
||||
# op-by-op mode and inside larger computations we still lower to static slices.
|
||||
@slice_p.def_impl
|
||||
def _slice_impl(x, start_indices, limit_indices, strides):
|
||||
if strides is not None:
|
||||
return dispatch.apply_primitive(
|
||||
slice_p, x, start_indices=start_indices,
|
||||
limit_indices=limit_indices, strides=strides)
|
||||
slice_sizes = tuple(np.array(limit_indices) - np.array(start_indices))
|
||||
return dispatch.apply_primitive(dynamic_slice_p, x, *start_indices,
|
||||
slice_sizes=slice_sizes)
|
||||
|
||||
|
||||
def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
|
||||
strides = strides or [1] * len(start_indices)
|
||||
aval_out, = ctx.avals_out
|
||||
|
@ -4216,11 +4216,19 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) ->
|
||||
assert np.shape(ind) == () # checked above
|
||||
start_indices.append(ind)
|
||||
slice_sizes.append(1)
|
||||
# Try to use static slicing when possible.
|
||||
if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices):
|
||||
int_start_indices = [int(i) for i in start_indices] # type: ignore
|
||||
int_limit_indices = [i + s for i, s in zip(int_start_indices, slice_sizes)]
|
||||
arr = lax.slice(
|
||||
arr, start_indices=int_start_indices, limit_indices=int_limit_indices)
|
||||
else:
|
||||
# We must be careful with dtypes because dynamic_slice requires all
|
||||
# start indices to have matching types.
|
||||
if len(start_indices) > 1:
|
||||
start_indices = util.promote_dtypes(*start_indices)
|
||||
arr = lax.dynamic_slice(arr, start_indices=start_indices, slice_sizes=slice_sizes)
|
||||
arr = lax.dynamic_slice(
|
||||
arr, start_indices=start_indices, slice_sizes=slice_sizes)
|
||||
if int_indices:
|
||||
arr = lax.squeeze(arr, tuple(int_indices))
|
||||
return arr
|
||||
|
@ -885,22 +885,22 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
|
||||
def testSimpleIndexingUsesSlice(self):
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 7)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.dynamic_slice_p)
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p)
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
|
||||
|
||||
# Simple reverses lower to lax.rev_p
|
||||
@ -908,6 +908,12 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)
|
||||
|
||||
# Non-static indices produce a dynamic slice
|
||||
jaxpr = jax.make_jaxpr(lambda x, i: x[i])(jnp.ones((4,)), 2)
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 6)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
|
||||
|
||||
def testTrivialGatherIsntGenerated(self):
|
||||
# https://github.com/google/jax/issues/1621
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
|
||||
|
Loading…
x
Reference in New Issue
Block a user