Use select instead of rem to handle index wraparound.

This commit is contained in:
Peter Hawkins 2019-08-15 15:22:55 -04:00
parent a36c08291a
commit 6d357fe884
2 changed files with 17 additions and 8 deletions

View File

@ -1282,10 +1282,8 @@ def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0):
slice_sizes = list(operand.shape)
axis = int(axis)
axis_size = _const(start_index, operand.shape[axis])
start_indices[axis] = rem(start_index, axis_size)
start_indices[axis] = start_index
slice_sizes[axis] = int(slice_size)
return dynamic_slice(operand, start_indices, slice_sizes)
@ -1301,7 +1299,7 @@ def dynamic_index_in_dim(operand, index, axis=0, keepdims=True):
def dynamic_update_slice_in_dim(operand, update, start_index, axis):
axis = int(axis)
start_indices = [0] * _ndim(operand)
start_indices[axis] = start_index % operand.shape[axis]
start_indices[axis] = start_index
return dynamic_update_slice(operand, update, start_indices)
@ -4216,12 +4214,16 @@ def _dynamic_slice_indices(operand, start_indices):
.format(start_indices.shape))
start_indices = [reshape(slice(start_indices, [i], [i+1]), ())
for i in range(operand.ndim)]
else:
start_indices = [onp.asarray(i) if isinstance(i, int) else i
for i in start_indices]
if len(start_indices) != operand.ndim:
msg = ("Length of slice indices must match number of operand dimensions ({} "
"vs {})")
raise ValueError(msg.format(len(start_indices, operand.shape)))
# map int over operand.shape to raise any dynamic-shape errors
return [rem(i, int(d)) for i, d in zip(start_indices, operand.shape)]
return [select(lt(i, _const(i, 0)), add(i, _const(i, int(d))), i)
for i, d in zip(start_indices, operand.shape)]

View File

@ -2169,6 +2169,13 @@ def take(a, indices, axis=None, out=None, mode=None):
slice_sizes=tuple(slice_sizes))
def _normalize_index(index, axis_size):
"""Normalizes an index value in the range [-N, N) to the range [0, N)."""
return lax.select(
lax.lt(index, _constant_like(index, 0)),
lax.add(index, _constant_like(index, axis_size)),
index)
@partial(jit, static_argnums=(2,))
def _take_along_axis(arr, indices, axis):
if axis is None:
@ -2199,7 +2206,7 @@ def _take_along_axis(arr, indices, axis):
j = 0
for i in range(rank):
if i == axis:
indices = indices % _constant_like(indices, axis_size)
indices = _normalize_index(indices, axis_size)
gather_indices.append(lax.reshape(indices, gather_index_shape))
slice_sizes.append(1)
start_index_map.append(i)
@ -2306,7 +2313,7 @@ def _index_to_gather(x_shape, idx):
advanced_pairs = (
(asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
if (isinstance(e, collections.Sequence) or isinstance(e, ndarray)))
advanced_pairs = ((mod(e, _constant_like(e, x_shape[j])), i, j)
advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
for e, i, j in advanced_pairs)
advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
advanced_axes_are_contiguous = onp.all(onp.diff(idx_advanced_axes) == 1)
@ -2376,7 +2383,7 @@ def _index_to_gather(x_shape, idx):
# Handle basic int indexes.
if (isinstance(abstract_i, ConcreteArray) or
isinstance(abstract_i, ShapedArray)) and _int(abstract_i):
i = mod(i, _constant_like(i, x_shape[x_axis]))
i = _normalize_index(i, x_shape[x_axis])
i = lax.convert_element_type(i, int32)
i = broadcast_to(i, tuple(gather_indices.shape[:-1]) + (1,))
gather_indices = concatenate((gather_indices, i), -1)