mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Support None and negative indices in slice_in_dim
This commit is contained in:
parent
b15a27a7fc
commit
48cb6af6b4
@ -1309,6 +1309,17 @@ def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0):
|
||||
limit_indices = list(operand.shape)
|
||||
strides = [1] * operand.ndim
|
||||
|
||||
# translate `None`
|
||||
len_axis = operand.shape[axis]
|
||||
start_index = start_index if start_index is not None else 0
|
||||
limit_index = limit_index if limit_index is not None else len_axis
|
||||
|
||||
# translate negative indices
|
||||
if start_index < 0:
|
||||
start_index = start_index + len_axis
|
||||
if limit_index < 0:
|
||||
limit_index = limit_index + len_axis
|
||||
|
||||
axis = int(axis)
|
||||
start_indices[axis] = int(start_index)
|
||||
limit_indices[axis] = int(limit_index)
|
||||
|
Loading…
x
Reference in New Issue
Block a user