mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
gather passing all operand vmap tests
This commit is contained in:
parent
b6cb3509cd
commit
b53eb241f7
48
jax/lax.py
48
jax/lax.py
@ -2085,40 +2085,32 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
|
||||
operand_bdim, start_indices_bdim = batch_dims
|
||||
|
||||
if operand_bdim is not None and start_indices_bdim is None:
|
||||
collapsed_slice_dims = set(dimension_numbers.collapsed_slice_dims)
|
||||
num_preceding_window_dims = sum( # 2
|
||||
1 for i in range(len(slice_sizes))
|
||||
if i < operand_bdim and i not in collapsed_slice_dims)
|
||||
offset_dims = list(dimension_numbers.offset_dims) # [2, 3, 7]
|
||||
if num_preceding_window_dims == 0:
|
||||
bdim_offset_dim = 0
|
||||
else:
|
||||
bdim_offset_dim = offset_dims[num_preceding_window_dims - 1] + 1
|
||||
new_offset_dims = (offset_dims[:num_preceding_window_dims]
|
||||
+ [bdim_offset_dim]
|
||||
+ list(onp.add(1, offset_dims[num_preceding_window_dims:])))
|
||||
new_offset_dims = tuple(new_offset_dims)
|
||||
|
||||
slice_sizes = list(slice_sizes)
|
||||
slice_sizes.insert(operand_bdim, operand.shape[operand_bdim])
|
||||
slice_sizes = tuple(slice_sizes)
|
||||
|
||||
collapsed_slice_dims = tuple(i + 1 if i >= operand_bdim else i
|
||||
for i in dimension_numbers.collapsed_slice_dims)
|
||||
|
||||
start_index_map = tuple(i + 1 if i > operand_bdim else i
|
||||
for i in dimension_numbers.start_index_map)
|
||||
|
||||
operand = batching.move_dim_to_front(operand, operand_bdim)
|
||||
slice_sizes = (operand.shape[0],) + slice_sizes
|
||||
offset_dims = (0,) + tuple(onp.add(1, dimension_numbers.offset_dims))
|
||||
collapsed_slice_dims = tuple(onp.add(1, dimension_numbers.collapsed_slice_dims))
|
||||
start_index_map = tuple(onp.add(1, dimension_numbers.start_index_map))
|
||||
dnums = GatherDimensionNumbers(
|
||||
offset_dims=new_offset_dims,
|
||||
offset_dims=offset_dims,
|
||||
collapsed_slice_dims=collapsed_slice_dims,
|
||||
start_index_map=start_index_map,
|
||||
index_vector_dim=dimension_numbers.index_vector_dim)
|
||||
|
||||
return gather(operand, start_indices, dimension_numbers=dnums,
|
||||
slice_sizes=slice_sizes), bdim_offset_dim
|
||||
slice_sizes=slice_sizes), 0
|
||||
|
||||
elif operand_bdim is None and start_indices_bdim is not None:
|
||||
start_indices = batching.move_dim_to_front(start_indices, start_indices_bdim)
|
||||
offset_dims = tuple(onp.add(1, dimension_numbers.offset_dims))
|
||||
index_vector_dim = dimension_numbers.index_vector_dim + 1
|
||||
start_index_map = (0,) + tuple(dimension_numbers.start_index_map)
|
||||
dnums = GatherDimensionNumbers(
|
||||
offset_dims=offset_dims,
|
||||
collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
|
||||
start_index_map=start_index_map, index_vector_dim=index_vector_dim)
|
||||
return gather(operand, start_indices, dimension_numbers=dnums,
|
||||
slice_sizes=slice_sizes), 0
|
||||
else:
|
||||
raise NotImplementedError # TODO(mattjj):
|
||||
raise NotImplementedError # TODO(mattjj, phawkins)
|
||||
|
||||
|
||||
gather_p = standard_primitive(
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user