mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
simplify gather shape rule
Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
parent
a61d5f6e78
commit
9e49500c54
@ -2631,22 +2631,11 @@ def _gather_shape_rule(operand, start_indices, dimension_numbers, slice_sizes,
|
||||
msg = ("slice_sizes must have rank equal to the gather operand; "
|
||||
"operand.shape={}, slice_sizes={}".format(operand_shape, slice_sizes))
|
||||
raise ValueError(msg)
|
||||
expanded_start_indices_shape = list(start_indices.shape)
|
||||
result_rank = len(dimension_numbers.offset_dims)
|
||||
result_rank += len(expanded_start_indices_shape) - 1
|
||||
output_shape = []
|
||||
offset_dims_seen = 0
|
||||
gather_dims_seen = 0
|
||||
for i in xrange(result_rank):
|
||||
if i in dimension_numbers.offset_dims:
|
||||
while offset_dims_seen in dimension_numbers.collapsed_slice_dims:
|
||||
offset_dims_seen += 1
|
||||
output_shape.append(slice_sizes[offset_dims_seen])
|
||||
offset_dims_seen += 1
|
||||
else:
|
||||
output_shape.append(expanded_start_indices_shape[gather_dims_seen])
|
||||
gather_dims_seen += 1
|
||||
return tuple(output_shape)
|
||||
result_rank = len(dimension_numbers.offset_dims) + start_indices.ndim - 1
|
||||
start_indices_shape = iter(start_indices.shape[:-1])
|
||||
slice_sizes = iter(onp.delete(slice_sizes, dimension_numbers.collapsed_slice_dims))
|
||||
return tuple(next(slice_sizes) if i in dimension_numbers.offset_dims
|
||||
else next(start_indices_shape) for i in range(result_rank))
|
||||
|
||||
def _gather_translation_rule(c, operand, start_indices, dimension_numbers,
|
||||
slice_sizes, operand_shape):
|
||||
|
Loading…
x
Reference in New Issue
Block a user