mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use select
instead of rem
to handle index wraparound.
This commit is contained in:
parent
a36c08291a
commit
6d357fe884
@ -1282,10 +1282,8 @@ def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0):
|
||||
slice_sizes = list(operand.shape)
|
||||
|
||||
axis = int(axis)
|
||||
axis_size = _const(start_index, operand.shape[axis])
|
||||
start_indices[axis] = rem(start_index, axis_size)
|
||||
start_indices[axis] = start_index
|
||||
slice_sizes[axis] = int(slice_size)
|
||||
|
||||
return dynamic_slice(operand, start_indices, slice_sizes)
|
||||
|
||||
|
||||
@ -1301,7 +1299,7 @@ def dynamic_index_in_dim(operand, index, axis=0, keepdims=True):
|
||||
def dynamic_update_slice_in_dim(operand, update, start_index, axis):
|
||||
axis = int(axis)
|
||||
start_indices = [0] * _ndim(operand)
|
||||
start_indices[axis] = start_index % operand.shape[axis]
|
||||
start_indices[axis] = start_index
|
||||
return dynamic_update_slice(operand, update, start_indices)
|
||||
|
||||
|
||||
@ -4216,12 +4214,16 @@ def _dynamic_slice_indices(operand, start_indices):
|
||||
.format(start_indices.shape))
|
||||
start_indices = [reshape(slice(start_indices, [i], [i+1]), ())
|
||||
for i in range(operand.ndim)]
|
||||
else:
|
||||
start_indices = [onp.asarray(i) if isinstance(i, int) else i
|
||||
for i in start_indices]
|
||||
if len(start_indices) != operand.ndim:
|
||||
msg = ("Length of slice indices must match number of operand dimensions ({} "
|
||||
"vs {})")
|
||||
raise ValueError(msg.format(len(start_indices, operand.shape)))
|
||||
# map int over operand.shape to raise any dynamic-shape errors
|
||||
return [rem(i, int(d)) for i, d in zip(start_indices, operand.shape)]
|
||||
return [select(lt(i, _const(i, 0)), add(i, _const(i, int(d))), i)
|
||||
for i, d in zip(start_indices, operand.shape)]
|
||||
|
||||
|
||||
|
||||
|
@ -2169,6 +2169,13 @@ def take(a, indices, axis=None, out=None, mode=None):
|
||||
slice_sizes=tuple(slice_sizes))
|
||||
|
||||
|
||||
def _normalize_index(index, axis_size):
|
||||
"""Normalizes an index value in the range [-N, N) to the range [0, N)."""
|
||||
return lax.select(
|
||||
lax.lt(index, _constant_like(index, 0)),
|
||||
lax.add(index, _constant_like(index, axis_size)),
|
||||
index)
|
||||
|
||||
@partial(jit, static_argnums=(2,))
|
||||
def _take_along_axis(arr, indices, axis):
|
||||
if axis is None:
|
||||
@ -2199,7 +2206,7 @@ def _take_along_axis(arr, indices, axis):
|
||||
j = 0
|
||||
for i in range(rank):
|
||||
if i == axis:
|
||||
indices = indices % _constant_like(indices, axis_size)
|
||||
indices = _normalize_index(indices, axis_size)
|
||||
gather_indices.append(lax.reshape(indices, gather_index_shape))
|
||||
slice_sizes.append(1)
|
||||
start_index_map.append(i)
|
||||
@ -2306,7 +2313,7 @@ def _index_to_gather(x_shape, idx):
|
||||
advanced_pairs = (
|
||||
(asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
|
||||
if (isinstance(e, collections.Sequence) or isinstance(e, ndarray)))
|
||||
advanced_pairs = ((mod(e, _constant_like(e, x_shape[j])), i, j)
|
||||
advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
|
||||
for e, i, j in advanced_pairs)
|
||||
advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
|
||||
advanced_axes_are_contiguous = onp.all(onp.diff(idx_advanced_axes) == 1)
|
||||
@ -2376,7 +2383,7 @@ def _index_to_gather(x_shape, idx):
|
||||
# Handle basic int indexes.
|
||||
if (isinstance(abstract_i, ConcreteArray) or
|
||||
isinstance(abstract_i, ShapedArray)) and _int(abstract_i):
|
||||
i = mod(i, _constant_like(i, x_shape[x_axis]))
|
||||
i = _normalize_index(i, x_shape[x_axis])
|
||||
i = lax.convert_element_type(i, int32)
|
||||
i = broadcast_to(i, tuple(gather_indices.shape[:-1]) + (1,))
|
||||
gather_indices = concatenate((gather_indices, i), -1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user