mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] simplify BCOO indexing implementation
This commit is contained in:
parent
94ba43bfba
commit
6ed9b14d8d
@ -789,30 +789,10 @@ def _reshape(self, *args, **kwargs):
|
||||
|
||||
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None, fill_value=None):
|
||||
# mirrors lax_numpy._rewriting_take.
|
||||
|
||||
# Handle some special cases, falling back if error messages might differ.
|
||||
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
|
||||
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
|
||||
if 0 <= idx < arr.shape[0]:
|
||||
return sparsify(lambda arr: lax.index_in_dim(arr, idx, keepdims=False))(arr)
|
||||
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
|
||||
isinstance(idx, slice) and
|
||||
(type(idx.start) is int or idx.start is None) and
|
||||
(type(idx.stop) is int or idx.stop is None) and
|
||||
(type(idx.step) is int or idx.step is None)):
|
||||
n = arr.shape[0]
|
||||
start = idx.start if idx.start is not None else 0
|
||||
stop = idx.stop if idx.stop is not None else n
|
||||
step = idx.step if idx.step is not None else 1
|
||||
if (0 <= start < n and 0 <= stop <= n and 0 < step and
|
||||
(start, stop, step) != (0, n, 1)):
|
||||
return sparsify(lambda arr: lax.slice_in_dim(arr, start, stop, step))(arr)
|
||||
|
||||
treedef, static_idx, dynamic_idx = lax_numpy._split_index_for_jit(idx, arr.shape)
|
||||
result = sparsify(
|
||||
lambda arr, idx: lax_numpy._gather(arr, treedef, static_idx, idx, indices_are_sorted,
|
||||
unique_indices, mode, fill_value))(arr, dynamic_idx)
|
||||
# Only sparsify the array argument; sparse indices not yet supported
|
||||
result = sparsify(functools.partial(
|
||||
lax_numpy._rewriting_take, idx=idx, indices_are_sorted=indices_are_sorted,
|
||||
mode=mode, unique_indices=unique_indices, fill_value=fill_value))(arr)
|
||||
# Account for a corner case in the rewriting_take implementation.
|
||||
if not isinstance(result, BCOO) and np.size(result) == 0:
|
||||
result = BCOO.fromdense(result)
|
||||
|
Loading…
x
Reference in New Issue
Block a user