mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00
Merge pull request #10216 from lgeiger:slice-none
PiperOrigin-RevId: 441877962
This commit is contained in:
commit
0443f5ed9a
@ -3749,11 +3749,15 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
|
||||
# Normalize the slice to use None when possible
|
||||
start, stop, step = i.start, i.stop, i.step
|
||||
try:
|
||||
if ((step is None or core.symbolic_equal_dim(step, 1)) and
|
||||
stop is not None and core.symbolic_equal_dim(stop, x_shape[x_axis])):
|
||||
# The following is a useful special case with shape polymorphism
|
||||
stop = None
|
||||
except TypeError:
|
||||
if step is None or core.symbolic_equal_dim(step, 1):
|
||||
step = None
|
||||
if step is None:
|
||||
if start is None or core.symbolic_equal_dim(start, 0):
|
||||
start = None
|
||||
if stop is None or (not isinstance(stop, core.Tracer) and
|
||||
core.greater_equal_dim(stop, x_shape[x_axis])):
|
||||
stop = None
|
||||
except (TypeError, core.InconclusiveDimensionOperation):
|
||||
pass
|
||||
|
||||
# Handle slice(None)
|
||||
|
@ -849,6 +849,11 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
||||
self.assertNotIn('gather', str(jaxpr))
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
|
||||
|
||||
def testIndexingEmptyDimension(self):
|
||||
# Issue 2671: XLA error when indexing into dimension of size 0
|
||||
x = jnp.ones((2, 0))
|
||||
|
Loading…
x
Reference in New Issue
Block a user