mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add support for mixing basic and advanced indexing in the same scatter operation.
This commit is contained in:
parent
b45ea2b416
commit
0850318a83
@ -563,10 +563,13 @@ def broadcast(operand, sizes):
|
||||
def broadcast_in_dim(operand, shape, broadcast_dimensions):
  """Broadcasts `operand` to `shape` by binding `broadcast_in_dim_p`.

  Args:
    operand: value exposing an `ndim` attribute.
    shape: sequence describing the target shape.
    broadcast_dimensions: sequence of ints; every entry is checked to lie in
      the half-open range `[0, len(shape))`.

  Returns:
    `operand` itself when its rank already equals `len(shape)` and
    `broadcast_dimensions` is empty; otherwise the result of
    `broadcast_in_dim_p.bind`.

  Raises:
    ValueError: if any entry of `broadcast_dimensions` is negative or is
      greater than or equal to `len(shape)`.
  """
  out_ndim = len(shape)
  # Nothing to broadcast: same rank and no dimension mapping supplied.
  if operand.ndim == out_ndim and not len(broadcast_dimensions):
    return operand
  # Reject out-of-range dimensions up front so the caller gets a readable
  # error instead of a failure further down inside the primitive.
  if any(x < 0 or x >= out_ndim for x in broadcast_dimensions):
    msg = ("broadcast dimensions must be >= 0 and < ndim(shape), got {} for "
           "shape {}")
    raise ValueError(msg.format(broadcast_dimensions, shape))
  return broadcast_in_dim_p.bind(
      operand, shape=tuple(shape),
      broadcast_dimensions=tuple(broadcast_dimensions))
