Add jaxpr test to ensure that no normalization happens for constant indices

This commit is contained in:
Lukas Geiger 2022-04-21 20:40:10 +01:00
parent 0c5b1326c7
commit 5e2dd9ccd4

View File

@ -871,6 +871,18 @@ class IndexingTest(jtu.JaxTestCase):
jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
@parameterized.named_parameters(
{"testcase_name": f"_{idx_type_name}_{idx}", "idx": idx, "idx_type": idx_type}
for idx in (-3, 0, 5)
for idx_type_name, idx_type in (
("int", int), ("np.array", np.array), ("jnp.array", jnp.array)))
def testConstantIndexing(self, idx, idx_type):
x = jnp.arange(10)
idx = idx_type(idx)
jaxpr = jax.make_jaxpr(lambda: x[idx])()
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.gather_p)
def testIndexingEmptyDimension(self):
# Issue 2671: XLA error when indexing into dimension of size 0
x = jnp.ones((2, 0))