Fix xla Shape vs shape tuple confusion and verify indexing test passes.

This commit is contained in:
Peter Hawkins 2019-03-01 11:05:04 -05:00
parent 8de992d706
commit f7bcb5a9c2

View File

@ -2537,8 +2537,8 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
proto.offset_dims.extend(dimension_numbers.offset_dims)
proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims)
proto.start_index_map.extend(dimension_numbers.start_index_map)
assert len(indices_shape) > 0
proto.index_vector_dim = len(indices_shape) - 1
assert indices_shape.rank() > 0
proto.index_vector_dim = indices_shape.rank() - 1
return proto
def _gather_dtype_rule(operand, start_indices, **kwargs):
@ -2688,8 +2688,8 @@ def _scatter_dimensions_proto(indices_shape, dimension_numbers):
proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims)
proto.scatter_dims_to_operand_dims.extend(
dimension_numbers.scatter_dims_to_operand_dims)
assert len(indices_shape) > 0
proto.index_vector_dim = len(indices_shape) - 1
assert indices_shape.rank() > 0
proto.index_vector_dim = indices_shape.rank() - 1
return proto
def _scatter_dtype_rule(operand, scatter_indices, updates, **kwargs):