|
||||
|
||||
def reshape(operand, new_sizes, dimensions=None):
|
||||
"""Wraps XLA's `Reshape
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import numpy as onp
|
||||
|
||||
from ..abstract_arrays import ShapedArray, ConcreteArray
|
||||
@ -39,6 +41,10 @@ def _is_advanced_int_indexer(idx):
|
||||
isinstance(idx, tuple) and all(onp.ndim(elt) == 0 for elt in idx))
|
||||
return out and np._is_advanced_int_indexer(idx)
|
||||
|
||||
def _triggers_unpack(x):
|
||||
return (isinstance(x, np.ndarray) or isinstance(x, collections.Sequence)
|
||||
or isinstance(x, slice) or x is Ellipsis or x is None)
|
||||
|
||||
def _scatter_update(x, idx, y, scatter_op):
|
||||
"""Helper for indexed updates.
|
||||
|
||||
@ -67,50 +73,49 @@ def _scatter_update(x, idx, y, scatter_op):
|
||||
y_shape = np.shape(y)
|
||||
y = lax.convert_element_type(y, lax.dtype(x))
|
||||
|
||||
# Check if there's advanced indexing going on, and handle differently based on
|
||||
# whether it is or isn't mixed with basic indexing.
|
||||
if _is_advanced_int_indexer(idx):
|
||||
if np._is_advanced_int_indexer_without_slices(idx):
|
||||
if isinstance(idx, (tuple, list)):
|
||||
if any(onp.shape(e) for e in idx):
|
||||
# At least one sequence element in the index list means broadcasting.
|
||||
idx = np.broadcast_arrays(*idx)
|
||||
else:
|
||||
# The index list is a flat list of integers.
|
||||
idx = [lax.concatenate([lax.reshape(e, (1,)) for e in idx], 0)]
|
||||
else:
|
||||
# The indexer is just a single integer array.
|
||||
idx = [idx]
|
||||
|
||||
stacked_idx = np.concatenate(
|
||||
[np.mod(np.reshape(a, a.shape + (1,)), np._constant_like(a, x.shape[i]))
|
||||
for i, a in enumerate(idx)], axis=-1)
|
||||
|
||||
y = np.broadcast_to(y, idx[0].shape + onp.shape(x)[len(idx):])
|
||||
|
||||
dnums = lax.ScatterDimensionNumbers(
|
||||
update_window_dims=tuple(range(len(idx[0].shape), y.ndim)),
|
||||
inserted_window_dims=tuple(range(len(idx))),
|
||||
scatter_dims_to_operand_dims=tuple(range(len(idx))))
|
||||
return scatter_op(x, stacked_idx, y, dnums)
|
||||
elif np._is_advanced_int_indexer(idx):
|
||||
# TODO(mattjj, phawkins): one of us is going to implement this case someday
|
||||
msg = "Unimplemented case for indexed update. Open a feature request!"
|
||||
raise NotImplementedError(msg)
|
||||
else:
|
||||
assert False # unreachable
|
||||
|
||||
# At this point there's no advanced indexing going on, so we process each
|
||||
# element of the index one at a time to build up a scatter.
|
||||
# "Basic slicing is initiated if the selection object is a non-array,
|
||||
# non-tuple sequence containing slice objects, [Ellipses, or newaxis
|
||||
# objects]". Detects this case and canonicalizes to a tuple.
|
||||
if not isinstance(idx, tuple):
|
||||
idx = (idx,)
|
||||
if isinstance(idx, collections.Sequence) and not isinstance(idx, np.ndarray):
|
||||
if any(_triggers_unpack(i) for i in idx):
|
||||
idx = tuple(idx)
|
||||
else:
|
||||
idx = (idx,)
|
||||
else:
|
||||
idx = (idx,)
|
||||
|
||||
# Remove ellipses and add trailing slice(None)s.
|
||||
idx = np._canonicalize_tuple_index(x, idx)
|
||||
|
||||
# Check for advanced indexing.
|
||||
|
||||
# Do the advanced indexing axes appear contiguously? If not, NumPy semantics
|
||||
# move the advanced axes to the front.
|
||||
advanced_axes_are_contiguous = False
|
||||
|
||||
advanced_indexes = None
|
||||
|
||||
# The positions of the advanced indexing axes in `idx`.
|
||||
idx_advanced_axes = []
|
||||
|
||||
# The positions of the advanced indexes in x's shape.
|
||||
# collapsed, after None axes have been removed. See below.
|
||||
x_advanced_axes = None
|
||||
|
||||
if _is_advanced_int_indexer(idx):
|
||||
idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
|
||||
advanced_pairs = (
|
||||
(np.asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
|
||||
if (isinstance(e, collections.Sequence) or isinstance(e, np.ndarray)))
|
||||
advanced_pairs = ((np.mod(e, np._constant_like(e, x_shape[j])), i, j)
|
||||
for e, i, j in advanced_pairs)
|
||||
advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
|
||||
advanced_axes_are_contiguous = onp.all(onp.diff(idx_advanced_axes) == 1)
|
||||
|
||||
_int = lambda aval: not aval.shape and onp.issubdtype(aval.dtype, onp.integer)
|
||||
|
||||
x_axis = 0
|
||||
x_axis = 0 # Current axis in x.
|
||||
y_axis = 0 # Current axis in y, before collapsing. See below.
|
||||
collapsed_y_axis = 0 # Current axis in y, after collapsing.
|
||||
|
||||
@ -131,7 +136,35 @@ def _scatter_update(x, idx, y, scatter_op):
|
||||
# Finally, we reverse reversed_y_dims to handle slices with negative strides.
|
||||
reversed_y_dims = []
|
||||
|
||||
for i in idx:
|
||||
|
||||
for idx_pos, i in enumerate(idx):
|
||||
# If the advanced indices are not contiguous they are moved to the front
|
||||
# of the slice. Otherwise, they replace the chunk of advanced indices.
|
||||
if (advanced_indexes is not None and
|
||||
(advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
|
||||
not advanced_axes_are_contiguous and idx_pos == 0)):
|
||||
advanced_indexes = np.broadcast_arrays(*advanced_indexes)
|
||||
shape = advanced_indexes[0].shape
|
||||
ndim = len(shape)
|
||||
advanced_indexes = [
|
||||
lax.convert_element_type(lax.reshape(a, shape + (1,)), np.int32)
|
||||
for a in advanced_indexes]
|
||||
|
||||
scatter_indices = lax.broadcast_in_dim(
|
||||
scatter_indices, onp.insert(scatter_indices.shape, -1, shape),
|
||||
tuple(range(scatter_indices.ndim - 1)) + (scatter_indices.ndim + ndim - 1,))
|
||||
scatter_indices = np.concatenate([scatter_indices] + advanced_indexes, -1)
|
||||
scatter_dims_to_operand_dims.extend(x_advanced_axes)
|
||||
inserted_window_dims.extend(x_advanced_axes)
|
||||
slice_shape.extend(shape)
|
||||
collapsed_slice_shape.extend(shape)
|
||||
y_axis += ndim
|
||||
collapsed_y_axis += ndim
|
||||
|
||||
if idx_pos in idx_advanced_axes:
|
||||
x_axis += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
abstract_i = core.get_aval(i)
|
||||
except TypeError:
|
||||
@ -200,7 +233,7 @@ def _scatter_update(x, idx, y, scatter_op):
|
||||
|
||||
dnums = lax.ScatterDimensionNumbers(
|
||||
update_window_dims = tuple(update_window_dims),
|
||||
inserted_window_dims = tuple(inserted_window_dims),
|
||||
inserted_window_dims = tuple(sorted(inserted_window_dims)),
|
||||
scatter_dims_to_operand_dims = tuple(scatter_dims_to_operand_dims)
|
||||
)
|
||||
return scatter_op(x, scatter_indices, y, dnums)
|
||||
|
@ -316,7 +316,7 @@ ADVANCED_INDEXING_TESTS_NO_REPEATS = [
|
||||
]),
|
||||
]
|
||||
|
||||
MIXED_ADVANCED_INDEXING_TESTS = [
|
||||
MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [
|
||||
("SlicesAndOneIntArrayIndex",
|
||||
[IndexSpec(shape=(2, 3), indexer=(onp.array([0, 1]), slice(1, 2))),
|
||||
IndexSpec(shape=(2, 3), indexer=(slice(0, 2),
|
||||
@ -325,7 +325,7 @@ MIXED_ADVANCED_INDEXING_TESTS = [
|
||||
onp.array([0, 2]),
|
||||
slice(None))),
|
||||
IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis,
|
||||
onp.array([[0, 2], [1, 1]]),
|
||||
onp.array([[0, 2], [1, 3]]),
|
||||
slice(None))),
|
||||
]),
|
||||
("SlicesAndTwoIntArrayIndices",
|
||||
@ -346,10 +346,7 @@ MIXED_ADVANCED_INDEXING_TESTS = [
|
||||
onp.array([-1, 2]))),
|
||||
IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, -2]),
|
||||
slice(None, None, 2),
|
||||
onp.array([-1, 2, -1]))),
|
||||
IndexSpec(shape=(3, 4, 5), indexer=(onp.array([[0, 2], [2, 0]]),
|
||||
Ellipsis,
|
||||
onp.array([[1, 0], [1, 0]]))),
|
||||
onp.array([-1, 2, 1]))),
|
||||
]),
|
||||
("NonesAndIntArrayIndices",
|
||||
[IndexSpec(shape=(3, 4, 5), indexer=[onp.array([0, 2]),
|
||||
@ -370,6 +367,22 @@ MIXED_ADVANCED_INDEXING_TESTS = [
|
||||
]),
|
||||
]
|
||||
|
||||
# Full list of mixed basic/advanced indexing cases: every case from
# MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS followed by the extra cases below.
# NOTE(review): the extra cases appear to be the ones whose integer index
# arrays contain repeated values (e.g. [[0, 2], [1, 1]]) - confirm.
MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [
    ("SlicesAndOneIntArrayIndex",
     [
     IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis,
                                         onp.array([[0, 2], [1, 1]]),
                                         slice(None))),
     ]),
    ("SlicesAndTwoIntArrayIndices",
     [IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, -2]),
                                          slice(None, None, 2),
                                          onp.array([-1, 2, -1]))),
      IndexSpec(shape=(3, 4, 5), indexer=(onp.array([[0, 2], [2, 0]]),
                                          Ellipsis,
                                          onp.array([[1, 0], [1, 0]]))),
     ]),]
