fix gather batching rule bug

This commit is contained in:
Matthew Johnson 2019-02-11 11:30:44 -08:00
parent d5ee720aea
commit 90d92a5a5c

View File

@ -2126,7 +2126,9 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
# (7, 3, 4, 6) where we concatenated an iota that counts along our batch
# dimension to the front of the ndindex.
index_vector_dim = dimension_numbers.index_vector_dim + 1
counts = broadcasted_iota(start_indices.dtype, start_indices.shape, 0)
count_shape = list(start_indices.shape)
count_shape[index_vector_dim] = 1
counts = broadcasted_iota(start_indices.dtype, tuple(count_shape), 0)
start_indices = concatenate([counts, start_indices], index_vector_dim)
slice_sizes = (1,) + slice_sizes