Add support for mixing basic and advanced indexing in the same scatter operation.

This commit is contained in:
Peter Hawkins 2019-07-14 10:57:41 -04:00
parent b45ea2b416
commit 0850318a83
3 changed files with 120 additions and 49 deletions

View File

@ -563,10 +563,13 @@ def broadcast(operand, sizes):
def broadcast_in_dim(operand, shape, broadcast_dimensions):
if operand.ndim == len(shape) and not len(broadcast_dimensions):
return operand
else:
return broadcast_in_dim_p.bind(
operand, shape=tuple(shape),
broadcast_dimensions=tuple(broadcast_dimensions))
if any(x < 0 or x >= len(shape) for x in broadcast_dimensions):
msg = ("broadcast dimensions must be >= 0 and < ndim(shape), got {} for "
"shape {}")
raise ValueError(msg.format(broadcast_dimensions, shape))
return broadcast_in_dim_p.bind(
operand, shape=tuple(shape),
broadcast_dimensions=tuple(broadcast_dimensions))
def reshape(operand, new_sizes, dimensions=None):
"""Wraps XLA's `Reshape

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as onp
from ..abstract_arrays import ShapedArray, ConcreteArray
@ -39,6 +41,10 @@ def _is_advanced_int_indexer(idx):
isinstance(idx, tuple) and all(onp.ndim(elt) == 0 for elt in idx))
return out and np._is_advanced_int_indexer(idx)
def _triggers_unpack(x):
return (isinstance(x, np.ndarray) or isinstance(x, collections.Sequence)
or isinstance(x, slice) or x is Ellipsis or x is None)
def _scatter_update(x, idx, y, scatter_op):
"""Helper for indexed updates.
@ -67,50 +73,49 @@ def _scatter_update(x, idx, y, scatter_op):
y_shape = np.shape(y)
y = lax.convert_element_type(y, lax.dtype(x))
# Check if there's advanced indexing going on, and handle differently based on
# whether it is or isn't mixed with basic indexing.
if _is_advanced_int_indexer(idx):
if np._is_advanced_int_indexer_without_slices(idx):
if isinstance(idx, (tuple, list)):
if any(onp.shape(e) for e in idx):
# At least one sequence element in the index list means broadcasting.
idx = np.broadcast_arrays(*idx)
else:
# The index list is a flat list of integers.
idx = [lax.concatenate([lax.reshape(e, (1,)) for e in idx], 0)]
else:
# The indexer is just a single integer array.
idx = [idx]
stacked_idx = np.concatenate(
[np.mod(np.reshape(a, a.shape + (1,)), np._constant_like(a, x.shape[i]))
for i, a in enumerate(idx)], axis=-1)
y = np.broadcast_to(y, idx[0].shape + onp.shape(x)[len(idx):])
dnums = lax.ScatterDimensionNumbers(
update_window_dims=tuple(range(len(idx[0].shape), y.ndim)),
inserted_window_dims=tuple(range(len(idx))),
scatter_dims_to_operand_dims=tuple(range(len(idx))))
return scatter_op(x, stacked_idx, y, dnums)
elif np._is_advanced_int_indexer(idx):
# TODO(mattjj, phawkins): one of us is going to implement this case someday
msg = "Unimplemented case for indexed update. Open a feature request!"
raise NotImplementedError(msg)
else:
assert False # unreachable
# At this point there's no advanced indexing going on, so we process each
# element of the index one at a time to build up a scatter.
# "Basic slicing is initiated if the selection object is a non-array,
# non-tuple sequence containing slice objects, [Ellipses, or newaxis
# objects]". Detects this case and canonicalizes to a tuple.
if not isinstance(idx, tuple):
idx = (idx,)
if isinstance(idx, collections.Sequence) and not isinstance(idx, np.ndarray):
if any(_triggers_unpack(i) for i in idx):
idx = tuple(idx)
else:
idx = (idx,)
else:
idx = (idx,)
# Remove ellipses and add trailing slice(None)s.
idx = np._canonicalize_tuple_index(x, idx)
# Check for advanced indexing.
# Do the advanced indexing axes appear contiguously? If not, NumPy semantics
# move the advanced axes to the front.
advanced_axes_are_contiguous = False
advanced_indexes = None
# The positions of the advanced indexing axes in `idx`.
idx_advanced_axes = []
# The positions of the advanced indexes in x's shape.
# collapsed, after None axes have been removed. See below.
x_advanced_axes = None
if _is_advanced_int_indexer(idx):
idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
advanced_pairs = (
(np.asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
if (isinstance(e, collections.Sequence) or isinstance(e, np.ndarray)))
advanced_pairs = ((np.mod(e, np._constant_like(e, x_shape[j])), i, j)
for e, i, j in advanced_pairs)
advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
advanced_axes_are_contiguous = onp.all(onp.diff(idx_advanced_axes) == 1)
_int = lambda aval: not aval.shape and onp.issubdtype(aval.dtype, onp.integer)
x_axis = 0
x_axis = 0 # Current axis in x.
y_axis = 0 # Current axis in y, before collapsing. See below.
collapsed_y_axis = 0 # Current axis in y, after collapsing.
@ -131,7 +136,35 @@ def _scatter_update(x, idx, y, scatter_op):
# Finally, we reverse reversed_y_dims to handle slices with negative strides.
reversed_y_dims = []
for i in idx:
for idx_pos, i in enumerate(idx):
# If the advanced indices are not contiguous they are moved to the front
# of the slice. Otherwise, they replace the chunk of advanced indices.
if (advanced_indexes is not None and
(advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
not advanced_axes_are_contiguous and idx_pos == 0)):
advanced_indexes = np.broadcast_arrays(*advanced_indexes)
shape = advanced_indexes[0].shape
ndim = len(shape)
advanced_indexes = [
lax.convert_element_type(lax.reshape(a, shape + (1,)), np.int32)
for a in advanced_indexes]
scatter_indices = lax.broadcast_in_dim(
scatter_indices, onp.insert(scatter_indices.shape, -1, shape),
tuple(range(scatter_indices.ndim - 1)) + (scatter_indices.ndim + ndim - 1,))
scatter_indices = np.concatenate([scatter_indices] + advanced_indexes, -1)
scatter_dims_to_operand_dims.extend(x_advanced_axes)
inserted_window_dims.extend(x_advanced_axes)
slice_shape.extend(shape)
collapsed_slice_shape.extend(shape)
y_axis += ndim
collapsed_y_axis += ndim
if idx_pos in idx_advanced_axes:
x_axis += 1
continue
try:
abstract_i = core.get_aval(i)
except TypeError:
@ -200,7 +233,7 @@ def _scatter_update(x, idx, y, scatter_op):
dnums = lax.ScatterDimensionNumbers(
update_window_dims = tuple(update_window_dims),
inserted_window_dims = tuple(inserted_window_dims),
inserted_window_dims = tuple(sorted(inserted_window_dims)),
scatter_dims_to_operand_dims = tuple(scatter_dims_to_operand_dims)
)
return scatter_op(x, scatter_indices, y, dnums)

View File

@ -316,7 +316,7 @@ ADVANCED_INDEXING_TESTS_NO_REPEATS = [
]),
]
MIXED_ADVANCED_INDEXING_TESTS = [
MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [
("SlicesAndOneIntArrayIndex",
[IndexSpec(shape=(2, 3), indexer=(onp.array([0, 1]), slice(1, 2))),
IndexSpec(shape=(2, 3), indexer=(slice(0, 2),
@ -325,7 +325,7 @@ MIXED_ADVANCED_INDEXING_TESTS = [
onp.array([0, 2]),
slice(None))),
IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis,
onp.array([[0, 2], [1, 1]]),
onp.array([[0, 2], [1, 3]]),
slice(None))),
]),
("SlicesAndTwoIntArrayIndices",
@ -346,10 +346,7 @@ MIXED_ADVANCED_INDEXING_TESTS = [
onp.array([-1, 2]))),
IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, -2]),
slice(None, None, 2),
onp.array([-1, 2, -1]))),
IndexSpec(shape=(3, 4, 5), indexer=(onp.array([[0, 2], [2, 0]]),
Ellipsis,
onp.array([[1, 0], [1, 0]]))),
onp.array([-1, 2, 1]))),
]),
("NonesAndIntArrayIndices",
[IndexSpec(shape=(3, 4, 5), indexer=[onp.array([0, 2]),
@ -370,6 +367,22 @@ MIXED_ADVANCED_INDEXING_TESTS = [
]),
]
MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [
("SlicesAndOneIntArrayIndex",
[
IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis,
onp.array([[0, 2], [1, 1]]),
slice(None))),
]),
("SlicesAndTwoIntArrayIndices",
[IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, -2]),
slice(None, None, 2),
onp.array([-1, 2, -1]))),
IndexSpec(shape=(3, 4, 5), indexer=(onp.array([[0, 2], [2, 0]]),
Ellipsis,
onp.array([[1, 0], [1, 0]]))),
]),]
class IndexingTest(jtu.JaxTestCase):
"""Tests for Numpy indexing translation rules."""
@ -794,6 +807,28 @@ class IndexedUpdateTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer,
"update_shape": update_shape, "update_dtype": update_dtype,
"op": op
} for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
for shape, indexer in index_specs
for op in UpdateOps
for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes)
for rng in [jtu.rand_default()]))
def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
rng, indexer, op):
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
name, jtu.format_shape_dtype_string(shape, dtype), indexer,