Merge pull request #13869 from jakevdp:bcoo-extract-api

PiperOrigin-RevId: 501842123
This commit is contained in:
jax authors 2023-01-13 07:27:19 -08:00
commit d1593289a0
4 changed files with 97 additions and 67 deletions

View File

@@ -85,12 +85,13 @@ def _todense_transpose(ct, *bufs, tree):
standin = object()
obj = tree_util.tree_unflatten(tree, [standin] * len(bufs))
from jax.experimental.sparse import BCOO, bcoo_extract
from jax.experimental.sparse import BCOO
from jax.experimental.sparse.bcoo import _bcoo_extract
if obj is standin:
return (ct,)
elif isinstance(obj, BCOO):
_, indices = bufs
return bcoo_extract(indices, ct), indices
return _bcoo_extract(indices, ct), indices
elif isinstance(obj, COO):
_, row, col = bufs
return _coo_extract(row, col, ct), row, col

View File

@@ -236,7 +236,7 @@ def _bcoo_todense_transpose(ct, data, indices, *, spinfo):
raise ValueError("Cannot transpose with respect to sparse indices")
assert ct.shape == shape
assert ct.dtype == data.aval.dtype
return bcoo_extract(indices, ct), indices
return _bcoo_extract(indices, ct), indices
def _bcoo_todense_batching_rule(batched_args, batch_dims, *, spinfo):
data, indices, spinfo = _bcoo_batch_dims_to_front(batched_args, batch_dims, spinfo)
@@ -320,7 +320,7 @@ def _bcoo_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
indices = jnp.zeros(mask.shape[:n_batch] + (nse, 0), index_dtype)
else:
indices = jnp.moveaxis(jnp.array(indices, index_dtype), 0, n_batch + 1)
data = bcoo_extract(indices, mat)
data = _bcoo_extract(indices, mat)
true_nse = mask.sum(list(range(n_batch, mask.ndim)))[..., None]
true_nonzeros = lax.broadcasted_iota(true_nse.dtype, (1,) * n_batch + (nse,), n_batch) < true_nse
@@ -346,7 +346,7 @@ def _bcoo_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype
if type(Mdot) is ad.Zero:
data_dot = ad.Zero.from_value(data)
else:
data_dot = bcoo_extract(indices, Mdot)
data_dot = _bcoo_extract(indices, Mdot)
tangents_out = (data_dot, ad.Zero.from_value(indices))
@@ -381,44 +381,70 @@ mlir.register_lowering(bcoo_fromdense_p, mlir.lower_fun(
bcoo_extract_p = core.Primitive('bcoo_extract')
def bcoo_extract(indices: Array, mat: Array, *, assume_unique=True) -> Array:
"""Extract BCOO data values from a dense matrix at given BCOO indices.
def bcoo_extract(sparr: BCOO, arr: ArrayLike, *, assume_unique: Optional[bool] = None) -> BCOO:
"""Extract values from a dense array according to the sparse array's indices.
Args:
sparr : BCOO array whose indices will be used for the output.
arr : ArrayLike with shape equal to self.shape
assume_unique : bool, defaults to sparr.unique_indices
If True, extract values for every index, even if index contains duplicates.
If False, duplicate indices will have their values summed and returned in
the position of the first index.
Returns:
extracted : a BCOO array with the same sparsity pattern as self.
"""
if not isinstance(sparr, BCOO):
raise ValueError(f"First argument to bcoo_extract should be a BCOO array. Got {type(sparr)=}")
arr = jnp.asarray(arr)
if arr.shape != sparr.shape:
raise ValueError(f"shape mismatch: {sparr.shape=} {arr.shape=}")
if assume_unique is None:
assume_unique = sparr.unique_indices
data = _bcoo_extract(sparr.indices, arr, assume_unique=assume_unique)
return BCOO((data, sparr.indices), **sparr._info._asdict())
def _bcoo_extract(indices: Array, arr: Array, *, assume_unique=True) -> Array:
"""Extract BCOO data values from a dense array at given BCOO indices.
Args:
indices: An ndarray; see BCOO indices.
mat: A dense matrix.
arr: A dense array.
assume_unique: bool, default=True
If True, then indices will be assumed unique and a value will be extracted
from mat for each index. Otherwise, extra work will be done to de-duplicate
from arr for each index. Otherwise, extra work will be done to de-duplicate
indices to zero-out duplicate extracted values.
Returns:
An ndarray; see BCOO data.
"""
return bcoo_extract_p.bind(indices, mat, assume_unique=assume_unique)
return bcoo_extract_p.bind(indices, arr, assume_unique=assume_unique)
@bcoo_extract_p.def_impl
def _bcoo_extract_impl(indices, mat, *, assume_unique):
mat = jnp.asarray(mat)
props = _validate_bcoo_indices(indices, mat.shape)
def _bcoo_extract_impl(indices, arr, *, assume_unique):
arr = jnp.asarray(arr)
props = _validate_bcoo_indices(indices, arr.shape)
if not assume_unique:
indices, sort_ind = _unique_indices(indices, shape=mat.shape, return_index=True)
indices, sort_ind = _unique_indices(indices, shape=arr.shape, return_index=True)
original_props = props
props = _validate_bcoo_indices(indices, mat.shape)
props = _validate_bcoo_indices(indices, arr.shape)
ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
for s, i_s in zip(mat.shape[:props.n_batch], indices.shape[:props.n_batch]))
for s, i_s in zip(arr.shape[:props.n_batch], indices.shape[:props.n_batch]))
grid = tuple(np.meshgrid(*ind_slices, indexing='ij', sparse=True))
sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(props.n_sparse))
batch_slices = tuple(np.arange(s) for s in mat.shape[:props.n_batch])
batch_slices = tuple(np.arange(s) for s in arr.shape[:props.n_batch])
grid = np.meshgrid(*batch_slices, np.arange(1), indexing='ij', sparse=True)
batch_ind = tuple(grid)[:-1]
if not sparse_ind + batch_ind:
result = mat[None]
result = arr[None]
else:
result = mat.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
result = arr.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
if props.n_sparse == 0 and props.nse != 1:
if assume_unique:
result = lax.broadcast_in_dim(
@@ -437,27 +463,27 @@ def _bcoo_extract_impl(indices, mat, *, assume_unique):
return result
@bcoo_extract_p.def_abstract_eval
def _bcoo_extract_abstract_eval(indices, mat, *, assume_unique):
def _bcoo_extract_abstract_eval(indices, arr, *, assume_unique):
_ = bool(assume_unique)
n_batch, _, n_dense, nse = _validate_bcoo_indices(indices, mat.shape)
out_shape = mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]
return core.ShapedArray(out_shape, mat.dtype)
n_batch, _, n_dense, nse = _validate_bcoo_indices(indices, arr.shape)
out_shape = arr.shape[:n_batch] + (nse,) + arr.shape[arr.ndim - n_dense:]
return core.ShapedArray(out_shape, arr.dtype)
def _bcoo_extract_jvp(mat_dot, indices, mat, *, assume_unique):
assert mat_dot.shape == mat.shape
return bcoo_extract(indices, mat_dot, assume_unique=assume_unique)
def _bcoo_extract_jvp(arr_dot, indices, arr, *, assume_unique):
assert arr_dot.shape == arr.shape
return _bcoo_extract(indices, arr_dot, assume_unique=assume_unique)
def _bcoo_extract_transpose(ct, indices, mat, *, assume_unique):
def _bcoo_extract_transpose(ct, indices, arr, *, assume_unique):
if not assume_unique:
raise NotImplementedError("transpose of bcoo_extract with assume_unique=False")
assert ad.is_undefined_primal(mat)
assert ad.is_undefined_primal(arr)
if ad.is_undefined_primal(indices):
raise ValueError("Cannot transpose with respect to sparse indices")
assert ct.dtype == mat.aval.dtype
return indices, _bcoo_todense(ct, indices, spinfo=SparseInfo(mat.aval.shape))
assert ct.dtype == arr.aval.dtype
return indices, _bcoo_todense(ct, indices, spinfo=SparseInfo(arr.aval.shape))
def _bcoo_extract_batching_rule(batched_args, batch_dims, *, assume_unique):
indices, mat = batched_args
indices, arr = batched_args
assert any(b is not None for b in batch_dims)
if batch_dims[0] is None:
bdim = batch_dims[1]
@@ -465,9 +491,9 @@ def _bcoo_extract_batching_rule(batched_args, batch_dims, *, assume_unique):
elif batch_dims[1] is None:
# TODO(jakevdp) can we handle this case without explicit broadcasting?
bdim = batch_dims[0]
result_shape = list(mat.shape)
result_shape = list(arr.shape)
result_shape.insert(bdim, indices.shape[bdim])
mat = lax.broadcast_in_dim(mat, result_shape, (bdim,))
arr = lax.broadcast_in_dim(arr, result_shape, (bdim,))
else:
if batch_dims[0] != batch_dims[1]:
raise NotImplementedError("bcoo_extract with unequal batch dimensions.")
@@ -475,7 +501,7 @@ def _bcoo_extract_batching_rule(batched_args, batch_dims, *, assume_unique):
n_batch = indices.ndim - 2
if bdim >= n_batch:
raise ValueError(f"{batch_dims=} out of range for indices with {n_batch=}")
return bcoo_extract(indices, mat, assume_unique=assume_unique), bdim
return _bcoo_extract(indices, arr, assume_unique=assume_unique), bdim
ad.defjvp(bcoo_extract_p, None, _bcoo_extract_jvp)
ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
@@ -1004,7 +1030,7 @@ def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_num
# Fallback to direct approach when above is not possible.
out_dense_T = lax.dot_general(ct, rhs, dimension_numbers=dims)
out_dense = lax.transpose(out_dense_T, out_axes)
result = bcoo_extract(lhs_indices, out_dense)
result = _bcoo_extract(lhs_indices, out_dense)
return result, lhs_indices, rhs
else:
dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch)) # type: ignore[assignment]
@@ -1078,12 +1104,12 @@ def bcoo_dot_general_sampled(A: Array, B: Array, indices: Array, *, dimension_nu
def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers):
# TODO(jakevdp): use a more efficient implementation that avoids the full dot product.
dense_result = lax.dot_general(A, B, dimension_numbers=dimension_numbers)
return bcoo_extract(indices, dense_result)
return _bcoo_extract(indices, dense_result)
@bcoo_dot_general_sampled_p.def_abstract_eval
def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers):
dense_result, = pe.abstract_eval_fun(lambda *args: [lax.dot_general(*args, dimension_numbers=dimension_numbers)], A, B)
sparse_result, = pe.abstract_eval_fun(lambda *args: [bcoo_extract(*args)], indices, dense_result)
sparse_result, = pe.abstract_eval_fun(lambda *args: [_bcoo_extract(*args)], indices, dense_result)
return sparse_result
def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers):
@@ -2221,7 +2247,7 @@ def _bcoo_multiply_dense(data: Array, indices: Array, v: Array, *, spinfo: Spars
return lax.mul(data, v)
if shape == v.shape:
# Note: due to distributive property, no deduplication necessary!
return lax.mul(data, bcoo_extract(indices, v))
return lax.mul(data, _bcoo_extract(indices, v))
if lax.broadcast_shapes(v.shape, shape) != shape:
raise NotImplementedError(

View File

@@ -291,7 +291,7 @@ def bcsr_extract(indices: ArrayLike, indptr: ArrayLike, mat: ArrayLike) -> Array
def _bcsr_extract_impl(indices, indptr, mat):
mat = jnp.asarray(mat)
bcoo_indices = _bcsr_to_bcoo(indices, indptr, shape=mat.shape)
return bcoo.bcoo_extract(bcoo_indices, mat)
return bcoo._bcoo_extract(bcoo_indices, mat)
@bcsr_extract_p.def_abstract_eval

View File

@@ -804,21 +804,24 @@ class BCOOTest(sptu.SparseTestCase):
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for layout in iter_sparse_layouts(shape)],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
assume_unique=[True, False]
assume_unique=[True, False, None]
)
def test_bcoo_extract(self, shape, dtype, n_batch, n_dense, assume_unique):
rng = rand_sparse(self.rng())
M = rng(shape, dtype)
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
n_dense=n_dense)
data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse)
bcoo_extract = partial(sparse.bcoo_extract, assume_unique=assume_unique)
data2 = bcoo_extract(indices, M)
self.assertArraysEqual(data, data2)
def args_maker():
x = rng(shape, dtype)
x_bcoo = sparse.bcoo_fromdense(x, n_batch=n_batch, n_dense=n_dense)
# Unique indices are required for this test when assume_unique == True.
self.assertTrue(x_bcoo.unique_indices)
return x_bcoo, x
data3 = jit(bcoo_extract)(indices, M)
self.assertArraysEqual(data, data3)
dense_op = lambda _, x: x
sparse_op = partial(sparse.bcoo_extract, assume_unique=assume_unique)
self._CheckAgainstDense(dense_op, sparse_op, args_maker)
self._CompileAndCheckSparse(sparse_op, args_maker)
self._CheckBatchingSparse(dense_op, sparse_op, args_maker, bdims=2 * self._random_bdims(n_batch))
def test_bcoo_extract_duplicate_indices(self):
data = jnp.array([1, 3, 9, 27, 81, 243])
@@ -826,10 +829,10 @@ class BCOOTest(sptu.SparseTestCase):
shape = (6,)
mat = sparse.BCOO((data, indices), shape=shape).todense()
data1 = sparse.bcoo_extract(indices, mat, assume_unique=True)
data1 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=True)
self.assertArraysEqual(data1, jnp.array([10, 3, 10, 270, 81, 270]))
data2 = sparse.bcoo_extract(indices, mat, assume_unique=False)
data2 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=False)
self.assertArraysEqual(data2, jnp.array([10, 3, 0, 270, 81, 0]))
def test_bcoo_extract_duplicate_indices_n_sparse_0(self):
@@ -838,10 +841,10 @@ class BCOOTest(sptu.SparseTestCase):
shape = (3,)
mat = sparse.BCOO((data, indices), shape=shape).todense()
data1 = sparse.bcoo_extract(indices, mat, assume_unique=True)
data1 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=True)
self.assertArraysEqual(data1, jnp.array([[1, 1], [5, 5], [9, 9]]))
data2 = sparse.bcoo_extract(indices, mat, assume_unique=False)
data2 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=False)
self.assertArraysEqual(data2, jnp.array([[1, 0], [5, 0], [9, 0]]))
def test_bcoo_extract_batching(self):
@@ -850,18 +853,18 @@ class BCOOTest(sptu.SparseTestCase):
mat = jnp.arange(4.).reshape((4, 1))
# in_axes = (0, None)
expected = jnp.vstack([sparse.bcoo_extract(i, mat[0]) for i in indices])
actual = vmap(sparse.bcoo_extract, in_axes=(0, None))(indices, mat[0])
expected = jnp.vstack([sparse_bcoo._bcoo_extract(i, mat[0]) for i in indices])
actual = vmap(sparse_bcoo._bcoo_extract, in_axes=(0, None))(indices, mat[0])
self.assertArraysEqual(expected, actual)
# in_axes = (None, 0)
expected = jnp.vstack([sparse.bcoo_extract(indices[0], m) for m in mat])
actual = vmap(sparse.bcoo_extract, in_axes=(None, 0))(indices[0], mat)
expected = jnp.vstack([sparse_bcoo._bcoo_extract(indices[0], m) for m in mat])
actual = vmap(sparse_bcoo._bcoo_extract, in_axes=(None, 0))(indices[0], mat)
self.assertArraysEqual(expected, actual)
# in_axes = (0, 0)
expected = jnp.vstack([sparse.bcoo_extract(i, m) for i, m in zip(indices, mat)])
actual = vmap(sparse.bcoo_extract, in_axes=0)(indices, mat)
expected = jnp.vstack([sparse_bcoo._bcoo_extract(i, m) for i, m in zip(indices, mat)])
actual = vmap(sparse_bcoo._bcoo_extract, in_axes=0)(indices, mat)
self.assertArraysEqual(expected, actual)
@jtu.sample_product(
@@ -877,7 +880,7 @@ class BCOOTest(sptu.SparseTestCase):
n_dense=n_dense)
data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense)
extract = partial(sparse.bcoo_extract, indices)
extract = partial(sparse_bcoo._bcoo_extract, indices)
j1 = jax.jacfwd(extract)(M)
j2 = jax.jacrev(extract)(M)
hess = jax.hessian(extract)(M)
@@ -890,11 +893,11 @@ class BCOOTest(sptu.SparseTestCase):
# (n_batch, n_sparse, n_dense) = (1, 0, 0), nse = 2
args_maker = lambda: (jnp.zeros((3, 2, 0), dtype='int32'), jnp.arange(3))
self._CompileAndCheck(sparse.bcoo_extract, args_maker)
self._CompileAndCheck(sparse_bcoo._bcoo_extract, args_maker)
# (n_batch, n_sparse, n_dense) = (0, 0, 1), nse = 2
args_maker = lambda: (jnp.zeros((2, 0), dtype='int32'), jnp.arange(3))
self._CompileAndCheck(sparse.bcoo_extract, args_maker)
self._CompileAndCheck(sparse_bcoo._bcoo_extract, args_maker)
@jtu.sample_product(
[dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense)
@@ -1252,7 +1255,7 @@ class BCOOTest(sptu.SparseTestCase):
def dense_fun(lhs, rhs, indices):
AB = lax.dot_general(lhs, rhs, dimension_numbers=props.dimension_numbers)
return sparse.bcoo_extract(indices, AB)
return sparse_bcoo._bcoo_extract(indices, AB)
def sparse_fun(lhs, rhs, indices):
return sparse.bcoo_dot_general_sampled(
lhs, rhs, indices, dimension_numbers=props.dimension_numbers)
@@ -1295,7 +1298,7 @@ class BCOOTest(sptu.SparseTestCase):
def dense_fun(lhs, rhs, indices):
AB = lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
return sparse.bcoo_extract(indices, AB)
return sparse_bcoo._bcoo_extract(indices, AB)
def sparse_fun(lhs, rhs, indices):
return sparse.bcoo_dot_general_sampled(
lhs, rhs, indices, dimension_numbers=dimension_numbers)
@@ -2171,7 +2174,7 @@ class SparseGradTest(sptu.SparseTestCase):
val_sp, grad_sp = sparse.value_and_grad(f, argnums=0, has_aux=has_aux)(Xsp, y)
self.assertIsInstance(grad_sp, sparse.BCOO)
self.assertAllClose(val_de, val_sp)
self.assertAllClose(grad_sp.data, sparse.bcoo_extract(grad_sp.indices, grad_de))
self.assertAllClose(grad_sp.data, sparse_bcoo._bcoo_extract(grad_sp.indices, grad_de))
with self.subTest("wrt dense"):
self.assertAllClose(jax.value_and_grad(f, argnums=1, has_aux=has_aux)(X, y),
@@ -2199,7 +2202,7 @@ class SparseGradTest(sptu.SparseTestCase):
grad_sp, aux_sp = grad_sp
self.assertAllClose(aux_de, aux_sp)
self.assertIsInstance(grad_sp, sparse.BCOO)
self.assertAllClose(grad_sp.data, sparse.bcoo_extract(grad_sp.indices, grad_de))
self.assertAllClose(grad_sp.data, sparse_bcoo._bcoo_extract(grad_sp.indices, grad_de))
with self.subTest("wrt dense"):
self.assertAllClose(jax.grad(f, argnums=1, has_aux=has_aux)(X, y),
@@ -2233,7 +2236,7 @@ class SparseGradTest(sptu.SparseTestCase):
grad_sp, aux_sp = grad_sp
self.assertAllClose(aux_de, aux_sp)
self.assertIsInstance(grad_sp, sparse.BCOO)
self.assertAllClose(grad_sp.data, sparse.bcoo_extract(grad_sp.indices, grad_de))
self.assertAllClose(grad_sp.data, sparse_bcoo._bcoo_extract(grad_sp.indices, grad_de))
with self.subTest("wrt dense"):
rtol = 0.01 if jtu.device_under_test() == 'tpu' else None