Update lax_test for index_vector_dim change.

This commit is contained in:
Peter Hawkins 2019-03-01 12:19:00 -05:00
parent 97c6ff3347
commit ce5857c91e

View File

@ -1387,18 +1387,18 @@ class LaxTest(jtu.JaxTestCase):
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
for dtype in all_dtypes
for shape, idxs, dnums, slice_sizes in [
((5,), onp.array([0, 2]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1,)),
((10,), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
index_vector_dim=1), (2,)),
((10, 5,), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1, 3)),
((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,)),
((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,)),
((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3)),
((10, 5), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
index_vector_dim=1), (1, 3)),
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]
for rng_idx in [jtu.rand_int(max(shape))]
for rng in [jtu.rand_default()]))
@ -1417,15 +1417,15 @@ class LaxTest(jtu.JaxTestCase):
"rng_idx": rng_idx}
for dtype in float_dtypes
for arg_shape, idxs, update_shape, dnums in [
((5,), onp.array([0, 2]), (2,), lax.ScatterDimensionNumbers(
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
((10,), onp.array([0, 0, 0]), (3, 2), lax.ScatterDimensionNumbers(
scatter_dims_to_operand_dims=(0,))),
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
((10, 5,), onp.array([0, 2, 1]), (3, 3), lax.ScatterDimensionNumbers(
scatter_dims_to_operand_dims=(0,))),
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
scatter_dims_to_operand_dims=(0,))),
]
for rng_idx in [jtu.rand_int(max(arg_shape))]
for rng in [jtu.rand_default()]))
@ -2131,15 +2131,15 @@ class LaxAutodiffTest(jtu.JaxTestCase):
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
for dtype in float_dtypes
for shape, idxs, dnums, slice_sizes in [
((5,), onp.array([0, 2]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1,)),
((10,), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
index_vector_dim=1), (2,)),
((10, 5,), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1, 3)),
((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,)),
((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,)),
((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3)),
]
for rng_idx in [jtu.rand_int(max(shape))]
for rng in [jtu.rand_default()]))
@ -2159,15 +2159,15 @@ class LaxAutodiffTest(jtu.JaxTestCase):
"rng_idx": rng_idx}
for dtype in float_dtypes
for arg_shape, idxs, update_shape, dnums in [
((5,), onp.array([0, 2]), (2,), lax.ScatterDimensionNumbers(
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
((10,), onp.array([0, 0, 0]), (3, 2), lax.ScatterDimensionNumbers(
scatter_dims_to_operand_dims=(0,))),
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
((10, 5,), onp.array([0, 2, 1]), (3, 3), lax.ScatterDimensionNumbers(
scatter_dims_to_operand_dims=(0,))),
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,), index_vector_dim=1)),
scatter_dims_to_operand_dims=(0,))),
]
for rng_idx in [jtu.rand_int(max(arg_shape))]
for rng in [jtu.rand_default()]))