mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
use more _const and _constant_like helpers
This commit is contained in:
parent
ea9c311349
commit
8df660e9ea
@ -571,7 +571,7 @@ def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0):
|
||||
slice_sizes = list(operand.shape)
|
||||
|
||||
axis = int(axis)
|
||||
axis_size = onp.array(operand.shape[axis], start_index.dtype)
|
||||
axis_size = _const(start_index, operand.shape[axis])
|
||||
start_indices[axis] = reshape(rem(start_index, axis_size), [1])
|
||||
slice_sizes[axis] = int(slice_size)
|
||||
|
||||
|
@ -1775,7 +1775,7 @@ def take(a, indices, axis=None, out=None, mode=None):
|
||||
# TODO(phawkins): we have no way to report out of bounds errors yet.
|
||||
raise NotImplementedError("The 'raise' mode to np.take is not supported.")
|
||||
elif mode == "wrap":
|
||||
indices = mod(indices, onp.array(a.shape[axis], _dtype(indices)))
|
||||
indices = mod(indices, _constant_like(indices, a.shape[axis]))
|
||||
elif mode != "clip" and mode is not None:
|
||||
raise ValueError("Invalid mode '{}' for np.take".format(mode))
|
||||
|
||||
@ -1834,7 +1834,7 @@ def _rewriting_take(arr, idx, axis=0):
|
||||
if isinstance(abstract_idx, ConcreteArray) and _int(abstract_idx):
|
||||
return lax.index_in_dim(arr, idx, axis, False)
|
||||
elif isinstance(abstract_idx, ShapedArray) and _int(abstract_idx):
|
||||
idx = mod(idx, onp.array(arr.shape[axis], _dtype(idx)))
|
||||
idx = mod(idx, _constant_like(idx, arr.shape[axis]))
|
||||
return lax.dynamic_index_in_dim(arr, idx, axis, False)
|
||||
|
||||
# Handle slice index (only static, otherwise an error is raised)
|
||||
@ -1903,7 +1903,7 @@ def _rewriting_take(arr, idx, axis=0):
|
||||
# The indexer is just a single integer array.
|
||||
idx = [idx]
|
||||
|
||||
flat_idx = tuple([mod(ravel(x), onp.array(arr.shape[i], _dtype(x)))
|
||||
flat_idx = tuple([mod(ravel(x), _constant_like(x, arr.shape[i]))
|
||||
for i, x in enumerate(idx)])
|
||||
# TODO(mattjj): if we instead lower directly to lax.gather, we can probably
|
||||
# eliminate the reshape here.
|
||||
@ -1922,7 +1922,7 @@ def _rewriting_take(arr, idx, axis=0):
|
||||
idx_advanced, axes = zip(*advanced_pairs)
|
||||
idx_advanced = broadcast_arrays(*idx_advanced)
|
||||
|
||||
flat_idx = tuple(mod(ravel(x), onp.array(arr_sliced.shape[i], _dtype(x)))
|
||||
flat_idx = tuple(mod(ravel(x), _constant_like(x, arr_sliced.shape[i]))
|
||||
for i, x in zip(axes, idx_advanced))
|
||||
# TODO(mattjj): if we instead lower directly to lax.gather, we can probably
|
||||
# eliminate the reshape here.
|
||||
|
Loading…
x
Reference in New Issue
Block a user