[sparse] simplify BCOO indexing implementation

This commit is contained in:
Jake VanderPlas 2022-11-02 10:13:35 -07:00
parent 94ba43bfba
commit 6ed9b14d8d

View File

@ -789,30 +789,10 @@ def _reshape(self, *args, **kwargs):
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
# mirrors lax_numpy._rewriting_take.
# Handle some special cases, falling back if error messages might differ.
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
return sparsify(lambda arr: lax.index_in_dim(arr, idx, keepdims=False))(arr)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
(type(idx.stop) is int or idx.stop is None) and
(type(idx.step) is int or idx.step is None)):
n = arr.shape[0]
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else n
step = idx.step if idx.step is not None else 1
if (0 <= start < n and 0 <= stop <= n and 0 < step and
(start, stop, step) != (0, n, 1)):
return sparsify(lambda arr: lax.slice_in_dim(arr, start, stop, step))(arr)
treedef, static_idx, dynamic_idx = lax_numpy._split_index_for_jit(idx, arr.shape)
result = sparsify(
lambda arr, idx: lax_numpy._gather(arr, treedef, static_idx, idx, indices_are_sorted,
unique_indices, mode, fill_value))(arr, dynamic_idx)
# Only sparsify the array argument; sparse indices not yet supported
result = sparsify(functools.partial(
lax_numpy._rewriting_take, idx=idx, indices_are_sorted=indices_are_sorted,
mode=mode, unique_indices=unique_indices, fill_value=fill_value))(arr)
# Account for a corner case in the rewriting_take implementation.
if not isinstance(result, BCOO) and np.size(result) == 0:
result = BCOO.fromdense(result)