# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from functools import partial
import itertools
import operator
import random
import unittest
from typing import Iterator, NamedTuple, Tuple

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.random
from jax import config
from jax import dtypes
from jax.experimental import sparse
from jax.experimental.sparse.bcoo import BCOOInfo
from jax import lax
from jax._src.lib import cusparse
from jax._src.lib import hipsparse
from jax._src.lib import xla_bridge
from jax import jit
from jax import tree_util
from jax import vmap
from jax._src import test_util as jtu
from jax._src.lax.lax import remaining, DotDimensionNumbers
from jax import xla
import jax.numpy as jnp
from jax.util import split_list
import numpy as np
import scipy.sparse

config.parse_flags_with_absl()
FLAGS = config.FLAGS

MATMUL_TOL = {
    np.float32: 1E-5,
    np.float64: 1E-10,
    np.complex64: 1E-5,
    np.complex128: 1E-10,
}
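# Note: jtu's allclose-style checks accept a dtype-keyed dict like this for
# rtol/atol and resolve the entry matching the dtype of the compared values.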
class BcooDotGeneralProperties(NamedTuple):
  lhs_shape: Tuple[int, ...]
  rhs_shape: Tuple[int, ...]
  dtype: np.dtype
  n_batch: int
  n_dense: int
  dimension_numbers: DotDimensionNumbers

  def testcase_name(self):
    return "_{}_{}_nbatch={}_ndense={}_dimension_numbers={}".format(
        jtu.format_shape_dtype_string(self.lhs_shape, self.dtype),
        jtu.format_shape_dtype_string(self.rhs_shape, self.dtype),
        self.n_batch, self.n_dense, self.dimension_numbers)


def _iter_subsets(s):
  return itertools.chain.from_iterable(
      itertools.combinations(s, n) for n in range(len(s) + 1))
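# For example, _iter_subsets(range(2)) yields (), (0,), (1,), (0, 1).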
def _generate_bcoo_dot_general_properties(
    shapes, dtypes) -> Iterator[BcooDotGeneralProperties]:
  """Generator of properties for bcoo_dot_general tests."""
  rng = random.Random(0)

  for shape in shapes:
    for n_batch in range(len(shape) + 1):
      for n_dense in range(len(shape) + 1 - n_batch):
        n_sparse = len(shape) - n_batch - n_dense
        subsets = split_list(range(len(shape)), [n_batch, n_sparse])
        for batch_dims in _iter_subsets(range(n_batch)):
          for contracting_dims in _iter_subsets(
              remaining(range(n_batch + n_sparse), batch_dims)):
            # We want coverage of permutations & dtypes without generating
            # hundreds of thousands of test cases; we do this by deterministic
            # pseudo-random sampling instead of iterating over the full space.
            rhs_permute = rng.sample(range(len(shape)), len(shape))
            lhs_permute = list(itertools.chain.from_iterable(
                rng.sample(subset, len(subset)) for subset in subsets))
            yield BcooDotGeneralProperties(
                lhs_shape=tuple(shape[p] for p in lhs_permute),
                rhs_shape=tuple(shape[p] for p in rhs_permute),
                dtype=rng.choice(dtypes),
                n_batch=n_batch,
                n_dense=n_dense,
                dimension_numbers=(
                    ([lhs_permute.index(d) for d in contracting_dims],
                     [rhs_permute.index(d) for d in contracting_dims]),
                    ([lhs_permute.index(d) for d in batch_dims],
                     [rhs_permute.index(d) for d in batch_dims])
                ),
            )


all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex


def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
  def _rand_sparse(shape, dtype, nse=nse):
    rand = rand_method(rng)
    size = np.prod(shape).astype(int)
    if 0 <= nse < 1:
      nse = nse * size
    nse = min(size, int(nse))
    M = rand(shape, dtype)
    indices = rng.choice(size, size - nse, replace=False)
    M.flat[indices] = 0
    return post(M)
  return _rand_sparse
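# Note: an `nse` below 1 is interpreted as a target density, anything else as
# an absolute count of stored elements. For example, the default factory
# rand_sparse(rng)((5, 8), np.float32) zeroes out roughly half of the 40
# entries before applying `post`.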


class cuSparseTest(jtu.JaxTestCase):
  def gpu_dense_conversion_warning_context(self, dtype):
    if jtu.device_under_test() == "gpu" and np.issubdtype(dtype, np.integer):
      return self.assertWarns(sparse.CuSparseEfficiencyWarning)
    return contextlib.nullcontext()

  def gpu_matmul_warning_context(self, dtype):
    if jtu.device_under_test() == "gpu" and dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
      return self.assertWarns(sparse.CuSparseEfficiencyWarning)
    return contextlib.nullcontext()
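  # Note: cuSPARSE kernels exist only for float32/float64/complex64/complex128,
  # so on GPU other dtypes take a dense fallback path; the two helpers above
  # make the resulting CuSparseEfficiencyWarning an expected part of each test
  # rather than noise.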
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes))
  def test_csr_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.indices, M.indptr)
    todense = lambda *args: sparse.csr_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    with self.gpu_dense_conversion_warning_context(dtype):
      self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_todense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    data, indices, indptr = sparse.csr_fromdense(M, nse=(M != 0).sum())
    row, col = sparse.util._csr_to_coo(indices, indptr)
    f = lambda data: sparse.csr_todense(data, indices, indptr, shape=M.shape)

    # Forward-mode
    primals, tangents = jax.jvp(f, [data], [jnp.ones_like(data)])
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(tangents, jnp.zeros_like(M).at[row, col].set(1))

    # Reverse-mode
    primals, vjp_fun = jax.vjp(f, data)
    data_out, = vjp_fun(primals)
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(data_out, data)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_fromdense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    nse = (M != 0).sum()
    f = lambda M: sparse.csr_fromdense(M, nse=nse)

    # Forward-mode
    primals, tangents = jax.jvp(f, [M], [jnp.ones_like(M)])
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(tangents[0], jnp.ones(nse, dtype=dtype))
    self.assertEqual(tangents[1].dtype, dtypes.float0)
    self.assertEqual(tangents[2].dtype, dtypes.float0)

    # Reverse-mode
    primals, vjp_fun = jax.vjp(f, M)
    M_out, = vjp_fun(primals)
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(M_out, M)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          jtu.format_shape_dtype_string(bshape, dtype)),
       "shape": shape, "dtype": dtype, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [(), (1,), (3,)]]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_matmul_ad(self, shape, dtype, bshape):
    csr_matmul = sparse.csr_matvec if len(bshape) == 1 else sparse.csr_matmat
    tol = {np.float32: 1E-5, np.float64: 1E-12, np.complex64: 1E-5, np.complex128: 1E-12}

    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())

    M = rng(shape, dtype)
    data, indices, indptr = sparse.csr_fromdense(M, nse=(M != 0).sum())
    x = rng_b(bshape, dtype)
    xdot = rng_b(bshape, dtype)

    # Forward-mode with respect to the vector
    f_dense = lambda x: M @ x
    f_sparse = lambda x: csr_matmul(data, indices, indptr, x, shape=M.shape)
    v_sparse, t_sparse = jax.jvp(f_sparse, [x], [xdot])
    v_dense, t_dense = jax.jvp(f_dense, [x], [xdot])
    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to the vector
    primals_dense, vjp_dense = jax.vjp(f_dense, x)
    primals_sparse, vjp_sparse = jax.vjp(f_sparse, x)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

    # Forward-mode with respect to nonzero elements of the matrix
    f_sparse = lambda data: csr_matmul(data, indices, indptr, x, shape=M.shape)
    f_dense = lambda data: sparse.csr_todense(data, indices, indptr, shape=M.shape) @ x
    data = rng((len(data),), data.dtype)
    data_dot = rng((len(data),), data.dtype)
    v_sparse, t_sparse = jax.jvp(f_sparse, [data], [data_dot])
    v_dense, t_dense = jax.jvp(f_dense, [data], [data_dot])

    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to nonzero elements of the matrix
    primals_dense, vjp_dense = jax.vjp(f_dense, data)
    primals_sparse, vjp_sparse = jax.vjp(f_sparse, data)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes))
  def test_csr_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_csr = scipy.sparse.csr_matrix(M)

    nse = M_csr.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse.csr_fromdense(M, nse=nse, index_dtype=index_dtype)

    data, indices, indptr = fromdense(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

    with self.gpu_dense_conversion_warning_context(dtype):
      data, indices, indptr = jit(fromdense)(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes
      for transpose in [True, False]))
  @jtu.skip_on_devices("rocm")  # will be fixed in rocm-5.1
  def test_csr_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.indices, M.indptr, v)
    matvec = lambda *args: sparse.csr_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    with self.gpu_matmul_warning_context(dtype):
      self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes
      for transpose in [True, False]))
  def test_csr_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.indices, M.indptr, B)
    matmat = lambda *args: sparse.csr_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    with self.gpu_matmul_warning_context(dtype):
      self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes))
  def test_coo_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.row, M.col)
    todense = lambda *args: sparse.coo_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    with self.gpu_dense_conversion_warning_context(dtype):
      self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes))
  def test_coo_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_coo = scipy.sparse.coo_matrix(M)

    nse = M_coo.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse.coo_fromdense(M, nse=nse, index_dtype=index_dtype)

    data, row, col = fromdense(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    with self.gpu_dense_conversion_warning_context(dtype):
      data, row, col = jit(fromdense)(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes
      for transpose in [True, False]))
  def test_coo_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.row, M.col, v)
    matvec = lambda *args: sparse.coo_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    with self.gpu_matmul_warning_context(dtype):
      self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in all_dtypes
      for transpose in [True, False]))
  @jtu.skip_on_devices("rocm")  # will be fixed in rocm-5.1
  def test_coo_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.row, M.col, B)
    matmat = lambda *args: sparse.coo_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    with self.gpu_matmul_warning_context(dtype):
      self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

  def test_coo_matmat_layout(self):
    # Regression test for https://github.com/google/jax/issues/7533
    d = jnp.array([1.0, 2.0, 3.0, 4.0])
    i = jnp.array([0, 0, 1, 2])
    j = jnp.array([0, 2, 0, 0])
    shape = (3, 3)

    x = jnp.arange(9).reshape(3, 3).astype(d.dtype)

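    # Note: passing x.T below exercises a non-default (transposed) memory
    # layout for the dense operand, which the GPU coo_matmat lowering
    # previously mishandled; hence the f(x) vs jit(f)(x) comparison.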
    def f(x):
      return sparse.coo_matmat(d, i, j, x.T, shape=shape)

    result = f(x)
    result_jit = jit(f)(x)

    self.assertAllClose(result, result_jit)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_gpu_translation_rule(self):
    version = xla_bridge.get_backend().platform_version
    if version.split()[0] != "rocm":
      cuda_version = None if version == "<unknown>" else int(
          version.split()[-1])
      if cuda_version is None or cuda_version < 11000:
        self.assertFalse(cusparse and cusparse.is_supported)
        self.assertNotIn(sparse.csr_todense_p,
                         xla._backend_specific_translations["gpu"])
      else:
        self.assertTrue(cusparse and cusparse.is_supported)
        self.assertIn(sparse.csr_todense_p,
                      xla._backend_specific_translations["gpu"])
    else:
      self.assertTrue(hipsparse and hipsparse.is_supported)
      self.assertIn(sparse.csr_todense_p,
                    xla._backend_specific_translations["gpu"])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(shape, dtype), mat_type),
       "shape": shape, "dtype": dtype, "mat_type": mat_type}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for mat_type in ['csr', 'coo']))
  def test_extra_nse(self, shape, dtype, mat_type):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    nse = (M != 0).sum() + 5
    fromdense = getattr(sparse, f"{mat_type}_fromdense")
    todense = getattr(sparse, f"{mat_type}_todense")
    args = fromdense(M, nse=nse, index_dtype=jnp.int32)
    M_out = todense(*args, shape=M.shape)
    self.assertArraysEqual(M, M_out)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    data, row, col = sparse.coo_fromdense(M, nse=(M != 0).sum())
    f = lambda data: sparse.coo_todense(data, row, col, shape=M.shape)

    # Forward-mode
    primals, tangents = jax.jvp(f, [data], [jnp.ones_like(data)])
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(tangents, jnp.zeros_like(M).at[row, col].set(1))

    # Reverse-mode
    primals, vjp_fun = jax.vjp(f, data)
    data_out, = vjp_fun(primals)
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(data_out, data)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    nse = (M != 0).sum()
    f = lambda M: sparse.coo_fromdense(M, nse=nse)

    # Forward-mode
    primals, tangents = jax.jvp(f, [M], [jnp.ones_like(M)])
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(tangents[0], jnp.ones(nse, dtype=dtype))
    self.assertEqual(tangents[1].dtype, dtypes.float0)
    self.assertEqual(tangents[2].dtype, dtypes.float0)

    # Reverse-mode
    primals, vjp_fun = jax.vjp(f, M)
    M_out, = vjp_fun(primals)
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(M_out, M)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          jtu.format_shape_dtype_string(bshape, dtype)),
       "shape": shape, "dtype": dtype, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [(), (1,), (3,)]]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_matmul_ad(self, shape, dtype, bshape):
    coo_matmul = sparse.coo_matvec if len(bshape) == 1 else sparse.coo_matmat
    tol = {np.float32: 1E-5, np.float64: 1E-12, np.complex64: 1E-5, np.complex128: 1E-12}

    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())

    M = rng(shape, dtype)
    data, row, col = sparse.coo_fromdense(M, nse=(M != 0).sum())
    x = rng_b(bshape, dtype)
    xdot = rng_b(bshape, dtype)

    # Forward-mode with respect to the vector
    f_dense = lambda x: M @ x
    f_sparse = lambda x: coo_matmul(data, row, col, x, shape=M.shape)
    v_sparse, t_sparse = jax.jvp(f_sparse, [x], [xdot])
    v_dense, t_dense = jax.jvp(f_dense, [x], [xdot])
    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to the vector
    primals_dense, vjp_dense = jax.vjp(f_dense, x)
    primals_sparse, vjp_sparse = jax.vjp(f_sparse, x)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

    # Forward-mode with respect to nonzero elements of the matrix
    f_sparse = lambda data: coo_matmul(data, row, col, x, shape=M.shape)
    f_dense = lambda data: sparse.coo_todense(data, row, col, shape=M.shape) @ x
    data = rng((len(data),), data.dtype)
    data_dot = rng((len(data),), data.dtype)
    v_sparse, t_sparse = jax.jvp(f_sparse, [data], [data_dot])
    v_dense, t_dense = jax.jvp(f_dense, [data], [data_dot])

    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to nonzero elements of the matrix
    primals_dense, vjp_dense = jax.vjp(f_dense, data)
    primals_sparse, vjp_sparse = jax.vjp(f_sparse, data)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)


class BCOOTest(jtu.JaxTestCase):
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in all_dtypes
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_empty(self, shape, dtype, n_batch, n_dense):
    M = sparse.empty(shape, dtype=dtype, n_batch=n_batch, n_dense=n_dense)
    self.assertIsInstance(M, sparse.BCOO)
    self.assertEqual(M.nse, 0)
    self.assertEqual(M.n_batch, n_batch)
    self.assertEqual(M.n_dense, n_dense)
    self.assertEqual(M.dtype, dtype)
    self.assertArraysEqual(M.todense(), jnp.empty(shape, dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in all_dtypes
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    n_sparse = M.ndim - n_batch - n_dense
    nse = int(sparse.bcoo._bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
    data_jit, indices_jit = jit(partial(sparse.bcoo_fromdense, nse=nse, n_batch=n_batch, n_dense=n_dense))(M)
    self.assertArraysEqual(data, data_jit)
    self.assertArraysEqual(indices, indices_jit)

    assert data.dtype == dtype
    assert data.shape == shape[:n_batch] + (nse,) + shape[n_batch + n_sparse:]
    assert indices.dtype == jnp.int32  # TODO: test passing this arg
    assert indices.shape == shape[:n_batch] + (nse, n_sparse)
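    # BCOO layout recap: the leading n_batch axes of data/indices mirror the
    # batch dimensions of M, each batch slot stores nse specified elements,
    # indices carries one coordinate per sparse dimension, and any trailing
    # dense dimensions are stored inline in data.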

    todense = partial(sparse.bcoo_todense, spinfo=BCOOInfo(shape))
    self.assertArraysEqual(M, todense(data, indices))
    self.assertArraysEqual(M, jit(todense)(data, indices))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_todense_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    todense = partial(sparse.bcoo_todense, indices=indices, spinfo=BCOOInfo(shape))
    j1 = jax.jacfwd(todense)(data)
    j2 = jax.jacrev(todense)(data)
    hess = jax.hessian(todense)(data)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, M.shape + data.shape)
    self.assertEqual(hess.shape, M.shape + 2 * data.shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_fromdense_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    nse = int(sparse.bcoo._bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))

    def fromdense(M):
      return sparse.bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense)[0]
    data = fromdense(M)

    j1 = jax.jacfwd(fromdense)(M)
    j2 = jax.jacrev(fromdense)(M)
    hess = jax.hessian(fromdense)(M)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, data.shape + M.shape)
    self.assertEqual(hess.shape, data.shape + 2 * M.shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_dense_round_trip_batched(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    n_sparse = M.ndim - n_batch - n_dense
    nse = int(sparse.bcoo._bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))

    fromdense = partial(sparse.bcoo_fromdense, nse=nse, n_dense=n_dense)
    todense = partial(sparse.bcoo_todense, spinfo=BCOOInfo(shape[n_batch:]))
    for _ in range(n_batch):
      fromdense = jax.vmap(fromdense)
      todense = jax.vmap(todense)

    data, indices = fromdense(M)

    assert data.dtype == dtype
    assert data.shape == shape[:n_batch] + (nse,) + shape[n_batch + n_sparse:]
    assert indices.dtype == jnp.int32  # TODO: test passing this arg
    assert indices.shape == shape[:n_batch] + (nse, n_sparse)

    self.assertArraysEqual(M, todense(data, indices))
    self.assertArraysEqual(M, jit(todense)(data, indices))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_extract(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M)
    data2 = sparse.bcoo_extract(indices, M)
    self.assertArraysEqual(data, data2)
    data3 = jit(sparse.bcoo_extract)(indices, M)
    self.assertArraysEqual(data, data3)

  def test_bcoo_extract_batching(self):
    # https://github.com/google/jax/issues/9431
    indices = jnp.zeros((4, 1, 1), dtype=int)
    mat = jnp.arange(4.).reshape((4, 1))

    # in_axes = (0, None)
    expected = jnp.vstack([sparse.bcoo_extract(i, mat[0]) for i in indices])
    actual = vmap(sparse.bcoo_extract, in_axes=(0, None))(indices, mat[0])
    self.assertArraysEqual(expected, actual)

    # in_axes = (None, 0)
    expected = jnp.vstack([sparse.bcoo_extract(indices[0], m) for m in mat])
    actual = vmap(sparse.bcoo_extract, in_axes=(None, 0))(indices[0], mat)
    self.assertArraysEqual(expected, actual)

    # in_axes = (0, 0)
    expected = jnp.vstack([sparse.bcoo_extract(i, m) for i, m in zip(indices, mat)])
    actual = vmap(sparse.bcoo_extract, in_axes=0)(indices, mat)
    self.assertArraysEqual(expected, actual)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    extract = partial(sparse.bcoo_extract, indices)
    j1 = jax.jacfwd(extract)(M)
    j2 = jax.jacrev(extract)(M)
    hess = jax.hessian(extract)(M)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, data.shape + M.shape)
    self.assertEqual(hess.shape, data.shape + 2 * M.shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_transpose(self, shape, dtype, n_batch, n_dense):
    n_sparse = len(shape) - n_batch - n_dense
    rng = self.rng()
    sprng = rand_sparse(rng)
    M = sprng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    permutation = np.concatenate([
        rng.permutation(range(n_batch)),
        rng.permutation(range(n_batch, n_batch + n_sparse)),
        rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)

    M_T = M.transpose(permutation)
    trans = partial(sparse.bcoo_transpose, spinfo=BCOOInfo(shape), permutation=permutation)
    self.assertArraysEqual(M_T, sparse.bcoo_todense(*trans(data, indices), spinfo=BCOOInfo(M_T.shape)))
    self.assertArraysEqual(M_T, sparse.bcoo_todense(*jit(trans)(data, indices), spinfo=BCOOInfo(M_T.shape)))

    # test batched
    def trans(M):
      return M.transpose([p - n_batch for p in permutation[n_batch:]])
    for _ in range(n_batch):
      trans = jax.vmap(trans)
    Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
    self.assertArraysEqual(trans(M), trans(Msp).todense())

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_transpose_ad(self, shape, dtype, n_batch, n_dense):
    n_sparse = len(shape) - n_batch - n_dense
    rng = self.rng()
    sprng = rand_sparse(self.rng())

    M = sprng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    permutation = np.concatenate([
        rng.permutation(range(n_batch)),
        rng.permutation(range(n_batch, n_batch + n_sparse)),
        rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)

    def f_sparse(data):
      return sparse.bcoo_transpose(data, indices, spinfo=BCOOInfo(shape), permutation=permutation)[0]

    jf_sparse = jax.jacfwd(f_sparse)(data)
    jr_sparse = jax.jacrev(f_sparse)(data)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    # TODO(jakevdp) also test against dense version?
    self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(1, len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    M1 = sparse.bcoo_todense(data, indices[:1], spinfo=BCOOInfo(M.shape))
    M2 = sparse.bcoo_todense(data, jnp.stack(shape[0] * [indices[0]]), spinfo=BCOOInfo(M.shape))
    self.assertAllClose(M1, M2)

    M3 = sparse.bcoo_todense(data[:1], indices, spinfo=BCOOInfo(M.shape))
    M4 = sparse.bcoo_todense(jnp.stack(shape[0] * [data[0]]), indices, spinfo=BCOOInfo(M.shape))
    self.assertAllClose(M3, M4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": props.testcase_name(), "props": props}
      for props in _generate_bcoo_dot_general_properties(
          shapes=[(5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)],
          dtypes=jtu.dtypes.floating + jtu.dtypes.complex,
      )))
  def test_bcoo_dot_general(self, props: BcooDotGeneralProperties):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    def args_maker():
      lhs = rng_sparse(props.lhs_shape, props.dtype)
      rhs = rng(props.rhs_shape, props.dtype)
      data, indices = sparse.bcoo_fromdense(lhs, n_batch=props.n_batch, n_dense=props.n_dense)
      return data, indices, lhs, rhs

    def f_dense(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=props.dimension_numbers)

    def f_sparse(data, indices, lhs, rhs):
      return sparse.bcoo_dot_general(data, indices, rhs, lhs_spinfo=BCOOInfo(lhs.shape),
                                     dimension_numbers=props.dimension_numbers)

    tol = {'float32': 3E-2} if jtu.device_under_test() == 'tpu' else {}
    self._CheckAgainstNumpy(f_dense, f_sparse, args_maker, tol=tol)
    self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker, tol=tol)
    # TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
    # self._CompileAndCheck(f_sparse, args_maker)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}"
           .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                   jtu.format_shape_dtype_string(rhs_shape, dtype),
                   lhs_contracting, rhs_contracting),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting}
      for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
          [(5,), (5,), [0], [0]],
          [(5,), (5, 7), [0], [0]],
          [(5,), (7, 5), [0], [1]],
          [(5, 7), (5,), [0], [0]],
          [(7, 5), (5,), [1], [0]],
          [(3, 5), (2, 5), [1], [1]],
          [(3, 5), (5, 2), [1], [0]],
          [(5, 3), (2, 5), [0], [1]],
          [(5, 3), (5, 2), [0], [0]],
      ]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_dot_general_cusparse(
      self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    def args_maker():
      lhs = rng_sparse(lhs_shape, dtype)
      rhs = rng(rhs_shape, dtype)
      data, indices = sparse.bcoo_fromdense(lhs, index_dtype=jnp.int32)
      return data, indices, lhs, rhs

    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))

    def f_dense(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

    def f_sparse(data, indices, lhs, rhs):
      return sparse.bcoo_dot_general(data, indices, rhs,
                                     dimension_numbers=dimension_numbers,
                                     lhs_spinfo=BCOOInfo(lhs.shape))

    self._CompileAndCheck(f_sparse, args_maker)
    self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
    """Tests bcoo_dot_general with out-of-bound and unsorted indices."""

    rhs = jnp.ones((5, 3), dtype=jnp.float32)

    # bcoo_fromdense pads with out-of-bound indices when nse > nnz.
    lhs_2d_dense = jnp.array([[1, 0, 2, 3, 0], [0, 0, 0, 4, 0]],
                             dtype=jnp.float32)
    lhs_2d_sparse, lhs_sparse_2d_indices = sparse.bcoo_fromdense(
        lhs_2d_dense, nse=7)

    def create_unsorted_indices(data, indices):
      data_to_shuffle = jnp.hstack((jnp.expand_dims(data, axis=1), indices))
      key = jax.random.PRNGKey(1701)
      data_after_shuffle = jax.random.permutation(key, data_to_shuffle)
      return (data_after_shuffle[:, 0],
              data_after_shuffle[:, 1:].astype(dtype=jnp.int32))
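    # Note: stacking data next to indices and permuting whole rows keeps each
    # value paired with its coordinates while destroying any sorted order, so
    # the kernels below see a valid but unsorted sparse representation.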

    # Randomly permute the indices so they are unsorted.
    lhs_2d_sparse, lhs_sparse_2d_indices = create_unsorted_indices(
        lhs_2d_sparse, lhs_sparse_2d_indices)

    dimension_numbers_2d = (([1], [0]), ([], []))

    def args_maker_2d():
      return lhs_2d_sparse, lhs_sparse_2d_indices, lhs_2d_dense, rhs

    def f_dense_2d(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers_2d)

    def f_sparse_2d(data, indices, lhs, rhs):
      return sparse.bcoo_dot_general(data, indices, rhs,
                                     dimension_numbers=dimension_numbers_2d,
                                     lhs_spinfo=BCOOInfo(lhs.shape))

    with self.subTest(msg="2D"):
      self._CompileAndCheck(f_sparse_2d, args_maker_2d)
      self._CheckAgainstNumpy(f_dense_2d, f_sparse_2d, args_maker_2d)

    lhs_1d_dense = jnp.array([0, 1, 0, 2, 0], dtype=jnp.float32)
    lhs_1d_sparse, lhs_sparse_1d_indices = sparse.bcoo_fromdense(
        lhs_1d_dense, nse=5)

    # Randomly permute the indices so they are unsorted.
    lhs_1d_sparse, lhs_sparse_1d_indices = create_unsorted_indices(
        lhs_1d_sparse, lhs_sparse_1d_indices)

    dimension_numbers_1d = (([0], [0]), ([], []))

    def args_maker_1d():
      return lhs_1d_sparse, lhs_sparse_1d_indices, lhs_1d_dense, rhs

    def f_dense_1d(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers_1d)

    def f_sparse_1d(data, indices, lhs, rhs):
      return sparse.bcoo_dot_general(data, indices, rhs,
                                     dimension_numbers=dimension_numbers_1d,
                                     lhs_spinfo=BCOOInfo(lhs.shape))

    with self.subTest(msg="1D"):
      self._CompileAndCheck(f_sparse_1d, args_maker_1d)
      self._CheckAgainstNumpy(f_dense_1d, f_sparse_1d, args_maker_1d)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": props.testcase_name(), "props": props}
      for props in _generate_bcoo_dot_general_properties(
          shapes=[(5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)],
          dtypes=jtu.dtypes.floating + jtu.dtypes.complex,
      )))
  def test_bcoo_rdot_general(self, props: BcooDotGeneralProperties):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    lhs_shape, rhs_shape = props.rhs_shape, props.lhs_shape
    dimension_numbers = tuple(d[::-1] for d in props.dimension_numbers)

    def args_maker():
      lhs = rng_sparse(lhs_shape, props.dtype)
      rhs = rng(rhs_shape, props.dtype)
      data, indices = sparse.bcoo_fromdense(rhs, n_batch=props.n_batch, n_dense=props.n_dense)
      return data, indices, lhs, rhs

    def f_dense(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

    def f_sparse(data, indices, lhs, rhs):
      return sparse.bcoo_rdot_general(lhs, data, indices, rhs_spinfo=BCOOInfo(rhs.shape),
                                      dimension_numbers=dimension_numbers)

    tol = {'float32': 3E-2} if jtu.device_under_test() == 'tpu' else {}
    self._CheckAgainstNumpy(f_dense, f_sparse, args_maker, tol=tol)
    self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker, tol=tol)
    # TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
    # self._CompileAndCheck(f_sparse, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
           .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                   jtu.format_shape_dtype_string(rhs_shape, dtype),
                   dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    X = rng_sparse(lhs_shape, dtype)
    data, indices = sparse.bcoo_fromdense(X, n_batch=n_batch, n_dense=n_dense)
    Y = rng(rhs_shape, dtype)

    def f_dense(X, Y):
      return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

    def f_sparse(data, indices, Y):
      return sparse.bcoo_dot_general(data, indices, Y, lhs_spinfo=BCOOInfo(X.shape),
                                     dimension_numbers=dimension_numbers)

    for data, indices in itertools.product([data, data[:1]], [indices, indices[:1]]):
      X = sparse.bcoo_todense(data, indices, spinfo=BCOOInfo(X.shape))
      self.assertAllClose(f_dense(X, Y), f_sparse(data, indices, Y))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
           .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                   jtu.format_shape_dtype_string(rhs_shape, dtype),
                   dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((4, 5), (5, 3), (([1], [0]), ([], [])), 0, 0),
          ((2, 4, 5), (2, 5, 3), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0),
          # This requires contraction over dense dimensions, which is not yet implemented:
          # ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating))
  def test_bcoo_dot_general_ad(self, lhs_shape, rhs_shape, dtype,
                               dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    X = rng_sparse(lhs_shape, dtype)
    data, indices = sparse.bcoo_fromdense(X, n_batch=n_batch, n_dense=n_dense)
    Y = rng(rhs_shape, dtype)

    # gradient with respect to rhs
    def f_dense(Y):
      return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

    def f_sparse(Y):
      return sparse.bcoo_dot_general(data, indices, Y, lhs_spinfo=BCOOInfo(X.shape),
                                     dimension_numbers=dimension_numbers)

    jf_dense = jax.jacfwd(f_dense)(Y)
    jr_dense = jax.jacrev(f_dense)(Y)
    jf_sparse = jax.jacfwd(f_sparse)(Y)
    jr_sparse = jax.jacrev(f_sparse)(Y)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    self.assertAllClose(jf_dense, jf_sparse, rtol=tol)
    self.assertAllClose(jr_dense, jr_sparse, rtol=tol)
    self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

    # gradient with respect to lhs
    def g_dense(X):
      return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

    def g_sparse(data):
      return sparse.bcoo_dot_general(data, indices, Y, lhs_spinfo=BCOOInfo(X.shape),
                                     dimension_numbers=dimension_numbers)

    jf_dense = jax.jacfwd(g_dense)(X)
    jr_dense = jax.jacrev(g_dense)(X)
    jf_sparse = jax.jacfwd(g_sparse)(data)
    jr_sparse = jax.jacrev(g_sparse)(data)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    self.assertAllClose(jf_dense, jr_dense, rtol=tol)
    self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

    # Extract the sparse jacobian from the dense & compare.
    def extract(X):
      return sparse.bcoo_extract(indices, X)
    for _ in range(g_dense(X).ndim):
      extract = jax.vmap(extract)
    self.assertAllClose(extract(jf_dense), jf_sparse, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
           .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                   jtu.format_shape_dtype_string(rhs_shape, dtype),
                   dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 0, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 0, 1),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 0, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 1),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 0, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 1, 2),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_dot_general_sampled(self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_default(self.rng())
    sprng = rand_sparse(self.rng())
    out_shape = lax.dot_general(
        jnp.zeros(lhs_shape), jnp.zeros(rhs_shape),
        dimension_numbers=dimension_numbers).shape

    args_maker = lambda: [
        rng(lhs_shape, dtype), rng(rhs_shape, dtype),
        sparse.BCOO.fromdense(sprng(out_shape, dtype),
                              n_batch=n_batch, n_dense=n_dense).indices]

    def dense_fun(lhs, rhs, indices):
      AB = lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
      return sparse.bcoo_extract(indices, AB)

    def sparse_fun(lhs, rhs, indices):
      return sparse.bcoo_dot_general_sampled(
          lhs, rhs, indices, dimension_numbers=dimension_numbers)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    self._CheckAgainstNumpy(dense_fun, sparse_fun, args_maker, tol=tol)
    # TODO: python_should_be_executing check occasionally fails... why?
    # self._CompileAndCheck(sparse_fun, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
           .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                   jtu.format_shape_dtype_string(rhs_shape, dtype),
                   dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 1),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 1),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating))
  def test_bcoo_dot_general_sampled_ad(self, lhs_shape, rhs_shape, dtype, dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_default(self.rng())
    sprng = rand_sparse(self.rng())
    out_shape = lax.dot_general(
        jnp.zeros(lhs_shape), jnp.zeros(rhs_shape),
        dimension_numbers=dimension_numbers).shape

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    indices = sparse.BCOO.fromdense(sprng(out_shape, dtype),
                                    n_batch=n_batch, n_dense=n_dense).indices

    def dense_fun(lhs, rhs, indices):
      AB = lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
      return sparse.bcoo_extract(indices, AB)

    def sparse_fun(lhs, rhs, indices):
      return sparse.bcoo_dot_general_sampled(
          lhs, rhs, indices, dimension_numbers=dimension_numbers)

    jf_dense = jax.jacfwd(dense_fun)(lhs, rhs, indices)
    jf_sparse = jax.jacfwd(sparse_fun)(lhs, rhs, indices)
    jr_dense = jax.jacrev(dense_fun)(lhs, rhs, indices)
    jr_sparse = jax.jacrev(sparse_fun)(lhs, rhs, indices)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    self.assertAllClose(jf_sparse, jf_dense, atol=tol)
    self.assertAllClose(jr_sparse, jr_dense, atol=tol)
    self.assertAllClose(jf_sparse, jr_sparse, atol=tol)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}[n_batch={}]_{}[n_batch={}]_swap={}_dims={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype), lhs_n_batch,
          jtu.format_shape_dtype_string(rhs_shape, dtype), rhs_n_batch,
          swap, dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape,
       "lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch,
       "dimension_numbers": dimension_numbers, "swap": swap, "dtype": dtype}
      for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dimension_numbers in [
          # (batched) outer products (no contraction)
          ((5,), 0, (6,), 0, (([], []), ([], []))),
          ((3, 5), 0, (2, 4), 0, (([], []), ([], []))),
          ((3, 5), 1, (3, 4), 1, (([], []), ([0], [0]))),
          # (batched) vector-vector products
          ((5,), 0, (5,), 0, (([0], [0]), ([], []))),
          ((7,), 0, (7,), 0, (([0], [0]), ([], []))),
          ((5, 7), 1, (7,), 0, (([1], [0]), ([], []))),
          ((2, 3, 4), 2, (2, 4), 1, (([2], [1]), ([0], [0]))),
          ((2, 3, 4), 2, (2, 4), 1, (([2], [1]), ([], []))),
          ((2, 3, 4), 2, (3, 4), 1, (([2], [1]), ([1], [0]))),
          ((2, 3, 4), 2, (3, 4), 1, (([2], [1]), ([], []))),
          # (batched) matrix-vector products
          ((5, 7), 0, (7,), 0, (([1], [0]), ([], []))),
          ((2, 3, 4), 1, (4,), 0, (([2], [0]), ([], []))),
          ((2, 3, 4), 1, (2, 4), 1, (([2], [1]), ([0], [0]))),
          ((3, 2, 4), 1, (3, 4), 1, (([2], [1]), ([0], [0]))),
          ((2, 3, 4), 0, (2,), 0, (([0], [0]), ([], []))),
          # (batched) matrix-matrix products
          ((5, 7), 0, (7, 3), 0, (([1], [0]), ([], []))),
          ((2, 3, 4), 1, (4, 3), 0, (([2], [0]), ([], []))),
          ((2, 3, 4), 1, (2, 4, 3), 1, (([2], [1]), ([0], [0]))),
          # more general operations
          ((2, 3, 4, 3), 1, (2, 4, 3, 4), 1, (([2, 3], [1, 2]), ([0], [0]))),
          ((2, 3, 4, 3, 1), 2, (3, 2, 3, 4), 2, (([2, 3], [3, 2]), ([0, 1], [1, 0]))),
      ]
      for swap in [True, False]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_spdot_general(self, lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dtype, swap, dimension_numbers):
    if swap:
      dimension_numbers = tuple(d[::-1] for d in dimension_numbers)
      lhs_shape, rhs_shape = rhs_shape, lhs_shape
      lhs_n_batch, rhs_n_batch = rhs_n_batch, lhs_n_batch

    lhs_n_sparse = len(lhs_shape) - lhs_n_batch
    rhs_batch = dimension_numbers[1][1]
    lhs_contracting = dimension_numbers[0][0]
    should_error = (rhs_n_batch > len(rhs_batch) and lhs_n_sparse > len(lhs_contracting))
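    # bcoo_spdot_general rejects operands where the rhs has batch dimensions
    # that are not named in the dot's batch dims while the lhs still has
    # uncontracted sparse dimensions; the assertRaisesRegex branch below
    # covers exactly those cases.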
|
|
|
|
sprng = rand_sparse(self.rng())
|
|
def args_maker():
|
|
x = sprng(lhs_shape, dtype)
|
|
y = sprng(rhs_shape, dtype)
|
|
xsp = sparse.BCOO.fromdense(x, n_batch=lhs_n_batch)
|
|
ysp = sparse.BCOO.fromdense(y, n_batch=rhs_n_batch)
|
|
return x, y, xsp, ysp
|
|
|
|
def f_dense(x, y, xsp, ysp):
|
|
return lax.dot_general(x, y, dimension_numbers=dimension_numbers)
|
|
|
|
def f_sparse(x, y, xsp, ysp):
|
|
shape = sparse.bcoo._dot_general_validated_shape(xsp.shape, ysp.shape, dimension_numbers)
|
|
data, indices = sparse.bcoo_spdot_general(xsp.data, xsp.indices, ysp.data, ysp.indices,
|
|
lhs_spinfo=xsp._info, rhs_spinfo=ysp._info,
|
|
dimension_numbers=dimension_numbers)
|
|
return sparse.bcoo_todense(data, indices, spinfo=BCOOInfo(shape))
|
|
|
|
tol = {"complex128": 1E-14}
|
|
if should_error:
|
|
with self.assertRaisesRegex(ValueError, ".*cannot have unused batch dims on rhs with unused sparse dims on lhs."):
|
|
f_sparse(*args_maker())
|
|
else:
|
|
self._CheckAgainstNumpy(f_dense, f_sparse, args_maker, tol=tol)
|
|
self._CheckAgainstNumpy(jit(f_dense), jit(f_sparse), args_maker, tol=tol)
|
|
# TODO(jakevdp): This occasionally fails python_should_be_executing check. Why?
|
|
# self._CompileAndCheck(f_sparse, args_maker)
|
|
|
|
def test_bcoo_spdot_general_nse(self):
|
|
# vector-vector product -> nse=1
|
|
x = sparse.BCOO.fromdense(jnp.arange(3))
|
|
self.assertEqual((x @ x).nse, 1)
|
|
|
|
# matrix-vector product -> nse matches matrix
|
|
M = sparse.BCOO.fromdense(jnp.arange(6).reshape(2, 3))
|
|
self.assertEqual((M @ x).nse, M.nse)
|
|
|
|
# matrix-matrix product -> product of nse
|
|
N = sparse.BCOO.fromdense(jnp.arange(12).reshape(3, 4))
|
|
self.assertEqual((M @ N).nse, M.nse * N.nse)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}[n_batch={}]_rhs_shape={}[n_batch={}]_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype), lhs_n_batch,
               jtu.format_shape_dtype_string(rhs_shape, dtype), rhs_n_batch,
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch}
      for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dimension_numbers in [
          ((4, 5), 0, (5,), 0, (([1], [0]), ([], []))),
          ((2, 4, 5), 1, (5,), 0, (([2], [0]), ([], []))),
          ((4, 5), 0, (5, 3), 0, (([1], [0]), ([], []))),
          ((2, 4, 5), 1, (2, 5, 3), 1, (([2], [1]), ([0], [0]))),
      ]
      for dtype in jtu.dtypes.floating))
  def test_bcoo_spdot_general_ad(self, lhs_shape, rhs_shape, dtype,
                                 dimension_numbers, lhs_n_batch, rhs_n_batch):
    rng = rand_sparse(self.rng())

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)

    lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=lhs_n_batch)
    rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=rhs_n_batch)
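
    # f_dense and f_sparse compute the same scalar as a function of the raw
    # data buffers, so their forward-mode Jacobians should agree.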
    def f_dense(lhs_data, rhs_data):
      lhs = sparse.BCOO((lhs_data, lhs_sp.indices), shape=lhs_sp.shape).todense()
      rhs = sparse.BCOO((rhs_data, rhs_sp.indices), shape=rhs_sp.shape).todense()
      return (lhs @ rhs).sum()

    def f_sparse(lhs_data, rhs_data):
      lhs = sparse.BCOO((lhs_data, lhs_sp.indices), shape=lhs_sp.shape)
      rhs = sparse.BCOO((rhs_data, rhs_sp.indices), shape=rhs_sp.shape)
      return (lhs @ rhs).sum()

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-2}

    jf_dense_0 = jax.jacfwd(f_dense, argnums=0)(lhs_sp.data, rhs_sp.data)
    jf_sparse_0 = jax.jacfwd(f_sparse, argnums=0)(lhs_sp.data, rhs_sp.data)
    self.assertAllClose(jf_dense_0, jf_sparse_0, rtol=tol)

    jf_dense_1 = jax.jacfwd(f_dense, argnums=1)(lhs_sp.data, rhs_sp.data)
    jf_sparse_1 = jax.jacfwd(f_sparse, argnums=1)(lhs_sp.data, rhs_sp.data)
    self.assertAllClose(jf_dense_1, jf_sparse_1, rtol=tol)

    jf_dense_0, jf_dense_1 = jax.jacfwd(f_dense, argnums=(0, 1))(lhs_sp.data, rhs_sp.data)
    jf_sparse_0, jf_sparse_1 = jax.jacfwd(f_sparse, argnums=(0, 1))(lhs_sp.data, rhs_sp.data)
    self.assertAllClose(jf_dense_0, jf_sparse_0, rtol=tol)
    self.assertAllClose(jf_dense_1, jf_sparse_1, rtol=tol)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}[n_batch={}]_{}[n_batch={}]_in_axes={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype), lhs_n_batch,
          jtu.format_shape_dtype_string(rhs_shape, dtype), rhs_n_batch,
          in_axes),
       "lhs_shape": lhs_shape, "lhs_n_batch": lhs_n_batch,
       "rhs_shape": rhs_shape, "rhs_n_batch": rhs_n_batch,
       "dtype": dtype, "in_axes": in_axes}
      for lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, in_axes in [
          ((3, 5), 1, (3, 5), 1, 0),
          ((3, 4, 5), 1, (3, 5), 1, 0),
          ((3, 4, 5), 2, (3, 5), 1, 0),
          # TODO(jakevdp): test these once unequal batches are implemented
          # ((4, 5), 1, (5,), 0, (0, None)),
          # ((3, 4, 5), 1, (5,), 0, (0, None)),
          # ((4, 5), 0, (3, 5), 1, (None, 0)),
      ]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_spmm_batched(self, lhs_shape, lhs_n_batch, rhs_shape, rhs_n_batch, dtype, in_axes):
    sprng = rand_sparse(self.rng())
    def args_maker():
      x = sprng(lhs_shape, dtype)
      y = sprng(rhs_shape, dtype)
      xsp = sparse.BCOO.fromdense(x, n_batch=lhs_n_batch)
      ysp = sparse.BCOO.fromdense(y, n_batch=rhs_n_batch)
      return x, y, xsp, ysp
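
    # Mapping matmul over the leading axis should give the same result whether
    # the inputs are dense arrays or their BCOO counterparts.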
    def f_dense(x, y, _, __):
      return jax.vmap(operator.matmul, in_axes=in_axes)(x, y)
    def f_sparse(_, __, x, y):
      return jax.vmap(operator.matmul, in_axes=in_axes)(x, y)

    args = args_maker()
    result_dense = f_dense(*args)
    result_sparse = f_sparse(*args)
    self.assertAllClose(result_dense, result_sparse.todense())
    result_sparse_jit = jax.jit(f_sparse)(*args)
    self.assertAllClose(result_dense, result_sparse_jit.todense())

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}_nse={}_remove_zeros={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, nse, remove_zeros),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense,
       "nse": nse, "remove_zeros": remove_zeros}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)
      for nse in [None, np.prod(shape) - 1]
      for remove_zeros in [True, False]))
  def test_bcoo_sum_duplicates(self, shape, dtype, n_batch, n_dense, nse, remove_zeros):
    rng = self.rng()
    rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
    M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
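    # Randomize the sparse indices in-place so that duplicate entries are
    # (almost certainly) present.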
    for i, s in enumerate(shape[n_batch:len(shape) - n_dense]):
      M.indices = M.indices.at[..., i].set(rng.randint(0, s, size=M.nse))
    dedupe = partial(M.sum_duplicates, nse=nse, remove_zeros=remove_zeros)
    jit_dedupe = jax.jit(dedupe)

    M_dedup = dedupe()
    self.assertAllClose(M.todense(), M_dedup.todense())
    if nse:
      self.assertEqual(M_dedup.nse, nse)

    if not nse:
      with self.assertRaisesRegex(ValueError, ".*nse argument"):
        jit_dedupe()
    else:
      M_dedup = jit_dedupe()
      self.assertAllClose(M.todense(), M_dedup.todense())
      self.assertEqual(M_dedup.nse, nse)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_sort_indices(self, shape, dtype, n_batch, n_dense):
    rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
    M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
    M.indices = M.indices[..., ::-1, :]

    M_sorted = M.sort_indices()
    self.assertArraysEqual(M.todense(), M_sorted.todense())
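
    # After sort_indices, lexsort over the index columns should yield the
    # identity permutation, i.e. an iota along the trailing dimension.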
    indices = M_sorted.indices
    if indices.size > 0:
      flatind = indices.reshape(-1, *indices.shape[-2:]).transpose(0, 2, 1)
      sorted = jax.vmap(jnp.lexsort)(flatind[:, ::-1])
      self.assertArraysEqual(sorted, lax.broadcasted_iota(sorted.dtype, sorted.shape, sorted.ndim - 1))

  def test_bcoo_sum_duplicates_inferred_nse(self):
    x = sparse.BCOO.fromdense(jnp.diag(jnp.arange(4)))
    self.assertEqual(x.nse, 3)
    y = x + x.T
    self.assertEqual(y.nse, 6)
    y2 = y.sum_duplicates()
    self.assertEqual(y2.nse, 3)
    self.assertArraysEqual(y.todense(), y2.todense())

  def test_bcoo_sum_duplicates_remove_zeros(self):
    data = jnp.array([0, 1, 0, 0])
    indices = jnp.array([[0], [1], [2], [3]])
    x = sparse.BCOO((data, indices), shape=(4,))
    self.assertEqual(x.nse, 4)

    y1 = x.sum_duplicates(remove_zeros=True)
    self.assertArraysEqual(x.todense(), y1.todense())
    self.assertEqual(y1.nse, 1)

    y2 = x.sum_duplicates(remove_zeros=False)
    self.assertArraysEqual(x.todense(), y2.todense())
    self.assertEqual(y2.nse, x.nse)

  def test_bcoo_sum_duplicates_padding(self):
    # Regression test for https://github.com/google/jax/issues/8163
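    # Entries whose index equals the dimension size are out-of-bounds padding;
    # sum_duplicates should leave them in place rather than folding them together.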
    size = 3
    data = jnp.array([1, 0, 0])
    indices = jnp.array([1, size, size])[:, None]
    x = sparse.BCOO((data, indices), shape=(3,))
    y = x.sum_duplicates(nse=x.nse)
    self.assertArraysEqual(x.todense(), y.todense())
    self.assertArraysEqual(x.indices, y.indices)
    self.assertArraysEqual(x.data, y.data)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, axes),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "axes": axes}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)
      for naxes in range(len(shape))
      for axes in itertools.combinations(range(len(shape)), naxes)))
  def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
    data_out, indices_out, shape_out = sparse.bcoo_reduce_sum(data, indices, spinfo=BCOOInfo(shape), axes=axes)
    result_dense = M.sum(axes)
    result_sparse = sparse.bcoo_todense(data_out, indices_out, spinfo=BCOOInfo(shape_out))
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
      }
      for lhs_shape, rhs_shape in [[(3,), (3,)],
                                   [(3, 4), (4,)],
                                   [(4,), (4, 5)],
                                   [(3, 4), (4, 5)],
                                   [(3, 4), (2, 4, 5)],
                                   [(2, 3, 4), (4, 5)],
                                   [(2, 3, 4), (2, 4, 5)]]
      for lhs_dtype in all_dtypes
      for rhs_dtype in all_dtypes))
  def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
    rng = jtu.rand_default(self.rng())
    lhs = jnp.array(rng(lhs_shape, lhs_dtype))
    rhs = jnp.array(rng(rhs_shape, rhs_dtype))

    # Note: currently, batch dimensions in matmul must correspond to batch
    # dimensions in the sparse representation.
    lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=max(0, len(lhs_shape) - 2))
    rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=max(0, len(rhs_shape) - 2))

    out1 = lhs @ rhs
    out2 = lhs_sp @ rhs
    out3 = lhs @ rhs_sp

    tol = {np.float64: 1E-13, np.complex128: 1E-13,
           np.float32: 1E-6, np.complex64: 1E-6}
    self.assertAllClose(out1, out2, rtol=tol)
    self.assertAllClose(out1, out3, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_n_batch={}_n_dense={}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
          n_batch, n_dense),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "n_batch": n_batch, "n_dense": n_dense,
      }
      for lhs_shape, rhs_shape in [[(3,), ()], [(3,), (1,)], [(3,), (3,)],
                                   [(3, 4), ()], [(3, 4), (4,)], [(3, 4), (3, 1)], [(3, 4), (3, 4)],
                                   [(3, 4, 5), (4, 5)], [(3, 4, 5), (3, 1, 1)], [(3, 4, 5), (1, 4, 1)]]
      for n_batch in range(len(lhs_shape) + 1)
      for n_dense in range(len(lhs_shape) + 1 - n_batch)
      for lhs_dtype in all_dtypes
      for rhs_dtype in all_dtypes))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def test_bcoo_mul_dense(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense):
    rng_lhs = rand_sparse(self.rng())
    rng_rhs = jtu.rand_default(self.rng())
    lhs = jnp.array(rng_lhs(lhs_shape, lhs_dtype))
    rhs = jnp.array(rng_rhs(rhs_shape, rhs_dtype))

    sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)

    out1 = lhs * rhs
    out2 = (sp(lhs) * rhs).todense()
    out3 = (rhs * sp(lhs)).todense()

    tol = {np.float64: 1E-13, np.complex128: 1E-13,
           np.float32: 1E-6, np.complex64: 1E-6}
    self.assertAllClose(out1, out2, rtol=tol)
    self.assertAllClose(out1, out3, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_n_batch={}_n_dense={}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
          n_batch, n_dense),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "n_batch": n_batch, "n_dense": n_dense,
      }
      # TODO(jakevdp): add broadcasted shapes (from bcoo_mul_dense) once sparse-sparse mul
      # supports inputs of differing rank.
      for lhs_shape, rhs_shape in [[(3,), (1,)], [(3,), (3,)],
                                   [(3, 4), (1, 1)], [(3, 4), (1, 4)], [(3, 4), (3, 1)], [(3, 4), (3, 4)],
                                   [(3, 4, 5), (1, 4, 5)], [(3, 4, 5), (3, 1, 1)], [(3, 4, 5), (1, 4, 1)]]
      # TODO(jakevdp): add tests for batch & dense dimensions.
      for n_batch in range(len(lhs_shape) + 1)
      for n_dense in range(len(lhs_shape) + 1 - n_batch)
      for lhs_dtype in all_dtypes
      for rhs_dtype in all_dtypes))
  def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    lhs = jnp.array(rng(lhs_shape, lhs_dtype))
    rhs = jnp.array(rng(rhs_shape, rhs_dtype))

    sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)

    out1 = lhs * rhs
    out2 = (sp(lhs) * sp(rhs)).todense()

    tol = {np.float64: 1E-13, np.complex128: 1E-13,
           np.float32: 1E-6, np.complex64: 1E-6}
    self.assertAllClose(out1, out2, rtol=tol)

  def test_bcoo_mul_sparse_with_duplicates(self):
    # Regression test for https://github.com/google/jax/issues/8888
    indices = jnp.array([[0, 1, 0, 0, 1, 1],
                         [1, 0, 1, 2, 0, 2]]).T
    data = jnp.array([1, 2, 3, 4, 5, 6])
    mat = sparse.BCOO((data, indices), shape=(3, 3))
    self.assertArraysEqual((mat * mat).todense(), mat.todense() * mat.todense())

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_n_batch={}_n_dense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(), (3,), (3, 5), (3, 5, 4)]
      for dtype in all_dtypes
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_broadcast_in_dim(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    x = jnp.array(rng(shape, dtype))
    xsp = sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)
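
    # Indexing with None adds unit dimensions, exercising the
    # broadcast_in_dim support under test.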
    self.assertArraysEqual(xsp[None].todense(), x[None])
    if len(shape) >= 1:
      self.assertArraysEqual(xsp[:, None].todense(), x[:, None])
      self.assertArraysEqual(xsp[:, None, None].todense(), x[:, None, None])
    if len(shape) >= 2:
      self.assertArraysEqual(xsp[:, :, None].todense(), x[:, :, None])
      self.assertArraysEqual(xsp[:, None, :, None].todense(), x[:, None, :, None])

  def test_bcoo_vmap_shape(self, shape=(2, 3, 4, 5), dtype=np.float32):
    # This test checks that BCOO shape metadata interacts correctly with vmap.
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)

    def make_bcoo(M):
      return sparse.BCOO.fromdense(M, nse=np.prod(M.shape[:-1], dtype=int), n_dense=1)

    for _ in range(3):
      make_bcoo = jax.vmap(make_bcoo)
    Msp = make_bcoo(M)
    self.assertEqual(Msp.shape, M.shape)
    self.assertArraysEqual(Msp.todense(), M)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_unbatch(self, shape, dtype, n_batch, n_dense):
    rng_sparse = rand_sparse(self.rng())
    M1 = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
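    # _unbatch converts all batch dimensions to sparse dimensions; the
    # represented array should be unchanged.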
    M2 = M1._unbatch()
    self.assertEqual(M2.n_batch, 0)
    self.assertEqual(M1.n_dense, M2.n_dense)
    self.assertEqual(M1.shape, M2.shape)
    self.assertEqual(M1.dtype, M2.dtype)
    self.assertArraysEqual(M1.todense(), M2.todense())

  def test_bcoo_bad_fillvals(self):
    # The out-of-bounds padding entries store 100 rather than zero, which lets
    # us check that the routines below properly ignore these indices.
    data = jnp.array([1, 2, 3, 100, 100])
    indices = jnp.array([1, 2, 3, 5, 5])[:, None]
    x_sp = sparse.BCOO((data, indices), shape=(5,))
    x_de = x_sp.todense()

    data = jnp.array([3, 2, 100, 100])
    indices = jnp.array([2, 3, 5, 5])[:, None]
    y_sp = sparse.BCOO((data, indices), shape=(5,))
    y_de = y_sp.todense()

    self.assertArraysEqual(x_de, jnp.array([0, 1, 2, 3, 0]))
    self.assertArraysEqual(y_de, jnp.array([0, 0, 3, 2, 0]))

    self.assertArraysEqual(x_sp.sum_duplicates().todense(), x_de)
    self.assertArraysEqual(y_sp.sum_duplicates().todense(), y_de)

    # reduce_sum:
    self.assertArraysEqual(x_sp.sum(), x_de.sum())

    # bcoo_dot_general
    self.assertArraysEqual(x_sp @ y_de, x_de @ y_de)

    # bcoo_spdot_general
    self.assertArraysEqual((x_sp @ y_sp).todense(), x_de @ y_de)
    self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)


class SparseGradTest(jtu.JaxTestCase):
  def test_sparse_grad(self):
    rng_sparse = rand_sparse(self.rng())
    rng = jtu.rand_default(self.rng())

    y = rng(5, "float32")
    X = rng_sparse((10, 5), "float32")
    Xsp = sparse.BCOO.fromdense(X)

    def f(X, y):
      return jnp.sum(X @ y)

    grad_dense = jax.grad(f, argnums=0)(X, y)
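    # sparse.grad produces a sparse gradient with the same sparsity pattern as
    # the sparse input.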
    grad_sparse = sparse.grad(f, argnums=0)(Xsp, y)

    # extract sparse gradient from dense gradient
    indices = tuple(Xsp.indices.T)
    grad_sparse_from_dense = jnp.zeros_like(grad_dense).at[indices].set(grad_dense[indices])

    self.assertArraysEqual(grad_sparse.todense(), grad_sparse_from_dense)


class SparseObjectTest(jtu.JaxTestCase):
  def test_repr(self):
    M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
    self.assertEqual(repr(M), "BCOO(float32[5], nse=4)")

    M_invalid = sparse.BCOO(([], []), shape=(100,))
    self.assertEqual(repr(M_invalid), "BCOO(<invalid>)")

  @parameterized.named_parameters(
      {"testcase_name": "_{}{}".format(cls.__name__, shape), "cls": cls, "shape": shape}
      for cls in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]
      for shape in ([2, 5], [5, 3]))
  def test_empty(self, cls, shape):
    sparse_format = cls.__name__.lower()
    M = sparse.empty(shape, sparse_format=sparse_format)
    self.assertIsInstance(M, cls)
    self.assertEqual(M.nse, 0)
    self.assertArraysEqual(M.todense(), jnp.empty(shape))

  @parameterized.named_parameters(
      {"testcase_name": "{}_BCOO{}".format(nse, shape), "shape": shape, "nse": nse}
      for shape in ([2, 5], [5, 3])
      for nse in [0, 2])
  def test_empty_nse(self, shape, nse=2):
    M = sparse.empty(shape, nse=nse)
    self.assertEqual(M.nse, nse)
    self.assertArraysEqual(M.todense(), jnp.empty(shape))

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
      for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
  def test_block_until_ready(self, Obj, shape=(5, 8), dtype=np.float32):
    rng = rand_sparse(self.rng(), post=Obj.fromdense)
    M = rng(shape, dtype)
    self.assertEqual(M.shape, M.block_until_ready().shape)
    self.assertArraysEqual(M.data, M.block_until_ready().data)
    self.assertArraysEqual(M.todense(), M.block_until_ready().todense())

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
      for Obj in [jnp.array, sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
  def test_todense(self, Obj, shape=(5, 8), dtype=np.float32):
    rng = rand_sparse(self.rng())
    M_dense = rng(shape, dtype)
    M = jnp.array(M_dense) if Obj is jnp.array else Obj.fromdense(M_dense)
    self.assertArraysEqual(sparse.todense(M), M_dense)
    self.assertArraysEqual(jit(sparse.todense)(M), M_dense)

  def test_todense_scalar(self):
    self.assertEqual(sparse.todense(1.0), 1.0)
    self.assertEqual(jit(sparse.todense)(1.0), 1.0)

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
      for Obj in [jnp.array, sparse.BCOO])
  def test_todense_batching(self, Obj, shape=(5, 8), dtype=np.float32):
    rng = rand_sparse(self.rng())
    M_dense = rng(shape, dtype)
    if Obj is sparse.BCOO:
      M = sparse.BCOO.fromdense(M_dense, n_batch=1)
    else:
      M = jnp.asarray(M_dense)
    self.assertArraysEqual(vmap(sparse.todense)(M), M_dense)
    self.assertArraysEqual(jit(vmap(sparse.todense))(M), M_dense)

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
      for Obj in [jnp.array, sparse.BCOO])
  def test_todense_ad(self, Obj, shape=(3,), dtype=np.float32):
    M_dense = jnp.array([1., 2., 3.])
    M = M_dense if Obj is jnp.array else Obj.fromdense(M_dense)
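    # Bind the todense primitive on the flattened buffers; its Jacobian with
    # respect to the data buffer should be the identity matrix.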
    bufs, tree = tree_util.tree_flatten(M)
    jac = jnp.eye(M.shape[0], dtype=M.dtype)
    jac1 = jax.jacfwd(lambda *bufs: sparse.todense_p.bind(*bufs, tree=tree))(*bufs)
    jac2 = jax.jacrev(lambda *bufs: sparse.todense_p.bind(*bufs, tree=tree))(*bufs)
    self.assertArraysEqual(jac1, jac2)
    self.assertArraysEqual(jac, jac2)

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
      for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
  def test_attrs(self, Obj, shape=(5, 8), dtype=np.float16):
    rng = rand_sparse(self.rng(), post=Obj.fromdense)
    M = rng(shape, dtype)

    assert isinstance(M, Obj)
    assert M.shape == shape
    assert M.size == np.prod(shape)
    assert M.ndim == len(shape)
    assert M.dtype == dtype
    assert M.nse == (M.todense() != 0).sum()
    assert M.data.dtype == dtype

    with self.assertRaises(TypeError):
      hash(M)

    if isinstance(M, sparse.CSR):
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[0] + 1
    elif isinstance(M, sparse.CSC):
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[1] + 1
    elif isinstance(M, sparse.COO):
      assert len(M.data) == len(M.row) == len(M.col)
    elif isinstance(M, sparse.BCOO):
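      # BCOO buffer layout: data has shape (*batch, nse, *dense), and indices
      # has shape (*batch, nse, n_sparse).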
      assert M.data.shape[M.n_batch] == M.indices.shape[-2]
      assert M.indices.shape[-1] == M.n_sparse
    else:
      raise ValueError(f"Obj={Obj} not expected.")

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_Obj={}".format(
              jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
           "shape": shape, "dtype": dtype, "Obj": Obj}
          for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
          for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
      for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]))
  def test_dense_round_trip(self, shape, dtype, Obj):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M, Msparse.todense())

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_Obj={}".format(
              jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
           "shape": shape, "dtype": dtype, "Obj": Obj}
          for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
          for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
      for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]))
  def test_transpose(self, shape, dtype, Obj):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M.T, Msparse.T.todense())

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_Obj={}_bshape={}".format(
              jtu.format_shape_dtype_string(shape, dtype), Obj.__name__, bshape),
           "shape": shape, "dtype": dtype, "Obj": Obj, "bshape": bshape}
          for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
          for bshape in [shape[-1:] + s for s in [(), (3,), (4,)]]
          for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
      for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]))
  def test_matmul(self, shape, dtype, Obj, bshape):
    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())
    M = rng(shape, dtype)
    Msp = Obj.fromdense(M)

    # Test matching type
    x = rng_b(bshape, dtype)
    x = jnp.asarray(x)
    self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)

    # Test mismatched type
    x = rng_b(bshape, np.int32)
    x = jnp.asarray(x)
    self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}({})".format(
          input_type.__name__,
          jtu.format_shape_dtype_string(shape, dtype)),
       "input_type": input_type, "shape": shape, "dtype": dtype}
      for input_type in [scipy.sparse.coo_matrix, scipy.sparse.csr_matrix, scipy.sparse.csc_matrix]
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_from_scipy_sparse(self, input_type, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_sparse = input_type(M)
    M_bcoo = sparse.BCOO.from_scipy_sparse(M_sparse)
    self.assertArraysEqual(M, M_bcoo.todense())

  def test_bcoo_methods(self):
    M = jnp.arange(12).reshape(3, 4)
    Msp = sparse.BCOO.fromdense(M)

    self.assertArraysEqual(-M, (-Msp).todense())

    self.assertArraysEqual(2 * M, (2 * Msp).todense())
    self.assertArraysEqual(M * 2, (Msp * 2).todense())

    self.assertArraysEqual(M + M, (Msp + Msp).todense())

    self.assertArraysEqual(M.sum(0), Msp.sum(0).todense())
    self.assertArraysEqual(M.sum(1), Msp.sum(1).todense())
    self.assertArraysEqual(M.sum(), Msp.sum())


class SparseRandomTest(jtu.JaxTestCase):
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_indices_dtype={}_nbatch={}_ndense={}".format(
          jtu.format_shape_dtype_string(shape, dtype), indices_dtype, n_batch, n_dense),
       "shape": shape, "dtype": dtype, "indices_dtype": indices_dtype,
       "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for indices_dtype in jtu.dtypes.integer
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_random_bcoo(self, shape, dtype, indices_dtype, n_batch, n_dense):
    key = jax.random.PRNGKey(1701)
    mat = sparse.random_bcoo(
        key, shape=shape, dtype=dtype, indices_dtype=indices_dtype,
        n_batch=n_batch, n_dense=n_dense)

    mat_dense = mat.todense()
    self.assertEqual(mat_dense.shape, shape)
    self.assertEqual(mat_dense.dtype, dtype)
    self.assertEqual(mat.indices.dtype, indices_dtype)

    n_sparse = len(shape) - n_batch - n_dense
    batch_shape, sparse_shape, dense_shape = split_list(shape, [n_batch, n_sparse])
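
    # The expected number of nonzeros assumes random_bcoo's default density of
    # 0.2 over the sparse dimensions, replicated across batch and dense dims.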
    approx_expected_num_nonzero = (
        np.ceil(0.2 * np.prod(sparse_shape))
        * np.prod(batch_shape) * np.prod(dense_shape))
    num_nonzero = (mat_dense != 0).sum()
    self.assertAlmostEqual(num_nonzero, approx_expected_num_nonzero, delta=2)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())