simplify gather shape rule

Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
Matthew Johnson 2019-06-05 17:04:33 -07:00
parent a61d5f6e78
commit 9e49500c54

View File

@ -2631,22 +2631,11 @@ def _gather_shape_rule(operand, start_indices, dimension_numbers, slice_sizes,
msg = ("slice_sizes must have rank equal to the gather operand; "
"operand.shape={}, slice_sizes={}".format(operand_shape, slice_sizes))
raise ValueError(msg)
expanded_start_indices_shape = list(start_indices.shape)
result_rank = len(dimension_numbers.offset_dims)
result_rank += len(expanded_start_indices_shape) - 1
output_shape = []
offset_dims_seen = 0
gather_dims_seen = 0
for i in xrange(result_rank):
if i in dimension_numbers.offset_dims:
while offset_dims_seen in dimension_numbers.collapsed_slice_dims:
offset_dims_seen += 1
output_shape.append(slice_sizes[offset_dims_seen])
offset_dims_seen += 1
else:
output_shape.append(expanded_start_indices_shape[gather_dims_seen])
gather_dims_seen += 1
return tuple(output_shape)
result_rank = len(dimension_numbers.offset_dims) + start_indices.ndim - 1
start_indices_shape = iter(start_indices.shape[:-1])
slice_sizes = iter(onp.delete(slice_sizes, dimension_numbers.collapsed_slice_dims))
return tuple(next(slice_sizes) if i in dimension_numbers.offset_dims
else next(start_indices_shape) for i in range(result_rank))
def _gather_translation_rule(c, operand, start_indices, dimension_numbers,
slice_sizes, operand_shape):