mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] add support for bcoo equivalent of lax.slice
This commit is contained in:
parent
5527966b27
commit
269e75273a
@ -203,6 +203,7 @@ from jax.experimental.sparse.bcoo import (
|
||||
bcoo_update_layout as bcoo_update_layout,
|
||||
bcoo_reduce_sum as bcoo_reduce_sum,
|
||||
bcoo_reshape as bcoo_reshape,
|
||||
bcoo_slice as bcoo_slice,
|
||||
bcoo_sort_indices as bcoo_sort_indices,
|
||||
bcoo_sort_indices_p as bcoo_sort_indices_p,
|
||||
bcoo_spdot_general_p as bcoo_spdot_general_p,
|
||||
|
@ -17,7 +17,7 @@
|
||||
import functools
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Any, NamedTuple, Sequence, Tuple
|
||||
from typing import Any, NamedTuple, Optional, Sequence, Tuple
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -1785,6 +1785,69 @@ def bcoo_reshape(mat, *, new_sizes, dimensions):
|
||||
return BCOO((data, new_indices), shape=new_sizes)
|
||||
|
||||
|
||||
def bcoo_slice(mat, *, start_indices: Sequence[int], limit_indices: Sequence[int],
|
||||
strides: Optional[Sequence[int]]=None):
|
||||
"""Sparse implementation of {func}`jax.lax.slice`.
|
||||
|
||||
Args:
|
||||
operand: BCOO array to be reshaped.
|
||||
start_indices: sequence of integers of length `mat.ndim` specifying the starting
|
||||
indices of each slice.
|
||||
limit_indices: sequence of integers of length `mat.ndim` specifying the ending
|
||||
indices of each slice
|
||||
strides: sequence of integers of length `mat.ndim` specifying the stride for
|
||||
each slice
|
||||
|
||||
Returns:
|
||||
out: BCOO array containing the slice.
|
||||
"""
|
||||
if not isinstance(mat, BCOO):
|
||||
raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
|
||||
start_indices = [operator.index(i) for i in start_indices]
|
||||
limit_indices = [operator.index(i) for i in limit_indices]
|
||||
if strides is not None:
|
||||
strides = [operator.index(i) for i in strides]
|
||||
else:
|
||||
strides = [1] * mat.ndim
|
||||
if len(start_indices) != len(limit_indices) != len(strides) != mat.ndim:
|
||||
raise ValueError(f"bcoo_slice: indices must have size mat.ndim={mat.ndim}")
|
||||
if strides != [1] * mat.ndim:
|
||||
raise NotImplementedError(f"non-unit strides; got {strides}")
|
||||
|
||||
if not all(0 <= start <= end <= size
|
||||
for start, end, size in safe_zip(start_indices, limit_indices, mat.shape)):
|
||||
raise ValueError(f"bcoo_slice: invalid indices. Got start_indices={start_indices}, "
|
||||
f"limit_indices={limit_indices} and shape={mat.shape}")
|
||||
|
||||
start_batch, start_sparse, start_dense = split_list(start_indices, [mat.n_batch, mat.n_sparse])
|
||||
end_batch, end_sparse, end_dense = split_list(limit_indices, [mat.n_batch, mat.n_sparse])
|
||||
|
||||
data_slices = []
|
||||
index_slices = []
|
||||
for i, (start, end) in enumerate(zip(start_batch, end_batch)):
|
||||
data_slices.append(slice(None) if mat.data.shape[i] != mat.shape[i] else slice(start, end))
|
||||
index_slices.append(slice(None) if mat.indices.shape[i] != mat.shape[i] else slice(start, end))
|
||||
data_slices.append(slice(None))
|
||||
index_slices.extend([slice(None), slice(None)])
|
||||
for i, (start, end) in enumerate(zip(start_dense, end_dense)):
|
||||
data_slices.append(slice(start, end))
|
||||
new_data = mat.data[tuple(data_slices)]
|
||||
new_indices = mat.indices[tuple(index_slices)]
|
||||
new_shape = [end - start for start, end in safe_zip(start_indices, limit_indices)]
|
||||
|
||||
if mat.n_sparse:
|
||||
starts = jnp.expand_dims(jnp.array(start_sparse, dtype=new_indices.dtype), range(mat.n_batch + 1))
|
||||
ends = jnp.expand_dims(jnp.array(end_sparse, dtype=new_indices.dtype), range(mat.n_batch + 1))
|
||||
sparse_shape = jnp.array(mat.shape[mat.n_batch: mat.n_batch + mat.n_sparse], dtype=new_indices.dtype)
|
||||
|
||||
keep = jnp.all((new_indices >= starts) & (new_indices < ends), -1, keepdims=True)
|
||||
new_indices = jnp.where(keep, new_indices - starts, sparse_shape)
|
||||
|
||||
keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense))
|
||||
new_data = jnp.where(keep_data, new_data, 0)
|
||||
|
||||
return BCOO((new_data, new_indices), shape=new_shape)
|
||||
|
||||
def _tuple_replace(tup, ind, val):
|
||||
return tuple(val if i == ind else t for i, t in enumerate(tup))
|
||||
|
||||
|
@ -759,6 +759,13 @@ def _todense_sparse_rule(spenv, spvalue, *, tree):
|
||||
|
||||
sparse_rules[sparse.todense_p] = _todense_sparse_rule
|
||||
|
||||
def _slice_sparse_rule(spenv, *operands, **params):
|
||||
args = spvalues_to_arrays(spenv, operands)
|
||||
out = sparse.bcoo_slice(*args, **params)
|
||||
return arrays_to_spvalues(spenv, [out])
|
||||
|
||||
sparse_rules[lax.slice_p] = _slice_sparse_rule
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# BCOO methods derived from sparsify
|
||||
@ -775,6 +782,25 @@ def _reshape(self, *args, **kwargs):
|
||||
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None, fill_value=None):
|
||||
# mirrors lax_numpy._rewriting_take.
|
||||
|
||||
# Handle some special cases, falling back if error messages might differ.
|
||||
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
|
||||
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
|
||||
if 0 <= idx < arr.shape[0]:
|
||||
return sparsify(lambda arr: lax.index_in_dim(arr, idx, keepdims=False))(arr)
|
||||
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
|
||||
isinstance(idx, slice) and
|
||||
(type(idx.start) is int or idx.start is None) and
|
||||
(type(idx.stop) is int or idx.stop is None) and
|
||||
(type(idx.step) is int or idx.step is None)):
|
||||
n = arr.shape[0]
|
||||
start = idx.start if idx.start is not None else 0
|
||||
stop = idx.stop if idx.stop is not None else n
|
||||
step = idx.step if idx.step is not None else 1
|
||||
if (0 <= start < n and 0 <= stop <= n and 0 < step and
|
||||
(start, stop, step) != (0, n, 1)):
|
||||
return sparsify(lambda arr: lax.slice_in_dim(arr, start, stop, step))(arr)
|
||||
|
||||
treedef, static_idx, dynamic_idx = lax_numpy._split_index_for_jit(idx, arr.shape)
|
||||
result = sparsify(
|
||||
lambda arr, idx: lax_numpy._gather(arr, treedef, static_idx, idx, indices_are_sorted,
|
||||
|
@ -36,6 +36,7 @@ from jax import lax
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.util import unzip2
|
||||
from jax import jit
|
||||
from jax import tree_util
|
||||
from jax import vmap
|
||||
@ -926,6 +927,50 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
|
||||
self.assertArraysEqual(trans(M), trans(Msp).todense())
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
|
||||
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
|
||||
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
for dtype in jtu.dtypes.floating
|
||||
for n_batch in range(len(shape) + 1)
|
||||
for n_dense in range(len(shape) + 1 - n_batch)))
|
||||
def test_bcoo_slice(self, shape, dtype, n_batch, n_dense):
|
||||
rng = self.rng()
|
||||
sprng = rand_sparse(rng)
|
||||
M = sprng(shape, dtype)
|
||||
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
|
||||
|
||||
rng = self.rng()
|
||||
slices = rng.randint(0, M.shape, (2, M.ndim)).T
|
||||
slices.sort(1)
|
||||
start_indices, limit_indices = unzip2(slices)
|
||||
strides = None # strides currently not implemented
|
||||
kwds = dict(start_indices=start_indices, limit_indices=limit_indices, strides=strides)
|
||||
|
||||
dense_result = lax.slice(M, **kwds)
|
||||
sparse_result = sparse.bcoo_slice(Msp, **kwds)
|
||||
|
||||
self.assertArraysEqual(dense_result, sparse_result.todense())
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_nbatch={}_ndense={}_idx={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, idx),
|
||||
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense,
|
||||
"idx": idx}
|
||||
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
for dtype in jtu.dtypes.floating
|
||||
for n_batch in range(len(shape) + 1)
|
||||
for n_dense in range(len(shape) + 1 - n_batch)
|
||||
for idx in [1, slice(1, 3)]))
|
||||
def test_bcoo_getitem(self, shape, dtype, n_batch, n_dense, idx):
|
||||
# Note: __getitem__ is currently only supported for simple slices and indexing
|
||||
rng = self.rng()
|
||||
sprng = rand_sparse(rng)
|
||||
M = sprng(shape, dtype)
|
||||
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
|
||||
self.assertArraysEqual(M[idx], Msp[idx].todense())
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
|
||||
|
Loading…
x
Reference in New Issue
Block a user