mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use gather as the batched form of dynamic_slice.
Avoids building an unrolled slice/concatenate structure that is linear in the size of the batch dimension.
This commit is contained in:
parent
cd1e4601c7
commit
8fe8f23b27
32
jax/lax.py
32
jax/lax.py
@ -2664,29 +2664,17 @@ def _dynamic_slice_transpose_rule(t, operand, start_indices, slice_sizes,
|
||||
return [dynamic_update_slice(zeros, t, start_indices), ad_util.zero]
|
||||
|
||||
def _dynamic_slice_batching_rule(batched_args, batch_dims, slice_sizes,
|
||||
**unused_kwargs):
|
||||
operand, start_indices = batched_args
|
||||
op_bdim, idx_bdim = batch_dims
|
||||
operand_shape):
|
||||
# A dynamic slice is a special case of gather; we can delegate to the gather
|
||||
# batching rule.
|
||||
# TODO(phawkins): consider removing dynamic_slice entirely and using gather
|
||||
# always.
|
||||
dims = tuple(range(len(operand_shape)))
|
||||
dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(),
|
||||
start_index_map=dims)
|
||||
return _gather_batching_rule(batched_args, batch_dims, dnums, slice_sizes,
|
||||
operand_shape)
|
||||
|
||||
if idx_bdim is None:
|
||||
new_start_indices = concatenate(
|
||||
[start_indices[:op_bdim], _zeros(start_indices, shape=(1,)),
|
||||
start_indices[op_bdim:]], 0)
|
||||
new_slice_sizes = list(slice_sizes)
|
||||
new_slice_sizes.insert(op_bdim, operand.shape[op_bdim])
|
||||
out = dynamic_slice(operand, new_start_indices, new_slice_sizes)
|
||||
return out, op_bdim
|
||||
else:
|
||||
# TODO(mattjj): add support for Gather HLO, use it here
|
||||
start_indices = batching.bdim_at_front(start_indices, idx_bdim)
|
||||
if op_bdim is None:
|
||||
out = concatenate([dynamic_slice(operand, idx, slice_sizes)
|
||||
for idx in start_indices], 0)
|
||||
else:
|
||||
operand = batching.bdim_at_front(operand, op_bdim)
|
||||
out = concatenate([dynamic_slice(op, idx, slice_sizes)
|
||||
for op, idx in zip(operand, start_indices)], 0)
|
||||
return out, 0
|
||||
|
||||
dynamic_slice_p = standard_primitive(
|
||||
_dynamic_slice_shape_rule, _input_dtype, 'dynamic_slice',
|
||||
|
Loading…
x
Reference in New Issue
Block a user