|
||||
|
||||
class IndexingTest(jtu.JaxTestCase):
|
||||
"""Tests for Numpy indexing translation rules."""
|
||||
|
||||
@ -794,6 +807,28 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
    dict(
        testcase_name="{}_inshape={}_indexer={}_update={}_op={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer,
            jtu.format_shape_dtype_string(update_shape, update_dtype),
            op.name),
        shape=shape, dtype=dtype, rng=rng, indexer=indexer,
        update_shape=update_shape, update_dtype=update_dtype,
        op=op)
    for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
    for shape, indexer in index_specs
    for op in UpdateOps
    for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
    for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
    for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes)
    for rng in [jtu.rand_default()]))
def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                              rng, indexer, op):
  """Checks indexed updates that mix basic and advanced indexing.

  Runs the NumPy and JAX implementations of `op` with the same `indexer` on
  randomly generated operand/update pairs, compares them, and then runs the
  compile check on the JAX version.
  """
  def make_args():
    # Operand first, then the update values.
    return [rng(shape, dtype), rng(update_shape, update_dtype)]

  def reference_update(x, y):
    return UpdateOps.onp_fn(op, indexer, x, y)

  def jax_update(x, y):
    return UpdateOps.jax_fn(op, indexer, x, y)

  self._CheckAgainstNumpy(reference_update, jax_update, make_args,
                          check_dtypes=True)
  self._CompileAndCheck(jax_update, make_args, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list({
|
||||
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
|
||||
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
|
||||
|
Loading…
x
Reference in New Issue
Block a user