From 1c37f5091c3e45dcc9e6a5923afe6f5f5bdecedc Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 4 Oct 2023 19:56:04 -0700 Subject: [PATCH] sparse_test: Split into two so that each target is small enough to fit within a medium timeout. PiperOrigin-RevId: 570882867 --- jax/experimental/sparse/test_util.py | 81 +- tests/BUILD | 35 +- tests/sparse_bcoo_bcsr_test.py | 1958 ++++++++++++++++++++++++++ tests/sparse_test.py | 1936 ++----------------------- 4 files changed, 2194 insertions(+), 1816 deletions(-) create mode 100644 tests/sparse_bcoo_bcsr_test.py diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 131b350f2..3f4df1feb 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -13,26 +13,55 @@ # limitations under the License. """Sparse test utilities.""" -from collections.abc import Sequence +from collections.abc import Iterable, Iterator, Sequence import functools +import itertools +import math from typing import Any, Callable, Union - -import numpy as np +from typing import NamedTuple import jax from jax import lax -from jax._src import test_util as jtu -from jax._src.typing import DTypeLike from jax import tree_util -from jax.util import safe_zip, split_list +from jax._src import test_util as jtu +from jax._src.lax.lax import DotDimensionNumbers +from jax._src.lib import gpu_sparse +from jax._src.typing import DTypeLike from jax.experimental import sparse import jax.numpy as jnp +from jax.util import safe_zip, split_list +import numpy as np + +MATMUL_TOL = { + np.float32: 1e-5, + np.float64: 1e-10, + np.complex64: 1e-5, + np.complex128: 1e-10, +} + +GPU_LOWERING_ENABLED = gpu_sparse and ( + gpu_sparse.cuda_is_supported or gpu_sparse.rocm_is_supported +) def is_sparse(x): return isinstance(x, sparse.JAXSparse) +class BatchedDotGeneralProperties(NamedTuple): + lhs_shape: tuple[int, ...] + rhs_shape: tuple[int, ...] + n_batch: int + n_dense: int + dimension_numbers: DotDimensionNumbers + + +class SparseLayout(NamedTuple): + n_batch: int + n_dense: int + n_sparse: int + + class SparseTestCase(jtu.JaxTestCase): def assertSparseArraysEquivalent(self, x, y, *, check_dtypes=True, atol=None, rtol=None, canonicalize_dtypes=True, err_msg=''): @@ -170,3 +199,43 @@ def rand_bcsr(rng: np.random.RandomState, return functools.partial(_rand_sparse, rng=rng, rand_method=rand_method, nse=nse, n_batch=n_batch, n_dense=n_dense, sparse_format='bcsr') + + +def iter_subsets(s: Sequence) -> Iterable[tuple]: + """Return an iterator over all subsets of a sequence s""" + return itertools.chain.from_iterable( + itertools.combinations(s, n) for n in range(len(s) + 1) + ) + + +def iter_sparse_layouts( + shape: Sequence[int], min_n_batch=0 +) -> Iterator[SparseLayout]: + for n_batch in range(min_n_batch, len(shape) + 1): + for n_dense in range(len(shape) + 1 - n_batch): + n_sparse = len(shape) - n_batch - n_dense + yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense) + + +def iter_bcsr_layouts( + shape: Sequence[int], min_n_batch=0 +) -> Iterator[SparseLayout]: + n_sparse = 2 + for n_batch in range(min_n_batch, len(shape) - 1): + n_dense = len(shape) - n_sparse - n_batch + yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense) + + +def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default): + def _rand_sparse(shape, dtype, nse=nse): + rand = rand_method(rng) + size = math.prod(shape) + if 0 <= nse < 1: + nse = nse * size + nse = min(size, int(nse)) + M = rand(shape, dtype) + indices = rng.choice(size, size - nse, replace=False) + M.flat[indices] = 0 + return post(M) + + return _rand_sparse diff --git a/tests/BUILD b/tests/BUILD index f8de3810d..6eae1fc28 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -834,7 +834,40 @@ jax_test( backend_variant_args = { "cpu": ["--jax_num_generated_cases=40"], "cpu_x32": ["--jax_num_generated_cases=40"], - "cpu_no_jax_array": ["--jax_num_generated_cases=40"], + "gpu": ["--jax_num_generated_cases=40"], + }, + shard_count = { + "cpu": 50, + "gpu": 50, + "tpu": 50, + "iree": 10, + }, + tags = [ + "noasan", + "nomsan", + "notsan", + ], # Test times out under asan/msan/tsan. + deps = [ + "//jax:experimental_sparse", + "//jax:sparse_test_util", + ] + py_deps("scipy"), +) + +jax_test( + name = "sparse_bcoo_bcsr_test", + srcs = ["sparse_bcoo_bcsr_test.py"], + args = ["--jax_bcoo_cusparse_lowering=true"], + backend_tags = { + "cpu": [ + "nomsan", # Times out + "notsan", # Times out + ], + "tpu": ["optonly"], + }, + # Use fewer cases to prevent timeouts. + backend_variant_args = { + "cpu": ["--jax_num_generated_cases=40"], + "cpu_x32": ["--jax_num_generated_cases=40"], "gpu": ["--jax_num_generated_cases=40"], }, shard_count = { diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py new file mode 100644 index 000000000..8ba5b22c0 --- /dev/null +++ b/tests/sparse_bcoo_bcsr_test.py @@ -0,0 +1,1958 @@ +# Copyright 2021 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +from functools import partial +import itertools +import math +import operator +import random +import unittest + +from absl.testing import absltest +import jax +from jax import config +from jax import jit +from jax import lax +from jax import vmap +from jax._src import test_util as jtu +from jax._src import xla_bridge +from jax._src.lax.lax import remaining +from jax._src.util import unzip2 +from jax.experimental import sparse +from jax.experimental.sparse import bcoo as sparse_bcoo +from jax.experimental.sparse import bcsr as sparse_bcsr +from jax.experimental.sparse import test_util as sptu +from jax.experimental.sparse import util as sparse_util +import jax.numpy as jnp +import jax.random +from jax.util import split_list +import numpy as np + +config.parse_flags_with_absl() +FLAGS = config.FLAGS + +COMPATIBLE_SHAPE_PAIRS = [ + [(), ()], + [(), (1,)], + [(3,), (1, 3)], + [(3, 1), (3,)], + [(6,), (2, 3)], + [(3, 2), (6,)], + [(2, 3), (1, 6)], + [(2, 4), (4, 1, 2)], + [(3, 4, 5), (2, 6, 5)], + [(2,), (2,)], +] + + +def _generate_batched_dot_general_properties( + shapes=((5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)), sparse_format="bcoo" +) -> sptu.BatchedDotGeneralProperties: + """Generator of properties for bcoo_dot_general tests.""" + rng = random.Random(0) + + if sparse_format not in ['bcoo', 'bcsr']: + raise ValueError(f"Sparse format {sparse_format} not supported.") + + for shape in shapes: + for layout in sptu.iter_sparse_layouts(shape): + if sparse_format == "bcsr" and layout.n_sparse != 2: + continue + subsets = split_list(range(len(shape)), [layout.n_batch, layout.n_sparse]) + for batch_dims in sptu.iter_subsets(range(layout.n_batch)): + for contracting_dims in sptu.iter_subsets( + remaining(range(layout.n_batch + layout.n_sparse), batch_dims) + ): + # We want coverage of permutations without generating hundreds of thousands of test cases; + # we do this by deterministic pseudo-random sampling instead of iterating. + rhs_permute = rng.sample(range(len(shape)), len(shape)) + lhs_permute = list( + itertools.chain.from_iterable( + rng.sample(subset, len(subset)) for subset in subsets + ) + ) + yield sptu.BatchedDotGeneralProperties( + lhs_shape=tuple(shape[p] for p in lhs_permute), + rhs_shape=tuple(shape[p] for p in rhs_permute), + n_batch=layout.n_batch, + n_dense=layout.n_dense, + dimension_numbers=( + ( + [lhs_permute.index(d) for d in contracting_dims], + [rhs_permute.index(d) for d in contracting_dims], + ), + ( + [lhs_permute.index(d) for d in batch_dims], + [rhs_permute.index(d) for d in batch_dims], + ), + ), + ) + + +def _generate_bcoo_dot_general_sampled_properties( + shapes=((5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)) +) -> sptu.BatchedDotGeneralProperties: + """Generator of properties for bcoo_dot_general_sampled tests.""" + rng = random.Random(0) + + for shape in shapes: + for batch_dims in sptu.iter_subsets(range(len(shape))): + for contracting_dims in sptu.iter_subsets( + remaining(range(len(shape)), batch_dims) + ): + # We want coverage of permutations without generating hundreds of thousands of test cases; + # we do this by deterministic pseudo-random sampling instead of iterating. + lhs_permute = rng.sample(range(len(shape)), len(shape)) + rhs_permute = rng.sample(range(len(shape)), len(shape)) + lhs_shape = tuple(shape[p] for p in lhs_permute) + rhs_shape = tuple(shape[p] for p in rhs_permute) + dimension_numbers = ( + ( + [lhs_permute.index(d) for d in contracting_dims], + [rhs_permute.index(d) for d in contracting_dims], + ), + ( + [lhs_permute.index(d) for d in batch_dims], + [rhs_permute.index(d) for d in batch_dims], + ), + ) + out = jax.eval_shape(partial(lax.dot_general, dimension_numbers=dimension_numbers), + jax.ShapeDtypeStruct(lhs_shape, 'float32'), jax.ShapeDtypeStruct(rhs_shape, 'float32')) + for layout in sptu.iter_sparse_layouts(out.shape): + yield sptu.BatchedDotGeneralProperties( + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + n_batch=layout.n_batch, + n_dense=layout.n_dense, + dimension_numbers=dimension_numbers, + ) + + +all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex + +def _is_required_cuda_version_satisfied(cuda_version): + version = xla_bridge.get_backend().platform_version + if version == "" or version.split()[0] == "rocm": + return False + else: + return int(version.split()[-1]) >= cuda_version + +class BCOOTest(sptu.SparseTestCase): + + def gpu_matmul_warning_context(self, msg): + if sptu.GPU_LOWERING_ENABLED and config.jax_bcoo_cusparse_lowering: + return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg) + return contextlib.nullcontext() + + def test_repr(self): + x = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32')) + self.assertEqual(repr(x), "BCOO(float32[5], nse=4)") + + y = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3), n_batch=1) + self.assertEqual(repr(y), "BCOO(float32[2, 3], nse=3, n_batch=1)") + + y = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3), n_batch=1, n_dense=1) + self.assertEqual(repr(y), "BCOO(float32[2, 3], nse=1, n_batch=1, n_dense=1)") + + M_invalid = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3)) + M_invalid.indices = jnp.array([]) + self.assertEqual(repr(M_invalid), "BCOO()") + + @jit + def f(x): + self.assertEqual(repr(x), "DynamicJaxprTracer[BCOO(float32[5], nse=4)]") + f(x) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=all_dtypes, + ) + def test_empty(self, shape, dtype, n_batch, n_dense): + M = sparse.empty(shape, dtype=dtype, n_batch=n_batch, n_dense=n_dense) + self.assertIsInstance(M, sparse.BCOO) + self.assertEqual(M.nse, 0) + self.assertEqual(M.n_batch, n_batch) + self.assertEqual(M.n_dense, n_dense) + self.assertEqual(M.dtype, dtype) + self.assertArraysEqual(M.todense(), jnp.empty(shape, dtype)) + + @jtu.sample_product( + [ + dict(n_batch=layout.n_batch, n_dense=layout.n_dense) + for layout in sptu.iter_sparse_layouts((3, 3)) + ], + N=[3, 5], + M=[None, 4], + k=[-3, -1, 0, 2, 4], + dtype=all_dtypes, + ) + def test_eye(self, N, M, k, dtype, n_batch, n_dense): + mat = sparse.eye(N, M, k, dtype=dtype, n_batch=n_batch, n_dense=n_dense) + expected = jnp.eye(N, M, k, dtype=dtype) + expected_nse = sparse.BCOO.fromdense(expected, n_batch=n_batch, n_dense=n_dense).nse + + self.assertIsInstance(mat, sparse.BCOO) + self.assertEqual(mat.n_batch, n_batch) + self.assertEqual(mat.n_dense, n_dense) + self.assertEqual(mat.dtype, dtype) + self.assertEqual(mat.nse, expected_nse) + self.assertArraysEqual(mat.todense(), expected) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=all_dtypes, + ) + def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense): + n_sparse = len(shape) - n_batch - n_dense + rng = sptu.rand_sparse(self.rng()) + M = rng(shape, dtype) + nse = sparse.util._count_stored_elements(M, n_batch=n_batch, n_dense=n_dense) + def round_trip(M): + return sparse.BCOO.fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense).todense() + args_maker = lambda: [M] + ident = lambda x: x + + self._CheckAgainstNumpy(ident, round_trip, args_maker) + self._CompileAndCheck(round_trip, args_maker) + self._CheckBatchingSparse(ident, round_trip, args_maker, bdims=self._random_bdims(n_batch)) + if jnp.issubdtype(dtype, jnp.floating): + # For n_sparse != 0, we can't use an identity because output zeros must not + # be dependent on input zeros. This mimics the code in count_stored_elements(). + def expected(M): + if n_sparse == 0: return M + mask = (M != 0).any(range(M.ndim - n_dense, M.ndim), keepdims=True) + return jnp.where(mask, M, 0) + self._CheckGradsSparse(expected, round_trip, args_maker) + + def test_bcoo_fromdense_sorted_and_unique_indices(self): + rng = self.rng() + rng_sparse = sptu.rand_sparse(rng) + mat = sparse.BCOO.fromdense(rng_sparse((5, 6), np.float32)) + perm = rng.permutation(mat.nse) + mat_unsorted = sparse.BCOO((mat.data[perm], mat.indices[perm]), + shape=mat.shape, + unique_indices=mat.unique_indices) + mat_resorted = mat_unsorted.sort_indices() + with self.subTest('sorted indices'): + self.assertArraysEqual(mat.indices, mat_resorted.indices) + self.assertArraysEqual(mat.data, mat_resorted.data) + + with self.subTest('unique indices'): + self.assertTrue(mat.unique_indices) + self.assertTrue(mat_unsorted.unique_indices) + self.assertTrue(mat_resorted.unique_indices) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + assume_unique=[True, False, None], + ) + def test_bcoo_extract(self, shape, dtype, n_batch, n_dense, assume_unique): + rng = sptu.rand_sparse(self.rng()) + + def args_maker(): + x = rng(shape, dtype) + x_bcoo = sparse.bcoo_fromdense(x, n_batch=n_batch, n_dense=n_dense) + # Unique indices are required for this test when assume_unique == True. + self.assertTrue(x_bcoo.unique_indices) + return x_bcoo, x + + dense_op = lambda _, x: x + sparse_op = partial(sparse.bcoo_extract, assume_unique=assume_unique) + + self._CheckAgainstDense(dense_op, sparse_op, args_maker) + self._CheckBatchingSparse(dense_op, sparse_op, args_maker, bdims=2 * self._random_bdims(n_batch)) + + def test_bcoo_extract_duplicate_indices(self): + data = jnp.array([1, 3, 9, 27, 81, 243]) + indices = jnp.array([[0], [5], [0], [3], [2], [3]]) + shape = (6,) + mat = sparse.BCOO((data, indices), shape=shape).todense() + + data1 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=True) + self.assertArraysEqual(data1, jnp.array([10, 3, 10, 270, 81, 270])) + + data2 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=False) + self.assertArraysEqual(data2, jnp.array([10, 3, 0, 270, 81, 0])) + + def test_bcoo_extract_duplicate_indices_n_sparse_0(self): + data = jnp.arange(6).reshape(3, 2) + indices = jnp.empty((3, 2, 0), dtype=int) + shape = (3,) + mat = sparse.BCOO((data, indices), shape=shape).todense() + + data1 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=True) + self.assertArraysEqual(data1, jnp.array([[1, 1], [5, 5], [9, 9]])) + + data2 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=False) + self.assertArraysEqual(data2, jnp.array([[1, 0], [5, 0], [9, 0]])) + + def test_bcoo_extract_batching(self): + # https://github.com/google/jax/issues/9431 + indices = jnp.zeros((4, 1, 1), dtype=int) + mat = jnp.arange(4.).reshape((4, 1)) + + # in_axes = (0, None) + expected = jnp.vstack([sparse_bcoo._bcoo_extract(i, mat[0]) for i in indices]) + actual = vmap(sparse_bcoo._bcoo_extract, in_axes=(0, None))(indices, mat[0]) + self.assertArraysEqual(expected, actual) + + # in_axes = (None, 0) + expected = jnp.vstack([sparse_bcoo._bcoo_extract(indices[0], m) for m in mat]) + actual = vmap(sparse_bcoo._bcoo_extract, in_axes=(None, 0))(indices[0], mat) + self.assertArraysEqual(expected, actual) + + # in_axes = (0, 0) + expected = jnp.vstack([sparse_bcoo._bcoo_extract(i, m) for i, m in zip(indices, mat)]) + actual = vmap(sparse_bcoo._bcoo_extract, in_axes=0)(indices, mat) + self.assertArraysEqual(expected, actual) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.floating, + ) + def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense): + rng = sptu.rand_sparse(self.rng()) + M = rng(shape, dtype) + nse = sparse.util._count_stored_elements(M, n_batch=n_batch, + n_dense=n_dense) + data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense) + + extract = partial(sparse_bcoo._bcoo_extract, indices) + j1 = jax.jacfwd(extract)(M) + j2 = jax.jacrev(extract)(M) + hess = jax.hessian(extract)(M) + self.assertArraysAllClose(j1, j2) + self.assertEqual(j1.shape, data.shape + M.shape) + self.assertEqual(hess.shape, data.shape + 2 * M.shape) + + def test_bcoo_extract_zero_nse(self): + # Regression test for https://github.com/google/jax/issues/13653 + + # (n_batch, n_sparse, n_dense) = (1, 0, 0), nse = 2 + args_maker = lambda: (jnp.zeros((3, 2, 0), dtype='int32'), jnp.arange(3)) + self._CompileAndCheck(sparse_bcoo._bcoo_extract, args_maker) + + # (n_batch, n_sparse, n_dense) = (0, 0, 1), nse = 2 + args_maker = lambda: (jnp.zeros((2, 0), dtype='int32'), jnp.arange(3)) + self._CompileAndCheck(sparse_bcoo._bcoo_extract, args_maker) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.numeric, + ) + def test_bcoo_transpose(self, shape, dtype, n_batch, n_dense): + n_sparse = len(shape) - n_batch - n_dense + rng = self.rng() + sprng = sptu.rand_bcoo(rng, n_batch=n_batch, n_dense=n_dense) + + permutation = np.concatenate([ + rng.permutation(range(n_batch)), + rng.permutation(range(n_batch, n_batch + n_sparse)), + rng.permutation(range(n_batch + n_sparse, len(shape))), + ]).astype(int) + + args_maker = lambda: [sprng(shape, dtype)] + dense_func = partial(lax.transpose, permutation=permutation) + sparse_func = partial(sparse.bcoo_transpose, permutation=permutation) + + self._CheckAgainstDense(dense_func, sparse_func, args_maker) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(dense_func, sparse_func, args_maker) + self._CheckBatchingSparse(dense_func, sparse_func, args_maker, bdims=self._random_bdims(n_batch)) + + def test_bcoo_transpose_indices_sorted(self): + rng = self.rng() + rng_sparse = sptu.rand_sparse(rng) + n_batch, n_dense = 2, 2 + shape = (2, 3, 4, 5, 6, 7, 8) + mat = sparse.BCOO.fromdense(rng_sparse(shape, np.float32), + n_dense=n_dense, n_batch=n_batch) + + permutations = (1, 0, 2, 3, 4, 6, 5) + mat_T_indices_sorted = mat.transpose(axes=permutations) + self.assertTrue(mat_T_indices_sorted.indices_sorted) + + permutations = (0, 1, 3, 2, 4, 5, 6) + mat_T_indices_unsorted = mat.transpose(axes=permutations) + self.assertFalse(mat_T_indices_unsorted.indices_sorted) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape, min_n_batch=1) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense): + rng = sptu.rand_sparse(self.rng()) + M = rng(shape, dtype) + nse = sparse.util._count_stored_elements(M, n_batch=n_batch, + n_dense=n_dense) + data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense) + + M1 = sparse_bcoo._bcoo_todense(data, indices[:1], spinfo=sparse_util.SparseInfo(M.shape)) + M2 = sparse_bcoo._bcoo_todense(data, jnp.stack(shape[0] * [indices[0]]), spinfo=sparse_util.SparseInfo(M.shape)) + self.assertAllClose(M1, M2) + + M3 = sparse_bcoo._bcoo_todense(data[:1], indices, spinfo=sparse_util.SparseInfo(M.shape)) + M4 = sparse_bcoo._bcoo_todense(jnp.stack(shape[0] * [data[0]]), indices, spinfo=sparse_util.SparseInfo(M.shape)) + self.assertAllClose(M3, M4) + + @jtu.sample_product( + props=_generate_batched_dot_general_properties(), + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + @jax.default_matmul_precision("float32") + def test_bcoo_dot_general( + self, dtype: np.dtype, props: sptu.BatchedDotGeneralProperties + ): + rng = jtu.rand_default(self.rng()) + sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense) + args_maker = lambda: [sprng(props.lhs_shape, dtype), + rng(props.rhs_shape, dtype)] + dense_fun = partial(lax.dot_general, dimension_numbers=props.dimension_numbers) + sparse_fun = partial(sparse.bcoo_dot_general, dimension_numbers=props.dimension_numbers) + + tol = {np.float64: 1E-12, np.complex128: 1E-12, + np.float32: 1E-5, np.complex64: 1E-5} + self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol) + if jnp.issubdtype(dtype, jnp.floating) and props.n_dense == 0: + # Dense dimensions not yet fully supported in reverse mode. + modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev'] + self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol) + self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol, + bdims=self._random_bdims(props.n_batch, len(props.rhs_shape))) + + @jtu.sample_product( + [ + dict( + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + lhs_contracting=lhs_contracting, + rhs_contracting=rhs_contracting, + ) + for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ + [(5,), (5,), [0], [0]], + [(5,), (5, 7), [0], [0]], + [(5,), (7, 5), [0], [1]], + [(5, 7), (5,), [0], [0]], + [(7, 5), (5,), [1], [0]], + [(3, 5), (2, 5), [1], [1]], + [(3, 5), (5, 2), [1], [0]], + [(5, 3), (2, 5), [0], [1]], + [(5, 3), (5, 2), [0], [0]], + ] + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + @jax.default_matmul_precision("float32") + @unittest.skipIf( + not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" + ) + @jtu.run_on_devices("gpu") + def test_bcoo_dot_general_cusparse( + self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting + ): + rng = jtu.rand_small(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) + def args_maker(): + lhs = rng_sparse(lhs_shape, dtype) + rhs = rng(rhs_shape, dtype) + nse = sparse.util._count_stored_elements(lhs, n_batch=0, n_dense=0) + lhs_bcoo = sparse_bcoo.bcoo_fromdense(lhs, nse=nse, index_dtype=jnp.int32) + return lhs_bcoo, lhs, rhs + + dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) + + def f_dense(lhs_bcoo, lhs, rhs): + return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers) + + def f_sparse(lhs_bcoo, lhs, rhs): + return sparse_bcoo.bcoo_dot_general(lhs_bcoo, rhs, + dimension_numbers=dimension_numbers) + + self._CompileAndCheck(f_sparse, args_maker) + self._CheckAgainstNumpy(f_dense, f_sparse, args_maker) + + @jtu.sample_product( + [ + dict( + n_batch=n_batch, + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + lhs_contracting=lhs_contracting, + rhs_contracting=rhs_contracting, + ) + for n_batch, lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ + [1, (1, 2, 3), (3, 2), [2], [0]], + [1, (1, 3, 2), (3, 2), [1], [0]], + [1, (1, 3, 2), (4, 3), [1], [1]], + [1, (4, 2, 3), (3, 5), [2], [0]], + [1, (4, 2, 3), (2, 5), [1], [0]], + [1, (4, 2, 3), (5, 3), [2], [1]], + ] + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + @jax.default_matmul_precision("float32") + @unittest.skipIf( + not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" + ) + @jtu.run_on_devices("gpu") + def test_bcoo_batched_matmat_cusparse( + self, + n_batch, + lhs_shape, + rhs_shape, + dtype, + lhs_contracting, + rhs_contracting, + ): + rng = jtu.rand_small(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) + def args_maker(): + lhs = rng_sparse(lhs_shape, dtype) + rhs = rng(rhs_shape, dtype) + nse = sparse.util._count_stored_elements(lhs, n_batch=n_batch, + n_dense=0) + lhs_bcoo = sparse_bcoo.bcoo_fromdense(lhs, n_batch=n_batch, nse=nse, + index_dtype=jnp.int32) + return lhs_bcoo, lhs, rhs + + dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) + + def f_dense(lhs_bcoo, lhs, rhs): + return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers) + + def f_sparse(lhs_bcoo, lhs, rhs): + return sparse_bcoo.bcoo_dot_general(lhs_bcoo, rhs, + dimension_numbers=dimension_numbers) + + # TODO(tianjianlu): In some cases, this fails python_should_be_executing. + # self._CompileAndCheck(f_sparse, args_maker) + self._CheckAgainstNumpy(f_dense, f_sparse, args_maker) + + @jtu.sample_product( + [ + dict( + n_batch=n_batch, + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + lhs_contracting=lhs_contracting, + rhs_contracting=rhs_contracting, + ) + for n_batch, lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ + [1, (1, 2, 3), (3), [2], [0]], + [1, (1, 2), (3, 2), [1], [1]], + ] + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + @unittest.skipIf( + not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" + ) + @jtu.run_on_devices("gpu") + def test_bcoo_batched_matmat_default_lowering( + self, + n_batch, + lhs_shape, + rhs_shape, + dtype, + lhs_contracting, + rhs_contracting, + ): + rng = jtu.rand_small(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) + lhs = rng_sparse(lhs_shape, dtype) + rhs = rng(rhs_shape, dtype) + nse = sparse.util._count_stored_elements(lhs, n_batch=n_batch, + n_dense=0) + lhs_bcoo = sparse_bcoo.bcoo_fromdense( + lhs, n_batch=n_batch, nse=nse, index_dtype=jnp.int32 + ) + dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) + matmat_expected = lax.dot_general(lhs, rhs, + dimension_numbers=dimension_numbers) + sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general, + dimension_numbers=dimension_numbers)) + + # TODO(jakevdp): uncomment once batching is supported again. + # with self.gpu_matmul_warning_context( + # "bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"): + matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs) + self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback) + + @unittest.skipIf( + not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" + ) + @jtu.run_on_devices("gpu") + def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self): + """Tests bcoo dot general with out-of-bound and unsorted indices.""" + + rhs = jnp.ones((5, 3), dtype=jnp.float32) + + # It creates out-of-bound indices when nse > nnz. + lhs_mat_dense = jnp.array([[1, 0, 2, 3, 0], [0, 0, 0, 4, 0]], + dtype=jnp.float32) + lhs_mat_bcoo = sparse.BCOO.fromdense(lhs_mat_dense, nse=7) + rng = self.rng() + perm = rng.permutation(lhs_mat_bcoo.nse) + lhs_mat_bcoo_unsorted = sparse.BCOO( + (lhs_mat_bcoo.data[perm], lhs_mat_bcoo.indices[perm]), + shape=lhs_mat_dense.shape) + + dimension_numbers_2d = (([1], [0]), ([], [])) + sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general, + dimension_numbers=dimension_numbers_2d)) + + matmat_expected = lax.dot_general(lhs_mat_dense, rhs, + dimension_numbers=dimension_numbers_2d) + with self.subTest(msg="2D"): + with self.gpu_matmul_warning_context( + "bcoo_dot_general GPU lowering requires matrices with sorted indices*"): + matmat_unsorted_fallback = sp_matmat(lhs_mat_bcoo_unsorted, rhs) + self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback) + + lhs_vec_dense = jnp.array([0, 1, 0, 2, 0], dtype=jnp.float32) + lhs_vec_bcoo = sparse.BCOO.fromdense(lhs_vec_dense, nse=5) + rng = self.rng() + perm = rng.permutation(lhs_vec_bcoo.nse) + lhs_vec_bcoo_unsorted = sparse.BCOO( + (lhs_vec_bcoo.data[perm], lhs_vec_bcoo.indices[perm]), + shape=lhs_vec_dense.shape, indices_sorted=False) + + dimension_numbers_1d = (([0], [0]), ([], [])) + sp_vecmat = jit(partial(sparse_bcoo.bcoo_dot_general, + dimension_numbers=dimension_numbers_1d)) + + vecmat_expected = lax.dot_general(lhs_vec_dense, rhs, + dimension_numbers=dimension_numbers_1d) + + with self.subTest(msg="1D"): + with self.gpu_matmul_warning_context( + "bcoo_dot_general GPU lowering requires matrices with sorted indices*"): + vecmat_unsorted_fallback = sp_vecmat(lhs_vec_bcoo_unsorted, rhs) + self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback) + + @jtu.sample_product( + props=_generate_batched_dot_general_properties(), + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + @jax.default_matmul_precision("float32") + def test_bcoo_rdot_general( + self, dtype: np.dtype, props: sptu.BatchedDotGeneralProperties + ): + rng = jtu.rand_default(self.rng()) + sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense) + args_maker = lambda: [rng(props.rhs_shape, dtype), + sprng(props.lhs_shape, dtype)] + dimension_numbers = tuple(d[::-1] for d in props.dimension_numbers) + sparse_fun = partial(sparse.bcoo_dot_general, dimension_numbers=dimension_numbers) + dense_fun = partial(lax.dot_general, dimension_numbers=dimension_numbers) + + tol = {np.float64: 1E-12, np.complex128: 1E-12, + np.float32: 1E-5, np.complex64: 1E-5} + self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol) + if jnp.issubdtype(dtype, jnp.floating): + # Dense dimensions not yet fully supported in reverse mode. + modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev'] + self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol) + + @jtu.sample_product( + [ + dict( + n_batch=n_batch, + n_dense=n_dense, + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + dimension_numbers=dimension_numbers, + ) + for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [ + ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0), + ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0), + ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0), + ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0), + ( + (3, 4, 2, 4), + (3, 4, 3, 2), + (([2], [3]), ([0, 1], [0, 1])), + 2, + 0, + ), + ( + (3, 4, 2, 4), + (3, 4, 3, 2), + (([2], [3]), ([0, 1], [0, 1])), + 2, + 1, + ), + ] + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + @jax.default_matmul_precision("float32") + def test_bcoo_dot_general_partial_batch( + self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense + ): + rng = jtu.rand_small(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) + + X = rng_sparse(lhs_shape, dtype) + nse = sparse.util._count_stored_elements(X, n_batch=n_batch, + n_dense=n_dense) + data, indices = sparse_bcoo._bcoo_fromdense(X, nse=nse, n_batch=n_batch, n_dense=n_dense) + Y = rng(rhs_shape, dtype) + + def f_dense(X, Y): + return lax.dot_general(X, Y, dimension_numbers=dimension_numbers) + + def f_sparse(data, indices, Y): + return sparse_bcoo._bcoo_dot_general(data, indices, Y, lhs_spinfo=sparse_util.SparseInfo(X.shape), + dimension_numbers=dimension_numbers, preferred_element_type=None) + + for data, indices in itertools.product([data, data[:1]], [indices, indices[:1]]): + X = sparse_bcoo._bcoo_todense(data, indices, spinfo=sparse_util.SparseInfo(X.shape)) + self.assertAllClose(f_dense(X, Y), f_sparse(data, indices, Y)) + + @jtu.sample_product( + props=_generate_bcoo_dot_general_sampled_properties(), + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + @jax.default_matmul_precision("float32") + @jtu.skip_on_flag("jax_skip_slow_tests", True) + def test_bcoo_dot_general_sampled(self, props, dtype): + rng = jtu.rand_default(self.rng()) + sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense) + out = jax.eval_shape(partial(lax.dot_general, dimension_numbers=props.dimension_numbers), + jax.ShapeDtypeStruct(props.lhs_shape, dtype), + jax.ShapeDtypeStruct(props.rhs_shape, dtype)) + args_maker = lambda: [rng(props.lhs_shape, dtype), rng(props.rhs_shape, dtype), + sprng(out.shape, dtype).indices] + + def dense_fun(lhs, rhs, indices): + AB = lax.dot_general(lhs, rhs, dimension_numbers=props.dimension_numbers) + return sparse_bcoo._bcoo_extract(indices, AB) + def sparse_fun(lhs, rhs, indices): + return sparse.bcoo_dot_general_sampled( + lhs, rhs, indices, dimension_numbers=props.dimension_numbers) + + self._CheckAgainstDense(dense_fun, sparse_fun, args_maker) + if jnp.issubdtype(dtype, jnp.floating): + # Note: forward mode fails for some sparse layouts. + # TODO(jakevdp) fix forward-mode autodiff & enable tests here. + self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=['rev'], argnums=[0, 1]) + + @jtu.sample_product( + [ + { + "xshape": xshape, + "yshape": yshape, + "lhs_contract": lhs_contract, + "rhs_contract": rhs_contract, + } + for (xshape, yshape, lhs_contract, rhs_contract) in [ + [(4, 3), (4, 5), (0,), (0,)], + [(3, 4), (4, 5), (1,), (0,)], + [(4, 3), (5, 4), (0,), (1,)], + [(3, 4), (5, 4), (1,), (1,)], + [(3,), (3,), (), ()], + [(3,), (5,), (), ()], + [(5,), (3,), (), ()], + [(5,), (5,), (), ()], + ] + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + n_batch=[0, 1, 2], + ) + @jax.default_matmul_precision("float32") + def test_bcoo_dot_general_sampled_fast_cases( + self, xshape, yshape, lhs_contract, rhs_contract, n_batch, dtype): + rng = jtu.rand_default(self.rng()) + sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch) + dimension_numbers = ((lhs_contract, rhs_contract), ([], [])) + + out_shape = jax.eval_shape(partial(lax.dot_general, dimension_numbers=dimension_numbers), + jax.ShapeDtypeStruct(xshape, dtype), jax.ShapeDtypeStruct(yshape, dtype)) + + args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype), + sprng(out_shape.shape, out_shape.dtype).indices] + + def f1(x, y, indices): + mat_full = lax.dot_general(x, y, dimension_numbers=dimension_numbers) + return sparse_bcoo._bcoo_extract(indices, mat_full) + + def f2(x, y, indices): + return sparse.bcoo_dot_general_sampled(x, y, indices, dimension_numbers=dimension_numbers) + + self._CheckAgainstNumpy(f1, f2, args_maker, tol=sptu.MATMUL_TOL) + self._CompileAndCheck(f2, args_maker, tol=sptu.MATMUL_TOL) + + @jtu.sample_product( + [ + dict( + n_batch=n_batch, + n_dense=n_dense, + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + dimension_numbers=dimension_numbers, + ) + for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [ + ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0), + ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 1), + ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0), + ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0), + ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 1), + ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0), + ( + (3, 4, 2, 4), + (3, 4, 3, 2), + (([2], [3]), ([0, 1], [0, 1])), + 2, + 0, + ), + ( + (3, 4, 2, 4), + (3, 4, 3, 2), + (([2], [3]), ([0, 1], [0, 1])), + 2, + 1, + ), + ] + ], + dtype=jtu.dtypes.floating, + ) + @jax.default_matmul_precision("float32") + def test_bcoo_dot_general_sampled_ad(self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense): + rng = jtu.rand_default(self.rng()) + sprng = sptu.rand_sparse(self.rng()) + out_shape = lax.dot_general( + jnp.zeros(lhs_shape), + jnp.zeros(rhs_shape), + dimension_numbers=dimension_numbers, + ).shape + + lhs = rng(lhs_shape, dtype) + rhs = rng(rhs_shape, dtype) + indices = sparse.BCOO.fromdense(sprng(out_shape, dtype), + n_batch=n_batch, n_dense=n_dense).indices + + def dense_fun(lhs, rhs, indices): + AB = lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers) + return sparse_bcoo._bcoo_extract(indices, AB) + def sparse_fun(lhs, rhs, indices): + return sparse.bcoo_dot_general_sampled( + lhs, rhs, indices, dimension_numbers=dimension_numbers + ) + + jf_dense = jax.jacfwd(dense_fun)(lhs, rhs, indices) + jf_sparse = jax.jacfwd(sparse_fun)(lhs, rhs, indices) + jr_dense = jax.jacrev(dense_fun)(lhs, rhs, indices) + jr_sparse = jax.jacrev(sparse_fun)(lhs, rhs, indices) + + self.assertAllClose(jf_sparse, jf_dense) + self.assertAllClose(jr_sparse, jr_dense) + self.assertAllClose(jf_sparse, jr_sparse) + + @jtu.sample_product( + [ + dict( + lhs_n_batch=lhs_n_batch, + rhs_n_batch=rhs_n_batch, + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + dimension_numbers=dimension_numbers, + ) + for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dimension_numbers in [ + # (batched) outer products (no contraction) + ((5,), 0, (6,), 0, (([], []), ([], []))), + ((3, 5), 0, (2, 4), 0, (([], []), ([], []))), + ((3, 5), 1, (3, 4), 1, (([], []), ([0], [0]))), + # (batched) vector-vector products + ((5,), 0, (5,), 0, (([0], [0]), ([], []))), + ((7,), 0, (7,), 0, (([0], [0]), ([], []))), + ((5, 7), 1, (7,), 0, (([1], [0]), ([], []))), + ((2, 3, 4), 2, (2, 4), 1, (([2], [1]), ([0], [0]))), + ((2, 3, 4), 2, (2, 4), 1, (([2], [1]), ([], []))), + ((2, 3, 4), 2, (3, 4), 1, (([2], [1]), ([1], [0]))), + ((2, 3, 4), 2, (3, 4), 1, (([2], [1]), ([], []))), + # (batched) matrix-vector products + ((5, 7), 0, (7,), 0, (([1], [0]), ([], []))), + ((2, 3, 4), 1, (4,), 0, (([2], [0]), ([], []))), + ((2, 3, 4), 1, (2, 4), 1, (([2], [1]), ([0], [0]))), + ((3, 2, 4), 1, (3, 4), 1, (([2], [1]), ([0], [0]))), + ((2, 3, 4), 0, (2,), 0, (([0], [0]), ([], []))), + # (batched) matrix-matrix products + ((5, 7), 0, (7, 3), 0, (([1], [0]), ([], []))), + ((2, 3, 4), 1, (4, 3), 0, (([2], [0]), ([], []))), + ((2, 3, 4), 1, (2, 4, 3), 1, (([2], [1]), ([0], [0]))), + # more general operations + ( + (2, 3, 4, 3), + 1, + (2, 4, 3, 4), + 1, + (([2, 3], [1, 2]), ([0], [0])), + ), + ( + (2, 3, 4, 3, 1), + 2, + (3, 2, 3, 4), + 2, + (([2, 3], [3, 2]), ([0, 1], [1, 0])), + ), + ] + ], + swap=[True, False], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + @jax.default_matmul_precision("float32") + @jtu.skip_on_flag("jax_skip_slow_tests", True) + def test_bcoo_spdot_general(self, lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dtype, swap, dimension_numbers): + if swap: + dimension_numbers = tuple(d[::-1] for d in dimension_numbers) + lhs_shape, rhs_shape = rhs_shape, lhs_shape + lhs_n_batch, rhs_n_batch = rhs_n_batch, lhs_n_batch + + lhs_n_sparse = len(lhs_shape) - lhs_n_batch + rhs_batch = dimension_numbers[1][1] + lhs_contracting = dimension_numbers[0][0] + should_error = (rhs_n_batch > len(rhs_batch) and lhs_n_sparse > len(lhs_contracting)) + + sprng = sptu.rand_bcoo(self.rng()) + args_maker = lambda: [sprng(lhs_shape, dtype, n_batch=lhs_n_batch), + sprng(rhs_shape, dtype, n_batch=rhs_n_batch)] + + def f_dense(x, y): + return lax.dot_general(x, y, dimension_numbers=dimension_numbers) + + def f_sparse(xsp, ysp): + return sparse.bcoo_dot_general(xsp, ysp, dimension_numbers=dimension_numbers) + + if should_error: + with self.assertRaisesRegex(ValueError, ".*cannot have unused batch dims on rhs with unused sparse dims on lhs."): + f_sparse(*args_maker()) + else: + tol = {"float32": 1E-5, "complex64": 1E-5, "float64": 1E-14, "complex128": 1E-14} + self._CheckAgainstDense(f_dense, f_sparse, args_maker, tol=tol) + self._CheckBatchingSparse(f_dense, f_sparse, args_maker, tol=tol) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(f_dense, f_sparse, args_maker, modes=['fwd']) + + @jtu.sample_product(lhs_shape=[(5,), (4, 5)], rhs_shape=[(5,), (5, 4)]) + @jax.default_matmul_precision("float32") + def test_bcoo_spdot_general_nse(self, lhs_shape, rhs_shape): + rng = sptu.rand_bcoo(self.rng()) + dtype = jnp.float32 + lhs = rng(lhs_shape, dtype) + rhs = rng(rhs_shape, dtype) + out = lhs @ rhs + + expected_out = lhs.todense() @ rhs.todense() + expected_nse = min(lhs.nse * rhs.nse, out.size) + + self.assertArraysAllClose(out.todense(), expected_out) + self.assertEqual(out.nse, expected_nse) + + def test_bcoo_spdot_general_ad_bug(self): + # Regression test for https://github.com/google/jax/issues/10163 + A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]]) + A_values = jnp.array([-2.0, 1.0, -1.0, 0.5, 2.0]) + A_shape = (2, 3) + + B_indices = jnp.array([[0, 2], [2, 1], [0, 3], [1, 3], [1, 0], [0, 0]]) + B_values = jnp.array([10.0, 100.0, 1000.0, -5.0, -50.0, -500.0]) + B_shape = (3, 4) + + def sp_sp_product(v1, v2): + A = sparse.BCOO((v1, A_indices), shape=A_shape) + B = sparse.BCOO((v2, B_indices), shape=B_shape) + return (A @ B).todense() + + def sp_de_product(v1, v2): + A = sparse.BCOO((v1, A_indices), shape=A_shape) + B = sparse.BCOO((v2, B_indices), shape=B_shape).todense() + return A @ B + + def de_de_product(v1, v2): + sparse1 = sparse.BCOO((v1, A_indices), shape=A_shape).todense() + dense2 = sparse.BCOO((v2, B_indices), shape=B_shape).todense() + return sparse1 @ dense2 + + sp_sp_jac = jax.jacfwd(sp_sp_product, argnums=1)(A_values, B_values) + sp_de_jac = jax.jacfwd(sp_de_product, argnums=1)(A_values, B_values) + de_de_jac = jax.jacfwd(de_de_product, argnums=1)(A_values, B_values) + + self.assertAllClose(sp_sp_jac, de_de_jac) + self.assertAllClose(sp_de_jac, de_de_jac) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(), (5,), (5, 8), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.numeric, + ) + def test_bcoo_slice(self, shape, dtype, n_batch, n_dense): + rng = self.rng() + sprng = sptu.rand_bcoo(rng, n_batch=n_batch, n_dense=n_dense) + args_maker = lambda: [sprng(shape, dtype)] + + slices = rng.randint(0, np.array(shape) + 1, (2, len(shape))).T + slices.sort(1) + start_indices, limit_indices = unzip2(slices) + strides = list(rng.randint(1, 4, len(shape))) + kwds = dict(start_indices=start_indices, limit_indices=limit_indices, strides=strides) + + dense_func = partial(lax.slice, **kwds) + sparse_func = partial(sparse.bcoo_slice, **kwds) + + self._CheckAgainstDense(dense_func, sparse_func, args_maker) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(dense_func, sparse_func, args_maker) + + mat, = args_maker() + out = sparse_func(mat) + + # Array layout is the same + self.assertEqual(mat.n_batch, out.n_batch) + self.assertEqual(mat.n_sparse, out.n_sparse) + self.assertEqual(mat.n_dense, out.n_dense) + + # Unnecessary padding eliminated + max_nse = math.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse]) + self.assertLessEqual(out.nse, max_nse) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(), (5,), (5, 8), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.numeric, + ) + def test_bcoo_dynamic_slice(self, shape, dtype, n_batch, n_dense): + rng = self.rng() + sprng = sptu.rand_bcoo(rng, n_batch=n_batch, n_dense=n_dense) + args_maker = lambda: [sprng(shape, dtype)] + + rng = self.rng() + # Note: test out-of-range start indices + start_indices = rng.randint(-max(shape, default=0), max(shape, default=0), len(shape)) + slice_sizes = rng.randint(0, shape, len(shape)) + kwds = dict(start_indices=start_indices, slice_sizes=slice_sizes) + dense_func = partial(lax.dynamic_slice, **kwds) + sparse_func = partial(sparse.bcoo_dynamic_slice, **kwds) + + self._CheckAgainstDense(dense_func, sparse_func, args_maker) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(dense_func, sparse_func, args_maker) + + mat, = args_maker() + out = sparse_func(mat) + + # Array layout is the same + self.assertEqual(mat.n_batch, out.n_batch) + self.assertEqual(mat.n_sparse, out.n_sparse) + self.assertEqual(mat.n_dense, out.n_dense) + + # Unnecessary padding eliminated + max_nse = math.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse]) + self.assertLessEqual(out.nse, max_nse) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=n_batch, n_dense=n_dense, idx=idx) + for shape, idx in [ + [(5,), np.index_exp[:]], + [(5,), np.index_exp[4]], + [(5,), np.index_exp[::2]], + [(5,), np.index_exp[1::2]], + [(5,), 1], + [(3, 4), np.index_exp[1]], + [(3, 4), np.index_exp[1, 2]], + [(3, 4), np.index_exp[np.array([1, 2])]], + [(3, 4), np.index_exp[np.array([[1], [2]]), 0]], + [(3, 4), np.index_exp[np.array([[1], [2]]), 1:]], + [(3, 4), np.index_exp[np.array([True, False, True])]], + [(3, 4), np.index_exp[:2, np.array([True, False, True, False])]], + [(3, 4), np.index_exp[None, 0, np.array([[2]])]], + [(3, 4, 5), np.index_exp[2]], + [(3, 4, 5), np.index_exp[:, 2]], + ] + for n_batch in range(len(shape) + 1) + for n_dense in [0] # TODO(jakevdp): add tests with n_dense + ], + dtype=jtu.dtypes.numeric, + ) + def test_bcoo_getitem(self, shape, dtype, n_batch, n_dense, idx): + sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) + args_maker = lambda: [sprng(shape, dtype)] + + fun = lambda x: x[idx] + + self._CheckAgainstDense(fun, fun, args_maker) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(fun, fun, args_maker) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=n_batch, n_dense=n_dense) + for shape in [(2,), (3, 4), (5, 6, 2)] + for n_batch in range(len(shape) + 1) + for n_dense in [0] # TODO(jakevdp): add tests with n_dense + ], + dtype=jtu.dtypes.numeric, + ) + def test_bcoo_iter(self, shape, dtype, n_batch, n_dense): + sprng = sptu.rand_sparse(self.rng()) + args_maker = lambda: [sprng(shape, dtype)] + + self._CheckAgainstDense(list, list, args_maker) + + @jtu.sample_product( + [ + dict( + shape=shape, + n_batch=layout.n_batch, + n_dense=layout.n_dense, + nse=nse, + ) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + for nse in [None, math.prod(shape) - 1] + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + remove_zeros=[True, False], + ) + @jtu.skip_on_flag("jax_skip_slow_tests", True) + def test_bcoo_sum_duplicates(self, shape, dtype, n_batch, n_dense, nse, remove_zeros): + sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) + + def args_maker(): + # Create a matrix with duplicate indices + M = sprng(shape, dtype) + new_indices = jnp.concatenate([M.indices, M.indices], axis=n_batch) + new_data = jnp.concatenate([M.data, M.data], axis=n_batch) + return [sparse.BCOO((new_data, new_indices), shape=M.shape)] + + dense_fun = lambda x: x + def sparse_fun(x): + out = x.sum_duplicates(nse=nse, remove_zeros=remove_zeros) + self.assertTrue(out.unique_indices) + if nse: + self.assertEqual(out.nse, nse) + return out + self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, check_jit=(nse is not None)) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(dense_fun, sparse_fun, args_maker) + if nse is not None: + self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + def test_bcoo_sort_indices(self, shape, dtype, n_batch, n_dense): + rng_sparse = sptu.rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) + M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense) + M.indices = M.indices[..., ::-1, :] + M.indices_sorted = False + + M_sorted = M.sort_indices() + self.assertArraysEqual(M.todense(), M_sorted.todense()) + self.assertEqual(M.unique_indices, M_sorted.unique_indices) + self.assertEqual(True, M_sorted.indices_sorted) + + indices = M_sorted.indices + if indices.size > 0: + flatind = indices.reshape(-1, *indices.shape[-2:]).transpose(0, 2, 1) + sorted = jax.vmap(jnp.lexsort)(flatind[:, ::-1]) + self.assertArraysEqual(sorted, lax.broadcasted_iota(sorted.dtype, sorted.shape, sorted.ndim - 1)) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape, min_n_batch=1) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + def test_bcoo_sort_indices_batching(self, shape, dtype, n_batch, n_dense): + rng_sparse = sptu.rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) + M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense) + M.indices = M.indices[..., ::-1, :] + M.indices_sorted = False + + identity = lambda M: M + sort_ind = lambda M: M.sort_indices() + for b in range(n_batch): + identity = jax.vmap(identity, in_axes=b) + sort_ind = jax.vmap(sort_ind, in_axes=b) + M_sorted = sort_ind(M) + M_expected = identity(M) + self.assertArraysEqual(M_expected.todense(), M_sorted.todense()) + self.assertEqual(M.unique_indices, M_sorted.unique_indices) + self.assertEqual(True, M_sorted.indices_sorted) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.floating, + ) + def test_bcoo_sort_indices_ad(self, shape, dtype, n_batch, n_dense): + rng_sparse = sptu.rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) + M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense) + M.indices = M.indices[..., ::-1, :] + + def sort_indices(data): + return sparse.BCOO((data, M.indices), shape=M.shape).sort_indices().data + + data_dot_fwd = jax.jacfwd(sort_indices)(M.data) + data_dot_rev = jax.jacrev(sort_indices)(M.data) + + self.assertAllClose(data_dot_fwd, data_dot_rev) + + def test_bcoo_sort_indices_broadcasted(self): + rng_index = jtu.rand_int(self.rng(), low=0, high=10) + rng_data = jtu.rand_default(self.rng()) + + # Construct matrix with three broadcasted batch dimensions. + indices = rng_index((1, 3, 1, 10, 2), dtype='int32') + data = rng_data((1, 1, 4, 10, 3), dtype='int32') + shape = (2, 3, 4, 5, 4, 3) + mat = sparse.BCOO((data, indices), shape=shape) + + indices_shape_out = indices.shape + data_shape_out = (*map(max, indices.shape[:3], data.shape[:3]), *data.shape[3:]) + + mat_sorted = sparse.bcoo_sort_indices(mat) + assert mat_sorted.indices.shape == indices_shape_out + assert mat_sorted.data.shape == data_shape_out + self.assertArraysEqual(mat.todense(), mat_sorted.todense()) + + mat_sorted_jit = jit(sparse.bcoo_sort_indices)(mat) + assert mat_sorted_jit.indices.shape == indices_shape_out + assert mat_sorted_jit.data.shape == data_shape_out + self.assertArraysEqual(mat.todense(), mat_sorted_jit.todense()) + + def test_bcoo_sum_duplicates_inferred_nse(self): + x = sparse.BCOO.fromdense(jnp.diag(jnp.arange(4))) + self.assertEqual(x.nse, 3) + y = x + x.T + self.assertEqual(y.nse, 6) + y2 = y.sum_duplicates() + self.assertEqual(y2.nse, 3) + self.assertArraysEqual(y.todense(), y2.todense()) + + def test_bcoo_sum_duplicates_remove_zeros(self): + data = jnp.array([0, 1, 0, 0]) + indices = jnp.array([[0], [1], [2], [3]]) + x = sparse.BCOO((data, indices), shape=(4,)) + self.assertEqual(x.nse, 4) + + y1 = x.sum_duplicates(remove_zeros=True) + self.assertArraysEqual(x.todense(), y1.todense()) + self.assertEqual(y1.nse, 1) + + y2 = x.sum_duplicates(remove_zeros=False) + self.assertArraysEqual(x.todense(), y2.todense()) + self.assertEqual(y2.nse, x.nse) + + def test_bcoo_sum_duplicates_padding(self): + # Regression test for https://github.com/google/jax/issues/8163 + size = 3 + data = jnp.array([1, 0, 0]) + indices = jnp.array([1, size, size])[:, None] + x = sparse.BCOO((data, indices), shape=(3,)) + y = x.sum_duplicates(nse=x.nse) + self.assertArraysEqual(x.todense(), y.todense()) + self.assertArraysEqual(x.indices, y.indices) + self.assertArraysEqual(x.data, y.data) + + @jtu.sample_product( + [ + dict( + shape=shape, + n_batch=layout.n_batch, + n_dense=layout.n_dense, + axes=axes, + ) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + for naxes in range(len(shape)) + for axes in itertools.combinations(range(len(shape)), naxes) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes): + sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) + args_maker = lambda: [sprng(shape, dtype)] + sparse_fun = partial(sparse.bcoo_reduce_sum, axes=axes) + dense_fun = partial(lambda x: x.sum(axes)) + + tol = {np.float64: 1E-14} + self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(dense_fun, sparse_fun, args_maker) + + @jtu.sample_product( + [ + dict( + shape=shape, + dimensions=dimensions, + n_batch=layout.n_batch, + n_dense=layout.n_dense, + ) + for shape, dimensions in [ + [(1,), (0,)], + [(1,), (-1,)], + [(2, 1, 4), (1,)], + [(2, 1, 3, 1), (1,)], + [(2, 1, 3, 1), (1, 3)], + [(2, 1, 3, 1), (3,)], + ] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.numeric, + ) + def test_bcoo_squeeze(self, shape, dtype, dimensions, n_batch, n_dense): + sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) + args_maker = lambda: [sprng(shape, dtype)] + dense_func = partial(lax.squeeze, dimensions=dimensions) + sparse_func = partial(sparse.bcoo_squeeze, dimensions=dimensions) + + self._CheckAgainstDense(dense_func, sparse_func, args_maker) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(dense_func, sparse_func, args_maker) + + @jtu.sample_product( + [ + dict(batch_shapes=shapes, batch_perm=perm) + for shapes in COMPATIBLE_SHAPE_PAIRS + for perm in itertools.permutations(range(len(shapes[0]))) + ], + [ + dict(sparse_shapes=shapes, sparse_perm=perm) + for shapes in COMPATIBLE_SHAPE_PAIRS + for perm in itertools.permutations(range(len(shapes[0]))) + ], + [ + dict(dense_shapes=shapes, dense_perm=perm) + for shapes in [[(), ()]] # TODO(jakevdp) add support for dense shapes + for perm in itertools.permutations(range(len(shapes[0]))) + ], + dtype=jtu.dtypes.numeric, + ) + def test_bcoo_reshape( + self, + batch_shapes, + sparse_shapes, + dense_shapes, + batch_perm, + sparse_perm, + dense_perm, + dtype, + ): + # Sparse reshapes cannot mix between sparse, dense, and batch dimensions. + shape = (*batch_shapes[0], *sparse_shapes[0], *dense_shapes[0]) + new_sizes = (*batch_shapes[1], *sparse_shapes[1], *dense_shapes[1]) + n_batch = len(batch_shapes[0]) + n_sparse = len(sparse_shapes[0]) + n_dense = len(dense_shapes[0]) + dimensions = ( + *batch_perm, + *(dim + n_batch for dim in sparse_perm), + *(dim + n_batch + n_sparse for dim in dense_perm), + ) + + rng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) + args_maker = lambda: [rng(shape, dtype)] + + sparse_func = partial(sparse.bcoo_reshape, new_sizes=new_sizes, dimensions=dimensions) + dense_func = partial(lax.reshape, new_sizes=new_sizes, dimensions=dimensions) + + self._CheckAgainstDense(dense_func, sparse_func, args_maker) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(dense_func, sparse_func, args_maker) + + def test_bcoo_reshape_error(self): + x = sparse.BCOO.fromdense(jnp.ones((2, 2, 3)), n_batch=1) + with self.assertRaisesRegex(ValueError, ".*cannot mix batch and sparse dimensions.*"): + x.reshape(3, 2, 2) + y = sparse.BCOO((x.data[:1], x.indices), shape=x.shape) + with self.assertRaisesRegex(NotImplementedError, "reshape of arrays with broadcasted batch dimensions."): + y.reshape(2, 3, 2) + + @jtu.sample_product( + [ + dict( + shape=shape, + dimensions=dimensions, + n_batch=layout.n_batch, + n_dense=layout.n_dense, + ) + for shape in [(3,), (3, 4), (3, 4, 5)] + for dimensions in sptu.iter_subsets(range(len(shape))) + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.numeric, + ) + def test_bcoo_rev(self, shape, dtype, n_batch, n_dense, dimensions): + sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) + args_maker = lambda: [sprng(shape, dtype)] + dense_func = partial(lax.rev, dimensions=dimensions) + sparse_func = partial(sparse.bcoo_rev, dimensions=dimensions) + + self._CheckAgainstDense(dense_func, sparse_func, args_maker) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(dense_func, sparse_func, args_maker) + + def test_bcsr_matmul_with_out_of_bounds_data(self): + # Simple regression test of a failure mode for cuSparse. + data = jnp.array([1, 2, 3, 4], dtype='float32') + indices = jnp.array([0, 1, 2, 3]) + indptr = jnp.array([0, 1, 3, 3]) + M = sparse.BCSR((data, indices, indptr), shape=(3, 4)) + x = jnp.array([1, 2, 3, 4], dtype='float32') + + sparse_result = jax.jit(operator.matmul)(M, x) + dense_result = jax.jit(operator.matmul)(M.todense(), x) + self.assertAllClose(sparse_result, dense_result) + + @jtu.sample_product( + [ + dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) + for lhs_shape, rhs_shape in [ + [(3, 4), (4,)], + [(3, 4), (4, 5)], + [(3, 4), (2, 4, 5)], + ] + ], + lhs_dtype=all_dtypes, + rhs_dtype=all_dtypes, + ) + @jax.default_matmul_precision("float32") + @jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning) + def test_bcsr_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + # Note: currently, batch dimensions in matmul must correspond to batch + # dimensions in the sparse representation. + n_batch_lhs = max(0, len(lhs_shape) - 2) + + rng = jtu.rand_default(self.rng()) + sprng = sptu.rand_bcsr(self.rng()) + args_maker = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=n_batch_lhs), + jnp.array(rng(rhs_shape, rhs_dtype))] + + tol = {np.float64: 1E-7, np.complex128: 1E-6, + np.float32: 2E-6, np.complex64: 2E-6} + + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker, + tol=tol) + + @jtu.sample_product( + [ + dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) + for lhs_shape, rhs_shape in [ + [(3,), (3,)], + [(3, 4), (4,)], + [(4,), (4, 5)], + [(3, 4), (4, 5)], + [(3, 4), (2, 4, 5)], + [(2, 3, 4), (4, 5)], + [(2, 3, 4), (2, 4, 5)], + ] + ], + lhs_dtype=all_dtypes, + rhs_dtype=all_dtypes, + ) + @jax.default_matmul_precision("float32") + @jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning) + def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + if (jtu.test_device_matches(["cuda"]) and + _is_required_cuda_version_satisfied(12000)): + raise unittest.SkipTest("Triggers a bug in cuda-12 b/287344632") + + # Note: currently, batch dimensions in matmul must correspond to batch + # dimensions in the sparse representation. + n_batch_lhs = max(0, len(lhs_shape) - 2) + n_batch_rhs = max(0, len(rhs_shape) - 2) + + rng = jtu.rand_default(self.rng()) + sprng = sptu.rand_bcoo(self.rng()) + args_maker_de_sp = lambda: [jnp.array(rng(lhs_shape, lhs_dtype)), + sprng(rhs_shape, rhs_dtype, n_batch=n_batch_rhs)] + args_maker_sp_de = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=n_batch_lhs), + jnp.array(rng(rhs_shape, rhs_dtype))] + + tol = {np.float64: 1E-7, np.complex128: 1E-7, + np.float32: 1E-6, np.complex64: 1E-6} + + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker_de_sp, tol=tol) + self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker_sp_de, tol=tol) + + @jtu.sample_product( + [ + dict( + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + n_batch=layout.n_batch, + n_dense=layout.n_dense, + ) + for lhs_shape, rhs_shape in [ + [(3,), ()], + [(3,), (1,)], + [(3,), (3,)], + [(3, 4), ()], + [(3, 4), (4,)], + [(3, 4), (3, 1)], + [(3, 4), (3, 4)], + [(3, 4, 5), (4, 5)], + [(3, 4, 5), (3, 1, 1)], + [(3, 4, 5), (1, 4, 1)], + ] + for layout in sptu.iter_sparse_layouts(lhs_shape) + ], + lhs_dtype=all_dtypes, + rhs_dtype=all_dtypes, + ) + @jax.numpy_rank_promotion( + "allow" + ) # This test explicitly exercises implicit rank promotion. + def test_bcoo_mul_dense(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense): + rng = jtu.rand_default(self.rng()) + sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) + + args_maker_sp_de = lambda: [sprng(lhs_shape, lhs_dtype), jnp.array(rng(rhs_shape, rhs_dtype))] + args_maker_de_sp = lambda: [jnp.array(rng(rhs_shape, rhs_dtype)), sprng(lhs_shape, lhs_dtype)] + + tol = {np.float64: 1E-13, np.complex128: 1E-13, + np.float32: 1E-6, np.complex64: 1E-6} + + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstDense(operator.mul, operator.mul, args_maker_de_sp, tol=tol) + self._CheckAgainstDense(operator.mul, operator.mul, args_maker_sp_de, tol=tol) + + @jtu.sample_product( + [ + dict( + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + lhs_n_batch=lhs_n_batch, + rhs_n_batch=rhs_n_batch, + n_dense=n_dense, + ) + # TODO(jakevdp): add broadcasted shapes (from bcoo_mul_dense) once sparse-sparse mul + # supports inputs of differing rank. + for lhs_shape, rhs_shape in [ + [(3,), (1,)], + [(3,), (3,)], + [(3, 4), (1, 1)], + [(3, 4), (1, 4)], + [(3, 4), (3, 1)], + [(3, 4), (3, 4)], + [(3, 4, 5), (1, 4, 5)], + [(3, 4, 5), (3, 1, 1)], + [(3, 4, 5), (1, 4, 1)], + ] + # TODO(jakevdp): add tests for batch & dense dimensions. + for lhs_n_batch in range(len(lhs_shape) + 1) + for rhs_n_batch in range(len(lhs_shape) + 1) + for n_dense in range( + len(lhs_shape) + 1 - max(lhs_n_batch, rhs_n_batch) + ) + ], + lhs_dtype=all_dtypes, + rhs_dtype=all_dtypes, + ) + def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, lhs_n_batch, rhs_n_batch, n_dense): + sprng = sptu.rand_bcoo(self.rng(), n_dense=n_dense) + args_maker = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=lhs_n_batch), + sprng(rhs_shape, rhs_dtype, n_batch=rhs_n_batch)] + + tol = {np.float64: 1E-13, np.complex128: 1E-13, + np.float32: 1E-5, np.complex64: 1E-5} + + with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): + self._CheckAgainstDense(operator.mul, operator.mul, args_maker, tol=tol) + + def test_bcoo_mul_sparse_with_duplicates(self): + # Regression test for https://github.com/google/jax/issues/8888 + indices = jnp.array([[0, 1, 0, 0, 1, 1], + [1, 0, 1, 2, 0, 2]]).T + data = jnp.array([1, 2, 3, 4, 5, 6]) + mat = sparse.BCOO((data, indices), shape=(3, 3)) + self.assertArraysEqual((mat * mat).todense(), mat.todense() * mat.todense()) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(), (3,), (3, 5), (3, 5, 4)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=all_dtypes, + ) + def test_bcoo_broadcast_in_dim(self, shape, dtype, n_batch, n_dense): + rng = sptu.rand_sparse(self.rng()) + x = jnp.array(rng(shape, dtype)) + xsp = sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense) + + self.assertEqual(xsp[None].n_batch, xsp.n_batch + 1) + self.assertArraysEqual(xsp[None].todense(), x[None]) + + if len(shape) >= 1: + self.assertEqual(xsp[:, None].n_batch, xsp.n_batch if xsp.n_batch < 1 else xsp.n_batch + 1) + self.assertArraysEqual(xsp[:, None].todense(), x[:, None]) + self.assertArraysEqual(xsp[:, None, None].todense(), x[:, None, None]) + if len(shape) >= 2: + self.assertEqual(xsp[:, :, None].n_batch, xsp.n_batch if xsp.n_batch < 2 else xsp.n_batch + 1) + self.assertArraysEqual(xsp[:, :, None].todense(), x[:, :, None]) + self.assertArraysEqual(xsp[:, None, :, None].todense(), x[:, None, :, None]) + + @jtu.sample_product( + [ + dict( + shape=shape, + n_batch=layout.n_batch, + n_dense=layout.n_dense, + dimension=dimension, + ) + for shape in [(3,), (3, 5), (3, 5, 4)] + for layout in sptu.iter_sparse_layouts(shape) + for dimension in range( + len(shape) - layout.n_dense + ) # Concatenation of dense dimensions not implemented. + ], + dtype=all_dtypes, + ) + def test_bcoo_concatenate(self, shape, dtype, n_batch, n_dense, dimension): + sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) + args_maker = lambda: [[sprng(shape, dtype) for i in range(3)]] + dense_func = partial(lax.concatenate, dimension=dimension) + sparse_func = partial(sparse.bcoo_concatenate, dimension=dimension) + + self._CheckAgainstDense(dense_func, sparse_func, args_maker) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(dense_func, sparse_func, args_maker) + + @jtu.sample_product( + lhs_shape=[(1, 1, 5), (1, 1, 10)], + rhs_shape=[(1, 1, 5), (1, 1, 10)], + padding=["SAME", "VALID", [(3, 3)]], + dtype=jtu.dtypes.inexact, + format=["sp-de", "de-sp", "sp-sp"], + ) + @jax.default_matmul_precision("float32") + @jtu.skip_on_flag("jax_skip_slow_tests", True) + def test_bcoo_conv_general_dilated(self, lhs_shape, rhs_shape, dtype, padding, format): + kwds = dict(window_strides=(1,), padding=padding) + sparse_fun = partial(sparse.bcoo_conv_general_dilated, **kwds) + dense_fun = partial(lax.conv_general_dilated, **kwds) + sprng = sptu.rand_bcoo(self.rng(), n_batch=2, n_dense=0) + rng = jtu.rand_default(self.rng()) + + def args_maker(): + lhs = (sprng if format.startswith('sp') else rng)(lhs_shape, dtype) + rhs = (sprng if format.endswith('sp') else rng)(rhs_shape, dtype) + return lhs, rhs + + tol = {np.float32: 1E-5, np.complex64: 1E-5, np.float64: 1E-14, np.complex128: 1E-14} + self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol) + + def test_bcoo_vmap_shape(self, shape=(2, 3, 4, 5), dtype=np.float32): + # This test checks that BCOO shape metadata interacts correctly with vmap. + rng = sptu.rand_sparse(self.rng()) + M = rng(shape, dtype) + + def make_bcoo(M): + return sparse_bcoo._bcoo_fromdense(M, nse=math.prod(M.shape[:-1]), n_dense=1) + + todense = partial(sparse_bcoo._bcoo_todense, spinfo=sparse_util.SparseInfo(shape)) + + for _ in range(3): + make_bcoo = jax.vmap(make_bcoo) + Msp_data, Msp_indices = make_bcoo(M) + Msp_dense = todense(Msp_data, Msp_indices) + self.assertEqual(Msp_dense.shape, M.shape) + self.assertArraysEqual(Msp_dense, M) + + @jtu.sample_product( + [ + dict( + shape=shape, + n_batch=layout.n_batch, + n_dense=layout.n_dense, + n_batch_out=layout_out.n_batch, + n_dense_out=layout_out.n_dense, + ) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + for layout_out in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.integer, + ) + def test_bcoo_update_layout(self, shape, dtype, n_batch, n_batch_out, n_dense, n_dense_out): + rng = sptu.rand_sparse(self.rng()) + mat = sparse.BCOO.fromdense(rng(shape, dtype), n_batch=n_batch, n_dense=n_dense) + kwds = dict(n_batch=n_batch_out, n_dense=n_dense_out) + # TODO(jakevdp): in case of length-0 or length-1 shapes errors/warnings will not be raised. + if n_dense_out > n_dense or n_batch_out > n_batch: + with self.assertRaises(sparse.SparseEfficiencyError): + sparse.bcoo_update_layout(mat, **kwds) + with self.assertRaises(sparse.SparseEfficiencyError): + sparse.bcoo_update_layout(mat, **kwds, on_inefficient='error') + with self.assertWarns(sparse.SparseEfficiencyWarning): + sparse.bcoo_update_layout(mat, **kwds, on_inefficient='warn') + kwds['on_inefficient'] = None + mat_new = sparse.bcoo_update_layout(mat, **kwds) + self.assertEqual(mat_new.n_batch, n_batch_out) + self.assertEqual(mat_new.n_dense, n_dense_out) + self.assertArraysEqual(mat.todense(), mat_new.todense()) + + def test_bcoo_update_layout_method(self, shape=(2, 3, 4)): + # simple test to make sure update_layout method properly forwards. + rng = sptu.rand_sparse(self.rng()) + mat = sparse.BCOO.fromdense(rng((2, 3, 4), 'float32'), n_batch=1, n_dense=1) + mat_new = mat.update_layout(n_batch=0, n_dense=0) + self.assertEqual(mat_new.n_batch, 0) + self.assertEqual(mat_new.n_dense, 0) + self.assertArraysEqual(mat.todense(), mat_new.todense()) + + def test_bcoo_bad_fillvals(self): + # Extra values have 100 rather than zero. This lets us check that logic is + # properly ignoring these indices. + data = jnp.array([1, 2, 3, 100, 100]) + indices = jnp.array([1, 2, 3, 5, 5])[:, None] + x_sp = sparse.BCOO((data, indices), shape=(5,)) + x_de = x_sp.todense() + + data = jnp.array([3, 2, 100, 100]) + indices = jnp.array([2, 3, 5, 5])[:, None] + y_sp = sparse.BCOO((data, indices), shape=(5,)) + y_de = y_sp.todense() + + self.assertArraysEqual(x_de, jnp.array([0, 1, 2, 3, 0])) + self.assertArraysEqual(y_de, jnp.array([0, 0, 3, 2, 0])) + + self.assertArraysEqual(x_sp.sum_duplicates().todense(), x_de) + self.assertArraysEqual(y_sp.sum_duplicates().todense(), y_de) + + # reduce_sum: + self.assertArraysEqual(x_sp.sum(), x_de.sum()) + + # bcoo_dot_general + self.assertArraysEqual(x_sp @ y_de, x_de @ y_de) + + # bcoo_rdot_general + self.assertArraysEqual(x_de @ y_sp, x_de @ y_de) + + # bcoo_spdot_general + self.assertArraysEqual((x_sp @ y_sp).todense(), x_de @ y_de) + self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de) + + +# TODO(tianjianlu): Unify the testing for BCOOTest and BCSRTest. +class BCSRTest(sptu.SparseTestCase): + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_bcsr_layouts(shape) + ], + dtype=all_dtypes, + ) + def test_bcsr_dense_round_trip(self, shape, dtype, n_batch, n_dense): + n_sparse = len(shape) - n_batch - n_dense + rng = sptu.rand_sparse(self.rng()) + M = rng(shape, dtype) + nse = sparse.util._count_stored_elements(M, n_batch=n_batch, n_dense=n_dense) + def round_trip(M): + return sparse.BCSR.fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense).todense() + args_maker = lambda: [M] + ident = lambda x: x + + self._CheckAgainstNumpy(ident, round_trip, args_maker) + self._CompileAndCheck(round_trip, args_maker) + self._CheckBatchingSparse(ident, round_trip, args_maker, bdims=self._random_bdims(n_batch)) + if jnp.issubdtype(dtype, jnp.floating): + # For n_sparse != 0, we can't use an identity because output zeros must not + # be dependent on input zeros. This mimics the code in count_stored_elements(). + def expected(M): + if n_sparse == 0: return M + mask = (M != 0).any(range(M.ndim - n_dense, M.ndim), keepdims=True) + return jnp.where(mask, M, 0) + self._CheckGradsSparse(expected, round_trip, args_maker) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=n_batch) + for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for n_batch in range(len(shape) - 1) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + def test_bcsr_bcoo_round_trip(self, shape, n_batch, dtype): + n_sparse = 2 + n_dense = len(shape) - n_sparse - n_batch + rng = self.rng() + sprng = sptu.rand_bcsr(rng, n_batch=n_batch, n_dense=n_dense) + + M_bcsr = sprng(shape, dtype) + self.assertIsInstance(M_bcsr, sparse.BCSR) + + M_dense = M_bcsr.todense() + M_bcoo = M_bcsr.to_bcoo() + self.assertIsInstance(M_bcoo, sparse.BCOO) + self.assertAllClose(M_dense, M_bcoo.todense()) + + M_bcsr2 = sparse.BCSR.from_bcoo(M_bcoo) + self.assertAllClose(M_dense, M_bcsr2.todense()) + self.assertArraysEqual(M_bcsr.indptr, M_bcsr2.indptr) + + # TODO(jakevdp): This will only be true in general when M_bcsr.indices is sorted. + # self.assertSparseArraysEquivalent(M_bcsr, M_bcsr2) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=n_batch) + for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for n_batch in range(len(shape) - 1) + ], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + def test_bcsr_extract(self, shape, dtype, n_batch): + n_dense = len(shape) - n_batch - 2 + rng = sptu.rand_sparse(self.rng()) + M = rng(shape, dtype) + nse = sparse.util._count_stored_elements(M, n_batch=n_batch, + n_dense=n_dense) + data, indices, indptr = sparse_bcsr._bcsr_fromdense( + M, nse=nse, n_batch=n_batch, n_dense=n_dense) + data2 = sparse.bcsr_extract(indices, indptr, M) + self.assertArraysEqual(data, data2) + args_maker_bcsr_extract = lambda: [indices, indptr, M] + self._CompileAndCheck(sparse.bcsr_extract, args_maker_bcsr_extract) + + @jtu.sample_product( + props=_generate_batched_dot_general_properties( + shapes=((2, 3), (2, 3, 4), (2, 3, 4, 4)), sparse_format="bcsr" + ), + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + ) + @jax.default_matmul_precision("float32") + def test_bcsr_dot_general( + self, dtype: np.dtype, props: sptu.BatchedDotGeneralProperties + ): + rng = jtu.rand_default(self.rng()) + sprng = sptu.rand_bcsr(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense) + args_maker = lambda: [sprng(props.lhs_shape, dtype), + rng(props.rhs_shape, dtype)] + dense_fun = partial(lax.dot_general, + dimension_numbers=props.dimension_numbers) + sparse_fun = partial(sparse.bcsr_dot_general, + dimension_numbers=props.dimension_numbers) + + tol = {np.float64: 1E-12, np.complex128: 1E-12, + np.float32: 1E-5, np.complex64: 1E-5} + + self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol) + if jnp.issubdtype(dtype, jnp.floating) and props.n_dense == 0: + # Dense dimensions not yet fully supported in reverse mode. + modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev'] + self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol) + self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol, + bdims=self._random_bdims(props.n_batch, len(props.rhs_shape))) + + @jtu.sample_product( + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(3, 5), (3, 5, 4)] + for layout in sptu.iter_bcsr_layouts(shape) + ], + dtype=all_dtypes, + ) + def test_bcsr_broadcast_in_dim(self, shape, dtype, n_batch, n_dense): + rng = sptu.rand_sparse(self.rng()) + x = jnp.array(rng(shape, dtype)) + xsp = sparse.BCSR.fromdense(x, n_batch=n_batch, n_dense=n_dense) + + self.assertEqual(xsp[None].n_batch, xsp.n_batch + 1) + self.assertArraysEqual(xsp[None].todense(), x[None]) + + if n_batch == 1: + self.assertEqual(xsp[:, None].n_batch, xsp.n_batch + 1) + self.assertArraysEqual(xsp[:, None].todense(), x[:, None]) + + @jtu.sample_product( + [ + dict( + shape=shape, + n_batch=layout.n_batch, + n_dense=layout.n_dense, + dimension=dimension, + ) + for shape in [(3, 5), (3, 5, 4)] + for layout in sptu.iter_sparse_layouts(shape) + for dimension in range( + len(shape) - layout.n_dense + ) # Concatenation of dense dimensions not implemented. + ], + dtype=all_dtypes, + ) + def test_bcsr_concatenate(self, shape, dtype, n_batch, n_dense, dimension): + sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) + args_maker = lambda: [[sprng(shape, dtype) for i in range(3)]] + dense_func = partial(lax.concatenate, dimension=dimension) + sparse_func = partial(sparse.bcoo_concatenate, dimension=dimension) + + self._CheckAgainstDense(dense_func, sparse_func, args_maker) + if jnp.issubdtype(dtype, jnp.floating): + self._CheckGradsSparse(dense_func, sparse_func, args_maker) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 17c55a7df..6498fe492 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -12,15 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterable, Iterator, Sequence import contextlib from functools import partial import itertools import math -import operator -import random import unittest -from typing import NamedTuple from absl.testing import absltest from absl.testing import parameterized @@ -37,15 +33,12 @@ from jax.experimental.sparse import bcsr as sparse_bcsr from jax.experimental.sparse import util as sparse_util from jax.experimental.sparse import test_util as sptu from jax.experimental.sparse import _lowerings -from jax import lax from jax._src import xla_bridge from jax._src.lib import gpu_sparse -from jax._src.util import unzip2 from jax import jit from jax import tree_util from jax import vmap from jax._src import test_util as jtu -from jax._src.lax.lax import remaining, DotDimensionNumbers from jax.interpreters import mlir import jax.numpy as jnp from jax.util import split_list @@ -55,146 +48,8 @@ import scipy.sparse config.parse_flags_with_absl() FLAGS = config.FLAGS -MATMUL_TOL = { - np.float32: 1E-5, - np.float64: 1E-10, - np.complex64: 1e-5, - np.complex128: 1E-10, -} - -GPU_LOWERING_ENABLED = gpu_sparse and (gpu_sparse.cuda_is_supported or - gpu_sparse.rocm_is_supported) - -COMPATIBLE_SHAPE_PAIRS = [ - [(), ()], - [(), (1,)], - [(3,), (1, 3)], - [(3, 1), (3,)], - [(6,), (2, 3)], - [(3, 2), (6,)], - [(2, 3), (1, 6)], - [(2, 4), (4, 1, 2)], - [(3, 4, 5), (2, 6, 5)], - [(2,), (2,)] -] - - -class BatchedDotGeneralProperties(NamedTuple): - lhs_shape: tuple[int, ...] - rhs_shape: tuple[int, ...] - n_batch: int - n_dense: int - dimension_numbers: DotDimensionNumbers - - -def _iter_subsets(s: Sequence) -> Iterable[tuple]: - """Return an iterator over all subsets of a sequence s""" - return itertools.chain.from_iterable(itertools.combinations(s, n) for n in range(len(s) + 1)) - - -class SparseLayout(NamedTuple): - n_batch: int - n_dense: int - n_sparse: int - - -def iter_sparse_layouts(shape: Sequence[int], min_n_batch=0) -> Iterator[SparseLayout]: - for n_batch in range(min_n_batch, len(shape) + 1): - for n_dense in range(len(shape) + 1 - n_batch): - n_sparse = len(shape) - n_batch - n_dense - yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense) - -def iter_bcsr_layouts(shape: Sequence[int], min_n_batch=0) -> Iterator[SparseLayout]: - n_sparse = 2 - for n_batch in range(min_n_batch, len(shape) - 1): - n_dense = len(shape) - n_sparse - n_batch - yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense) - - -def _generate_batched_dot_general_properties( - shapes=((5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)), - sparse_format='bcoo') -> BatchedDotGeneralProperties: - """Generator of properties for bcoo_dot_general tests.""" - rng = random.Random(0) - - if sparse_format not in ['bcoo', 'bcsr']: - raise ValueError(f"Sparse format {sparse_format} not supported.") - - for shape in shapes: - for layout in iter_sparse_layouts(shape): - if sparse_format == "bcsr" and layout.n_sparse != 2: - continue - subsets = split_list(range(len(shape)), [layout.n_batch, layout.n_sparse]) - for batch_dims in _iter_subsets(range(layout.n_batch)): - for contracting_dims in _iter_subsets(remaining(range(layout.n_batch + layout.n_sparse), batch_dims)): - # We want coverage of permutations without generating hundreds of thousands of test cases; - # we do this by deterministic pseudo-random sampling instead of iterating. - rhs_permute = rng.sample(range(len(shape)), len(shape)) - lhs_permute = list(itertools.chain.from_iterable( - rng.sample(subset, len(subset)) for subset in subsets)) - yield BatchedDotGeneralProperties( - lhs_shape=tuple(shape[p] for p in lhs_permute), - rhs_shape=tuple(shape[p] for p in rhs_permute), - n_batch=layout.n_batch, - n_dense=layout.n_dense, - dimension_numbers=( - ([lhs_permute.index(d) for d in contracting_dims], [rhs_permute.index(d) for d in contracting_dims]), - ([lhs_permute.index(d) for d in batch_dims], [rhs_permute.index(d) for d in batch_dims]) - ), - ) - - -def _generate_bcoo_dot_general_sampled_properties(shapes=((5,), (2, 3), (2, 3, 4), (2, 3, 4, 4))) -> BatchedDotGeneralProperties: - """Generator of properties for bcoo_dot_general_sampled tests.""" - rng = random.Random(0) - - for shape in shapes: - for batch_dims in _iter_subsets(range(len(shape))): - for contracting_dims in _iter_subsets(remaining(range(len(shape)), batch_dims)): - # We want coverage of permutations without generating hundreds of thousands of test cases; - # we do this by deterministic pseudo-random sampling instead of iterating. - lhs_permute = rng.sample(range(len(shape)), len(shape)) - rhs_permute = rng.sample(range(len(shape)), len(shape)) - lhs_shape = tuple(shape[p] for p in lhs_permute) - rhs_shape = tuple(shape[p] for p in rhs_permute) - dimension_numbers = ( - ([lhs_permute.index(d) for d in contracting_dims], [rhs_permute.index(d) for d in contracting_dims]), - ([lhs_permute.index(d) for d in batch_dims], [rhs_permute.index(d) for d in batch_dims]) - ) - out = jax.eval_shape(partial(lax.dot_general, dimension_numbers=dimension_numbers), - jax.ShapeDtypeStruct(lhs_shape, 'float32'), jax.ShapeDtypeStruct(rhs_shape, 'float32')) - for layout in iter_sparse_layouts(out.shape): - yield BatchedDotGeneralProperties( - lhs_shape=lhs_shape, rhs_shape=rhs_shape, - n_batch=layout.n_batch, n_dense=layout.n_dense, - dimension_numbers=dimension_numbers) - - all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex - -def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default): - def _rand_sparse(shape, dtype, nse=nse): - rand = rand_method(rng) - size = math.prod(shape) - if 0 <= nse < 1: - nse = nse * size - nse = min(size, int(nse)) - M = rand(shape, dtype) - indices = rng.choice(size, size - nse, replace=False) - M.flat[indices] = 0 - return post(M) - return _rand_sparse - - -def _is_required_cuda_version_satisfied(cuda_version): - version = xla_bridge.get_backend().platform_version - if version == "" or version.split()[0] == "rocm": - return False - else: - return int(version.split()[-1]) >= cuda_version - - class cuSparseTest(sptu.SparseTestCase): def gpu_dense_conversion_warning_context(self, dtype): if jtu.test_device_matches(["gpu"]) and np.issubdtype(dtype, np.integer): @@ -207,11 +62,11 @@ class cuSparseTest(sptu.SparseTestCase): return contextlib.nullcontext() @jtu.sample_product( - shape=[(5, 8), (8, 5), (5, 5), (8, 8)], - dtype=all_dtypes, + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, ) def test_csr_todense(self, shape, dtype): - rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix) + rng = sptu.rand_sparse(self.rng(), post=scipy.sparse.csr_matrix) M = rng(shape, dtype) args = (M.data, M.indices, M.indptr) @@ -226,7 +81,7 @@ class cuSparseTest(sptu.SparseTestCase): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) def test_csr_todense_ad(self, shape, dtype): - rng = rand_sparse(self.rng(), post=jnp.array) + rng = sptu.rand_sparse(self.rng(), post=jnp.array) M = rng(shape, dtype) data, indices, indptr = sparse_csr._csr_fromdense(M, nse=(M != 0).sum()) row, col = sparse_util._csr_to_coo(indices, indptr) @@ -248,7 +103,7 @@ class cuSparseTest(sptu.SparseTestCase): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) def test_csr_fromdense_ad(self, shape, dtype): - rng = rand_sparse(self.rng(), post=jnp.array) + rng = sptu.rand_sparse(self.rng(), post=jnp.array) M = rng(shape, dtype) nse = (M != 0).sum() f = lambda M: sparse_csr._csr_fromdense(M, nse=nse) @@ -283,7 +138,7 @@ class cuSparseTest(sptu.SparseTestCase): tol = {np.float32: 2E-5, np.float64: 1E-12, np.complex64: 1E-5, np.complex128: 1E-12} - rng = rand_sparse(self.rng(), post=jnp.array) + rng = sptu.rand_sparse(self.rng(), post=jnp.array) rng_b = jtu.rand_default(self.rng()) M = rng(shape, dtype) @@ -327,11 +182,11 @@ class cuSparseTest(sptu.SparseTestCase): self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol) @jtu.sample_product( - shape=[(5, 8), (8, 5), (5, 5), (8, 8)], - dtype=all_dtypes, + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, ) def test_csr_fromdense(self, shape, dtype): - rng = rand_sparse(self.rng()) + rng = sptu.rand_sparse(self.rng()) M = rng(shape, dtype) M_csr = scipy.sparse.csr_matrix(M) @@ -360,7 +215,7 @@ class cuSparseTest(sptu.SparseTestCase): op = lambda M: M.T if transpose else M v_rng = jtu.rand_default(self.rng()) - rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix) + rng = sptu.rand_sparse(self.rng(), post=scipy.sparse.csr_matrix) M = rng(shape, dtype) v = v_rng(op(M).shape[1], dtype) @@ -368,19 +223,19 @@ class cuSparseTest(sptu.SparseTestCase): matvec = lambda *args: sparse_csr._csr_matvec(*args, shape=M.shape, transpose=transpose) with self.gpu_matmul_dtype_warning_context(dtype): - self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL) - self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL) + self.assertAllClose(op(M) @ v, matvec(*args), rtol=sptu.MATMUL_TOL) + self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=sptu.MATMUL_TOL) @jtu.sample_product( - shape=[(5, 8), (8, 5), (5, 5), (8, 8)], - dtype=all_dtypes, - transpose=[True, False], + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + transpose=[True, False], ) def test_csr_matmat(self, shape, dtype, transpose): op = lambda M: M.T if transpose else M B_rng = jtu.rand_default(self.rng()) - rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix) + rng = sptu.rand_sparse(self.rng(), post=scipy.sparse.csr_matrix) M = rng(shape, dtype) B = B_rng((op(M).shape[1], 4), dtype) @@ -388,15 +243,15 @@ class cuSparseTest(sptu.SparseTestCase): matmat = lambda *args: sparse_csr._csr_matmat(*args, shape=shape, transpose=transpose) with self.gpu_matmul_dtype_warning_context(dtype): - self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL) - self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL) + self.assertAllClose(op(M) @ B, matmat(*args), rtol=sptu.MATMUL_TOL) + self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=sptu.MATMUL_TOL) @jtu.sample_product( - shape=[(5, 8), (8, 5), (5, 5), (8, 8)], - dtype=all_dtypes, + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, ) def test_coo_todense(self, shape, dtype): - rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix) + rng = sptu.rand_sparse(self.rng(), post=scipy.sparse.coo_matrix) M = rng(shape, dtype) args = (M.data, M.row, M.col) @@ -407,11 +262,11 @@ class cuSparseTest(sptu.SparseTestCase): self.assertArraysEqual(M.toarray(), jit(todense)(*args)) @jtu.sample_product( - shape=[(5, 8), (8, 5), (5, 5), (8, 8)], - dtype=all_dtypes, + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, ) def test_coo_fromdense(self, shape, dtype): - rng = rand_sparse(self.rng()) + rng = sptu.rand_sparse(self.rng()) M = rng(shape, dtype) M_coo = scipy.sparse.coo_matrix(M) @@ -432,15 +287,15 @@ class cuSparseTest(sptu.SparseTestCase): self.assertArraysEqual(col, M_coo.col.astype(index_dtype)) @jtu.sample_product( - shape=[(5, 8), (8, 5), (5, 5), (8, 8)], - dtype=all_dtypes, - transpose=[True, False], + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + transpose=[True, False], ) def test_coo_matvec(self, shape, dtype, transpose): op = lambda M: M.T if transpose else M v_rng = jtu.rand_default(self.rng()) - rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix) + rng = sptu.rand_sparse(self.rng(), post=scipy.sparse.coo_matrix) M = rng(shape, dtype) v = v_rng(op(M).shape[1], dtype) @@ -448,19 +303,19 @@ class cuSparseTest(sptu.SparseTestCase): matvec = lambda *args: sparse_coo._coo_matvec(*args, spinfo=sparse_coo.COOInfo(shape=M.shape, rows_sorted=True), transpose=transpose) with self.gpu_matmul_dtype_warning_context(dtype): - self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL) - self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL) + self.assertAllClose(op(M) @ v, matvec(*args), rtol=sptu.MATMUL_TOL) + self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=sptu.MATMUL_TOL) @jtu.sample_product( - shape=[(5, 8), (8, 5), (5, 5), (8, 8)], - dtype=all_dtypes, - transpose=[True, False], + shape=[(5, 8), (8, 5), (5, 5), (8, 8)], + dtype=all_dtypes, + transpose=[True, False], ) def test_coo_matmat(self, shape, dtype, transpose): op = lambda M: M.T if transpose else M B_rng = jtu.rand_default(self.rng()) - rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix) + rng = sptu.rand_sparse(self.rng(), post=scipy.sparse.coo_matrix) M = rng(shape, dtype) B = B_rng((op(M).shape[1], 4), dtype) @@ -468,8 +323,8 @@ class cuSparseTest(sptu.SparseTestCase): matmat = lambda *args: sparse_coo._coo_matmat(*args, spinfo=sparse_coo.COOInfo(shape=shape, rows_sorted=True), transpose=transpose) with self.gpu_matmul_dtype_warning_context(dtype): - self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL) - self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL) + self.assertAllClose(op(M) @ B, matmat(*args), rtol=sptu.MATMUL_TOL) + self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=sptu.MATMUL_TOL) def test_coo_matmat_layout(self): # Regression test for https://github.com/google/jax/issues/7533 @@ -490,7 +345,7 @@ class cuSparseTest(sptu.SparseTestCase): def test_coo_sorted_indices(self): rng = self.rng() - sprng = rand_sparse(rng) + sprng = sptu.rand_sparse(rng) mat = sparse.COO.fromdense(sprng((5, 6), np.float32)) perm = rng.permutation(mat.nse) @@ -498,7 +353,9 @@ class cuSparseTest(sptu.SparseTestCase): mat_resorted = mat_unsorted._sort_indices() self.assertArraysEqual(mat.todense(), mat_resorted.todense()) - @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") + @unittest.skipIf( + not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" + ) @jtu.run_on_devices("gpu") def test_coo_sorted_indices_gpu_lowerings(self): dtype = jnp.float32 @@ -585,7 +442,7 @@ class cuSparseTest(sptu.SparseTestCase): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) def test_coo_todense_ad(self, shape, dtype): - rng = rand_sparse(self.rng(), post=jnp.array) + rng = sptu.rand_sparse(self.rng(), post=jnp.array) M = rng(shape, dtype) data, row, col = sparse_coo._coo_fromdense(M, nse=(M != 0).sum()) f = lambda data: sparse_coo._coo_todense(data, row, col, spinfo=sparse_coo.COOInfo(shape=M.shape, rows_sorted=True)) @@ -606,7 +463,7 @@ class cuSparseTest(sptu.SparseTestCase): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) def test_coo_fromdense_ad(self, shape, dtype): - rng = rand_sparse(self.rng(), post=jnp.array) + rng = sptu.rand_sparse(self.rng(), post=jnp.array) M = rng(shape, dtype) nse = (M != 0).sum() f = lambda M: sparse_coo._coo_fromdense(M, nse=nse) @@ -640,7 +497,7 @@ class cuSparseTest(sptu.SparseTestCase): coo_matmul = sparse_coo._coo_matvec if len(bshape) == 1 else sparse_coo._coo_matmat tol = {np.float32: 1E-5, np.float64: 1E-12, np.complex64: 1E-5, np.complex128: 1E-12} - rng = rand_sparse(self.rng(), post=jnp.array) + rng = sptu.rand_sparse(self.rng(), post=jnp.array) rng_b = jtu.rand_default(self.rng()) M = rng(shape, dtype) @@ -685,13 +542,15 @@ class cuSparseTest(sptu.SparseTestCase): self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol) @jtu.sample_product( - shape=[(4, 5), (3, 4), (5, 4)], - dtype=_lowerings.SUPPORTED_DATA_DTYPES, - transpose=[True, False], + shape=[(4, 5), (3, 4), (5, 4)], + dtype=_lowerings.SUPPORTED_DATA_DTYPES, + transpose=[True, False], + ) + @unittest.skipIf( + not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" ) - @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") def test_coo_spmv(self, shape, dtype, transpose): - rng_sparse = rand_sparse(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) rng_dense = jtu.rand_default(self.rng()) mat = rng_sparse(shape, dtype) @@ -708,13 +567,15 @@ class cuSparseTest(sptu.SparseTestCase): self.assertArraysAllClose(actual, expected) @jtu.sample_product( - shape=[(4, 5), (3, 4), (5, 4)], - dtype=_lowerings.SUPPORTED_DATA_DTYPES, - transpose=[True, False], + shape=[(4, 5), (3, 4), (5, 4)], + dtype=_lowerings.SUPPORTED_DATA_DTYPES, + transpose=[True, False], + ) + @unittest.skipIf( + not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" ) - @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") def test_coo_spmm(self, shape, dtype, transpose): - rng_sparse = rand_sparse(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) rng_dense = jtu.rand_default(self.rng()) mat = rng_sparse(shape, dtype) @@ -731,13 +592,15 @@ class cuSparseTest(sptu.SparseTestCase): self.assertArraysAllClose(actual, expected) @jtu.sample_product( - shape=[(4, 5), (3, 4), (5, 4)], - dtype=_lowerings.SUPPORTED_DATA_DTYPES, - transpose=[True, False], + shape=[(4, 5), (3, 4), (5, 4)], + dtype=_lowerings.SUPPORTED_DATA_DTYPES, + transpose=[True, False], + ) + @unittest.skipIf( + not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" ) - @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") def test_csr_spmv(self, shape, dtype, transpose): - rng_sparse = rand_sparse(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) rng_dense = jtu.rand_default(self.rng()) mat = rng_sparse(shape, dtype) @@ -752,13 +615,15 @@ class cuSparseTest(sptu.SparseTestCase): self.assertArraysAllClose(actual, expected) @jtu.sample_product( - shape=[(4, 5), (3, 4), (5, 4)], - dtype=_lowerings.SUPPORTED_DATA_DTYPES, - transpose=[True, False], + shape=[(4, 5), (3, 4), (5, 4)], + dtype=_lowerings.SUPPORTED_DATA_DTYPES, + transpose=[True, False], + ) + @unittest.skipIf( + not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse" ) - @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") def test_csr_spmm(self, shape, dtype, transpose): - rng_sparse = rand_sparse(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) rng_dense = jtu.rand_default(self.rng()) mat = rng_sparse(shape, dtype) @@ -772,1574 +637,10 @@ class cuSparseTest(sptu.SparseTestCase): shape=mat.shape) self.assertArraysAllClose(actual, expected) - -class BCOOTest(sptu.SparseTestCase): - - def gpu_matmul_warning_context(self, msg): - if GPU_LOWERING_ENABLED and config.jax_bcoo_cusparse_lowering: - return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg) - return contextlib.nullcontext() - - def test_repr(self): - x = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32')) - self.assertEqual(repr(x), "BCOO(float32[5], nse=4)") - - y = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3), n_batch=1) - self.assertEqual(repr(y), "BCOO(float32[2, 3], nse=3, n_batch=1)") - - y = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3), n_batch=1, n_dense=1) - self.assertEqual(repr(y), "BCOO(float32[2, 3], nse=1, n_batch=1, n_dense=1)") - - M_invalid = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3)) - M_invalid.indices = jnp.array([]) - self.assertEqual(repr(M_invalid), "BCOO()") - - @jit - def f(x): - self.assertEqual(repr(x), "DynamicJaxprTracer[BCOO(float32[5], nse=4)]") - f(x) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape)], - dtype=all_dtypes, - ) - def test_empty(self, shape, dtype, n_batch, n_dense): - M = sparse.empty(shape, dtype=dtype, n_batch=n_batch, n_dense=n_dense) - self.assertIsInstance(M, sparse.BCOO) - self.assertEqual(M.nse, 0) - self.assertEqual(M.n_batch, n_batch) - self.assertEqual(M.n_dense, n_dense) - self.assertEqual(M.dtype, dtype) - self.assertArraysEqual(M.todense(), jnp.empty(shape, dtype)) - - @jtu.sample_product( - [dict(n_batch=layout.n_batch, n_dense=layout.n_dense) - for layout in iter_sparse_layouts((3, 3))], - N=[3, 5], - M=[None, 4], - k=[-3, -1, 0, 2, 4], - dtype=all_dtypes, - ) - def test_eye(self, N, M, k, dtype, n_batch, n_dense): - mat = sparse.eye(N, M, k, dtype=dtype, n_batch=n_batch, n_dense=n_dense) - expected = jnp.eye(N, M, k, dtype=dtype) - expected_nse = sparse.BCOO.fromdense(expected, n_batch=n_batch, n_dense=n_dense).nse - - self.assertIsInstance(mat, sparse.BCOO) - self.assertEqual(mat.n_batch, n_batch) - self.assertEqual(mat.n_dense, n_dense) - self.assertEqual(mat.dtype, dtype) - self.assertEqual(mat.nse, expected_nse) - self.assertArraysEqual(mat.todense(), expected) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape)], - dtype=all_dtypes, - ) - def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense): - n_sparse = len(shape) - n_batch - n_dense - rng = rand_sparse(self.rng()) - M = rng(shape, dtype) - nse = sparse.util._count_stored_elements(M, n_batch=n_batch, n_dense=n_dense) - def round_trip(M): - return sparse.BCOO.fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense).todense() - args_maker = lambda: [M] - ident = lambda x: x - - self._CheckAgainstNumpy(ident, round_trip, args_maker) - self._CompileAndCheck(round_trip, args_maker) - self._CheckBatchingSparse(ident, round_trip, args_maker, bdims=self._random_bdims(n_batch)) - if jnp.issubdtype(dtype, jnp.floating): - # For n_sparse != 0, we can't use an identity because output zeros must not - # be dependent on input zeros. This mimics the code in count_stored_elements(). - def expected(M): - if n_sparse == 0: return M - mask = (M != 0).any(range(M.ndim - n_dense, M.ndim), keepdims=True) - return jnp.where(mask, M, 0) - self._CheckGradsSparse(expected, round_trip, args_maker) - - def test_bcoo_fromdense_sorted_and_unique_indices(self): - rng = self.rng() - rng_sparse = rand_sparse(rng) - mat = sparse.BCOO.fromdense(rng_sparse((5, 6), np.float32)) - perm = rng.permutation(mat.nse) - mat_unsorted = sparse.BCOO((mat.data[perm], mat.indices[perm]), - shape=mat.shape, - unique_indices=mat.unique_indices) - mat_resorted = mat_unsorted.sort_indices() - with self.subTest('sorted indices'): - self.assertArraysEqual(mat.indices, mat_resorted.indices) - self.assertArraysEqual(mat.data, mat_resorted.data) - - with self.subTest('unique indices'): - self.assertTrue(mat.unique_indices) - self.assertTrue(mat_unsorted.unique_indices) - self.assertTrue(mat_resorted.unique_indices) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape)], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - assume_unique=[True, False, None] - ) - def test_bcoo_extract(self, shape, dtype, n_batch, n_dense, assume_unique): - rng = rand_sparse(self.rng()) - - def args_maker(): - x = rng(shape, dtype) - x_bcoo = sparse.bcoo_fromdense(x, n_batch=n_batch, n_dense=n_dense) - # Unique indices are required for this test when assume_unique == True. - self.assertTrue(x_bcoo.unique_indices) - return x_bcoo, x - - dense_op = lambda _, x: x - sparse_op = partial(sparse.bcoo_extract, assume_unique=assume_unique) - - self._CheckAgainstDense(dense_op, sparse_op, args_maker) - self._CheckBatchingSparse(dense_op, sparse_op, args_maker, bdims=2 * self._random_bdims(n_batch)) - - def test_bcoo_extract_duplicate_indices(self): - data = jnp.array([1, 3, 9, 27, 81, 243]) - indices = jnp.array([[0], [5], [0], [3], [2], [3]]) - shape = (6,) - mat = sparse.BCOO((data, indices), shape=shape).todense() - - data1 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=True) - self.assertArraysEqual(data1, jnp.array([10, 3, 10, 270, 81, 270])) - - data2 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=False) - self.assertArraysEqual(data2, jnp.array([10, 3, 0, 270, 81, 0])) - - def test_bcoo_extract_duplicate_indices_n_sparse_0(self): - data = jnp.arange(6).reshape(3, 2) - indices = jnp.empty((3, 2, 0), dtype=int) - shape = (3,) - mat = sparse.BCOO((data, indices), shape=shape).todense() - - data1 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=True) - self.assertArraysEqual(data1, jnp.array([[1, 1], [5, 5], [9, 9]])) - - data2 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=False) - self.assertArraysEqual(data2, jnp.array([[1, 0], [5, 0], [9, 0]])) - - def test_bcoo_extract_batching(self): - # https://github.com/google/jax/issues/9431 - indices = jnp.zeros((4, 1, 1), dtype=int) - mat = jnp.arange(4.).reshape((4, 1)) - - # in_axes = (0, None) - expected = jnp.vstack([sparse_bcoo._bcoo_extract(i, mat[0]) for i in indices]) - actual = vmap(sparse_bcoo._bcoo_extract, in_axes=(0, None))(indices, mat[0]) - self.assertArraysEqual(expected, actual) - - # in_axes = (None, 0) - expected = jnp.vstack([sparse_bcoo._bcoo_extract(indices[0], m) for m in mat]) - actual = vmap(sparse_bcoo._bcoo_extract, in_axes=(None, 0))(indices[0], mat) - self.assertArraysEqual(expected, actual) - - # in_axes = (0, 0) - expected = jnp.vstack([sparse_bcoo._bcoo_extract(i, m) for i, m in zip(indices, mat)]) - actual = vmap(sparse_bcoo._bcoo_extract, in_axes=0)(indices, mat) - self.assertArraysEqual(expected, actual) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape)], - dtype=jtu.dtypes.floating, - ) - def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense): - rng = rand_sparse(self.rng()) - M = rng(shape, dtype) - nse = sparse.util._count_stored_elements(M, n_batch=n_batch, - n_dense=n_dense) - data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense) - - extract = partial(sparse_bcoo._bcoo_extract, indices) - j1 = jax.jacfwd(extract)(M) - j2 = jax.jacrev(extract)(M) - hess = jax.hessian(extract)(M) - self.assertArraysAllClose(j1, j2) - self.assertEqual(j1.shape, data.shape + M.shape) - self.assertEqual(hess.shape, data.shape + 2 * M.shape) - - def test_bcoo_extract_zero_nse(self): - # Regression test for https://github.com/google/jax/issues/13653 - - # (n_batch, n_sparse, n_dense) = (1, 0, 0), nse = 2 - args_maker = lambda: (jnp.zeros((3, 2, 0), dtype='int32'), jnp.arange(3)) - self._CompileAndCheck(sparse_bcoo._bcoo_extract, args_maker) - - # (n_batch, n_sparse, n_dense) = (0, 0, 1), nse = 2 - args_maker = lambda: (jnp.zeros((2, 0), dtype='int32'), jnp.arange(3)) - self._CompileAndCheck(sparse_bcoo._bcoo_extract, args_maker) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape)], - dtype=jtu.dtypes.numeric, - ) - def test_bcoo_transpose(self, shape, dtype, n_batch, n_dense): - n_sparse = len(shape) - n_batch - n_dense - rng = self.rng() - sprng = sptu.rand_bcoo(rng, n_batch=n_batch, n_dense=n_dense) - - permutation = np.concatenate([ - rng.permutation(range(n_batch)), - rng.permutation(range(n_batch, n_batch + n_sparse)), - rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int) - - args_maker = lambda: [sprng(shape, dtype)] - dense_func = partial(lax.transpose, permutation=permutation) - sparse_func = partial(sparse.bcoo_transpose, permutation=permutation) - - self._CheckAgainstDense(dense_func, sparse_func, args_maker) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(dense_func, sparse_func, args_maker) - self._CheckBatchingSparse(dense_func, sparse_func, args_maker, bdims=self._random_bdims(n_batch)) - - def test_bcoo_transpose_indices_sorted(self): - rng = self.rng() - rng_sparse = rand_sparse(rng) - n_batch, n_dense = 2, 2 - shape = (2, 3, 4, 5, 6, 7, 8) - mat = sparse.BCOO.fromdense(rng_sparse(shape, np.float32), - n_dense=n_dense, n_batch=n_batch) - - permutations = (1, 0, 2, 3, 4, 6, 5) - mat_T_indices_sorted = mat.transpose(axes=permutations) - self.assertTrue(mat_T_indices_sorted.indices_sorted) - - permutations = (0, 1, 3, 2, 4, 5, 6) - mat_T_indices_unsorted = mat.transpose(axes=permutations) - self.assertFalse(mat_T_indices_unsorted.indices_sorted) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape, min_n_batch=1) - ], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense): - rng = rand_sparse(self.rng()) - M = rng(shape, dtype) - nse = sparse.util._count_stored_elements(M, n_batch=n_batch, - n_dense=n_dense) - data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense) - - M1 = sparse_bcoo._bcoo_todense(data, indices[:1], spinfo=sparse_util.SparseInfo(M.shape)) - M2 = sparse_bcoo._bcoo_todense(data, jnp.stack(shape[0] * [indices[0]]), spinfo=sparse_util.SparseInfo(M.shape)) - self.assertAllClose(M1, M2) - - M3 = sparse_bcoo._bcoo_todense(data[:1], indices, spinfo=sparse_util.SparseInfo(M.shape)) - M4 = sparse_bcoo._bcoo_todense(jnp.stack(shape[0] * [data[0]]), indices, spinfo=sparse_util.SparseInfo(M.shape)) - self.assertAllClose(M3, M4) - - @jtu.sample_product( - props=_generate_batched_dot_general_properties(), - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - @jax.default_matmul_precision("float32") - def test_bcoo_dot_general(self, dtype: np.dtype, props: BatchedDotGeneralProperties): - rng = jtu.rand_default(self.rng()) - sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense) - args_maker = lambda: [sprng(props.lhs_shape, dtype), - rng(props.rhs_shape, dtype)] - dense_fun = partial(lax.dot_general, dimension_numbers=props.dimension_numbers) - sparse_fun = partial(sparse.bcoo_dot_general, dimension_numbers=props.dimension_numbers) - - tol = {np.float64: 1E-12, np.complex128: 1E-12, - np.float32: 1E-5, np.complex64: 1E-5} - self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol) - if jnp.issubdtype(dtype, jnp.floating) and props.n_dense == 0: - # Dense dimensions not yet fully supported in reverse mode. - modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev'] - self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol) - self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol, - bdims=self._random_bdims(props.n_batch, len(props.rhs_shape))) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, - lhs_contracting=lhs_contracting, rhs_contracting=rhs_contracting) - for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ - [(5,), (5,), [0], [0]], - [(5,), (5, 7), [0], [0]], - [(5,), (7, 5), [0], [1]], - [(5, 7), (5,), [0], [0]], - [(7, 5), (5,), [1], [0]], - [(3, 5), (2, 5), [1], [1]], - [(3, 5), (5, 2), [1], [0]], - [(5, 3), (2, 5), [0], [1]], - [(5, 3), (5, 2), [0], [0]], - ] - ], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - @jax.default_matmul_precision("float32") - @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") - @jtu.run_on_devices("gpu") - def test_bcoo_dot_general_cusparse( - self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting): - rng = jtu.rand_small(self.rng()) - rng_sparse = rand_sparse(self.rng()) - def args_maker(): - lhs = rng_sparse(lhs_shape, dtype) - rhs = rng(rhs_shape, dtype) - nse = sparse.util._count_stored_elements(lhs, n_batch=0, n_dense=0) - lhs_bcoo = sparse_bcoo.bcoo_fromdense(lhs, nse=nse, index_dtype=jnp.int32) - return lhs_bcoo, lhs, rhs - - dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) - - def f_dense(lhs_bcoo, lhs, rhs): - return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers) - - def f_sparse(lhs_bcoo, lhs, rhs): - return sparse_bcoo.bcoo_dot_general(lhs_bcoo, rhs, - dimension_numbers=dimension_numbers) - - self._CompileAndCheck(f_sparse, args_maker) - self._CheckAgainstNumpy(f_dense, f_sparse, args_maker) - - @jtu.sample_product( - [dict(n_batch=n_batch, lhs_shape=lhs_shape, rhs_shape=rhs_shape, - lhs_contracting=lhs_contracting, rhs_contracting=rhs_contracting) - for n_batch, lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ - [1, (1, 2, 3), (3, 2), [2], [0]], - [1, (1, 3, 2), (3, 2), [1], [0]], - [1, (1, 3, 2), (4, 3), [1], [1]], - [1, (4, 2, 3), (3, 5), [2], [0]], - [1, (4, 2, 3), (2, 5), [1], [0]], - [1, (4, 2, 3), (5, 3), [2], [1]], - ] - ], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - @jax.default_matmul_precision("float32") - @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") - @jtu.run_on_devices("gpu") - def test_bcoo_batched_matmat_cusparse( - self, n_batch, lhs_shape, rhs_shape, dtype, lhs_contracting, - rhs_contracting): - rng = jtu.rand_small(self.rng()) - rng_sparse = rand_sparse(self.rng()) - def args_maker(): - lhs = rng_sparse(lhs_shape, dtype) - rhs = rng(rhs_shape, dtype) - nse = sparse.util._count_stored_elements(lhs, n_batch=n_batch, - n_dense=0) - lhs_bcoo = sparse_bcoo.bcoo_fromdense(lhs, n_batch=n_batch, nse=nse, - index_dtype=jnp.int32) - return lhs_bcoo, lhs, rhs - - dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) - - def f_dense(lhs_bcoo, lhs, rhs): - return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers) - - def f_sparse(lhs_bcoo, lhs, rhs): - return sparse_bcoo.bcoo_dot_general(lhs_bcoo, rhs, - dimension_numbers=dimension_numbers) - - # TODO(tianjianlu): In some cases, this fails python_should_be_executing. - # self._CompileAndCheck(f_sparse, args_maker) - self._CheckAgainstNumpy(f_dense, f_sparse, args_maker) - - @jtu.sample_product( - [dict(n_batch=n_batch, lhs_shape=lhs_shape, rhs_shape=rhs_shape, - lhs_contracting=lhs_contracting, rhs_contracting=rhs_contracting) - for n_batch, lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ - [1, (1, 2, 3), (3), [2], [0]], - [1, (1, 2), (3, 2), [1], [1]], - ] - ], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") - @jtu.run_on_devices("gpu") - def test_bcoo_batched_matmat_default_lowering( - self, n_batch, lhs_shape, rhs_shape, dtype, lhs_contracting, - rhs_contracting): - rng = jtu.rand_small(self.rng()) - rng_sparse = rand_sparse(self.rng()) - lhs = rng_sparse(lhs_shape, dtype) - rhs = rng(rhs_shape, dtype) - nse = sparse.util._count_stored_elements(lhs, n_batch=n_batch, - n_dense=0) - lhs_bcoo = sparse_bcoo.bcoo_fromdense(lhs, n_batch=n_batch, nse=nse, - index_dtype=jnp.int32) - dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) - matmat_expected = lax.dot_general(lhs, rhs, - dimension_numbers=dimension_numbers) - sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general, - dimension_numbers=dimension_numbers)) - - # TODO(jakevdp): uncomment once batching is supported again. - # with self.gpu_matmul_warning_context( - # "bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"): - matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs) - self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback) - - @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") - @jtu.run_on_devices("gpu") - def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self): - """Tests bcoo dot general with out-of-bound and unsorted indices.""" - - rhs = jnp.ones((5, 3), dtype=jnp.float32) - - # It creates out-of-bound indices when nse > nnz. - lhs_mat_dense = jnp.array([[1, 0, 2, 3, 0], [0, 0, 0, 4, 0]], - dtype=jnp.float32) - lhs_mat_bcoo = sparse.BCOO.fromdense(lhs_mat_dense, nse=7) - rng = self.rng() - perm = rng.permutation(lhs_mat_bcoo.nse) - lhs_mat_bcoo_unsorted = sparse.BCOO( - (lhs_mat_bcoo.data[perm], lhs_mat_bcoo.indices[perm]), - shape=lhs_mat_dense.shape) - - dimension_numbers_2d = (([1], [0]), ([], [])) - sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general, - dimension_numbers=dimension_numbers_2d)) - - matmat_expected = lax.dot_general(lhs_mat_dense, rhs, - dimension_numbers=dimension_numbers_2d) - with self.subTest(msg="2D"): - with self.gpu_matmul_warning_context( - "bcoo_dot_general GPU lowering requires matrices with sorted indices*"): - matmat_unsorted_fallback = sp_matmat(lhs_mat_bcoo_unsorted, rhs) - self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback) - - lhs_vec_dense = jnp.array([0, 1, 0, 2, 0], dtype=jnp.float32) - lhs_vec_bcoo = sparse.BCOO.fromdense(lhs_vec_dense, nse=5) - rng = self.rng() - perm = rng.permutation(lhs_vec_bcoo.nse) - lhs_vec_bcoo_unsorted = sparse.BCOO( - (lhs_vec_bcoo.data[perm], lhs_vec_bcoo.indices[perm]), - shape=lhs_vec_dense.shape, indices_sorted=False) - - dimension_numbers_1d = (([0], [0]), ([], [])) - sp_vecmat = jit(partial(sparse_bcoo.bcoo_dot_general, - dimension_numbers=dimension_numbers_1d)) - - vecmat_expected = lax.dot_general(lhs_vec_dense, rhs, - dimension_numbers=dimension_numbers_1d) - - with self.subTest(msg="1D"): - with self.gpu_matmul_warning_context( - "bcoo_dot_general GPU lowering requires matrices with sorted indices*"): - vecmat_unsorted_fallback = sp_vecmat(lhs_vec_bcoo_unsorted, rhs) - self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback) - - @jtu.sample_product( - props=_generate_batched_dot_general_properties(), - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - @jax.default_matmul_precision("float32") - def test_bcoo_rdot_general(self, dtype: np.dtype, props: BatchedDotGeneralProperties): - rng = jtu.rand_default(self.rng()) - sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense) - args_maker = lambda: [rng(props.rhs_shape, dtype), - sprng(props.lhs_shape, dtype)] - dimension_numbers = tuple(d[::-1] for d in props.dimension_numbers) - sparse_fun = partial(sparse.bcoo_dot_general, dimension_numbers=dimension_numbers) - dense_fun = partial(lax.dot_general, dimension_numbers=dimension_numbers) - - tol = {np.float64: 1E-12, np.complex128: 1E-12, - np.float32: 1E-5, np.complex64: 1E-5} - self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol) - if jnp.issubdtype(dtype, jnp.floating): - # Dense dimensions not yet fully supported in reverse mode. - modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev'] - self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol) - - @jtu.sample_product( - [dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape, - rhs_shape=rhs_shape, dimension_numbers=dimension_numbers) - for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [ - ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0), - ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0), - ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0), - ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0), - ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0), - ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1), - ] - ], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - @jax.default_matmul_precision("float32") - def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype, - dimension_numbers, n_batch, n_dense): - rng = jtu.rand_small(self.rng()) - rng_sparse = rand_sparse(self.rng()) - - X = rng_sparse(lhs_shape, dtype) - nse = sparse.util._count_stored_elements(X, n_batch=n_batch, - n_dense=n_dense) - data, indices = sparse_bcoo._bcoo_fromdense(X, nse=nse, n_batch=n_batch, n_dense=n_dense) - Y = rng(rhs_shape, dtype) - - def f_dense(X, Y): - return lax.dot_general(X, Y, dimension_numbers=dimension_numbers) - - def f_sparse(data, indices, Y): - return sparse_bcoo._bcoo_dot_general(data, indices, Y, lhs_spinfo=sparse_util.SparseInfo(X.shape), - dimension_numbers=dimension_numbers, preferred_element_type=None) - - for data, indices in itertools.product([data, data[:1]], [indices, indices[:1]]): - X = sparse_bcoo._bcoo_todense(data, indices, spinfo=sparse_util.SparseInfo(X.shape)) - self.assertAllClose(f_dense(X, Y), f_sparse(data, indices, Y)) - - @jtu.sample_product( - props=_generate_bcoo_dot_general_sampled_properties(), - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - @jax.default_matmul_precision("float32") - @jtu.skip_on_flag("jax_skip_slow_tests", True) - def test_bcoo_dot_general_sampled(self, props, dtype): - rng = jtu.rand_default(self.rng()) - sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense) - out = jax.eval_shape(partial(lax.dot_general, dimension_numbers=props.dimension_numbers), - jax.ShapeDtypeStruct(props.lhs_shape, dtype), - jax.ShapeDtypeStruct(props.rhs_shape, dtype)) - args_maker = lambda: [rng(props.lhs_shape, dtype), rng(props.rhs_shape, dtype), - sprng(out.shape, dtype).indices] - - def dense_fun(lhs, rhs, indices): - AB = lax.dot_general(lhs, rhs, dimension_numbers=props.dimension_numbers) - return sparse_bcoo._bcoo_extract(indices, AB) - def sparse_fun(lhs, rhs, indices): - return sparse.bcoo_dot_general_sampled( - lhs, rhs, indices, dimension_numbers=props.dimension_numbers) - - self._CheckAgainstDense(dense_fun, sparse_fun, args_maker) - if jnp.issubdtype(dtype, jnp.floating): - # Note: forward mode fails for some sparse layouts. - # TODO(jakevdp) fix forward-mode autodiff & enable tests here. - self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=['rev'], argnums=[0, 1]) - - @jtu.sample_product( - [{'xshape': xshape, 'yshape': yshape, 'lhs_contract': lhs_contract, 'rhs_contract': rhs_contract} - for (xshape, yshape, lhs_contract, rhs_contract) in [ - [(4, 3), (4, 5), (0,), (0,)], - [(3, 4), (4, 5), (1,), (0,)], - [(4, 3), (5, 4), (0,), (1,)], - [(3, 4), (5, 4), (1,), (1,)], - [(3,), (3,), (), ()], - [(3,), (5,), (), ()], - [(5,), (3,), (), ()], - [(5,), (5,), (), ()], - ]], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - n_batch=[0, 1, 2], - ) - @jax.default_matmul_precision("float32") - def test_bcoo_dot_general_sampled_fast_cases( - self, xshape, yshape, lhs_contract, rhs_contract, n_batch, dtype): - rng = jtu.rand_default(self.rng()) - sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch) - dimension_numbers = ((lhs_contract, rhs_contract), ([], [])) - - out_shape = jax.eval_shape(partial(lax.dot_general, dimension_numbers=dimension_numbers), - jax.ShapeDtypeStruct(xshape, dtype), jax.ShapeDtypeStruct(yshape, dtype)) - - args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype), - sprng(out_shape.shape, out_shape.dtype).indices] - - def f1(x, y, indices): - mat_full = lax.dot_general(x, y, dimension_numbers=dimension_numbers) - return sparse_bcoo._bcoo_extract(indices, mat_full) - - def f2(x, y, indices): - return sparse.bcoo_dot_general_sampled(x, y, indices, dimension_numbers=dimension_numbers) - - self._CheckAgainstNumpy(f1, f2, args_maker, tol=MATMUL_TOL) - self._CompileAndCheck(f2, args_maker, tol=MATMUL_TOL) - - @jtu.sample_product( - [dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape, - rhs_shape=rhs_shape, dimension_numbers=dimension_numbers) - for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [ - ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0), - ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 1), - ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0), - ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0), - ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 1), - ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0), - ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0), - ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1), - ] - ], - dtype=jtu.dtypes.floating, - ) - @jax.default_matmul_precision("float32") - def test_bcoo_dot_general_sampled_ad(self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense): - rng = jtu.rand_default(self.rng()) - sprng = rand_sparse(self.rng()) - out_shape = lax.dot_general( - jnp.zeros(lhs_shape), jnp.zeros(rhs_shape), - dimension_numbers=dimension_numbers).shape - - lhs = rng(lhs_shape, dtype) - rhs = rng(rhs_shape, dtype) - indices = sparse.BCOO.fromdense(sprng(out_shape, dtype), - n_batch=n_batch, n_dense=n_dense).indices - - def dense_fun(lhs, rhs, indices): - AB = lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers) - return sparse_bcoo._bcoo_extract(indices, AB) - def sparse_fun(lhs, rhs, indices): - return sparse.bcoo_dot_general_sampled( - lhs, rhs, indices, dimension_numbers=dimension_numbers) - - jf_dense = jax.jacfwd(dense_fun)(lhs, rhs, indices) - jf_sparse = jax.jacfwd(sparse_fun)(lhs, rhs, indices) - jr_dense = jax.jacrev(dense_fun)(lhs, rhs, indices) - jr_sparse = jax.jacrev(sparse_fun)(lhs, rhs, indices) - - self.assertAllClose(jf_sparse, jf_dense) - self.assertAllClose(jr_sparse, jr_dense) - self.assertAllClose(jf_sparse, jr_sparse) - - @jtu.sample_product( - [dict(lhs_n_batch=lhs_n_batch, rhs_n_batch=rhs_n_batch, lhs_shape=lhs_shape, - rhs_shape=rhs_shape, dimension_numbers=dimension_numbers) - for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dimension_numbers in [ - # (batched) outer products (no contraction) - ((5,), 0, (6,), 0, (([], []), ([], []))), - ((3, 5), 0, (2, 4), 0, (([], []), ([], []))), - ((3, 5), 1, (3, 4), 1, (([], []), ([0], [0]))), - # (batched) vector-vector products - ((5,), 0, (5,), 0, (([0], [0]), ([], []))), - ((7,), 0, (7,), 0, (([0], [0]), ([], []))), - ((5, 7), 1, (7,), 0, (([1], [0]), ([], []))), - ((2, 3, 4), 2, (2, 4), 1, (([2], [1]), ([0], [0]))), - ((2, 3, 4), 2, (2, 4), 1, (([2], [1]), ([], []))), - ((2, 3, 4), 2, (3, 4), 1, (([2], [1]), ([1], [0]))), - ((2, 3, 4), 2, (3, 4), 1, (([2], [1]), ([], []))), - # (batched) matrix-vector products - ((5, 7), 0, (7,), 0, (([1], [0]), ([], []))), - ((2, 3, 4), 1, (4,), 0, (([2], [0]), ([], []))), - ((2, 3, 4), 1, (2, 4), 1, (([2], [1]), ([0], [0]))), - ((3, 2, 4), 1, (3, 4), 1, (([2], [1]), ([0], [0]))), - ((2, 3, 4), 0, (2,), 0, (([0], [0]), ([], []))), - # (batched) matrix-matrix products - ((5, 7), 0, (7, 3), 0, (([1], [0]), ([], []))), - ((2, 3, 4), 1, (4, 3), 0, (([2], [0]), ([], []))), - ((2, 3, 4), 1, (2, 4, 3), 1, (([2], [1]), ([0], [0]))), - # more general operations - ((2, 3, 4, 3), 1, (2, 4, 3, 4), 1, (([2, 3], [1, 2]), ([0], [0]))), - ((2, 3, 4, 3, 1), 2, (3, 2, 3, 4), 2, (([2, 3], [3, 2]), ([0, 1], [1, 0]))), - ] - ], - swap=[True, False], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - @jax.default_matmul_precision("float32") - @jtu.skip_on_flag("jax_skip_slow_tests", True) - def test_bcoo_spdot_general(self, lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dtype, swap, dimension_numbers): - if swap: - dimension_numbers = tuple(d[::-1] for d in dimension_numbers) - lhs_shape, rhs_shape = rhs_shape, lhs_shape - lhs_n_batch, rhs_n_batch = rhs_n_batch, lhs_n_batch - - lhs_n_sparse = len(lhs_shape) - lhs_n_batch - rhs_batch = dimension_numbers[1][1] - lhs_contracting = dimension_numbers[0][0] - should_error = (rhs_n_batch > len(rhs_batch) and lhs_n_sparse > len(lhs_contracting)) - - sprng = sptu.rand_bcoo(self.rng()) - args_maker = lambda: [sprng(lhs_shape, dtype, n_batch=lhs_n_batch), - sprng(rhs_shape, dtype, n_batch=rhs_n_batch)] - - def f_dense(x, y): - return lax.dot_general(x, y, dimension_numbers=dimension_numbers) - - def f_sparse(xsp, ysp): - return sparse.bcoo_dot_general(xsp, ysp, dimension_numbers=dimension_numbers) - - if should_error: - with self.assertRaisesRegex(ValueError, ".*cannot have unused batch dims on rhs with unused sparse dims on lhs."): - f_sparse(*args_maker()) - else: - tol = {"float32": 1E-5, "complex64": 1E-5, "float64": 1E-14, "complex128": 1E-14} - self._CheckAgainstDense(f_dense, f_sparse, args_maker, tol=tol) - self._CheckBatchingSparse(f_dense, f_sparse, args_maker, tol=tol) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(f_dense, f_sparse, args_maker, modes=['fwd']) - - @jtu.sample_product( - lhs_shape=[(5,), (4, 5)], - rhs_shape=[(5,), (5, 4)]) - @jax.default_matmul_precision("float32") - def test_bcoo_spdot_general_nse(self, lhs_shape, rhs_shape): - rng = sptu.rand_bcoo(self.rng()) - dtype = jnp.float32 - lhs = rng(lhs_shape, dtype) - rhs = rng(rhs_shape, dtype) - out = lhs @ rhs - - expected_out = lhs.todense() @ rhs.todense() - expected_nse = min(lhs.nse * rhs.nse, out.size) - - self.assertArraysAllClose(out.todense(), expected_out) - self.assertEqual(out.nse, expected_nse) - - def test_bcoo_spdot_general_ad_bug(self): - # Regression test for https://github.com/google/jax/issues/10163 - A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]]) - A_values = jnp.array([-2.0, 1.0, -1.0, 0.5, 2.0]) - A_shape = (2, 3) - - B_indices = jnp.array([[0, 2], [2, 1], [0, 3], [1, 3], [1, 0], [0, 0]]) - B_values = jnp.array([10.0, 100.0, 1000.0, -5.0, -50.0, -500.0]) - B_shape = (3, 4) - - def sp_sp_product(v1, v2): - A = sparse.BCOO((v1, A_indices), shape=A_shape) - B = sparse.BCOO((v2, B_indices), shape=B_shape) - return (A @ B).todense() - - def sp_de_product(v1, v2): - A = sparse.BCOO((v1, A_indices), shape=A_shape) - B = sparse.BCOO((v2, B_indices), shape=B_shape).todense() - return A @ B - - def de_de_product(v1, v2): - sparse1 = sparse.BCOO((v1, A_indices), shape=A_shape).todense() - dense2 = sparse.BCOO((v2, B_indices), shape=B_shape).todense() - return sparse1 @ dense2 - - sp_sp_jac = jax.jacfwd(sp_sp_product, argnums=1)(A_values, B_values) - sp_de_jac = jax.jacfwd(sp_de_product, argnums=1)(A_values, B_values) - de_de_jac = jax.jacfwd(de_de_product, argnums=1)(A_values, B_values) - - self.assertAllClose(sp_sp_jac, de_de_jac) - self.assertAllClose(sp_de_jac, de_de_jac) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(), (5,), (5, 8), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape)], - dtype=jtu.dtypes.numeric, - ) - def test_bcoo_slice(self, shape, dtype, n_batch, n_dense): - rng = self.rng() - sprng = sptu.rand_bcoo(rng, n_batch=n_batch, n_dense=n_dense) - args_maker = lambda: [sprng(shape, dtype)] - - slices = rng.randint(0, np.array(shape) + 1, (2, len(shape))).T - slices.sort(1) - start_indices, limit_indices = unzip2(slices) - strides = list(rng.randint(1, 4, len(shape))) - kwds = dict(start_indices=start_indices, limit_indices=limit_indices, strides=strides) - - dense_func = partial(lax.slice, **kwds) - sparse_func = partial(sparse.bcoo_slice, **kwds) - - self._CheckAgainstDense(dense_func, sparse_func, args_maker) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(dense_func, sparse_func, args_maker) - - mat, = args_maker() - out = sparse_func(mat) - - # Array layout is the same - self.assertEqual(mat.n_batch, out.n_batch) - self.assertEqual(mat.n_sparse, out.n_sparse) - self.assertEqual(mat.n_dense, out.n_dense) - - # Unnecessary padding eliminated - max_nse = math.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse]) - self.assertLessEqual(out.nse, max_nse) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(), (5,), (5, 8), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape)], - dtype=jtu.dtypes.numeric, - ) - def test_bcoo_dynamic_slice(self, shape, dtype, n_batch, n_dense): - rng = self.rng() - sprng = sptu.rand_bcoo(rng, n_batch=n_batch, n_dense=n_dense) - args_maker = lambda: [sprng(shape, dtype)] - - rng = self.rng() - # Note: test out-of-range start indices - start_indices = rng.randint(-max(shape, default=0), max(shape, default=0), len(shape)) - slice_sizes = rng.randint(0, shape, len(shape)) - kwds = dict(start_indices=start_indices, slice_sizes=slice_sizes) - dense_func = partial(lax.dynamic_slice, **kwds) - sparse_func = partial(sparse.bcoo_dynamic_slice, **kwds) - - self._CheckAgainstDense(dense_func, sparse_func, args_maker) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(dense_func, sparse_func, args_maker) - - mat, = args_maker() - out = sparse_func(mat) - - # Array layout is the same - self.assertEqual(mat.n_batch, out.n_batch) - self.assertEqual(mat.n_sparse, out.n_sparse) - self.assertEqual(mat.n_dense, out.n_dense) - - # Unnecessary padding eliminated - max_nse = math.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse]) - self.assertLessEqual(out.nse, max_nse) - - @jtu.sample_product( - [dict(shape=shape, n_batch=n_batch, n_dense=n_dense, idx=idx) - for shape, idx in [ - [(5,), np.index_exp[:]], - [(5,), np.index_exp[4]], - [(5,), np.index_exp[::2]], - [(5,), np.index_exp[1::2]], - [(5,), 1], - [(3, 4), np.index_exp[1]], - [(3, 4), np.index_exp[1, 2]], - [(3, 4), np.index_exp[np.array([1, 2])]], - [(3, 4), np.index_exp[np.array([[1], [2]]), 0]], - [(3, 4), np.index_exp[np.array([[1], [2]]), 1:]], - [(3, 4), np.index_exp[np.array([True, False, True])]], - [(3, 4), np.index_exp[:2, np.array([True, False, True, False])]], - [(3, 4), np.index_exp[None, 0, np.array([[2]])]], - [(3, 4, 5), np.index_exp[2]], - [(3, 4, 5), np.index_exp[:, 2]] - ] - for n_batch in range(len(shape) + 1) - for n_dense in [0] # TODO(jakevdp): add tests with n_dense - ], - dtype=jtu.dtypes.numeric, - ) - def test_bcoo_getitem(self, shape, dtype, n_batch, n_dense, idx): - sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) - args_maker = lambda: [sprng(shape, dtype)] - - fun = lambda x: x[idx] - - self._CheckAgainstDense(fun, fun, args_maker) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(fun, fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, n_batch=n_batch, n_dense=n_dense) - for shape in [(2,), (3, 4), (5, 6, 2)] - for n_batch in range(len(shape) + 1) - for n_dense in [0] # TODO(jakevdp): add tests with n_dense - ], - dtype=jtu.dtypes.numeric, - ) - def test_bcoo_iter(self, shape, dtype, n_batch, n_dense): - sprng = rand_sparse(self.rng()) - args_maker = lambda: [sprng(shape, dtype)] - - self._CheckAgainstDense(list, list, args_maker) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, nse=nse) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape) - for nse in [None, math.prod(shape) - 1] - ], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - remove_zeros=[True, False], - ) - @jtu.skip_on_flag("jax_skip_slow_tests", True) - def test_bcoo_sum_duplicates(self, shape, dtype, n_batch, n_dense, nse, remove_zeros): - sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) - - def args_maker(): - # Create a matrix with duplicate indices - M = sprng(shape, dtype) - new_indices = jnp.concatenate([M.indices, M.indices], axis=n_batch) - new_data = jnp.concatenate([M.data, M.data], axis=n_batch) - return [sparse.BCOO((new_data, new_indices), shape=M.shape)] - - dense_fun = lambda x: x - def sparse_fun(x): - out = x.sum_duplicates(nse=nse, remove_zeros=remove_zeros) - self.assertTrue(out.unique_indices) - if nse: - self.assertEqual(out.nse, nse) - return out - self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, check_jit=(nse is not None)) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(dense_fun, sparse_fun, args_maker) - if nse is not None: - self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape)], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - def test_bcoo_sort_indices(self, shape, dtype, n_batch, n_dense): - rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) - M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense) - M.indices = M.indices[..., ::-1, :] - M.indices_sorted = False - - M_sorted = M.sort_indices() - self.assertArraysEqual(M.todense(), M_sorted.todense()) - self.assertEqual(M.unique_indices, M_sorted.unique_indices) - self.assertEqual(True, M_sorted.indices_sorted) - - indices = M_sorted.indices - if indices.size > 0: - flatind = indices.reshape(-1, *indices.shape[-2:]).transpose(0, 2, 1) - sorted = jax.vmap(jnp.lexsort)(flatind[:, ::-1]) - self.assertArraysEqual(sorted, lax.broadcasted_iota(sorted.dtype, sorted.shape, sorted.ndim - 1)) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape, min_n_batch=1)], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - def test_bcoo_sort_indices_batching(self, shape, dtype, n_batch, n_dense): - rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) - M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense) - M.indices = M.indices[..., ::-1, :] - M.indices_sorted = False - - identity = lambda M: M - sort_ind = lambda M: M.sort_indices() - for b in range(n_batch): - identity = jax.vmap(identity, in_axes=b) - sort_ind = jax.vmap(sort_ind, in_axes=b) - M_sorted = sort_ind(M) - M_expected = identity(M) - self.assertArraysEqual(M_expected.todense(), M_sorted.todense()) - self.assertEqual(M.unique_indices, M_sorted.unique_indices) - self.assertEqual(True, M_sorted.indices_sorted) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape)], - dtype=jtu.dtypes.floating, - ) - def test_bcoo_sort_indices_ad(self, shape, dtype, n_batch, n_dense): - rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero) - M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense) - M.indices = M.indices[..., ::-1, :] - - def sort_indices(data): - return sparse.BCOO((data, M.indices), shape=M.shape).sort_indices().data - - data_dot_fwd = jax.jacfwd(sort_indices)(M.data) - data_dot_rev = jax.jacrev(sort_indices)(M.data) - - self.assertAllClose(data_dot_fwd, data_dot_rev) - - def test_bcoo_sort_indices_broadcasted(self): - rng_index = jtu.rand_int(self.rng(), low=0, high=10) - rng_data = jtu.rand_default(self.rng()) - - # Construct matrix with three broadcasted batch dimensions. - indices = rng_index((1, 3, 1, 10, 2), dtype='int32') - data = rng_data((1, 1, 4, 10, 3), dtype='int32') - shape = (2, 3, 4, 5, 4, 3) - mat = sparse.BCOO((data, indices), shape=shape) - - indices_shape_out = indices.shape - data_shape_out = (*map(max, indices.shape[:3], data.shape[:3]), *data.shape[3:]) - - mat_sorted = sparse.bcoo_sort_indices(mat) - assert mat_sorted.indices.shape == indices_shape_out - assert mat_sorted.data.shape == data_shape_out - self.assertArraysEqual(mat.todense(), mat_sorted.todense()) - - mat_sorted_jit = jit(sparse.bcoo_sort_indices)(mat) - assert mat_sorted_jit.indices.shape == indices_shape_out - assert mat_sorted_jit.data.shape == data_shape_out - self.assertArraysEqual(mat.todense(), mat_sorted_jit.todense()) - - def test_bcoo_sum_duplicates_inferred_nse(self): - x = sparse.BCOO.fromdense(jnp.diag(jnp.arange(4))) - self.assertEqual(x.nse, 3) - y = x + x.T - self.assertEqual(y.nse, 6) - y2 = y.sum_duplicates() - self.assertEqual(y2.nse, 3) - self.assertArraysEqual(y.todense(), y2.todense()) - - def test_bcoo_sum_duplicates_remove_zeros(self): - data = jnp.array([0, 1, 0, 0]) - indices = jnp.array([[0], [1], [2], [3]]) - x = sparse.BCOO((data, indices), shape=(4,)) - self.assertEqual(x.nse, 4) - - y1 = x.sum_duplicates(remove_zeros=True) - self.assertArraysEqual(x.todense(), y1.todense()) - self.assertEqual(y1.nse, 1) - - y2 = x.sum_duplicates(remove_zeros=False) - self.assertArraysEqual(x.todense(), y2.todense()) - self.assertEqual(y2.nse, x.nse) - - def test_bcoo_sum_duplicates_padding(self): - # Regression test for https://github.com/google/jax/issues/8163 - size = 3 - data = jnp.array([1, 0, 0]) - indices = jnp.array([1, size, size])[:, None] - x = sparse.BCOO((data, indices), shape=(3,)) - y = x.sum_duplicates(nse=x.nse) - self.assertArraysEqual(x.todense(), y.todense()) - self.assertArraysEqual(x.indices, y.indices) - self.assertArraysEqual(x.data, y.data) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, axes=axes) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape) - for naxes in range(len(shape)) - for axes in itertools.combinations(range(len(shape)), naxes) - ], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes): - sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) - args_maker = lambda: [sprng(shape, dtype)] - sparse_fun = partial(sparse.bcoo_reduce_sum, axes=axes) - dense_fun = partial(lambda x: x.sum(axes)) - - tol = {np.float64: 1E-14} - self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(dense_fun, sparse_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dimensions=dimensions, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape, dimensions in [ - [(1,), (0,)], - [(1,), (-1,)], - [(2, 1, 4), (1,)], - [(2, 1, 3, 1), (1,)], - [(2, 1, 3, 1), (1, 3)], - [(2, 1, 3, 1), (3,)], - ] - for layout in iter_sparse_layouts(shape)], - dtype=jtu.dtypes.numeric, - ) - def test_bcoo_squeeze(self, shape, dtype, dimensions, n_batch, n_dense): - sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) - args_maker = lambda: [sprng(shape, dtype)] - dense_func = partial(lax.squeeze, dimensions=dimensions) - sparse_func = partial(sparse.bcoo_squeeze, dimensions=dimensions) - - self._CheckAgainstDense(dense_func, sparse_func, args_maker) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(dense_func, sparse_func, args_maker) - - @jtu.sample_product( - [dict(batch_shapes=shapes, batch_perm=perm) - for shapes in COMPATIBLE_SHAPE_PAIRS - for perm in itertools.permutations(range(len(shapes[0])))], - [dict(sparse_shapes=shapes, sparse_perm=perm) - for shapes in COMPATIBLE_SHAPE_PAIRS - for perm in itertools.permutations(range(len(shapes[0])))], - [dict(dense_shapes=shapes, dense_perm=perm) - for shapes in [[(),()]] # TODO(jakevdp) add support for dense shapes - for perm in itertools.permutations(range(len(shapes[0])))], - dtype=jtu.dtypes.numeric - ) - def test_bcoo_reshape(self, batch_shapes, sparse_shapes, dense_shapes, - batch_perm, sparse_perm, dense_perm, dtype): - # Sparse reshapes cannot mix between sparse, dense, and batch dimensions. - shape = (*batch_shapes[0], *sparse_shapes[0], *dense_shapes[0]) - new_sizes = (*batch_shapes[1], *sparse_shapes[1], *dense_shapes[1]) - n_batch = len(batch_shapes[0]) - n_sparse = len(sparse_shapes[0]) - n_dense = len(dense_shapes[0]) - dimensions = ( - *batch_perm, - *(dim + n_batch for dim in sparse_perm), - *(dim + n_batch + n_sparse for dim in dense_perm) - ) - - rng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) - args_maker = lambda: [rng(shape, dtype)] - - sparse_func = partial(sparse.bcoo_reshape, new_sizes=new_sizes, dimensions=dimensions) - dense_func = partial(lax.reshape, new_sizes=new_sizes, dimensions=dimensions) - - self._CheckAgainstDense(dense_func, sparse_func, args_maker) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(dense_func, sparse_func, args_maker) - - def test_bcoo_reshape_error(self): - x = sparse.BCOO.fromdense(jnp.ones((2, 2, 3)), n_batch=1) - with self.assertRaisesRegex(ValueError, ".*cannot mix batch and sparse dimensions.*"): - x.reshape(3, 2, 2) - y = sparse.BCOO((x.data[:1], x.indices), shape=x.shape) - with self.assertRaisesRegex(NotImplementedError, "reshape of arrays with broadcasted batch dimensions."): - y.reshape(2, 3, 2) - - @jtu.sample_product( - [dict(shape=shape, dimensions=dimensions, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(3,), (3, 4), (3, 4, 5)] - for dimensions in _iter_subsets(range(len(shape))) - for layout in iter_sparse_layouts(shape)], - dtype=jtu.dtypes.numeric, - ) - def test_bcoo_rev(self, shape, dtype, n_batch, n_dense, dimensions): - sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) - args_maker = lambda: [sprng(shape, dtype)] - dense_func = partial(lax.rev, dimensions=dimensions) - sparse_func = partial(sparse.bcoo_rev, dimensions=dimensions) - - self._CheckAgainstDense(dense_func, sparse_func, args_maker) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(dense_func, sparse_func, args_maker) - - def test_bcsr_matmul_with_out_of_bounds_data(self): - # Simple regression test of a failure mode for cuSparse. - data = jnp.array([1, 2, 3, 4], dtype='float32') - indices = jnp.array([0, 1, 2, 3]) - indptr = jnp.array([0, 1, 3, 3]) - M = sparse.BCSR((data, indices, indptr), shape=(3, 4)) - x = jnp.array([1, 2, 3, 4], dtype='float32') - - sparse_result = jax.jit(operator.matmul)(M, x) - dense_result = jax.jit(operator.matmul)(M.todense(), x) - self.assertAllClose(sparse_result, dense_result) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) - for lhs_shape, rhs_shape in [[(3, 4), (4,)], - [(3, 4), (4, 5)], - [(3, 4), (2, 4, 5)]] - ], - lhs_dtype=all_dtypes, - rhs_dtype=all_dtypes, - ) - @jax.default_matmul_precision("float32") - @jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning) - def test_bcsr_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - # Note: currently, batch dimensions in matmul must correspond to batch - # dimensions in the sparse representation. - n_batch_lhs = max(0, len(lhs_shape) - 2) - - rng = jtu.rand_default(self.rng()) - sprng = sptu.rand_bcsr(self.rng()) - args_maker = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=n_batch_lhs), - jnp.array(rng(rhs_shape, rhs_dtype))] - - tol = {np.float64: 1E-7, np.complex128: 1E-6, - np.float32: 2E-6, np.complex64: 2E-6} - - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker, - tol=tol) - - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) - for lhs_shape, rhs_shape in [[(3,), (3,)], - [(3, 4), (4,)], - [(4,), (4, 5)], - [(3, 4), (4, 5)], - [(3, 4), (2, 4, 5)], - [(2, 3, 4), (4, 5)], - [(2, 3, 4), (2, 4, 5)]] - ], - lhs_dtype=all_dtypes, - rhs_dtype=all_dtypes, - ) - @jax.default_matmul_precision("float32") - @jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning) - def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - if (jtu.test_device_matches(["cuda"]) and - _is_required_cuda_version_satisfied(12000)): - raise unittest.SkipTest("Triggers a bug in cuda-12 b/287344632") - - # Note: currently, batch dimensions in matmul must correspond to batch - # dimensions in the sparse representation. - n_batch_lhs = max(0, len(lhs_shape) - 2) - n_batch_rhs = max(0, len(rhs_shape) - 2) - - rng = jtu.rand_default(self.rng()) - sprng = sptu.rand_bcoo(self.rng()) - args_maker_de_sp = lambda: [jnp.array(rng(lhs_shape, lhs_dtype)), - sprng(rhs_shape, rhs_dtype, n_batch=n_batch_rhs)] - args_maker_sp_de = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=n_batch_lhs), - jnp.array(rng(rhs_shape, rhs_dtype))] - - tol = {np.float64: 1E-7, np.complex128: 1E-7, - np.float32: 1E-6, np.complex64: 1E-6} - - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker_de_sp, tol=tol) - self._CheckAgainstDense(operator.matmul, operator.matmul, args_maker_sp_de, tol=tol) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, n_batch=layout.n_batch, - n_dense=layout.n_dense) - for lhs_shape, rhs_shape in [[(3,), ()], [(3,), (1,)], [(3,), (3,)], - [(3, 4), ()], [(3, 4), (4,)], [(3, 4), (3, 1)], [(3, 4), (3, 4)], - [(3, 4, 5), (4, 5)], [(3, 4, 5), (3, 1, 1)], [(3, 4, 5), (1, 4, 1)]] - for layout in iter_sparse_layouts(lhs_shape) - ], - lhs_dtype=all_dtypes, - rhs_dtype=all_dtypes, - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def test_bcoo_mul_dense(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense): - rng = jtu.rand_default(self.rng()) - sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) - - args_maker_sp_de = lambda: [sprng(lhs_shape, lhs_dtype), jnp.array(rng(rhs_shape, rhs_dtype))] - args_maker_de_sp = lambda: [jnp.array(rng(rhs_shape, rhs_dtype)), sprng(lhs_shape, lhs_dtype)] - - tol = {np.float64: 1E-13, np.complex128: 1E-13, - np.float32: 1E-6, np.complex64: 1E-6} - - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstDense(operator.mul, operator.mul, args_maker_de_sp, tol=tol) - self._CheckAgainstDense(operator.mul, operator.mul, args_maker_sp_de, tol=tol) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, lhs_n_batch=lhs_n_batch, - rhs_n_batch=rhs_n_batch, n_dense=n_dense) - # TODO(jakevdp): add broadcasted shapes (from bcoo_mul_dense) once sparse-sparse mul - # supports inputs of differing rank. - for lhs_shape, rhs_shape in [[(3,), (1,)], [(3,), (3,)], - [(3, 4), (1, 1)], [(3, 4), (1, 4)], [(3, 4), (3, 1)], [(3, 4), (3, 4)], - [(3, 4, 5), (1, 4, 5)], [(3, 4, 5), (3, 1, 1)], [(3, 4, 5), (1, 4, 1)]] - # TODO(jakevdp): add tests for batch & dense dimensions. - for lhs_n_batch in range(len(lhs_shape) + 1) - for rhs_n_batch in range(len(lhs_shape) + 1) - for n_dense in range(len(lhs_shape) + 1 - max(lhs_n_batch, rhs_n_batch)) - ], - lhs_dtype=all_dtypes, - rhs_dtype=all_dtypes, - ) - def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, lhs_n_batch, rhs_n_batch, n_dense): - sprng = sptu.rand_bcoo(self.rng(), n_dense=n_dense) - args_maker = lambda: [sprng(lhs_shape, lhs_dtype, n_batch=lhs_n_batch), - sprng(rhs_shape, rhs_dtype, n_batch=rhs_n_batch)] - - tol = {np.float64: 1E-13, np.complex128: 1E-13, - np.float32: 1E-5, np.complex64: 1E-5} - - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstDense(operator.mul, operator.mul, args_maker, tol=tol) - - def test_bcoo_mul_sparse_with_duplicates(self): - # Regression test for https://github.com/google/jax/issues/8888 - indices = jnp.array([[0, 1, 0, 0, 1, 1], - [1, 0, 1, 2, 0, 2]]).T - data = jnp.array([1, 2, 3, 4, 5, 6]) - mat = sparse.BCOO((data, indices), shape=(3, 3)) - self.assertArraysEqual((mat * mat).todense(), mat.todense() * mat.todense()) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(), (3,), (3, 5), (3, 5, 4)] - for layout in iter_sparse_layouts(shape)], - dtype=all_dtypes, - ) - def test_bcoo_broadcast_in_dim(self, shape, dtype, n_batch, n_dense): - rng = rand_sparse(self.rng()) - x = jnp.array(rng(shape, dtype)) - xsp = sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense) - - self.assertEqual(xsp[None].n_batch, xsp.n_batch + 1) - self.assertArraysEqual(xsp[None].todense(), x[None]) - - if len(shape) >= 1: - self.assertEqual(xsp[:, None].n_batch, xsp.n_batch if xsp.n_batch < 1 else xsp.n_batch + 1) - self.assertArraysEqual(xsp[:, None].todense(), x[:, None]) - self.assertArraysEqual(xsp[:, None, None].todense(), x[:, None, None]) - if len(shape) >= 2: - self.assertEqual(xsp[:, :, None].n_batch, xsp.n_batch if xsp.n_batch < 2 else xsp.n_batch + 1) - self.assertArraysEqual(xsp[:, :, None].todense(), x[:, :, None]) - self.assertArraysEqual(xsp[:, None, :, None].todense(), x[:, None, :, None]) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, dimension=dimension) - for shape in [ (3,), (3, 5), (3, 5, 4)] - for layout in iter_sparse_layouts(shape) - for dimension in range(len(shape) - layout.n_dense) # Concatenation of dense dimensions not implemented. - ], - dtype=all_dtypes, - ) - def test_bcoo_concatenate(self, shape, dtype, n_batch, n_dense, dimension): - sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) - args_maker = lambda: [[sprng(shape, dtype) for i in range(3)]] - dense_func = partial(lax.concatenate, dimension=dimension) - sparse_func = partial(sparse.bcoo_concatenate, dimension=dimension) - - self._CheckAgainstDense(dense_func, sparse_func, args_maker) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(dense_func, sparse_func, args_maker) - - @jtu.sample_product( - lhs_shape=[(1, 1, 5), (1, 1, 10)], - rhs_shape=[(1, 1, 5), (1, 1, 10)], - padding=['SAME', 'VALID', [(3, 3)]], - dtype=jtu.dtypes.inexact, - format=['sp-de', 'de-sp', 'sp-sp'] - ) - @jax.default_matmul_precision("float32") - @jtu.skip_on_flag("jax_skip_slow_tests", True) - def test_bcoo_conv_general_dilated(self, lhs_shape, rhs_shape, dtype, padding, format): - kwds = dict(window_strides=(1,), padding=padding) - sparse_fun = partial(sparse.bcoo_conv_general_dilated, **kwds) - dense_fun = partial(lax.conv_general_dilated, **kwds) - sprng = sptu.rand_bcoo(self.rng(), n_batch=2, n_dense=0) - rng = jtu.rand_default(self.rng()) - - def args_maker(): - lhs = (sprng if format.startswith('sp') else rng)(lhs_shape, dtype) - rhs = (sprng if format.endswith('sp') else rng)(rhs_shape, dtype) - return lhs, rhs - - tol = {np.float32: 1E-5, np.complex64: 1E-5, np.float64: 1E-14, np.complex128: 1E-14} - self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol) - - def test_bcoo_vmap_shape(self, shape=(2, 3, 4, 5), dtype=np.float32): - # This test checks that BCOO shape metadata interacts correctly with vmap. - rng = rand_sparse(self.rng()) - M = rng(shape, dtype) - - def make_bcoo(M): - return sparse_bcoo._bcoo_fromdense(M, nse=math.prod(M.shape[:-1]), n_dense=1) - - todense = partial(sparse_bcoo._bcoo_todense, spinfo=sparse_util.SparseInfo(shape)) - - for _ in range(3): - make_bcoo = jax.vmap(make_bcoo) - Msp_data, Msp_indices = make_bcoo(M) - Msp_dense = todense(Msp_data, Msp_indices) - self.assertEqual(Msp_dense.shape, M.shape) - self.assertArraysEqual(Msp_dense, M) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, - n_batch_out=layout_out.n_batch, n_dense_out=layout_out.n_dense) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape) - for layout_out in iter_sparse_layouts(shape) - ], - dtype=jtu.dtypes.integer, - ) - def test_bcoo_update_layout(self, shape, dtype, n_batch, n_batch_out, n_dense, n_dense_out): - rng = rand_sparse(self.rng()) - mat = sparse.BCOO.fromdense(rng(shape, dtype), n_batch=n_batch, n_dense=n_dense) - kwds = dict(n_batch=n_batch_out, n_dense=n_dense_out) - # TODO(jakevdp): in case of length-0 or length-1 shapes errors/warnings will not be raised. - if n_dense_out > n_dense or n_batch_out > n_batch: - with self.assertRaises(sparse.SparseEfficiencyError): - sparse.bcoo_update_layout(mat, **kwds) - with self.assertRaises(sparse.SparseEfficiencyError): - sparse.bcoo_update_layout(mat, **kwds, on_inefficient='error') - with self.assertWarns(sparse.SparseEfficiencyWarning): - sparse.bcoo_update_layout(mat, **kwds, on_inefficient='warn') - kwds['on_inefficient'] = None - mat_new = sparse.bcoo_update_layout(mat, **kwds) - self.assertEqual(mat_new.n_batch, n_batch_out) - self.assertEqual(mat_new.n_dense, n_dense_out) - self.assertArraysEqual(mat.todense(), mat_new.todense()) - - def test_bcoo_update_layout_method(self, shape=(2, 3, 4)): - # simple test to make sure update_layout method properly forwards. - rng = rand_sparse(self.rng()) - mat = sparse.BCOO.fromdense(rng((2, 3, 4), 'float32'), n_batch=1, n_dense=1) - mat_new = mat.update_layout(n_batch=0, n_dense=0) - self.assertEqual(mat_new.n_batch, 0) - self.assertEqual(mat_new.n_dense, 0) - self.assertArraysEqual(mat.todense(), mat_new.todense()) - - def test_bcoo_bad_fillvals(self): - # Extra values have 100 rather than zero. This lets us check that logic is - # properly ignoring these indices. - data = jnp.array([1, 2, 3, 100, 100]) - indices = jnp.array([1, 2, 3, 5, 5])[:, None] - x_sp = sparse.BCOO((data, indices), shape=(5,)) - x_de = x_sp.todense() - - data = jnp.array([3, 2, 100, 100]) - indices = jnp.array([2, 3, 5, 5])[:, None] - y_sp = sparse.BCOO((data, indices), shape=(5,)) - y_de = y_sp.todense() - - self.assertArraysEqual(x_de, jnp.array([0, 1, 2, 3, 0])) - self.assertArraysEqual(y_de, jnp.array([0, 0, 3, 2, 0])) - - self.assertArraysEqual(x_sp.sum_duplicates().todense(), x_de) - self.assertArraysEqual(y_sp.sum_duplicates().todense(), y_de) - - # reduce_sum: - self.assertArraysEqual(x_sp.sum(), x_de.sum()) - - # bcoo_dot_general - self.assertArraysEqual(x_sp @ y_de, x_de @ y_de) - - # bcoo_rdot_general - self.assertArraysEqual(x_de @ y_sp, x_de @ y_de) - - # bcoo_spdot_general - self.assertArraysEqual((x_sp @ y_sp).todense(), x_de @ y_de) - self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de) - - -# TODO(tianjianlu): Unify the testing for BCOOTest and BCSRTest. -class BCSRTest(sptu.SparseTestCase): - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_bcsr_layouts(shape)], - dtype=all_dtypes, - ) - def test_bcsr_dense_round_trip(self, shape, dtype, n_batch, n_dense): - n_sparse = len(shape) - n_batch - n_dense - rng = rand_sparse(self.rng()) - M = rng(shape, dtype) - nse = sparse.util._count_stored_elements(M, n_batch=n_batch, n_dense=n_dense) - def round_trip(M): - return sparse.BCSR.fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense).todense() - args_maker = lambda: [M] - ident = lambda x: x - - self._CheckAgainstNumpy(ident, round_trip, args_maker) - self._CompileAndCheck(round_trip, args_maker) - self._CheckBatchingSparse(ident, round_trip, args_maker, bdims=self._random_bdims(n_batch)) - if jnp.issubdtype(dtype, jnp.floating): - # For n_sparse != 0, we can't use an identity because output zeros must not - # be dependent on input zeros. This mimics the code in count_stored_elements(). - def expected(M): - if n_sparse == 0: return M - mask = (M != 0).any(range(M.ndim - n_dense, M.ndim), keepdims=True) - return jnp.where(mask, M, 0) - self._CheckGradsSparse(expected, round_trip, args_maker) - - @jtu.sample_product( - [dict(shape=shape, n_batch=n_batch) - for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for n_batch in range(len(shape) - 1) - ], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - def test_bcsr_bcoo_round_trip(self, shape, n_batch, dtype): - n_sparse = 2 - n_dense = len(shape) - n_sparse - n_batch - rng = self.rng() - sprng = sptu.rand_bcsr(rng, n_batch=n_batch, n_dense=n_dense) - - M_bcsr = sprng(shape, dtype) - self.assertIsInstance(M_bcsr, sparse.BCSR) - - M_dense = M_bcsr.todense() - M_bcoo = M_bcsr.to_bcoo() - self.assertIsInstance(M_bcoo, sparse.BCOO) - self.assertAllClose(M_dense, M_bcoo.todense()) - - M_bcsr2 = sparse.BCSR.from_bcoo(M_bcoo) - self.assertAllClose(M_dense, M_bcsr2.todense()) - self.assertArraysEqual(M_bcsr.indptr, M_bcsr2.indptr) - - # TODO(jakevdp): This will only be true in general when M_bcsr.indices is sorted. - # self.assertSparseArraysEquivalent(M_bcsr, M_bcsr2) - - @jtu.sample_product( - [dict(shape=shape, n_batch=n_batch) - for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for n_batch in range(len(shape) - 1) - ], - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - def test_bcsr_extract(self, shape, dtype, n_batch): - n_dense = len(shape) - n_batch - 2 - rng = rand_sparse(self.rng()) - M = rng(shape, dtype) - nse = sparse.util._count_stored_elements(M, n_batch=n_batch, - n_dense=n_dense) - data, indices, indptr = sparse_bcsr._bcsr_fromdense( - M, nse=nse, n_batch=n_batch, n_dense=n_dense) - data2 = sparse.bcsr_extract(indices, indptr, M) - self.assertArraysEqual(data, data2) - args_maker_bcsr_extract = lambda: [indices, indptr, M] - self._CompileAndCheck(sparse.bcsr_extract, args_maker_bcsr_extract) - - @jtu.sample_product( - props=_generate_batched_dot_general_properties( - shapes=((2, 3), (2, 3, 4), (2, 3, 4, 4)), sparse_format='bcsr'), - dtype=jtu.dtypes.floating + jtu.dtypes.complex, - ) - @jax.default_matmul_precision("float32") - def test_bcsr_dot_general(self, dtype: np.dtype, props: BatchedDotGeneralProperties): - rng = jtu.rand_default(self.rng()) - sprng = sptu.rand_bcsr(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense) - args_maker = lambda: [sprng(props.lhs_shape, dtype), - rng(props.rhs_shape, dtype)] - dense_fun = partial(lax.dot_general, - dimension_numbers=props.dimension_numbers) - sparse_fun = partial(sparse.bcsr_dot_general, - dimension_numbers=props.dimension_numbers) - - tol = {np.float64: 1E-12, np.complex128: 1E-12, - np.float32: 1E-5, np.complex64: 1E-5} - - self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol) - if jnp.issubdtype(dtype, jnp.floating) and props.n_dense == 0: - # Dense dimensions not yet fully supported in reverse mode. - modes = ['fwd'] if props.n_dense != 0 else ['fwd', 'rev'] - self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=modes, atol=tol, rtol=tol) - self._CheckBatchingSparse(dense_fun, sparse_fun, args_maker, atol=tol, rtol=tol, - bdims=self._random_bdims(props.n_batch, len(props.rhs_shape))) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(3, 5), (3, 5, 4)] - for layout in iter_bcsr_layouts(shape)], - dtype=all_dtypes, - ) - def test_bcsr_broadcast_in_dim(self, shape, dtype, n_batch, n_dense): - rng = rand_sparse(self.rng()) - x = jnp.array(rng(shape, dtype)) - xsp = sparse.BCSR.fromdense(x, n_batch=n_batch, n_dense=n_dense) - - self.assertEqual(xsp[None].n_batch, xsp.n_batch + 1) - self.assertArraysEqual(xsp[None].todense(), x[None]) - - if n_batch == 1: - self.assertEqual(xsp[:, None].n_batch, xsp.n_batch + 1) - self.assertArraysEqual(xsp[:, None].todense(), x[:, None]) - - @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, dimension=dimension) - for shape in [(3, 5), (3, 5, 4)] - for layout in iter_sparse_layouts(shape) - for dimension in range(len(shape) - layout.n_dense) # Concatenation of dense dimensions not implemented. - ], - dtype=all_dtypes, - ) - def test_bcsr_concatenate(self, shape, dtype, n_batch, n_dense, dimension): - sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch, n_dense=n_dense) - args_maker = lambda: [[sprng(shape, dtype) for i in range(3)]] - dense_func = partial(lax.concatenate, dimension=dimension) - sparse_func = partial(sparse.bcoo_concatenate, dimension=dimension) - - self._CheckAgainstDense(dense_func, sparse_func, args_maker) - if jnp.issubdtype(dtype, jnp.floating): - self._CheckGradsSparse(dense_func, sparse_func, args_maker) - - class SparseGradTest(sptu.SparseTestCase): @jtu.sample_product(has_aux=[True, False]) def test_sparse_value_and_grad(self, has_aux): - rng_sparse = rand_sparse(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) rng = jtu.rand_default(self.rng()) y = rng(5, "float32") @@ -2364,7 +665,7 @@ class SparseGradTest(sptu.SparseTestCase): @jtu.sample_product(has_aux=[True, False]) def test_sparse_grad(self, has_aux): - rng_sparse = rand_sparse(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) rng = jtu.rand_default(self.rng()) y = rng(5, "float32") @@ -2399,7 +700,7 @@ class SparseGradTest(sptu.SparseTestCase): jac_dense = getattr(jax, transform) jac_sparse = getattr(sparse, transform) - rng_sparse = rand_sparse(self.rng()) + rng_sparse = sptu.rand_sparse(self.rng()) rng = jtu.rand_default(self.rng()) y = rng(5, "float32") @@ -2538,7 +839,7 @@ class SparseObjectTest(sptu.SparseTestCase): {"testcase_name": f"_{Obj.__name__}", "Obj": Obj} for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]) def test_block_until_ready(self, Obj, shape=(5, 8), dtype=np.float32): - rng = rand_sparse(self.rng(), post=Obj.fromdense) + rng = sptu.rand_sparse(self.rng(), post=Obj.fromdense) M = rng(shape, dtype) self.assertEqual(M.shape, M.block_until_ready().shape) self.assertArraysEqual(M.data, M.block_until_ready().data) @@ -2548,7 +849,7 @@ class SparseObjectTest(sptu.SparseTestCase): {"testcase_name": f"_{Obj.__name__}", "Obj": Obj} for Obj in [jnp.array, sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]) def test_todense(self, Obj, shape=(5, 8), dtype=np.float32): - rng = rand_sparse(self.rng()) + rng = sptu.rand_sparse(self.rng()) M_dense = rng(shape, dtype) M = jnp.array(M_dense) if Obj is jnp.array else Obj.fromdense(M_dense) self.assertArraysEqual(sparse.todense(M), M_dense) @@ -2562,7 +863,7 @@ class SparseObjectTest(sptu.SparseTestCase): {"testcase_name": f"_{Obj.__name__}", "Obj": Obj} for Obj in [jnp.array, sparse.BCOO]) def test_todense_batching(self, Obj, shape=(5, 8), dtype=np.float32): - rng = rand_sparse(self.rng()) + rng = sptu.rand_sparse(self.rng()) M_dense = rng(shape, dtype) if Obj is sparse.BCOO: M = sparse.BCOO.fromdense(M_dense, n_batch=1) @@ -2588,7 +889,7 @@ class SparseObjectTest(sptu.SparseTestCase): {"testcase_name": f"_{Obj.__name__}", "Obj": Obj} for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO, sparse.BCSR]) def test_attrs(self, Obj, shape=(5, 8), dtype=np.float32): - rng = rand_sparse(self.rng(), post=Obj.fromdense) + rng = sptu.rand_sparse(self.rng(), post=Obj.fromdense) M = rng(shape, dtype) self.assertIsInstance(M, Obj) @@ -2629,7 +930,7 @@ class SparseObjectTest(sptu.SparseTestCase): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) def test_dense_round_trip(self, shape, dtype, Obj): - rng = rand_sparse(self.rng()) + rng = sptu.rand_sparse(self.rng()) M = rng(shape, dtype) Msparse = Obj.fromdense(M) self.assertArraysEqual(M, Msparse.todense()) @@ -2642,7 +943,7 @@ class SparseObjectTest(sptu.SparseTestCase): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) def test_transpose(self, shape, dtype, Obj): - rng = rand_sparse(self.rng()) + rng = sptu.rand_sparse(self.rng()) M = rng(shape, dtype) Msparse = Obj.fromdense(M) self.assertArraysEqual(M.T, Msparse.T.todense()) @@ -2659,7 +960,7 @@ class SparseObjectTest(sptu.SparseTestCase): for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") def test_matmul(self, shape, dtype, Obj, bshape): - rng = rand_sparse(self.rng(), post=jnp.array) + rng = sptu.rand_sparse(self.rng(), post=jnp.array) rng_b = jtu.rand_default(self.rng()) M = rng(shape, dtype) Msp = Obj.fromdense(M) @@ -2667,13 +968,15 @@ class SparseObjectTest(sptu.SparseTestCase): # Test matching type x = rng_b(bshape, dtype) x = jnp.asarray(x) - self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL, atol=MATMUL_TOL) + self.assertAllClose( + M @ x, Msp @ x, rtol=sptu.MATMUL_TOL, atol=sptu.MATMUL_TOL + ) # Test mismatched type x = rng_b(bshape, np.int32) x = jnp.asarray(x) with jax.numpy_dtype_promotion('standard'): - self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL) + self.assertAllClose(M @ x, Msp @ x, rtol=sptu.MATMUL_TOL) @jtu.sample_product( cls=[sparse.BCOO, sparse.BCSR], @@ -2684,7 +987,7 @@ class SparseObjectTest(sptu.SparseTestCase): ) def test_bcoo_bcsr_from_scipy_sparse(self, cls, input_type, shape, dtype): """Test BCOO and BCSR from_scipy_sparse.""" - rng = rand_sparse(self.rng()) + rng = sptu.rand_sparse(self.rng()) M = rng(shape, dtype) M_scipy = input_type(M) M_jax = cls.from_scipy_sparse(M_scipy) @@ -2715,7 +1018,7 @@ class SparseObjectTest(sptu.SparseTestCase): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) def test_bcoo_to_bcsr_round_trip(self, shape, dtype, n_batch): - rng = rand_sparse(self.rng()) + rng = sptu.rand_sparse(self.rng()) M = rng(shape, dtype) n_dense = len(shape) - 2 - n_batch nse = sparse.util._count_stored_elements(M, n_batch=n_batch, @@ -2742,12 +1045,15 @@ class SparseObjectTest(sptu.SparseTestCase): class SparseRandomTest(sptu.SparseTestCase): + @jtu.sample_product( - [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) - for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for layout in iter_sparse_layouts(shape)], - dtype=jtu.dtypes.floating, - indices_dtype=jtu.dtypes.integer, + [ + dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense) + for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] + for layout in sptu.iter_sparse_layouts(shape) + ], + dtype=jtu.dtypes.floating, + indices_dtype=jtu.dtypes.integer, ) def test_random_bcoo(self, shape, dtype, indices_dtype, n_batch, n_dense): key = jax.random.PRNGKey(1701) @@ -2779,9 +1085,9 @@ class SparseSolverTest(sptu.SparseTestCase): ) @jtu.run_on_devices("cpu", "cuda") def test_sparse_qr_linear_solver(self, size, reorder, dtype): - if jtu.test_device_matches(["cuda"]) and not GPU_LOWERING_ENABLED: + if jtu.test_device_matches(["cuda"]) and not sptu.GPU_LOWERING_ENABLED: raise unittest.SkipTest('test requires cusparse/cusolver') - rng = rand_sparse(self.rng()) + rng = sptu.rand_sparse(self.rng()) a = rng((size, size), dtype) nse = (a != 0).sum() data, indices, indptr = sparse_csr._csr_fromdense(a, nse=nse) @@ -2806,9 +1112,9 @@ class SparseSolverTest(sptu.SparseTestCase): ) @jtu.run_on_devices("cpu", "cuda") def test_sparse_qr_linear_solver_grads(self, size, dtype): - if jtu.test_device_matches(["cuda"]) and not GPU_LOWERING_ENABLED: + if jtu.test_device_matches(["cuda"]) and not sptu.GPU_LOWERING_ENABLED: raise unittest.SkipTest('test requires cusparse/cusolver') - rng = rand_sparse(self.rng()) + rng = sptu.rand_sparse(self.rng()) a = rng((size, size), dtype) nse = (a != 0).sum() data, indices, indptr = sparse_csr._csr_fromdense(a, nse=nse) @@ -2825,11 +1131,18 @@ class SparseSolverTest(sptu.SparseTestCase): class SparseUtilTest(sptu.SparseTestCase): @jtu.sample_product( - [dict(n_batch=n_batch, n_dense=n_dense, expected_nse=expected_nse) - for n_batch, n_dense, expected_nse in - [(0, 0, 4), (1, 0, 2), (0, 1, 2), (2, 0, 1), (1, 1, 1), (0, 2, 1)] - ], - dtype=all_dtypes, + [ + dict(n_batch=n_batch, n_dense=n_dense, expected_nse=expected_nse) + for n_batch, n_dense, expected_nse in [ + (0, 0, 4), + (1, 0, 2), + (0, 1, 2), + (2, 0, 1), + (1, 1, 1), + (0, 2, 1), + ] + ], + dtype=all_dtypes, ) def test_count_stored_elements(self, dtype, n_batch, n_dense, expected_nse): """Test counting nse.""" @@ -2839,11 +1152,12 @@ class SparseUtilTest(sptu.SparseTestCase): self.assertEqual(expected_nse, actual_nse) @jtu.sample_product( - [dict(n_batch=n_batch, n_dense=n_dense) - for n_batch in range(3) - for n_dense in range(3 - n_batch) - ], - dtype=all_dtypes, + [ + dict(n_batch=n_batch, n_dense=n_dense) + for n_batch in range(3) + for n_dense in range(3 - n_batch) + ], + dtype=all_dtypes, ) def test_count_stored_elements_empty(self, dtype, n_batch, n_dense): mat = np.empty((0, 4), dtype=dtype) @@ -2852,12 +1166,16 @@ class SparseUtilTest(sptu.SparseTestCase): self.assertEqual(0, actual_nse) @jtu.sample_product( - [dict(n_batch=n_batch, n_dense=n_dense, expected_nse=expected_nse) - for n_batch, n_dense, expected_nse in - [(0, 0, 14), (1, 0, np.array([6, 8])), (0, 1, 9), - (2, 0, np.array([[3, 3], [4, 4]]))] - ], - dtype=all_dtypes + [ + dict(n_batch=n_batch, n_dense=n_dense, expected_nse=expected_nse) + for n_batch, n_dense, expected_nse in [ + (0, 0, 14), + (1, 0, np.array([6, 8])), + (0, 1, 9), + (2, 0, np.array([[3, 3], [4, 4]])), + ] + ], + dtype=all_dtypes, ) def test_count_stored_elements_per_batch(self, dtype, n_batch, n_dense, expected_nse):