commit 7540690157
Author: jax authors
Date:   2021-06-09 18:49:00 -07:00

    Merge pull request #6897 from hawkinsp:indexunique

    PiperOrigin-RevId: 378550369

4 changed files with 48 additions and 18 deletions


@@ -951,7 +951,10 @@ class GatherDimensionNumbers(NamedTuple):
 def gather(operand: Array, start_indices: Array,
            dimension_numbers: GatherDimensionNumbers,
-           slice_sizes: Shape) -> Array:
+           slice_sizes: Shape,
+           *,
+           unique_indices: bool = False,
+           indices_are_sorted: bool = False) -> Array:
   """Gather operator.
 
   Wraps `XLA's Gather operator
@@ -969,13 +972,20 @@ def gather(operand: Array, start_indices: Array,
       how dimensions of `operand`, `start_indices` and the output relate.
     slice_sizes: the size of each slice. Must be a sequence of non-negative
       integers with length equal to `ndim(operand)`.
+    indices_are_sorted: whether `indices` is known to be sorted. If
+      true, may improve performance on some backends.
+    unique_indices: whether the indices in ``operand`` are
+      guaranteed to not overlap with each other. If true, may improve
+      performance on some backends.
 
   Returns:
     An array containing the gather output.
   """
   return gather_p.bind(
       operand, start_indices, dimension_numbers=dimension_numbers,
-      slice_sizes=canonicalize_shape(slice_sizes))
+      slice_sizes=canonicalize_shape(slice_sizes),
+      unique_indices=bool(unique_indices),
+      indices_are_sorted=bool(indices_are_sorted))
 
 
 class ScatterDimensionNumbers(NamedTuple):
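For orientation, a minimal usage sketch of the extended lax.gather API (not part of the diff; the dimension numbers and data are illustrative, and the two flags are unchecked promises to the compiler):

import jax.numpy as jnp
from jax import lax

operand = jnp.arange(8.0)
start_indices = jnp.array([[1], [3], [5]])   # sorted, non-overlapping
dnums = lax.GatherDimensionNumbers(offset_dims=(),
                                   collapsed_slice_dims=(0,),
                                   start_index_map=(0,))
out = lax.gather(operand, start_indices, dnums, slice_sizes=(1,),
                 unique_indices=True, indices_are_sorted=True)
# out == [1., 3., 5.]; passing True merely licenses backend optimizations.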
@@ -4231,7 +4241,7 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
   index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds)
   return _gather_batching_rule(
     [operand, index], [operand_bd, index_bdim], dimension_numbers=dnums,
-    slice_sizes=slice_sizes)
+    slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True)
 
 dynamic_slice_p = standard_primitive(
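The batched dynamic_slice can safely assert both flags because each batch element reads exactly one slice and the internally generated batch indices form an iota. A hedged illustration (the lambda and data here are invented for this sketch):

import jax.numpy as jnp
from jax import lax, vmap

x = jnp.arange(12.0).reshape(3, 4)
starts = jnp.array([0, 1, 2])
# Each batch element slices its own row, so the gather this lowers to has
# trivially sorted, non-overlapping start indices across the batch.
rows = vmap(lambda row, s: lax.dynamic_slice(row, (s,), (2,)))(x, starts)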
@@ -4368,7 +4378,7 @@ def _no_duplicate_dims(dims, op_name, name):
     raise TypeError(f"{name} in {op_name} op must not repeat; got: {dims}.")
 
 def _gather_shape_rule(operand, start_indices, *, dimension_numbers,
-                       slice_sizes):
+                       slice_sizes, unique_indices, indices_are_sorted):
   """Validates the well-formedness of the arguments to Gather.
 
   The code implements the checks based on the detailed operation semantics of
@@ -4481,19 +4491,23 @@ def _gather_shape_rule(operand, start_indices, *, dimension_numbers,
               else next(start_indices_shape) for i in range(output_shape_rank))
 
 def _gather_translation_rule(c, operand, start_indices, *, dimension_numbers,
-                             slice_sizes):
+                             slice_sizes, unique_indices, indices_are_sorted):
   indices_shape = c.get_shape(start_indices)
+  # We don't consume unique_indices directly in gather(), only in its transpose
+  # (scatter).
   return xops.Gather(
       operand, start_indices,
       _gather_dimensions_proto(indices_shape, dimension_numbers), slice_sizes,
-      indices_are_sorted=False)
+      indices_are_sorted=indices_are_sorted)
 
 def _gather_jvp_rule(g, operand, start_indices, *, dimension_numbers,
-                     slice_sizes):
-  return gather(g, start_indices, dimension_numbers, slice_sizes)
+                     slice_sizes, unique_indices, indices_are_sorted):
+  return gather(g, start_indices, dimension_numbers, slice_sizes,
+                unique_indices=unique_indices,
+                indices_are_sorted=indices_are_sorted)
 
 def _gather_transpose_rule(t, operand, start_indices, *, dimension_numbers,
-                           slice_sizes):
+                           slice_sizes, unique_indices, indices_are_sorted):
   assert ad.is_undefined_primal(operand)
   operand_shape = operand.aval.shape
   if type(t) is ad_util.Zero:
@@ -4505,12 +4519,12 @@ def _gather_transpose_rule(t, operand, start_indices, *, dimension_numbers,
       inserted_window_dims=dimension_numbers.collapsed_slice_dims,
       scatter_dims_to_operand_dims=dimension_numbers.start_index_map)
   out = scatter_add(zeros, start_indices, t, scatter_dnums,
-                    indices_are_sorted=False,
-                    unique_indices=False)
+                    unique_indices=unique_indices,
+                    indices_are_sorted=indices_are_sorted)
   return [out, None]
 
 def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
-                          slice_sizes):
+                          slice_sizes, unique_indices, indices_are_sorted):
   operand, start_indices = batched_args
   operand_bdim, start_indices_bdim = batch_dims
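As the comment in the translation rule notes, unique_indices pays off in the transpose: the cotangent of a gather is a scatter_add, which now inherits the gather's promises instead of pessimistically assuming False. A hedged sketch of a program that exercises this path:

import jax
import jax.numpy as jnp

def f(x):
  # Basic indexing lowers to a gather that, after this change, carries
  # unique_indices=True and indices_are_sorted=True.
  return jnp.sum(x[1:4] ** 2)

# Differentiating transposes the gather into a scatter_add; the scatter
# inherits the gather's uniqueness/sortedness hints.
g = jax.grad(f)(jnp.arange(5.0))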
@@ -4525,7 +4539,8 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
         collapsed_slice_dims=collapsed_slice_dims,
         start_index_map=start_index_map)
     return gather(operand, start_indices, dimension_numbers=dnums,
-                  slice_sizes=slice_sizes), 0
+                  slice_sizes=slice_sizes, unique_indices=unique_indices,
+                  indices_are_sorted=indices_are_sorted), 0
 
   elif operand_bdim is None and start_indices_bdim is not None:
     start_indices = batching.moveaxis(start_indices, start_indices_bdim, 0)
@@ -4534,8 +4549,11 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
         offset_dims=offset_dims,
         collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
         start_index_map=dimension_numbers.start_index_map)
+    # If batching indexed accesses into the same array, the batched gather may
+    # no longer have sorted or unique indices.
     return gather(operand, start_indices, dimension_numbers=dnums,
-                  slice_sizes=slice_sizes), 0
+                  slice_sizes=slice_sizes, unique_indices=False,
+                  indices_are_sorted=False), 0
 
   else:
     # move batch dimensions to the front to simplify logic
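The comment above covers the case where only the indices are batched: distinct batch elements index one shared operand, so their gathers can collide and the flags must be dropped. A small invented example of such a collision:

import jax.numpy as jnp
from jax import vmap

x = jnp.arange(5.0)              # shared, unbatched operand
idx = jnp.array([[0], [0]])      # two batch elements hit the same position
out = vmap(lambda i: x[i])(idx)  # the combined gather has duplicate indices,
                                 # so neither flag can be kept True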
@@ -4562,7 +4580,8 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
         collapsed_slice_dims=collapsed_slice_dims,
         start_index_map=start_index_map)
     return gather(operand, start_indices, dimension_numbers=dnums,
-                  slice_sizes=slice_sizes), 0
+                  slice_sizes=slice_sizes, unique_indices=unique_indices,
+                  indices_are_sorted=indices_are_sorted), 0
 
 gather_p = standard_primitive(
     _gather_shape_rule, _gather_dtype_rule, 'gather',


@@ -4976,7 +4976,9 @@ def _gather(arr, treedef, static_idx, dynamic_idx):
   # We avoid generating a gather when indexer.gather_indices.size is empty.
   if not core.is_empty_shape(indexer.gather_indices.shape):
     y = lax.gather(y, indexer.gather_indices, indexer.dnums,
-                   indexer.gather_slice_shape)
+                   indexer.gather_slice_shape,
+                   unique_indices=indexer.unique_indices,
+                   indices_are_sorted=indexer.indices_are_sorted)
 
   # Reverses axes with negative strides.
   if indexer.reversed_y_dims:
@@ -4998,6 +5000,12 @@ _Indexer = collections.namedtuple("_Indexer", [
   # A GatherDimensionNumbers object describing the gather to perform.
   "dnums",
 
+  # Are the gather_indices known to be non-overlapping and/or sorted?
+  # (In practice, these translate to "there are no advanced indices", because
+  # only advanced indices could lead to index repetition.)
+  "unique_indices",
+  "indices_are_sorted",
+
   # Slice dimensions that have negative strides, and so must be reversed after
   # the gather.
   "reversed_y_dims",
@@ -5237,7 +5245,9 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
       gather_slice_shape=gather_slice_shape,
       reversed_y_dims=reversed_y_dims,
       dnums=dnums,
-      gather_indices=gather_indices_array)
+      gather_indices=gather_indices_array,
+      unique_indices=advanced_indexes is None,
+      indices_are_sorted=advanced_indexes is None)
 
 def _should_unpack_list_index(x):
   """Helper for _eliminate_deprecated_list_indexing."""


@@ -102,7 +102,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
   )
   out = scatter_op(x, indexer.gather_indices, y, dnums,
                    indices_are_sorted=indices_are_sorted,
-                   unique_indices=unique_indices)
+                   unique_indices=indexer.unique_indices or unique_indices)
   return lax.convert_element_type(out, dtype)
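The `indexer.unique_indices or unique_indices` upgrade means a scatter driven by basic indexing is marked unique even when the caller did not request it. A hedged sketch using the functional-update API:

import jax.numpy as jnp

x = jnp.zeros(6)
# A basic-index update touches each target position at most once, so
# _scatter_impl can pass unique_indices=True to the underlying scatter
# regardless of what the caller supplied.
y = x.at[1:4].set(jnp.ones(3))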


@@ -2169,6 +2169,7 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
 @partial(bool_to_int8, argnums=0)
 def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
+            indices_are_sorted, unique_indices,
             _in_avals, _out_aval):
   """Tensorflow implementation of gather."""
   del _in_avals
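Finally, a hedged end-to-end sketch of the jax2tf side (assumes tensorflow is installed; the converted function is invented): the gather lowering now receives both flags as primitive params.

import jax.numpy as jnp
from jax.experimental import jax2tf

f = lambda x: x[1:3]      # lowers to a gather carrying both new params
f_tf = jax2tf.convert(f)  # jax2tf's _gather lowering accepts them above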