use more _const and _constant_like helpers

This commit is contained in:
Matthew Johnson 2019-02-13 08:25:11 -08:00
parent ea9c311349
commit 8df660e9ea
2 changed files with 5 additions and 5 deletions

View File

@ -571,7 +571,7 @@ def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0):
slice_sizes = list(operand.shape)
axis = int(axis)
axis_size = onp.array(operand.shape[axis], start_index.dtype)
axis_size = _const(start_index, operand.shape[axis])
start_indices[axis] = reshape(rem(start_index, axis_size), [1])
slice_sizes[axis] = int(slice_size)

View File

@ -1775,7 +1775,7 @@ def take(a, indices, axis=None, out=None, mode=None):
# TODO(phawkins): we have no way to report out of bounds errors yet.
raise NotImplementedError("The 'raise' mode to np.take is not supported.")
elif mode == "wrap":
indices = mod(indices, onp.array(a.shape[axis], _dtype(indices)))
indices = mod(indices, _constant_like(indices, a.shape[axis]))
elif mode != "clip" and mode is not None:
raise ValueError("Invalid mode '{}' for np.take".format(mode))
@ -1834,7 +1834,7 @@ def _rewriting_take(arr, idx, axis=0):
if isinstance(abstract_idx, ConcreteArray) and _int(abstract_idx):
return lax.index_in_dim(arr, idx, axis, False)
elif isinstance(abstract_idx, ShapedArray) and _int(abstract_idx):
idx = mod(idx, onp.array(arr.shape[axis], _dtype(idx)))
idx = mod(idx, _constant_like(idx, arr.shape[axis]))
return lax.dynamic_index_in_dim(arr, idx, axis, False)
# Handle slice index (only static, otherwise an error is raised)
@ -1903,7 +1903,7 @@ def _rewriting_take(arr, idx, axis=0):
# The indexer is just a single integer array.
idx = [idx]
flat_idx = tuple([mod(ravel(x), onp.array(arr.shape[i], _dtype(x)))
flat_idx = tuple([mod(ravel(x), _constant_like(x, arr.shape[i]))
for i, x in enumerate(idx)])
# TODO(mattjj): if we instead lower directly to lax.gather, we can probably
# eliminate the reshape here.
@ -1922,7 +1922,7 @@ def _rewriting_take(arr, idx, axis=0):
idx_advanced, axes = zip(*advanced_pairs)
idx_advanced = broadcast_arrays(*idx_advanced)
flat_idx = tuple(mod(ravel(x), onp.array(arr_sliced.shape[i], _dtype(x)))
flat_idx = tuple(mod(ravel(x), _constant_like(x, arr_sliced.shape[i]))
for i, x in zip(axes, idx_advanced))
# TODO(mattjj): if we instead lower directly to lax.gather, we can probably
# eliminate the reshape here.