Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

[sparse] add bcoo_gather & support for sparse indexing

commit 90dc008340 (parent a13541441b)
@@ -199,6 +199,7 @@ from jax.experimental.sparse.bcoo import (
     bcoo_extract_p as bcoo_extract_p,
     bcoo_fromdense as bcoo_fromdense,
     bcoo_fromdense_p as bcoo_fromdense_p,
+    bcoo_gather as bcoo_gather,
     bcoo_multiply_dense as bcoo_multiply_dense,
     bcoo_multiply_sparse as bcoo_multiply_sparse,
     bcoo_update_layout as bcoo_update_layout,
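For reference, a minimal sketch (not part of the commit) of the new public re-export added above:

# Minimal sketch (not part of the commit): bcoo_gather is now re-exported from
# the public sparse namespace alongside the other bcoo_* helpers.
from jax.experimental import sparse
from jax.experimental.sparse import bcoo_gather  # direct import of the new symbol

print(sparse.bcoo_gather is bcoo_gather)  # True: same function object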
@@ -18,7 +18,7 @@ from __future__ import annotations
 import functools
 from functools import partial
 import operator
-from typing import Any, NamedTuple, Optional, Sequence, Tuple
+from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
 import warnings

 import numpy as np
@@ -30,7 +30,9 @@ from jax import tree_util
 from jax import vmap
 from jax.config import config
 from jax.experimental.sparse._base import JAXSparse
-from jax.experimental.sparse.util import _broadcasting_vmap, _count_stored_elements, _safe_asarray, CuSparseEfficiencyWarning, SparseEfficiencyError, SparseEfficiencyWarning
+from jax.experimental.sparse.util import (
+  _broadcasting_vmap, _count_stored_elements, _safe_asarray, CuSparseEfficiencyWarning,
+  SparseEfficiencyError, SparseEfficiencyWarning)
 from jax.interpreters import batching
 from jax.interpreters import partial_eval as pe
 from jax.interpreters import mlir
@@ -41,12 +43,13 @@ from jax._src import api_util
 from jax._src.lax.lax import (
   _const, ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
   DotDimensionNumbers)
+from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode
 from jax._src.lib.mlir import ir
 from jax._src.lib import xla_bridge
 from jax._src.lib import gpu_sparse
 from jax._src.lib.mlir.dialects import mhlo
 from jax._src.numpy.setops import _unique
-from jax._src.typing import Array
+from jax._src.typing import Array, ArrayLike
 from jax._src.util import canonicalize_axis


@@ -1993,6 +1996,9 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
   Returns:
     out: BCOO array containing the slice.
   """
+  # Use abstract eval to validate inputs.
+  jax.eval_shape(partial(lax.dynamic_slice, slice_sizes=slice_sizes),
+                 jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices)
   if not isinstance(mat, BCOO):
     raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
   start_indices = tuple(jnp.asarray(i) for i in start_indices)
@@ -2012,16 +2018,17 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
   data_sizes = []
   indices_start = []
   indices_sizes = []
+  zero = _const(start_indices[0] if start_indices else np.int32, 0)
   for i, (start, size) in enumerate(zip(start_batch, size_batch)):
     data_is_broadcast = mat.data.shape[i] != mat.shape[i]
     indices_is_broadcast = mat.indices.shape[i] != mat.shape[i]
-    data_start.append(0 if data_is_broadcast else start)
+    data_start.append(zero if data_is_broadcast else start)
     data_sizes.append(1 if data_is_broadcast else size)
-    indices_start.append(0 if indices_is_broadcast else start)
+    indices_start.append(zero if indices_is_broadcast else start)
     indices_sizes.append(1 if indices_is_broadcast else size)
-  data_start.append(0)
+  data_start.append(zero)
   data_sizes.append(mat.nse)
-  indices_start.extend([0, 0])
+  indices_start.extend([zero, zero])
   indices_sizes.extend([mat.nse, mat.n_sparse])
   data_start.extend(start_dense)
   data_sizes.extend(size_dense)
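The literal zeros are replaced by a constant built from the caller's start indices, presumably so that every start index fed to the underlying dynamic slices shares one integer dtype even when the real indices are traced values. A minimal sketch (not part of the commit) of the public entry point under jit; it assumes bcoo_dynamic_slice is exported from jax.experimental.sparse, as in current JAX:

# Minimal sketch (not part of the commit): dynamically slicing a BCOO under jit,
# where the batch start index is a traced value and the other is a constant of
# the same integer dtype.
import jax
import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.BCOO.fromdense(jnp.arange(12.0).reshape(3, 4), n_batch=1)

@jax.jit
def take_row(i):
  # One traced start index plus one constant; both share dtype int32.
  return sparse.bcoo_dynamic_slice(mat, (i, jnp.int32(0)), slice_sizes=(1, 4))

print(take_row(jnp.int32(1)).todense())  # [[4. 5. 6. 7.]]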
@@ -2236,6 +2243,65 @@ def _bcoo_multiply_dense(data, indices, v, *, spinfo):
     _mul = _broadcasting_vmap(_mul)
   return _mul(data, indices, v)

+def bcoo_gather(operand: BCOO, start_indices: Array,
+                dimension_numbers: GatherDimensionNumbers,
+                slice_sizes: Shape, *,
+                unique_indices: bool = False,
+                indices_are_sorted: bool = False,
+                mode: Optional[Union[str, GatherScatterMode]] = None,
+                fill_value = None) -> BCOO:
+  """BCOO version of lax.gather."""
+  _validate_bcoo(operand.data, operand.indices, operand.shape)
+
+  # TODO(jakevdp) make use of unique_indices and indices_are_sorted?
+  if mode is None:
+    mode = GatherScatterMode.PROMISE_IN_BOUNDS
+  parsed_mode = GatherScatterMode.from_any(mode)
+  if parsed_mode != GatherScatterMode.PROMISE_IN_BOUNDS:
+    raise NotImplementedError(f"bcoo_gather: mode={mode} not yet supported.")
+
+  kwds = dict(dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
+              unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
+              mode=mode, fill_value=fill_value)
+
+  # Abstract eval lax.gather to validate arguments & determine output shape.
+  out_aval = jax.eval_shape(partial(lax.gather, **kwds),
+                            jax.ShapeDtypeStruct(operand.shape, operand.dtype),
+                            jax.ShapeDtypeStruct(start_indices.shape, start_indices.dtype))
+  offset_dims = dimension_numbers.offset_dims
+  collapsed_slice_dims = dimension_numbers.collapsed_slice_dims
+  start_index_map = dimension_numbers.start_index_map
+
+  # Expand start_indices & slice_sizes to full rank & use bcoo_dynamic_slice
+  full_start_indices: List[ArrayLike] = [_const(start_indices, 0)] * operand.ndim
+  in_axes: List[Optional[int]] = [None for i in range(operand.ndim)]
+  full_slice_sizes = list(operand.shape)
+  for i, j in enumerate(start_index_map):
+    full_start_indices[j] = start_indices[..., i].ravel()
+    full_slice_sizes[j] = slice_sizes[j]
+    in_axes[j] = 0
+  def slice_func(indices):
+    slc = bcoo_dynamic_slice(operand, indices, slice_sizes=full_slice_sizes)
+    return bcoo_squeeze(slc, dimensions=collapsed_slice_dims)
+  result = vmap(slice_func, in_axes=(in_axes,))(full_start_indices)
+  result = bcoo_reshape(result,
+                        new_sizes=(*start_indices.shape[:-1], *result.shape[1:]),
+                        dimensions=tuple(range(result.ndim)))
+
+  # Use offset_dims to permute result dimensions
+  if result.shape:
+    batch_dims = tuple(dim for dim in range(len(out_aval.shape))
+                       if dim not in offset_dims)
+    permutation = np.zeros(result.ndim, dtype=int)
+    permutation[np.array(batch_dims + offset_dims)] = np.arange(result.ndim)
+    if set(permutation[:len(batch_dims)]) != set(range(len(batch_dims))):
+      # TODO: jakevdp more granular approach here. Can we do this in a
+      # way that preserves the original batch dimensions?
+      result = bcoo_update_layout(result, n_batch=0)
+    result = bcoo_transpose(result, permutation=tuple(permutation))
+
+  return result.reshape(out_aval.shape).astype(out_aval.dtype)
+

 @tree_util.register_pytree_node_class
 class BCOO(JAXSparse):
   """Experimental batched COO matrix implemented in JAX
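For illustration, a minimal usage sketch (not part of the commit) showing bcoo_gather driven by the same GatherDimensionNumbers as lax.gather; the matrix and indices below are made up:

# Minimal sketch (not part of the commit): gather rows 2 and 0 of a sparse
# matrix and check the result against the dense lax.gather.
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse

M = jnp.array([[1., 0., 2.],
               [0., 3., 0.],
               [4., 0., 5.]])
Msp = sparse.BCOO.fromdense(M)

start_indices = jnp.array([[2], [0]])
dnums = lax.GatherDimensionNumbers(offset_dims=(1,),
                                   collapsed_slice_dims=(0,),
                                   start_index_map=(0,))
dense_rows = lax.gather(M, start_indices, dnums, slice_sizes=(1, 3))
sparse_rows = sparse.bcoo_gather(Msp, start_indices, dnums, slice_sizes=(1, 3))
assert (dense_rows == sparse_rows.todense()).all()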
@@ -2331,6 +2397,16 @@ class BCOO(JAXSparse):
     repr_ = f"{type(self.data).__name__}[{repr_}]"
     return repr_

+  # Stub methods: these are defined in transform.py
+  def reshape(self, *args, **kwargs) -> BCOO:
+    raise NotImplementedError("BCOO.reshape")
+
+  def astype(self, *args, **kwargs) -> BCOO:
+    raise NotImplementedError("BCOO.astype")
+
+  def sum(self) -> BCOO:
+    raise NotImplementedError("BCOO.sum")
+
   @classmethod
   def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0, n_batch=0):
     """Create a BCOO array from a (dense) :class:`DeviceArray`."""
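The stubs only raise; per the comment, the working implementations are attached in transform.py (imported with the sparse package), where they are built on sparsify. An illustrative sketch (not part of the commit), assuming that import has run:

# Illustrative sketch (not part of the commit): after jax.experimental.sparse is
# imported, these methods dispatch to sparsify-based implementations rather than
# raising NotImplementedError.
import jax.numpy as jnp
from jax.experimental import sparse

Msp = sparse.BCOO.fromdense(jnp.array([[1., 0.], [0., 2.]]))
print(Msp.reshape((4,)).todense())    # [1. 0. 0. 2.]
print(Msp.astype(jnp.float32).dtype)  # float32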
@@ -634,6 +634,18 @@ def _reshape_sparse(spenv, *spvalues, new_sizes, dimensions):

 sparse_rules[lax.reshape_p] = _reshape_sparse

+
+def _gather_sparse_rule(spenv, *args, dimension_numbers, slice_sizes, unique_indices,
+                        indices_are_sorted, mode, fill_value):
+  operand, start_indices = spvalues_to_arrays(spenv, args)
+  result = sparse.bcoo_gather(operand, start_indices, dimension_numbers=dimension_numbers,
+                              slice_sizes=slice_sizes, unique_indices=unique_indices,
+                              indices_are_sorted=indices_are_sorted,
+                              mode=mode, fill_value=fill_value)
+  return arrays_to_spvalues(spenv, (result,))
+
+sparse_rules[lax.gather_p] = _gather_sparse_rule
+
 def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
   # TODO(jakevdp): currently this approach discards all information about
   #   shared data & indices when generating the sparsified jaxpr. The
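Registering this rule for lax.gather_p is what lets gather-based operations, such as integer-array indexing, run inside sparsify-transformed functions. A minimal sketch (not part of the commit), with made-up values:

# Minimal sketch (not part of the commit): fancy indexing inside a sparsified
# function lowers to lax.gather_p and is now dispatched to bcoo_gather.
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[1., 0., 2.],
               [0., 3., 0.]])
Msp = sparse.BCOO.fromdense(M)

@sparse.sparsify
def take_rows(x, rows):
  return x[rows]  # integer-array indexing lowers to a gather

rows = jnp.array([1, 0])
print(take_rows(Msp, rows).todense())  # sparse operand, sparse result
print(take_rows(M, rows))              # the same function still works densely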
@@ -1019,18 +1019,31 @@ class BCOOTest(jtu.JaxTestCase):
     self.assertArraysEqual(dense_result, sparse_result_jit.todense())

   @jtu.sample_product(
-    [dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
-     for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
+    [dict(shape=shape, n_batch=n_batch, n_dense=n_dense, idx=idx)
+     for shape, idx in [
+         [(5,), np.index_exp[:]],
+         [(5,), np.index_exp[4]],
+         [(5,), np.index_exp[::2]],
+         [(5,), np.index_exp[1::2]],
+         [(5,), 1],
+         [(3, 4), np.index_exp[1]],
+         [(3, 4), np.index_exp[1, 2]],
+         [(3, 4), np.index_exp[np.array([1, 2])]],
+         [(3, 4), np.index_exp[np.array([[1], [2]]), 0]],
+         [(3, 4), np.index_exp[np.array([[1], [2]]), 1:]],
+         [(3, 4), np.index_exp[np.array([True, False, True])]],
+         [(3, 4), np.index_exp[:2, np.array([True, False, True, False])]],
+         [(3, 4), np.index_exp[None, 0, np.array([[2]])]],
+         [(3, 4, 5), np.index_exp[2]],
+         [(3, 4, 5), np.index_exp[:, 2]]
+     ]
      for n_batch in range(len(shape) + 1)
-     for n_dense in range(len(shape) + 1 - n_batch)
+     for n_dense in [0]  # TODO(jakevdp): add tests with n_dense
     ],
-    dtype=jtu.dtypes.floating,
-    idx=[1, slice(1, 3)],
+    dtype=jtu.dtypes.numeric,
   )
   def test_bcoo_getitem(self, shape, dtype, n_batch, n_dense, idx):
-    # Note: __getitem__ is currently only supported for simple slices and indexing
-    rng = self.rng()
-    sprng = rand_sparse(rng)
+    sprng = rand_sparse(self.rng())
     M = sprng(shape, dtype)
     Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
     self.assertArraysEqual(M[idx], Msp[idx].todense())
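For illustration (not part of the commit), the property the new test cases exercise: numpy-style indexing of a BCOO agrees with indexing its dense counterpart.

# Illustrative sketch (not part of the commit), mirroring the new test cases:
# BCOO.__getitem__ should match dense indexing for these index expressions.
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 4., 0.],
               [0., 5., 0., 6.]])
Msp = sparse.BCOO.fromdense(M)

for idx in [np.index_exp[1],
            np.index_exp[1, 2],
            np.index_exp[np.array([1, 2])],
            np.index_exp[np.array([True, False, True])]]:
  np.testing.assert_array_equal(M[idx], Msp[idx].todense())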