Merge pull request #10216 from lgeiger:slice-none

PiperOrigin-RevId: 441877962
This commit is contained in:
jax authors 2022-04-14 16:04:48 -07:00
commit 0443f5ed9a
2 changed files with 14 additions and 5 deletions

View File

@ -3749,11 +3749,15 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
# Normalize the slice to use None when possible
start, stop, step = i.start, i.stop, i.step
try:
if ((step is None or core.symbolic_equal_dim(step, 1)) and
stop is not None and core.symbolic_equal_dim(stop, x_shape[x_axis])):
# The following is a useful special case with shape polymorphism
stop = None
except TypeError:
if step is None or core.symbolic_equal_dim(step, 1):
step = None
if step is None:
if start is None or core.symbolic_equal_dim(start, 0):
start = None
if stop is None or (not isinstance(stop, core.Tracer) and
core.greater_equal_dim(stop, x_shape[x_axis])):
stop = None
except (TypeError, core.InconclusiveDimensionOperation):
pass
# Handle slice(None)

View File

@ -849,6 +849,11 @@ class IndexingTest(jtu.JaxTestCase):
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertNotIn('gather', str(jaxpr))
jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
def testIndexingEmptyDimension(self):
# Issue 2671: XLA error when indexing into dimension of size 0
x = jnp.ones((2, 0))