mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13160 from jakevdp:bcoo-squeeze
PiperOrigin-RevId: 487280563
This commit is contained in:
commit
63e3152764
@ -208,6 +208,7 @@ from jax.experimental.sparse.bcoo import (
|
||||
bcoo_sort_indices as bcoo_sort_indices,
|
||||
bcoo_sort_indices_p as bcoo_sort_indices_p,
|
||||
bcoo_spdot_general_p as bcoo_spdot_general_p,
|
||||
bcoo_squeeze as bcoo_squeeze,
|
||||
bcoo_sum_duplicates as bcoo_sum_duplicates,
|
||||
bcoo_sum_duplicates_p as bcoo_sum_duplicates_p,
|
||||
bcoo_todense as bcoo_todense,
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""BCOO (Bached coordinate format) matrix object and associated primitives."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from functools import partial
|
||||
@ -42,10 +43,11 @@ from jax._src.lax.lax import (
|
||||
DotDimensionNumbers)
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.numpy.setops import _unique
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
||||
from jax._src.lib import gpu_sparse
|
||||
|
||||
Dtype = Any
|
||||
Shape = Tuple[int, ...]
|
||||
@ -1868,6 +1870,34 @@ def bcoo_reshape(mat, *, new_sizes, dimensions):
|
||||
return BCOO((data, indices), shape=new_sizes)
|
||||
|
||||
|
||||
def bcoo_squeeze(arr: BCOO, *, dimensions: Sequence[int]) -> BCOO:
|
||||
"""Sparse implementation of {func}`jax.lax.squeeze`.
|
||||
|
||||
Squeeze any number of size 1 dimensions from an array.
|
||||
|
||||
Args:
|
||||
arr: BCOO array to be reshaped.
|
||||
dimensions: sequence of integers specifying dimensions to squeeze.
|
||||
|
||||
Returns:
|
||||
out: reshaped array.
|
||||
"""
|
||||
dimensions = tuple(canonicalize_axis(dim, arr.ndim) for dim in dimensions)
|
||||
if any(arr.shape[dim] != 1 for dim in dimensions):
|
||||
raise ValueError("cannot select an axis to squeeze out which has size not equal to one, "
|
||||
f"got shape={arr.shape} and dimensions={dimensions}")
|
||||
batch_dims = tuple(d for d in dimensions if d < arr.n_batch)
|
||||
sparse_dims = np.array([i for i in range(arr.n_sparse)
|
||||
if i + arr.n_batch not in dimensions], dtype=int)
|
||||
dense_dims = tuple(d - arr.n_sparse + 1 for d in dimensions
|
||||
if d >= arr.n_batch + arr.n_sparse)
|
||||
data_out = lax.squeeze(arr.data, batch_dims + dense_dims)
|
||||
indices_out = lax.squeeze(arr.indices[..., sparse_dims], batch_dims)
|
||||
out_shape = tuple(s for i, s in enumerate(arr.shape) if i not in dimensions)
|
||||
return BCOO((data_out, indices_out), shape=out_shape,
|
||||
indices_sorted=arr.indices_sorted, unique_indices=arr.unique_indices)
|
||||
|
||||
|
||||
def bcoo_slice(mat, *, start_indices: Sequence[int], limit_indices: Sequence[int],
|
||||
strides: Optional[Sequence[int]]=None):
|
||||
"""Sparse implementation of {func}`jax.lax.slice`.
|
||||
|
@ -65,7 +65,6 @@ from jax.util import safe_map, safe_zip, split_list
|
||||
from jax._src.config import config
|
||||
from jax._src.lax.control_flow import _check_tree_and_avals
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax.experimental import sparse
|
||||
from jax.experimental.sparse import BCOO
|
||||
|
||||
@ -622,22 +621,9 @@ def _concatenate_sparse(spenv, *spvalues, dimension):
|
||||
sparse_rules[lax.concatenate_p] = _concatenate_sparse
|
||||
|
||||
def _squeeze_sparse(spenv, *spvalues, dimensions):
|
||||
arr, = spvalues
|
||||
dimensions = tuple(canonicalize_axis(dim, arr.ndim) for dim in dimensions)
|
||||
if any(arr.shape[dim] != 1 for dim in dimensions):
|
||||
raise ValueError("cannot select an axis to squeeze out which has size not equal to one, "
|
||||
f"got shape={arr.shape} and dimensions={dimensions}")
|
||||
data = spenv.data(arr)
|
||||
indices = spenv.indices(arr)
|
||||
n_sparse = indices.shape[-1]
|
||||
n_batch = indices.ndim - 2
|
||||
batch_dims = tuple(d for d in dimensions if d < n_batch)
|
||||
sparse_dims = np.array([i for i in range(n_sparse) if i + n_batch not in dimensions], dtype=int)
|
||||
dense_dims = tuple(d - n_sparse + 1 for d in dimensions if d >= n_batch + n_sparse)
|
||||
data_out = lax.squeeze(data, batch_dims + dense_dims)
|
||||
indices_out = lax.squeeze(indices[..., sparse_dims], batch_dims)
|
||||
out_shape = tuple(s for i, s in enumerate(arr.shape) if i not in dimensions)
|
||||
return (spenv.sparse(out_shape, data_out, indices_out),)
|
||||
arr, = spvalues_to_arrays(spenv, spvalues)
|
||||
result = sparse.bcoo_squeeze(arr, dimensions=dimensions)
|
||||
return arrays_to_spvalues(spenv, (result,))
|
||||
|
||||
sparse_rules[lax.squeeze_p] = _squeeze_sparse
|
||||
|
||||
|
@ -1959,6 +1959,21 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
tol = {np.float32: 1E-6, np.float64: 1E-14}
|
||||
self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)
|
||||
|
||||
def test_bcoo_squeeze(self):
|
||||
# more comprehensive tests in sparsify_test:testSparseSqueeze
|
||||
rng = rand_sparse(self.rng())
|
||||
shape = (1, 2, 1, 3, 4)
|
||||
dimensions = (0, 2)
|
||||
M = rng(shape, 'float32')
|
||||
M_bcoo = sparse.BCOO.fromdense(M)
|
||||
|
||||
M2 = lax.squeeze(M, dimensions=dimensions)
|
||||
M2_bcoo = sparse.bcoo_squeeze(M_bcoo, dimensions=dimensions)
|
||||
M2_bcoo_jit = jax.jit(partial(sparse.bcoo_squeeze, dimensions=dimensions))(M_bcoo)
|
||||
|
||||
self.assertArraysEqual(M2, M2_bcoo.todense())
|
||||
self.assertArraysEqual(M2, M2_bcoo_jit.todense())
|
||||
|
||||
def test_bcoo_reshape_error(self):
|
||||
x = sparse.BCOO.fromdense(jnp.ones((2, 2, 3)), n_batch=1)
|
||||
with self.assertRaisesRegex(ValueError, ".*cannot mix batch and sparse dimensions.*"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user