mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix xla Shape vs shape tuple confusion and verify indexing test passes.
This commit is contained in:
parent
8de992d706
commit
f7bcb5a9c2
@ -2537,8 +2537,8 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
|
||||
proto.offset_dims.extend(dimension_numbers.offset_dims)
|
||||
proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims)
|
||||
proto.start_index_map.extend(dimension_numbers.start_index_map)
|
||||
assert len(indices_shape) > 0
|
||||
proto.index_vector_dim = len(indices_shape) - 1
|
||||
assert indices_shape.rank() > 0
|
||||
proto.index_vector_dim = indices_shape.rank() - 1
|
||||
return proto
|
||||
|
||||
def _gather_dtype_rule(operand, start_indices, **kwargs):
|
||||
@ -2688,8 +2688,8 @@ def _scatter_dimensions_proto(indices_shape, dimension_numbers):
|
||||
proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims)
|
||||
proto.scatter_dims_to_operand_dims.extend(
|
||||
dimension_numbers.scatter_dims_to_operand_dims)
|
||||
assert len(indices_shape) > 0
|
||||
proto.index_vector_dim = len(indices_shape) - 1
|
||||
assert indices_shape.rank() > 0
|
||||
proto.index_vector_dim = indices_shape.rank() - 1
|
||||
return proto
|
||||
|
||||
def _scatter_dtype_rule(operand, scatter_indices, updates, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user