Use gather as the batched form of dynamic_slice.

Avoids building an unrolled slice/concatenate structure that is linear in the size of the batch dimension.
This commit is contained in:
Peter Hawkins 2019-03-13 11:10:50 -04:00
parent cd1e4601c7
commit 8fe8f23b27

View File

@ -2664,29 +2664,17 @@ def _dynamic_slice_transpose_rule(t, operand, start_indices, slice_sizes,
return [dynamic_update_slice(zeros, t, start_indices), ad_util.zero]
def _dynamic_slice_batching_rule(batched_args, batch_dims, slice_sizes,
**unused_kwargs):
operand, start_indices = batched_args
op_bdim, idx_bdim = batch_dims
operand_shape):
# A dynamic slice is a special case of gather; we can delegate to the gather
# batching rule.
# TODO(phawkins): consider removing dynamic_slice entirely and using gather
# always.
dims = tuple(range(len(operand_shape)))
dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(),
start_index_map=dims)
return _gather_batching_rule(batched_args, batch_dims, dnums, slice_sizes,
operand_shape)
if idx_bdim is None:
new_start_indices = concatenate(
[start_indices[:op_bdim], _zeros(start_indices, shape=(1,)),
start_indices[op_bdim:]], 0)
new_slice_sizes = list(slice_sizes)
new_slice_sizes.insert(op_bdim, operand.shape[op_bdim])
out = dynamic_slice(operand, new_start_indices, new_slice_sizes)
return out, op_bdim
else:
# TODO(mattjj): add support for Gather HLO, use it here
start_indices = batching.bdim_at_front(start_indices, idx_bdim)
if op_bdim is None:
out = concatenate([dynamic_slice(operand, idx, slice_sizes)
for idx in start_indices], 0)
else:
operand = batching.bdim_at_front(operand, op_bdim)
out = concatenate([dynamic_slice(op, idx, slice_sizes)
for op, idx in zip(operand, start_indices)], 0)
return out, 0
dynamic_slice_p = standard_primitive(
_dynamic_slice_shape_rule, _input_dtype, 'dynamic_slice',