mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add jaxpr test to ensure that no normalization happens for constant indices
This commit is contained in:
parent
0c5b1326c7
commit
5e2dd9ccd4
@ -871,6 +871,18 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{idx_type_name}_{idx}", "idx": idx, "idx_type": idx_type}
|
||||
for idx in (-3, 0, 5)
|
||||
for idx_type_name, idx_type in (
|
||||
("int", int), ("np.array", np.array), ("jnp.array", jnp.array)))
|
||||
def testConstantIndexing(self, idx, idx_type):
|
||||
x = jnp.arange(10)
|
||||
idx = idx_type(idx)
|
||||
jaxpr = jax.make_jaxpr(lambda: x[idx])()
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
|
||||
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.gather_p)
|
||||
|
||||
def testIndexingEmptyDimension(self):
|
||||
# Issue 2671: XLA error when indexing into dimension of size 0
|
||||
x = jnp.ones((2, 0))
|
||||
|
Loading…
x
Reference in New Issue
Block a user