Update batching_test.py for index_vector_dim change.

This commit is contained in:
Peter Hawkins 2019-03-01 11:59:54 -05:00
parent f7bcb5a9c2
commit 97c6ff3347
2 changed files with 75 additions and 76 deletions

View File

@ -2633,7 +2633,7 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
count_shape = list(start_indices.shape)
count_shape[-1] = 1
counts = broadcasted_iota(start_indices.dtype, tuple(count_shape), 0)
start_indices = concatenate([counts, start_indices], len(counts_shape) - 1)
start_indices = concatenate([counts, start_indices], len(count_shape) - 1)
slice_sizes = (1,) + slice_sizes
collapsed_slice_dims = (0,) + tuple(onp.add(1, dimension_numbers.collapsed_slice_dims))

View File

@ -554,18 +554,20 @@ class BatchingTest(jtu.JaxTestCase):
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
for dtype in [onp.float32, onp.int32]
for axis, shape, idxs, dnums, slice_sizes in [
(0, (3, 5), onp.array([0, 2]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1,)),
(1, (10, 3), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
index_vector_dim=1), (2,)),
(1, (10, 3, 5), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1, 3)),
(2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
index_vector_dim=1), (1, 3)),
(0, (3, 5), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,)),
(1, (10, 3), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,)),
(1, (10, 3, 5), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3)),
(2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,),
start_index_map=(0, 1)),
(1, 3)),
]
for rng_idx in [jtu.rand_int(max(shape))]
for rng in [jtu.rand_default()])
@ -586,19 +588,20 @@ class BatchingTest(jtu.JaxTestCase):
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
for dtype in [onp.float32, onp.float64]
for axis, shape, idxs, dnums, slice_sizes in [
(0, (3, 5), onp.array([0, 2]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1,)),
(1, (10, 3), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
index_vector_dim=1), (2,)),
(1, (10, 3, 5), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1, 3)),
(2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
index_vector_dim=1), (1, 3)),
]
(0, (3, 5), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,)),
(1, (10, 3), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,)),
(1, (10, 3, 5), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3)),
(2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,),
start_index_map=(0, 1)),
(1, 3)), ]
for rng_idx in [jtu.rand_int(max(shape))]
for rng in [jtu.rand_default()])
def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
@ -619,21 +622,17 @@ class BatchingTest(jtu.JaxTestCase):
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
for dtype in [onp.float32, onp.int32]
for axis, shape, idxs, dnums, slice_sizes in [
(0, (5,), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1,)),
(1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T,
(0, (5,), onp.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
(1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
index_vector_dim=1), (2,)),
(1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T,
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
(1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1, 3)),
(0, (10, 5), onp.array([[[0, 2], [1, 0]],
[[1, 2], [0, 3]]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
index_vector_dim=1), (1, 3)),
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
(0, (10, 5), onp.array([[[0, 1], [2, 0]],
[[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
]
for rng_idx in [jtu.rand_int(max(shape))]
for rng in [jtu.rand_default()])
@ -654,21 +653,17 @@ class BatchingTest(jtu.JaxTestCase):
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
for dtype in [onp.float32, onp.float64]
for axis, shape, idxs, dnums, slice_sizes in [
(0, (5,), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1,)),
(1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T,
(0, (5,), onp.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
(1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
index_vector_dim=1), (2,)),
(1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T,
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
(1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1, 3)),
(0, (10, 5), onp.array([[[0, 2], [1, 0]],
[[1, 2], [0, 3]]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
index_vector_dim=1), (1, 3)),
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
(0, (10, 5), onp.array([[[0, 1], [2, 0]],
[[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
]
for rng_idx in [jtu.rand_int(max(shape))]
for rng in [jtu.rand_default()])
@ -691,21 +686,23 @@ class BatchingTest(jtu.JaxTestCase):
"rng": rng, "rng_idx": rng_idx}
for dtype in [onp.float32, onp.int32]
for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
(0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1,)),
(1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
(0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]),
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
index_vector_dim=1), (2,)),
(0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T,
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,)),
(1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1, 3)),
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,)),
(0, 1, (2, 10, 5,), onp.array([[[0, 2, 1], [0, 3, 3]]]).T,
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3)),
(2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]],
[[1, 0], [2, 0]]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
index_vector_dim=1), (1, 3)),
[[1, 0], [2, 0]]]),
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]
for rng_idx in [jtu.rand_int(max(shape))]
for rng in [jtu.rand_default()])
@ -729,21 +726,23 @@ class BatchingTest(jtu.JaxTestCase):
"rng": rng, "rng_idx": rng_idx}
for dtype in [onp.float32, onp.int32]
for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
(0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1,)),
(1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
(0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]),
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
index_vector_dim=1), (2,)),
(0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T,
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,)),
(1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1, 3)),
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,)),
(0, 1, (2, 10, 5,), onp.array([[[0, 2, 1], [0, 3, 3]]]).T,
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3)),
(2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]],
[[1, 0], [2, 0]]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
index_vector_dim=1), (1, 3)),
[[1, 0], [2, 0]]]),
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]
for rng_idx in [jtu.rand_int(max(shape))]
for rng in [jtu.rand_default()])