mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6897 from hawkinsp:indexunique
PiperOrigin-RevId: 378550369
This commit is contained in:
commit
7540690157
@ -951,7 +951,10 @@ class GatherDimensionNumbers(NamedTuple):
|
||||
|
||||
def gather(operand: Array, start_indices: Array,
|
||||
dimension_numbers: GatherDimensionNumbers,
|
||||
slice_sizes: Shape) -> Array:
|
||||
slice_sizes: Shape,
|
||||
*,
|
||||
unique_indices: bool = False,
|
||||
indices_are_sorted: bool = False) -> Array:
|
||||
"""Gather operator.
|
||||
|
||||
Wraps `XLA's Gather operator
|
||||
@ -969,13 +972,20 @@ def gather(operand: Array, start_indices: Array,
|
||||
how dimensions of `operand`, `start_indices` and the output relate.
|
||||
slice_sizes: the size of each slice. Must be a sequence of non-negative
|
||||
integers with length equal to `ndim(operand)`.
|
||||
indices_are_sorted: whether `indices` is known to be sorted. If
|
||||
true, may improve performance on some backends.
|
||||
unique_indices: whether the indices in ``operand`` are
|
||||
guaranteed to not overlap with each other. If true, may improve
|
||||
performance on some backends.
|
||||
|
||||
Returns:
|
||||
An array containing the gather output.
|
||||
"""
|
||||
return gather_p.bind(
|
||||
operand, start_indices, dimension_numbers=dimension_numbers,
|
||||
slice_sizes=canonicalize_shape(slice_sizes))
|
||||
slice_sizes=canonicalize_shape(slice_sizes),
|
||||
unique_indices=bool(unique_indices),
|
||||
indices_are_sorted=bool(indices_are_sorted))
|
||||
|
||||
|
||||
class ScatterDimensionNumbers(NamedTuple):
|
||||
@ -4231,7 +4241,7 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
|
||||
index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds)
|
||||
return _gather_batching_rule(
|
||||
[operand, index], [operand_bd, index_bdim], dimension_numbers=dnums,
|
||||
slice_sizes=slice_sizes)
|
||||
slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True)
|
||||
|
||||
|
||||
dynamic_slice_p = standard_primitive(
|
||||
@ -4368,7 +4378,7 @@ def _no_duplicate_dims(dims, op_name, name):
|
||||
raise TypeError(f"{name} in {op_name} op must not repeat; got: {dims}.")
|
||||
|
||||
def _gather_shape_rule(operand, start_indices, *, dimension_numbers,
|
||||
slice_sizes):
|
||||
slice_sizes, unique_indices, indices_are_sorted):
|
||||
"""Validates the well-formedness of the arguments to Gather.
|
||||
|
||||
The code implements the checks based on the detailed operation semantics of
|
||||
@ -4481,19 +4491,23 @@ def _gather_shape_rule(operand, start_indices, *, dimension_numbers,
|
||||
else next(start_indices_shape) for i in range(output_shape_rank))
|
||||
|
||||
def _gather_translation_rule(c, operand, start_indices, *, dimension_numbers,
|
||||
slice_sizes):
|
||||
slice_sizes, unique_indices, indices_are_sorted):
|
||||
indices_shape = c.get_shape(start_indices)
|
||||
# We don't consume unique_indices directly in gather(), only in its transpose
|
||||
# (scatter).
|
||||
return xops.Gather(
|
||||
operand, start_indices,
|
||||
_gather_dimensions_proto(indices_shape, dimension_numbers), slice_sizes,
|
||||
indices_are_sorted=False)
|
||||
indices_are_sorted=indices_are_sorted)
|
||||
|
||||
def _gather_jvp_rule(g, operand, start_indices, *, dimension_numbers,
|
||||
slice_sizes):
|
||||
return gather(g, start_indices, dimension_numbers, slice_sizes)
|
||||
slice_sizes, unique_indices, indices_are_sorted):
|
||||
return gather(g, start_indices, dimension_numbers, slice_sizes,
|
||||
unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted)
|
||||
|
||||
def _gather_transpose_rule(t, operand, start_indices, *, dimension_numbers,
|
||||
slice_sizes):
|
||||
slice_sizes, unique_indices, indices_are_sorted):
|
||||
assert ad.is_undefined_primal(operand)
|
||||
operand_shape = operand.aval.shape
|
||||
if type(t) is ad_util.Zero:
|
||||
@ -4505,12 +4519,12 @@ def _gather_transpose_rule(t, operand, start_indices, *, dimension_numbers,
|
||||
inserted_window_dims=dimension_numbers.collapsed_slice_dims,
|
||||
scatter_dims_to_operand_dims=dimension_numbers.start_index_map)
|
||||
out = scatter_add(zeros, start_indices, t, scatter_dnums,
|
||||
indices_are_sorted=False,
|
||||
unique_indices=False)
|
||||
unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted)
|
||||
return [out, None]
|
||||
|
||||
def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
slice_sizes):
|
||||
slice_sizes, unique_indices, indices_are_sorted):
|
||||
operand, start_indices = batched_args
|
||||
operand_bdim, start_indices_bdim = batch_dims
|
||||
|
||||
@ -4525,7 +4539,8 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
collapsed_slice_dims=collapsed_slice_dims,
|
||||
start_index_map=start_index_map)
|
||||
return gather(operand, start_indices, dimension_numbers=dnums,
|
||||
slice_sizes=slice_sizes), 0
|
||||
slice_sizes=slice_sizes, unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted), 0
|
||||
|
||||
elif operand_bdim is None and start_indices_bdim is not None:
|
||||
start_indices = batching.moveaxis(start_indices, start_indices_bdim, 0)
|
||||
@ -4534,8 +4549,11 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
offset_dims=offset_dims,
|
||||
collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
|
||||
start_index_map=dimension_numbers.start_index_map)
|
||||
# If batching indexed accesses into the same array, the batched gather may
|
||||
# no longer have sorted or unique indices.
|
||||
return gather(operand, start_indices, dimension_numbers=dnums,
|
||||
slice_sizes=slice_sizes), 0
|
||||
slice_sizes=slice_sizes, unique_indices=False,
|
||||
indices_are_sorted=False), 0
|
||||
|
||||
else:
|
||||
# move batch dimensions to the front to simplify logic
|
||||
@ -4562,7 +4580,8 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
collapsed_slice_dims=collapsed_slice_dims,
|
||||
start_index_map=start_index_map)
|
||||
return gather(operand, start_indices, dimension_numbers=dnums,
|
||||
slice_sizes=slice_sizes), 0
|
||||
slice_sizes=slice_sizes, unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted), 0
|
||||
|
||||
gather_p = standard_primitive(
|
||||
_gather_shape_rule, _gather_dtype_rule, 'gather',
|
||||
|
@ -4976,7 +4976,9 @@ def _gather(arr, treedef, static_idx, dynamic_idx):
|
||||
# We avoid generating a gather when indexer.gather_indices.size is empty.
|
||||
if not core.is_empty_shape(indexer.gather_indices.shape):
|
||||
y = lax.gather(y, indexer.gather_indices, indexer.dnums,
|
||||
indexer.gather_slice_shape)
|
||||
indexer.gather_slice_shape,
|
||||
unique_indices=indexer.unique_indices,
|
||||
indices_are_sorted=indexer.indices_are_sorted)
|
||||
|
||||
# Reverses axes with negative strides.
|
||||
if indexer.reversed_y_dims:
|
||||
@ -4998,6 +5000,12 @@ _Indexer = collections.namedtuple("_Indexer", [
|
||||
# A GatherDimensionNumbers object describing the gather to perform.
|
||||
"dnums",
|
||||
|
||||
# Are the gather_indices known to be non-overlapping and/or sorted?
|
||||
# (In practice, these translate to "there no advanced indices", because
|
||||
# only advanced indices could lead to index repetition.)
|
||||
"unique_indices",
|
||||
"indices_are_sorted",
|
||||
|
||||
# Slice dimensions that have negative strides, and so must be reversed after
|
||||
# the gather.
|
||||
"reversed_y_dims",
|
||||
@ -5237,7 +5245,9 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
|
||||
gather_slice_shape=gather_slice_shape,
|
||||
reversed_y_dims=reversed_y_dims,
|
||||
dnums=dnums,
|
||||
gather_indices=gather_indices_array)
|
||||
gather_indices=gather_indices_array,
|
||||
unique_indices=advanced_indexes is None,
|
||||
indices_are_sorted=advanced_indexes is None)
|
||||
|
||||
def _should_unpack_list_index(x):
|
||||
"""Helper for _eliminate_deprecated_list_indexing."""
|
||||
|
@ -102,7 +102,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
|
||||
)
|
||||
out = scatter_op(x, indexer.gather_indices, y, dnums,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices)
|
||||
unique_indices=indexer.unique_indices or unique_indices)
|
||||
return lax.convert_element_type(out, dtype)
|
||||
|
||||
|
||||
|
@ -2169,6 +2169,7 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
|
||||
|
||||
@partial(bool_to_int8, argnums=0)
|
||||
def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
|
||||
indices_are_sorted, unique_indices,
|
||||
_in_avals, _out_aval):
|
||||
"""Tensorflow implementation of gather."""
|
||||
del _in_avals
|
||||
|
Loading…
x
Reference in New Issue
Block a user