[sparse] add support for bcoo equivalent of lax.slice

This commit is contained in:
Jake VanderPlas 2022-08-25 09:19:44 -07:00
parent 5527966b27
commit 269e75273a
4 changed files with 136 additions and 1 deletions

View File

@ -203,6 +203,7 @@ from jax.experimental.sparse.bcoo import (
bcoo_update_layout as bcoo_update_layout,
bcoo_reduce_sum as bcoo_reduce_sum,
bcoo_reshape as bcoo_reshape,
bcoo_slice as bcoo_slice,
bcoo_sort_indices as bcoo_sort_indices,
bcoo_sort_indices_p as bcoo_sort_indices_p,
bcoo_spdot_general_p as bcoo_spdot_general_p,

View File

@ -17,7 +17,7 @@
import functools
from functools import partial
import operator
from typing import Any, NamedTuple, Sequence, Tuple
from typing import Any, NamedTuple, Optional, Sequence, Tuple
import warnings
import numpy as np
@ -1785,6 +1785,69 @@ def bcoo_reshape(mat, *, new_sizes, dimensions):
return BCOO((data, new_indices), shape=new_sizes)
def bcoo_slice(mat, *, start_indices: Sequence[int], limit_indices: Sequence[int],
strides: Optional[Sequence[int]]=None):
"""Sparse implementation of {func}`jax.lax.slice`.
Args:
operand: BCOO array to be reshaped.
start_indices: sequence of integers of length `mat.ndim` specifying the starting
indices of each slice.
limit_indices: sequence of integers of length `mat.ndim` specifying the ending
indices of each slice
strides: sequence of integers of length `mat.ndim` specifying the stride for
each slice
Returns:
out: BCOO array containing the slice.
"""
if not isinstance(mat, BCOO):
raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
start_indices = [operator.index(i) for i in start_indices]
limit_indices = [operator.index(i) for i in limit_indices]
if strides is not None:
strides = [operator.index(i) for i in strides]
else:
strides = [1] * mat.ndim
if len(start_indices) != len(limit_indices) != len(strides) != mat.ndim:
raise ValueError(f"bcoo_slice: indices must have size mat.ndim={mat.ndim}")
if strides != [1] * mat.ndim:
raise NotImplementedError(f"non-unit strides; got {strides}")
if not all(0 <= start <= end <= size
for start, end, size in safe_zip(start_indices, limit_indices, mat.shape)):
raise ValueError(f"bcoo_slice: invalid indices. Got start_indices={start_indices}, "
f"limit_indices={limit_indices} and shape={mat.shape}")
start_batch, start_sparse, start_dense = split_list(start_indices, [mat.n_batch, mat.n_sparse])
end_batch, end_sparse, end_dense = split_list(limit_indices, [mat.n_batch, mat.n_sparse])
data_slices = []
index_slices = []
for i, (start, end) in enumerate(zip(start_batch, end_batch)):
data_slices.append(slice(None) if mat.data.shape[i] != mat.shape[i] else slice(start, end))
index_slices.append(slice(None) if mat.indices.shape[i] != mat.shape[i] else slice(start, end))
data_slices.append(slice(None))
index_slices.extend([slice(None), slice(None)])
for i, (start, end) in enumerate(zip(start_dense, end_dense)):
data_slices.append(slice(start, end))
new_data = mat.data[tuple(data_slices)]
new_indices = mat.indices[tuple(index_slices)]
new_shape = [end - start for start, end in safe_zip(start_indices, limit_indices)]
if mat.n_sparse:
starts = jnp.expand_dims(jnp.array(start_sparse, dtype=new_indices.dtype), range(mat.n_batch + 1))
ends = jnp.expand_dims(jnp.array(end_sparse, dtype=new_indices.dtype), range(mat.n_batch + 1))
sparse_shape = jnp.array(mat.shape[mat.n_batch: mat.n_batch + mat.n_sparse], dtype=new_indices.dtype)
keep = jnp.all((new_indices >= starts) & (new_indices < ends), -1, keepdims=True)
new_indices = jnp.where(keep, new_indices - starts, sparse_shape)
keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense))
new_data = jnp.where(keep_data, new_data, 0)
return BCOO((new_data, new_indices), shape=new_shape)
def _tuple_replace(tup, ind, val):
return tuple(val if i == ind else t for i, t in enumerate(tup))

View File

@ -759,6 +759,13 @@ def _todense_sparse_rule(spenv, spvalue, *, tree):
sparse_rules[sparse.todense_p] = _todense_sparse_rule
def _slice_sparse_rule(spenv, *operands, **params):
args = spvalues_to_arrays(spenv, operands)
out = sparse.bcoo_slice(*args, **params)
return arrays_to_spvalues(spenv, [out])
sparse_rules[lax.slice_p] = _slice_sparse_rule
#------------------------------------------------------------------------------
# BCOO methods derived from sparsify
@ -775,6 +782,25 @@ def _reshape(self, *args, **kwargs):
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
# mirrors lax_numpy._rewriting_take.
# Handle some special cases, falling back if error messages might differ.
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
return sparsify(lambda arr: lax.index_in_dim(arr, idx, keepdims=False))(arr)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
(type(idx.stop) is int or idx.stop is None) and
(type(idx.step) is int or idx.step is None)):
n = arr.shape[0]
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else n
step = idx.step if idx.step is not None else 1
if (0 <= start < n and 0 <= stop <= n and 0 < step and
(start, stop, step) != (0, n, 1)):
return sparsify(lambda arr: lax.slice_in_dim(arr, start, stop, step))(arr)
treedef, static_idx, dynamic_idx = lax_numpy._split_index_for_jit(idx, arr.shape)
result = sparsify(
lambda arr, idx: lax_numpy._gather(arr, treedef, static_idx, idx, indices_are_sorted,

View File

@ -36,6 +36,7 @@ from jax import lax
from jax._src.lib import xla_extension_version
from jax._src.lib import gpu_sparse
from jax._src.lib import xla_bridge
from jax._src.util import unzip2
from jax import jit
from jax import tree_util
from jax import vmap
@ -926,6 +927,50 @@ class BCOOTest(jtu.JaxTestCase):
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
self.assertArraysEqual(trans(M), trans(Msp).todense())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)))
def test_bcoo_slice(self, shape, dtype, n_batch, n_dense):
rng = self.rng()
sprng = rand_sparse(rng)
M = sprng(shape, dtype)
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
rng = self.rng()
slices = rng.randint(0, M.shape, (2, M.ndim)).T
slices.sort(1)
start_indices, limit_indices = unzip2(slices)
strides = None # strides currently not implemented
kwds = dict(start_indices=start_indices, limit_indices=limit_indices, strides=strides)
dense_result = lax.slice(M, **kwds)
sparse_result = sparse.bcoo_slice(Msp, **kwds)
self.assertArraysEqual(dense_result, sparse_result.todense())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}_idx={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, idx),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense,
"idx": idx}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)
for idx in [1, slice(1, 3)]))
def test_bcoo_getitem(self, shape, dtype, n_batch, n_dense, idx):
# Note: __getitem__ is currently only supported for simple slices and indexing
rng = self.rng()
sprng = rand_sparse(rng)
M = sprng(shape, dtype)
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
self.assertArraysEqual(M[idx], Msp[idx].todense())
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),