Merge pull request #13160 from jakevdp:bcoo-squeeze

PiperOrigin-RevId: 487280563
This commit is contained in:
jax authors 2022-11-09 10:18:22 -08:00
commit 63e3152764
4 changed files with 50 additions and 18 deletions

View File

@ -208,6 +208,7 @@ from jax.experimental.sparse.bcoo import (
bcoo_sort_indices as bcoo_sort_indices,
bcoo_sort_indices_p as bcoo_sort_indices_p,
bcoo_spdot_general_p as bcoo_spdot_general_p,
bcoo_squeeze as bcoo_squeeze,
bcoo_sum_duplicates as bcoo_sum_duplicates,
bcoo_sum_duplicates_p as bcoo_sum_duplicates_p,
bcoo_todense as bcoo_todense,

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""BCOO (Bached coordinate format) matrix object and associated primitives."""
from __future__ import annotations
import functools
from functools import partial
@ -42,10 +43,11 @@ from jax._src.lax.lax import (
DotDimensionNumbers)
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge
from jax._src.lib import gpu_sparse
from jax._src.lib.mlir.dialects import mhlo
from jax._src.numpy.setops import _unique
from jax._src.util import canonicalize_axis
from jax._src.lib import gpu_sparse
Dtype = Any
Shape = Tuple[int, ...]
@ -1868,6 +1870,34 @@ def bcoo_reshape(mat, *, new_sizes, dimensions):
return BCOO((data, indices), shape=new_sizes)
def bcoo_squeeze(arr: BCOO, *, dimensions: Sequence[int]) -> BCOO:
"""Sparse implementation of {func}`jax.lax.squeeze`.
Squeeze any number of size 1 dimensions from an array.
Args:
arr: BCOO array to be reshaped.
dimensions: sequence of integers specifying dimensions to squeeze.
Returns:
out: reshaped array.
"""
dimensions = tuple(canonicalize_axis(dim, arr.ndim) for dim in dimensions)
if any(arr.shape[dim] != 1 for dim in dimensions):
raise ValueError("cannot select an axis to squeeze out which has size not equal to one, "
f"got shape={arr.shape} and dimensions={dimensions}")
batch_dims = tuple(d for d in dimensions if d < arr.n_batch)
sparse_dims = np.array([i for i in range(arr.n_sparse)
if i + arr.n_batch not in dimensions], dtype=int)
dense_dims = tuple(d - arr.n_sparse + 1 for d in dimensions
if d >= arr.n_batch + arr.n_sparse)
data_out = lax.squeeze(arr.data, batch_dims + dense_dims)
indices_out = lax.squeeze(arr.indices[..., sparse_dims], batch_dims)
out_shape = tuple(s for i, s in enumerate(arr.shape) if i not in dimensions)
return BCOO((data_out, indices_out), shape=out_shape,
indices_sorted=arr.indices_sorted, unique_indices=arr.unique_indices)
def bcoo_slice(mat, *, start_indices: Sequence[int], limit_indices: Sequence[int],
strides: Optional[Sequence[int]]=None):
"""Sparse implementation of {func}`jax.lax.slice`.

View File

@ -65,7 +65,6 @@ from jax.util import safe_map, safe_zip, split_list
from jax._src.config import config
from jax._src.lax.control_flow import _check_tree_and_avals
from jax._src.numpy import lax_numpy
from jax._src.util import canonicalize_axis
from jax.experimental import sparse
from jax.experimental.sparse import BCOO
@ -622,22 +621,9 @@ def _concatenate_sparse(spenv, *spvalues, dimension):
sparse_rules[lax.concatenate_p] = _concatenate_sparse
def _squeeze_sparse(spenv, *spvalues, dimensions):
arr, = spvalues
dimensions = tuple(canonicalize_axis(dim, arr.ndim) for dim in dimensions)
if any(arr.shape[dim] != 1 for dim in dimensions):
raise ValueError("cannot select an axis to squeeze out which has size not equal to one, "
f"got shape={arr.shape} and dimensions={dimensions}")
data = spenv.data(arr)
indices = spenv.indices(arr)
n_sparse = indices.shape[-1]
n_batch = indices.ndim - 2
batch_dims = tuple(d for d in dimensions if d < n_batch)
sparse_dims = np.array([i for i in range(n_sparse) if i + n_batch not in dimensions], dtype=int)
dense_dims = tuple(d - n_sparse + 1 for d in dimensions if d >= n_batch + n_sparse)
data_out = lax.squeeze(data, batch_dims + dense_dims)
indices_out = lax.squeeze(indices[..., sparse_dims], batch_dims)
out_shape = tuple(s for i, s in enumerate(arr.shape) if i not in dimensions)
return (spenv.sparse(out_shape, data_out, indices_out),)
arr, = spvalues_to_arrays(spenv, spvalues)
result = sparse.bcoo_squeeze(arr, dimensions=dimensions)
return arrays_to_spvalues(spenv, (result,))
sparse_rules[lax.squeeze_p] = _squeeze_sparse

View File

@ -1959,6 +1959,21 @@ class BCOOTest(jtu.JaxTestCase):
tol = {np.float32: 1E-6, np.float64: 1E-14}
self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)
def test_bcoo_squeeze(self):
# more comprehensive tests in sparsify_test:testSparseSqueeze
rng = rand_sparse(self.rng())
shape = (1, 2, 1, 3, 4)
dimensions = (0, 2)
M = rng(shape, 'float32')
M_bcoo = sparse.BCOO.fromdense(M)
M2 = lax.squeeze(M, dimensions=dimensions)
M2_bcoo = sparse.bcoo_squeeze(M_bcoo, dimensions=dimensions)
M2_bcoo_jit = jax.jit(partial(sparse.bcoo_squeeze, dimensions=dimensions))(M_bcoo)
self.assertArraysEqual(M2, M2_bcoo.todense())
self.assertArraysEqual(M2, M2_bcoo_jit.todense())
def test_bcoo_reshape_error(self):
x = sparse.BCOO.fromdense(jnp.ones((2, 2, 3)), n_batch=1)
with self.assertRaisesRegex(ValueError, ".*cannot mix batch and sparse dimensions.*"):