mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add numpy indexing batching tests
This commit is contained in:
parent
b53eb241f7
commit
6dfe2d6e36
21
jax/lax.py
21
jax/lax.py
@ -2102,15 +2102,30 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
|
||||
start_indices = batching.move_dim_to_front(start_indices, start_indices_bdim)
|
||||
offset_dims = tuple(onp.add(1, dimension_numbers.offset_dims))
|
||||
index_vector_dim = dimension_numbers.index_vector_dim + 1
|
||||
start_index_map = (0,) + tuple(dimension_numbers.start_index_map)
|
||||
dnums = GatherDimensionNumbers(
|
||||
offset_dims=offset_dims,
|
||||
collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
|
||||
start_index_map=start_index_map, index_vector_dim=index_vector_dim)
|
||||
start_index_map=dimension_numbers.start_index_map,
|
||||
index_vector_dim=index_vector_dim)
|
||||
return gather(operand, start_indices, dimension_numbers=dnums,
|
||||
slice_sizes=slice_sizes), 0
|
||||
|
||||
else:
|
||||
raise NotImplementedError # TODO(mattjj, phawkins)
|
||||
# TODO(mattjj,phawkins): this code is wrong
|
||||
operand = batching.move_dim_to_front(operand, operand_bdim)
|
||||
start_indices = batching.move_dim_to_front(start_indices, start_indices_bdim)
|
||||
|
||||
slice_sizes = (1,) + slice_sizes
|
||||
collapsed_slice_dims = (0,) + tuple(onp.add(1, dimension_numbers.collapsed_slice_dims))
|
||||
offset_dims = tuple(onp.add(1, dimension_numbers.offset_dims))
|
||||
start_index_map = tuple(onp.add(1, dimension_numbers.start_index_map))
|
||||
dnums = GatherDimensionNumbers(
|
||||
offset_dims=offset_dims,
|
||||
collapsed_slice_dims=collapsed_slice_dims,
|
||||
start_index_map=start_index_map,
|
||||
index_vector_dim=dimension_numbers.index_vector_dim + 1)
|
||||
return gather(operand, start_indices, dimension_numbers=dnums,
|
||||
slice_sizes=slice_sizes), 0
|
||||
|
||||
|
||||
gather_p = standard_primitive(
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user