[sparse] add bcoo_gather & support for sparse indexing

Jake VanderPlas 2022-11-11 04:25:14 -08:00
parent a13541441b
commit 90dc008340
4 changed files with 117 additions and 15 deletions
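
A minimal sketch of what this change enables, not part of the diff itself; the shapes and index expressions below are illustrative and mirror the new test cases:

import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array(np.arange(12, dtype=np.float32).reshape(3, 4))
Msp = sparse.BCOO.fromdense(M)

# NumPy-style indexing on a BCOO array now lowers to lax.gather_p, which
# the sparsify machinery dispatches to the new bcoo_gather:
assert np.array_equal(Msp[1].todense(), M[1])
assert np.array_equal(Msp[np.array([1, 2])].todense(), M[np.array([1, 2])])
assert np.array_equal(Msp[np.array([True, False, True])].todense(),
                      M[np.array([True, False, True])])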


@@ -199,6 +199,7 @@ from jax.experimental.sparse.bcoo import (
bcoo_extract_p as bcoo_extract_p,
bcoo_fromdense as bcoo_fromdense,
bcoo_fromdense_p as bcoo_fromdense_p,
bcoo_gather as bcoo_gather,
bcoo_multiply_dense as bcoo_multiply_dense,
bcoo_multiply_sparse as bcoo_multiply_sparse,
bcoo_update_layout as bcoo_update_layout,


@@ -18,7 +18,7 @@ from __future__ import annotations
import functools
from functools import partial
import operator
from typing import Any, NamedTuple, Optional, Sequence, Tuple
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
import warnings
import numpy as np
@@ -30,7 +30,9 @@ from jax import tree_util
from jax import vmap
from jax.config import config
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _broadcasting_vmap, _count_stored_elements, _safe_asarray, CuSparseEfficiencyWarning, SparseEfficiencyError, SparseEfficiencyWarning
from jax.experimental.sparse.util import (
_broadcasting_vmap, _count_stored_elements, _safe_asarray, CuSparseEfficiencyWarning,
SparseEfficiencyError, SparseEfficiencyWarning)
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
@@ -41,12 +43,13 @@ from jax._src import api_util
from jax._src.lax.lax import (
_const, ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)
from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge
from jax._src.lib import gpu_sparse
from jax._src.lib.mlir.dialects import mhlo
from jax._src.numpy.setops import _unique
from jax._src.typing import Array
from jax._src.typing import Array, ArrayLike
from jax._src.util import canonicalize_axis
@@ -1993,6 +1996,9 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
Returns:
out: BCOO array containing the slice.
"""
# Use abstract eval to validate inputs.
jax.eval_shape(partial(lax.dynamic_slice, slice_sizes=slice_sizes),
jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices)
if not isinstance(mat, BCOO):
raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
start_indices = tuple(jnp.asarray(i) for i in start_indices)
@@ -2012,16 +2018,17 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
data_sizes = []
indices_start = []
indices_sizes = []
zero = _const(start_indices[0] if start_indices else np.int32, 0)
for i, (start, size) in enumerate(zip(start_batch, size_batch)):
data_is_broadcast = mat.data.shape[i] != mat.shape[i]
indices_is_broadcast = mat.indices.shape[i] != mat.shape[i]
data_start.append(0 if data_is_broadcast else start)
data_start.append(zero if data_is_broadcast else start)
data_sizes.append(1 if data_is_broadcast else size)
indices_start.append(0 if indices_is_broadcast else start)
indices_start.append(zero if indices_is_broadcast else start)
indices_sizes.append(1 if indices_is_broadcast else size)
data_start.append(0)
data_start.append(zero)
data_sizes.append(mat.nse)
indices_start.extend([0, 0])
indices_start.extend([zero, zero])
indices_sizes.extend([mat.nse, mat.n_sparse])
data_start.extend(start_dense)
data_sizes.extend(size_dense)
@@ -2236,6 +2243,65 @@ def _bcoo_multiply_dense(data, indices, v, *, spinfo):
_mul = _broadcasting_vmap(_mul)
return _mul(data, indices, v)
def bcoo_gather(operand: BCOO, start_indices: Array,
dimension_numbers: GatherDimensionNumbers,
slice_sizes: Shape, *,
unique_indices: bool = False,
indices_are_sorted: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None,
fill_value = None) -> BCOO:
"""BCOO version of lax.gather."""
_validate_bcoo(operand.data, operand.indices, operand.shape)
# TODO(jakevdp) make use of unique_indices and indices_are_sorted?
if mode is None:
mode = GatherScatterMode.PROMISE_IN_BOUNDS
parsed_mode = GatherScatterMode.from_any(mode)
if parsed_mode != GatherScatterMode.PROMISE_IN_BOUNDS:
raise NotImplementedError(f"bcoo_gather: mode={mode} not yet supported.")
kwds = dict(dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
mode=mode, fill_value=fill_value)
# Abstract eval lax.gather to validate arguments & determine output shape.
out_aval = jax.eval_shape(partial(lax.gather, **kwds),
jax.ShapeDtypeStruct(operand.shape, operand.dtype),
jax.ShapeDtypeStruct(start_indices.shape, start_indices.dtype))
offset_dims = dimension_numbers.offset_dims
collapsed_slice_dims = dimension_numbers.collapsed_slice_dims
start_index_map = dimension_numbers.start_index_map
# Expand start_indices & slice_sizes to full rank & use bcoo_dynamic_slice
full_start_indices: List[ArrayLike] = [_const(start_indices, 0)] * operand.ndim
in_axes: List[Optional[int]] = [None for i in range(operand.ndim)]
full_slice_sizes = list(operand.shape)
for i, j in enumerate(start_index_map):
full_start_indices[j] = start_indices[..., i].ravel()
full_slice_sizes[j] = slice_sizes[j]
in_axes[j] = 0
def slice_func(indices):
slc = bcoo_dynamic_slice(operand, indices, slice_sizes=full_slice_sizes)
return bcoo_squeeze(slc, dimensions=collapsed_slice_dims)
result = vmap(slice_func, in_axes=(in_axes,))(full_start_indices)
result = bcoo_reshape(result,
new_sizes=(*start_indices.shape[:-1], *result.shape[1:]),
dimensions=tuple(range(result.ndim)))
# Use offset_dims to permute result dimensions
if result.shape:
batch_dims = tuple(dim for dim in range(len(out_aval.shape))
if dim not in offset_dims)
permutation = np.zeros(result.ndim, dtype=int)
permutation[np.array(batch_dims + offset_dims)] = np.arange(result.ndim)
if set(permutation[:len(batch_dims)]) != set(range(len(batch_dims))):
# TODO(jakevdp): use a more granular approach here. Can we do this in a
# way that preserves the original batch dimensions?
result = bcoo_update_layout(result, n_batch=0)
result = bcoo_transpose(result, permutation=tuple(permutation))
return result.reshape(out_aval.shape).astype(out_aval.dtype)
@tree_util.register_pytree_node_class
class BCOO(JAXSparse):
"""Experimental batched COO matrix implemented in JAX
@@ -2331,6 +2397,16 @@ class BCOO(JAXSparse):
repr_ = f"{type(self.data).__name__}[{repr_}]"
return repr_
# Stub methods: these are defined in transform.py
def reshape(self, *args, **kwargs) -> BCOO:
raise NotImplementedError("BCOO.reshape")
def astype(self, *args, **kwargs) -> BCOO:
raise NotImplementedError("BCOO.astype")
def sum(self) -> BCOO:
raise NotImplementedError("BCOO.sum")
@classmethod
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0, n_batch=0):
"""Create a BCOO array from a (dense) :class:`DeviceArray`."""


@@ -634,6 +634,18 @@ def _reshape_sparse(spenv, *spvalues, new_sizes, dimensions):
sparse_rules[lax.reshape_p] = _reshape_sparse
def _gather_sparse_rule(spenv, *args, dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
operand, start_indices = spvalues_to_arrays(spenv, args)
result = sparse.bcoo_gather(operand, start_indices, dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted,
mode=mode, fill_value=fill_value)
return arrays_to_spvalues(spenv, (result,))
sparse_rules[lax.gather_p] = _gather_sparse_rule
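
With this rule registered, gather operations inside sparsify-transformed functions stay sparse. A small illustrative sketch, not part of the diff:

import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array(np.arange(12, dtype=np.float32).reshape(3, 4))
Msp = sparse.BCOO.fromdense(M)

# x[jnp.array([0, 2])] traces to lax.gather_p, which the rule above routes
# to sparse.bcoo_gather, so the result remains a BCOO array.
take_rows = sparse.sparsify(lambda x: x[jnp.array([0, 2])])
out = take_rows(Msp)
assert isinstance(out, sparse.BCOO)
assert np.array_equal(out.todense(), M[jnp.array([0, 2])])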
def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
# TODO(jakevdp): currently this approach discards all information about
# shared data & indices when generating the sparsified jaxpr. The


@@ -1019,18 +1019,31 @@ class BCOOTest(jtu.JaxTestCase):
self.assertArraysEqual(dense_result, sparse_result_jit.todense())
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense, idx=idx)
for shape, idx in [
[(5,), np.index_exp[:]],
[(5,), np.index_exp[4]],
[(5,), np.index_exp[::2]],
[(5,), np.index_exp[1::2]],
[(5,), 1],
[(3, 4), np.index_exp[1]],
[(3, 4), np.index_exp[1, 2]],
[(3, 4), np.index_exp[np.array([1, 2])]],
[(3, 4), np.index_exp[np.array([[1], [2]]), 0]],
[(3, 4), np.index_exp[np.array([[1], [2]]), 1:]],
[(3, 4), np.index_exp[np.array([True, False, True])]],
[(3, 4), np.index_exp[:2, np.array([True, False, True, False])]],
[(3, 4), np.index_exp[None, 0, np.array([[2]])]],
[(3, 4, 5), np.index_exp[2]],
[(3, 4, 5), np.index_exp[:, 2]]
]
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)
for n_dense in [0] # TODO(jakevdp): add tests with n_dense
],
dtype=jtu.dtypes.floating,
idx=[1, slice(1, 3)],
dtype=jtu.dtypes.numeric,
)
def test_bcoo_getitem(self, shape, dtype, n_batch, n_dense, idx):
# Note: __getitem__ is currently only supported for simple slices and indexing
rng = self.rng()
sprng = rand_sparse(rng)
sprng = rand_sparse(self.rng())
M = sprng(shape, dtype)
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
self.assertArraysEqual(M[idx], Msp[idx].todense())