mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix gather batching rule bug
This commit is contained in:
parent
d5ee720aea
commit
90d92a5a5c
@ -2126,7 +2126,9 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
|
||||
# (7, 3, 4, 6) where we concatenated an iota that counts along our batch
|
||||
# dimension to the front of the ndindex.
|
||||
index_vector_dim = dimension_numbers.index_vector_dim + 1
|
||||
counts = broadcasted_iota(start_indices.dtype, start_indices.shape, 0)
|
||||
count_shape = list(start_indices.shape)
|
||||
count_shape[index_vector_dim] = 1
|
||||
counts = broadcasted_iota(start_indices.dtype, tuple(count_shape), 0)
|
||||
start_indices = concatenate([counts, start_indices], index_vector_dim)
|
||||
|
||||
slice_sizes = (1,) + slice_sizes
|
||||
|
Loading…
x
Reference in New Issue
Block a user