mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Expand constant indexing test to check slice
This commit is contained in:
parent
2e681ffe76
commit
167c6a9f0c
@ -873,9 +873,10 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{idx_type_name}_{idx}", "idx": idx, "idx_type": idx_type}
|
||||
for idx in (-3, 0, 5)
|
||||
for idx in (-3, 5)
|
||||
for idx_type_name, idx_type in (
|
||||
("int", int), ("np.array", np.array), ("jnp.array", jnp.array)))
|
||||
("int", int), ("np.array", np.array), ("jnp.array", jnp.array),
|
||||
("slice_up_to", slice), ("slice_from", lambda s: slice(s, None))))
|
||||
def testConstantIndexing(self, idx, idx_type):
|
||||
x = jnp.arange(10)
|
||||
idx = idx_type(idx)
|
||||
|
Loading…
x
Reference in New Issue
Block a user