improve error messages for lax.slice/index funs

c.f. #292
This commit is contained in:
Matthew Johnson 2019-02-02 21:41:06 -08:00
parent 1ab4a2ea54
commit 0afb6202c9

View File

@ -540,15 +540,17 @@ def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0):
limit_indices = list(operand.shape)
strides = [1] * operand.ndim
start_indices[axis] = start_index
limit_indices[axis] = limit_index
strides[axis] = stride
axis = int(axis)
start_indices[axis] = int(start_index)
limit_indices[axis] = int(limit_index)
strides[axis] = int(stride)
return slice(operand, start_indices, limit_indices, strides)
def index_in_dim(operand, index, axis=0, keepdims=True):
"""Convenience wrapper around slice to perform int indexing."""
index, axis = int(index), int(axis)
axis_size = operand.shape[axis]
wrapped_index = index + axis_size if index < 0 else index
if not 0 <= wrapped_index < axis_size:
@ -566,8 +568,9 @@ def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0):
start_indices = [onp.array([0])] * operand.ndim
slice_sizes = list(operand.shape)
axis = int(axis)
start_indices[axis] = reshape(rem(start_index, operand.shape[axis]), [1])
slice_sizes[axis] = slice_size
slice_sizes[axis] = int(slice_size)
start_indices = concatenate(start_indices, 0)
return dynamic_slice(operand, start_indices, slice_sizes)
@ -583,12 +586,14 @@ def dynamic_index_in_dim(operand, index, axis=0, keepdims=True):
def dynamic_update_slice_in_dim(operand, update, start_index, axis):
axis = int(axis)
start_indices = [0] * _ndim(operand)
start_indices[axis] = start_index % operand.shape[axis]
return dynamic_update_slice(operand, update, start_indices)
def dynamic_update_index_in_dim(operand, update, index, axis):
axis = int(axis)
if _ndim(update) != _ndim(operand):
assert _ndim(update) + 1 == _ndim(operand)
ax = axis % _ndim(operand)