Expand constant indexing test to check slice

This commit is contained in:
Lukas Geiger 2022-04-23 00:05:03 +01:00
parent 2e681ffe76
commit 167c6a9f0c

View File

@ -873,9 +873,10 @@ class IndexingTest(jtu.JaxTestCase):
@parameterized.named_parameters(
{"testcase_name": f"_{idx_type_name}_{idx}", "idx": idx, "idx_type": idx_type}
for idx in (-3, 0, 5)
for idx in (-3, 5)
for idx_type_name, idx_type in (
("int", int), ("np.array", np.array), ("jnp.array", jnp.array)))
("int", int), ("np.array", np.array), ("jnp.array", jnp.array),
("slice_up_to", slice), ("slice_from", lambda s: slice(s, None))))
def testConstantIndexing(self, idx, idx_type):
x = jnp.arange(10)
idx = idx_type(idx)