mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
parent
1ab4a2ea54
commit
0afb6202c9
13
jax/lax.py
13
jax/lax.py
@ -540,15 +540,17 @@ def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0):
|
||||
limit_indices = list(operand.shape)
|
||||
strides = [1] * operand.ndim
|
||||
|
||||
start_indices[axis] = start_index
|
||||
limit_indices[axis] = limit_index
|
||||
strides[axis] = stride
|
||||
axis = int(axis)
|
||||
start_indices[axis] = int(start_index)
|
||||
limit_indices[axis] = int(limit_index)
|
||||
strides[axis] = int(stride)
|
||||
|
||||
return slice(operand, start_indices, limit_indices, strides)
|
||||
|
||||
|
||||
def index_in_dim(operand, index, axis=0, keepdims=True):
|
||||
"""Convenience wrapper around slice to perform int indexing."""
|
||||
index, axis = int(index), int(axis)
|
||||
axis_size = operand.shape[axis]
|
||||
wrapped_index = index + axis_size if index < 0 else index
|
||||
if not 0 <= wrapped_index < axis_size:
|
||||
@ -566,8 +568,9 @@ def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0):
|
||||
start_indices = [onp.array([0])] * operand.ndim
|
||||
slice_sizes = list(operand.shape)
|
||||
|
||||
axis = int(axis)
|
||||
start_indices[axis] = reshape(rem(start_index, operand.shape[axis]), [1])
|
||||
slice_sizes[axis] = slice_size
|
||||
slice_sizes[axis] = int(slice_size)
|
||||
|
||||
start_indices = concatenate(start_indices, 0)
|
||||
return dynamic_slice(operand, start_indices, slice_sizes)
|
||||
@ -583,12 +586,14 @@ def dynamic_index_in_dim(operand, index, axis=0, keepdims=True):
|
||||
|
||||
|
||||
def dynamic_update_slice_in_dim(operand, update, start_index, axis):
|
||||
axis = int(axis)
|
||||
start_indices = [0] * _ndim(operand)
|
||||
start_indices[axis] = start_index % operand.shape[axis]
|
||||
return dynamic_update_slice(operand, update, start_indices)
|
||||
|
||||
|
||||
def dynamic_update_index_in_dim(operand, update, index, axis):
|
||||
axis = int(axis)
|
||||
if _ndim(update) != _ndim(operand):
|
||||
assert _ndim(update) + 1 == _ndim(operand)
|
||||
ax = axis % _ndim(operand)
|
||||
|
Loading…
x
Reference in New Issue
Block a user