mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Update batching_test.py for index_vector_dim change.
This commit is contained in:
parent
f7bcb5a9c2
commit
97c6ff3347
@ -2633,7 +2633,7 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
|
||||
count_shape = list(start_indices.shape)
|
||||
count_shape[-1] = 1
|
||||
counts = broadcasted_iota(start_indices.dtype, tuple(count_shape), 0)
|
||||
start_indices = concatenate([counts, start_indices], len(counts_shape) - 1)
|
||||
start_indices = concatenate([counts, start_indices], len(count_shape) - 1)
|
||||
|
||||
slice_sizes = (1,) + slice_sizes
|
||||
collapsed_slice_dims = (0,) + tuple(onp.add(1, dimension_numbers.collapsed_slice_dims))
|
||||
|
@ -554,18 +554,20 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
|
||||
for dtype in [onp.float32, onp.int32]
|
||||
for axis, shape, idxs, dnums, slice_sizes in [
|
||||
(0, (3, 5), onp.array([0, 2]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1,)),
|
||||
(1, (10, 3), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
|
||||
index_vector_dim=1), (2,)),
|
||||
(1, (10, 3, 5), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
(2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
(0, (3, 5), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1,)),
|
||||
(1, (10, 3), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
||||
(2,)),
|
||||
(1, (10, 3, 5), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1, 3)),
|
||||
(2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,),
|
||||
start_index_map=(0, 1)),
|
||||
(1, 3)),
|
||||
]
|
||||
for rng_idx in [jtu.rand_int(max(shape))]
|
||||
for rng in [jtu.rand_default()])
|
||||
@ -586,19 +588,20 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
|
||||
for dtype in [onp.float32, onp.float64]
|
||||
for axis, shape, idxs, dnums, slice_sizes in [
|
||||
(0, (3, 5), onp.array([0, 2]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1,)),
|
||||
(1, (10, 3), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
|
||||
index_vector_dim=1), (2,)),
|
||||
(1, (10, 3, 5), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
(2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
]
|
||||
(0, (3, 5), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1,)),
|
||||
(1, (10, 3), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
||||
(2,)),
|
||||
(1, (10, 3, 5), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1, 3)),
|
||||
(2, (10, 5, 3), onp.array([[0, 2], [1, 0]]),
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,),
|
||||
start_index_map=(0, 1)),
|
||||
(1, 3)), ]
|
||||
for rng_idx in [jtu.rand_int(max(shape))]
|
||||
for rng in [jtu.rand_default()])
|
||||
def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
|
||||
@ -619,21 +622,17 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
|
||||
for dtype in [onp.float32, onp.int32]
|
||||
for axis, shape, idxs, dnums, slice_sizes in [
|
||||
(0, (5,), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1,)),
|
||||
(1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T,
|
||||
(0, (5,), onp.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
|
||||
(1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
|
||||
index_vector_dim=1), (2,)),
|
||||
(1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T,
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
|
||||
(1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
(0, (10, 5), onp.array([[[0, 2], [1, 0]],
|
||||
[[1, 2], [0, 3]]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
|
||||
(0, (10, 5), onp.array([[[0, 1], [2, 0]],
|
||||
[[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
|
||||
]
|
||||
for rng_idx in [jtu.rand_int(max(shape))]
|
||||
for rng in [jtu.rand_default()])
|
||||
@ -654,21 +653,17 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
"slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
|
||||
for dtype in [onp.float32, onp.float64]
|
||||
for axis, shape, idxs, dnums, slice_sizes in [
|
||||
(0, (5,), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1,)),
|
||||
(1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T,
|
||||
(0, (5,), onp.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
|
||||
(1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
|
||||
index_vector_dim=1), (2,)),
|
||||
(1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T,
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
|
||||
(1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
(0, (10, 5), onp.array([[[0, 2], [1, 0]],
|
||||
[[1, 2], [0, 3]]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
|
||||
(0, (10, 5), onp.array([[[0, 1], [2, 0]],
|
||||
[[1, 0], [2, 3]]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
|
||||
]
|
||||
for rng_idx in [jtu.rand_int(max(shape))]
|
||||
for rng in [jtu.rand_default()])
|
||||
@ -691,21 +686,23 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
"rng": rng, "rng_idx": rng_idx}
|
||||
for dtype in [onp.float32, onp.int32]
|
||||
for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
|
||||
(0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1,)),
|
||||
(1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
|
||||
(0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]),
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
|
||||
index_vector_dim=1), (2,)),
|
||||
(0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T,
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1,)),
|
||||
(1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
||||
(2,)),
|
||||
(0, 1, (2, 10, 5,), onp.array([[[0, 2, 1], [0, 3, 3]]]).T,
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1, 3)),
|
||||
(2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]],
|
||||
[[1, 0], [2, 0]]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
[[1, 0], [2, 0]]]),
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
|
||||
(1, 3)),
|
||||
]
|
||||
for rng_idx in [jtu.rand_int(max(shape))]
|
||||
for rng in [jtu.rand_default()])
|
||||
@ -729,21 +726,23 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
"rng": rng, "rng_idx": rng_idx}
|
||||
for dtype in [onp.float32, onp.int32]
|
||||
for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
|
||||
(0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1,)),
|
||||
(1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
|
||||
(0, 0, (2, 5), onp.array([[[0], [2]], [[1], [3]]]),
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
|
||||
index_vector_dim=1), (2,)),
|
||||
(0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T,
|
||||
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1,)),
|
||||
(1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
|
||||
(2,)),
|
||||
(0, 1, (2, 10, 5,), onp.array([[[0, 2, 1], [0, 3, 3]]]).T,
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
|
||||
(1, 3)),
|
||||
(2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]],
|
||||
[[1, 0], [2, 0]]]), lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
|
||||
index_vector_dim=1), (1, 3)),
|
||||
[[1, 0], [2, 0]]]),
|
||||
lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
|
||||
(1, 3)),
|
||||
]
|
||||
for rng_idx in [jtu.rand_int(max(shape))]
|
||||
for rng in [jtu.rand_default()])
|
||||
|
Loading…
x
Reference in New Issue
Block a user