mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Update lax_test for index_vector_dim change.
This commit is contained in:
parent
97c6ff3347
commit
ce5857c91e
@ -1387,18 +1387,18 @@ class LaxTest(jtu.JaxTestCase):
|
||||
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
|
||||
for dtype in all_dtypes
|
||||
for shape, idxs, dnums, slice_sizes in [
|
||||
((5,), onp.array([0, 2]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1,)),
|
||||
((10,), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
|
||||
index_vector_dim=1), (2,)),
|
||||
((10, 5,), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1,)),
|
||||
((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
||||
(2,)),
|
||||
((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1, 3)),
|
||||
((10, 5), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
|
||||
(1, 3)),
|
||||
]
|
||||
for rng_idx in [jtu.rand_int(max(shape))]
|
||||
for rng in [jtu.rand_default()]))
|
||||
@ -1417,15 +1417,15 @@ class LaxTest(jtu.JaxTestCase):
|
||||
"rng_idx": rng_idx}
|
||||
for dtype in float_dtypes
|
||||
for arg_shape, idxs, update_shape, dnums in [
|
||||
((5,), onp.array([0, 2]), (2,), lax.ScatterDimensionNumbers(
|
||||
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
|
||||
((10,), onp.array([0, 0, 0]), (3, 2), lax.ScatterDimensionNumbers(
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(),
|
||||
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
|
||||
((10, 5,), onp.array([0, 2, 1]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]
|
||||
for rng_idx in [jtu.rand_int(max(arg_shape))]
|
||||
for rng in [jtu.rand_default()]))
|
||||
@ -2131,15 +2131,15 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
|
||||
for dtype in float_dtypes
|
||||
for shape, idxs, dnums, slice_sizes in [
|
||||
((5,), onp.array([0, 2]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1,)),
|
||||
((10,), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
|
||||
index_vector_dim=1), (2,)),
|
||||
((10, 5,), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1,)),
|
||||
((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
||||
(2,)),
|
||||
((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1, 3)),
|
||||
]
|
||||
for rng_idx in [jtu.rand_int(max(shape))]
|
||||
for rng in [jtu.rand_default()]))
|
||||
@ -2159,15 +2159,15 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
"rng_idx": rng_idx}
|
||||
for dtype in float_dtypes
|
||||
for arg_shape, idxs, update_shape, dnums in [
|
||||
((5,), onp.array([0, 2]), (2,), lax.ScatterDimensionNumbers(
|
||||
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
|
||||
((10,), onp.array([0, 0, 0]), (3, 2), lax.ScatterDimensionNumbers(
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(),
|
||||
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
|
||||
((10, 5,), onp.array([0, 2, 1]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]
|
||||
for rng_idx in [jtu.rand_int(max(arg_shape))]
|
||||
for rng in [jtu.rand_default()]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user