# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import itertools
from jax._src.api import vmap
import unittest

from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import api
from jax import config
from jax import dtypes
from jax.experimental import sparse
from jax.experimental.sparse.ops import _bcoo_nse, _dedupe_bcoo
from jax import lax
from jax.lib import cusparse
from jax.lib import xla_bridge
from jax import jit
from jax import test_util as jtu
from jax import xla
import jax.numpy as jnp
from jax import jvp
import numpy as np
import scipy.sparse

config.parse_flags_with_absl()
FLAGS = config.FLAGS

MATMUL_TOL = {
  np.float32: 1E-5,
  np.float64: 1E-10,
  np.complex64: 1e-5,
  np.complex128: 1E-10,
}
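# The tolerances above are used as rtol when comparing sparse matvec/matmat
# results against the corresponding dense matmul throughout this file.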

all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex


def rand_sparse(rng, nnz=0.5, post=lambda x: x):
  def _rand_sparse(shape, dtype, nnz=nnz):
    rand = jtu.rand_default(rng)
    size = np.prod(shape)
    if 0 <= nnz < 1:
      nnz = nnz * size
    nnz = min(size, int(nnz))
    M = rand(shape, dtype)
    indices = rng.choice(size, size - nnz, replace=False)
    M.flat[indices] = 0
    return post(M)
  return _rand_sparse
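# Note: rand_sparse returns a sampler with the same (shape, dtype) signature as
# jtu.rand_default, except that roughly an `nnz` fraction (or count, if nnz >= 1)
# of entries are nonzero; `post` lets tests wrap the dense output, e.g. in
# scipy.sparse.csr_matrix or jnp.array.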


class cuSparseTest(jtu.JaxTestCase):
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.indices, M.indptr)
    todense = lambda *args: sparse.csr_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_csr = scipy.sparse.csr_matrix(M)

    nnz = M_csr.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse.csr_fromdense(M, nnz=nnz, index_dtype=jnp.int32)

    data, indices, indptr = fromdense(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

    data, indices, indptr = jit(fromdense)(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.indices, M.indptr, v)
    matvec = lambda *args: sparse.csr_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.csr_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.indices, M.indptr, B)
    matmat = lambda *args: sparse.csr_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.row, M.col)
    todense = lambda *args: sparse.coo_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_coo = scipy.sparse.coo_matrix(M)

    nnz = M_coo.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse.coo_fromdense(M, nnz=nnz, index_dtype=jnp.int32)

    data, row, col = fromdense(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    data, row, col = jit(fromdense)(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matvec(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.row, M.col, v)
    matvec = lambda *args: sparse.coo_matvec(*args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=scipy.sparse.coo_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.row, M.col, B)
    matmat = lambda *args: sparse.coo_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

    y, dy = jvp(lambda x: sparse.coo_matmat(M.data, M.row, M.col, x, shape=shape, transpose=transpose).sum(), (B, ), (jnp.ones_like(B), ))
    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

    y, dy = jvp(lambda x: sparse.coo_matmat(x, M.row, M.col, B, shape=shape, transpose=transpose).sum(), (M.data, ), (jnp.ones_like(M.data), ))
    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_gpu_translation_rule(self):
    version = xla_bridge.get_backend().platform_version
    cuda_version = None if version == "<unknown>" else int(version.split()[-1])
    if cuda_version is None or cuda_version < 11000:
      self.assertFalse(cusparse and cusparse.is_supported)
      self.assertNotIn(sparse.csr_todense_p, xla.backend_specific_translations["gpu"])
    else:
      self.assertTrue(cusparse and cusparse.is_supported)
      self.assertIn(sparse.csr_todense_p, xla.backend_specific_translations["gpu"])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
        jtu.format_shape_dtype_string(shape, dtype), mat_type),
       "shape": shape, "dtype": dtype, "mat_type": mat_type}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for mat_type in ['csr', 'coo']))
  def test_extra_nnz(self, shape, dtype, mat_type):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    nnz = (M != 0).sum() + 5
    fromdense = getattr(sparse, f"{mat_type}_fromdense")
    todense = getattr(sparse, f"{mat_type}_todense")
    args = fromdense(M, nnz=nnz, index_dtype=jnp.int32)
    M_out = todense(*args, shape=M.shape)
    self.assertArraysEqual(M, M_out)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    data, row, col = sparse.coo_fromdense(M, nnz=(M != 0).sum())
    f = lambda data: sparse.coo_todense(data, row, col, shape=M.shape)

    # Forward-mode
    primals, tangents = api.jvp(f, [data], [jnp.ones_like(data)])
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(tangents, jnp.zeros_like(M).at[row, col].set(1))

    # Reverse-mode
    primals, vjp_fun = api.vjp(f, data)
    data_out, = vjp_fun(primals)
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(data_out, data)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    nnz = (M != 0).sum()
    f = lambda M: sparse.coo_fromdense(M, nnz=nnz)

    # Forward-mode
    primals, tangents = api.jvp(f, [M], [jnp.ones_like(M)])
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(tangents[0], jnp.ones(nnz, dtype=dtype))
    self.assertEqual(tangents[1].dtype, dtypes.float0)
    self.assertEqual(tangents[2].dtype, dtypes.float0)

    # Reverse-mode
    primals, vjp_fun = api.vjp(f, M)
    M_out, = vjp_fun(primals)
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(M_out, M)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        jtu.format_shape_dtype_string(bshape, dtype)),
       "shape": shape, "dtype": dtype, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [()]]  # TODO: matmul autodiff
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))  # TODO: other types
  def test_coo_matvec_ad(self, shape, dtype, bshape):
    tol = {np.float32: 1E-6, np.float64: 1E-13, np.complex64: 1E-6, np.complex128: 1E-13}

    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())

    M = rng(shape, dtype)
    data, row, col = sparse.coo_fromdense(M, nnz=(M != 0).sum())
    x = rng_b(bshape, dtype)
    xdot = rng_b(bshape, dtype)

    # Forward-mode with respect to the vector
    f_dense = lambda x: M @ x
    f_sparse = lambda x: sparse.coo_matvec(data, row, col, x, shape=M.shape)
    v_sparse, t_sparse = api.jvp(f_sparse, [x], [xdot])
    v_dense, t_dense = api.jvp(f_dense, [x], [xdot])
    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to the vector
    primals_dense, vjp_dense = api.vjp(f_dense, x)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, x)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

    # Forward-mode with respect to nonzero elements of the matrix
    f_sparse = lambda data: sparse.coo_matvec(data, row, col, x, shape=M.shape)
    f_dense = lambda data: sparse.coo_todense(data, row, col, shape=M.shape) @ x
    data = rng((len(data),), data.dtype)
    data_dot = rng((len(data),), data.dtype)
    v_sparse, t_sparse = api.jvp(f_sparse, [data], [data_dot])
    v_dense, t_dense = api.jvp(f_dense, [data], [data_dot])

    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to nonzero elements of the matrix
    primals_dense, vjp_dense = api.vjp(f_dense, data)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, data)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)


class BCOOTest(jtu.JaxTestCase):
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    n_sparse = M.ndim - n_batch - n_dense
    nse = int(_bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
    # TODO: test fromdense JIT

    assert data.dtype == dtype
    assert data.shape == shape[:n_batch] + (nse,) + shape[n_batch + n_sparse:]
    assert indices.dtype == jnp.int32  # TODO: test passing this arg
    assert indices.shape == shape[:n_batch] + (n_sparse, nse)

    todense = partial(sparse.bcoo_todense, shape=shape)
    self.assertArraysEqual(M, todense(data, indices))
    self.assertArraysEqual(M, jit(todense)(data, indices))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_todense_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    todense = partial(sparse.bcoo_todense, indices=indices, shape=shape)
    j1 = jax.jacfwd(todense)(data)
    j2 = jax.jacrev(todense)(data)
    hess = jax.hessian(todense)(data)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, M.shape + data.shape)
    self.assertEqual(hess.shape, M.shape + 2 * data.shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_fromdense_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    nse = int(_bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))

    def fromdense(M):
      return sparse.bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense)[0]
    data = fromdense(M)

    j1 = jax.jacfwd(fromdense)(M)
    j2 = jax.jacrev(fromdense)(M)
    hess = jax.hessian(fromdense)(M)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, data.shape + M.shape)
    self.assertEqual(hess.shape, data.shape + 2 * M.shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_dense_round_trip_batched(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    n_sparse = M.ndim - n_batch - n_dense
    nse = int(_bcoo_nse(M, n_batch=n_batch, n_dense=n_dense))

    fromdense = partial(sparse.bcoo_fromdense, nse=nse, n_dense=n_dense)
    todense = partial(sparse.bcoo_todense, shape=shape[n_batch:])
    for i in range(n_batch):
      fromdense = jax.vmap(fromdense)
      todense = jax.vmap(todense)

    data, indices = fromdense(M)

    assert data.dtype == dtype
    assert data.shape == shape[:n_batch] + (nse,) + shape[n_batch + n_sparse:]
    assert indices.dtype == jnp.int32  # TODO: test passing this arg
    assert indices.shape == shape[:n_batch] + (n_sparse, nse)

    self.assertArraysEqual(M, todense(data, indices))
    self.assertArraysEqual(M, jit(todense)(data, indices))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_extract(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M)
    data2 = sparse.bcoo_extract(indices, M)
    self.assertArraysEqual(data, data2)
    data3 = jit(sparse.bcoo_extract)(indices, M)
    self.assertArraysEqual(data, data3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    extract = partial(sparse.bcoo_extract, indices)
    j1 = jax.jacfwd(extract)(M)
    j2 = jax.jacrev(extract)(M)
    hess = jax.hessian(extract)(M)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, data.shape + M.shape)
    self.assertEqual(hess.shape, data.shape + 2 * M.shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_transpose(self, shape, dtype, n_batch, n_dense):
    n_sparse = len(shape) - n_batch - n_dense
    rng = self.rng()
    sprng = rand_sparse(rng)
    M = sprng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    permutation = np.concatenate([
      rng.permutation(range(n_batch)),
      rng.permutation(range(n_batch, n_batch + n_sparse)),
      rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)

    M_T = M.transpose(permutation)
    trans = partial(sparse.bcoo_transpose, shape=shape, permutation=permutation)
    self.assertArraysEqual(M_T, sparse.bcoo_todense(*trans(data, indices), shape=M_T.shape))
    self.assertArraysEqual(M_T, sparse.bcoo_todense(*jit(trans)(data, indices), shape=M_T.shape))

    # test batched
    def trans(M):
      return M.transpose([p - n_batch for p in permutation[n_batch:]])
    for _ in range(n_batch):
      trans = jax.vmap(trans)
    Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
    self.assertArraysEqual(trans(M), trans(Msp).todense())

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_transpose_ad(self, shape, dtype, n_batch, n_dense):
    n_sparse = len(shape) - n_batch - n_dense
    rng = self.rng()
    sprng = rand_sparse(self.rng())

    M = sprng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    permutation = np.concatenate([
      rng.permutation(range(n_batch)),
      rng.permutation(range(n_batch, n_batch + n_sparse)),
      rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)

    def f_sparse(data):
      return sparse.bcoo_transpose(data, indices, shape=shape, permutation=permutation)[0]

    jf_sparse = jax.jacfwd(f_sparse)(data)
    jr_sparse = jax.jacrev(f_sparse)(data)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    # TODO(jakevdp) also test against dense version?
    self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(1, len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    M1 = sparse.bcoo_todense(data, indices[:1], shape=M.shape)
    M2 = sparse.bcoo_todense(data, jnp.stack(shape[0] * [indices[0]]), shape=M.shape)
    self.assertAllClose(M1, M2)

    M3 = sparse.bcoo_todense(data[:1], indices, shape=M.shape)
    M4 = sparse.bcoo_todense(jnp.stack(shape[0] * [data[0]]), indices, shape=M.shape)
    self.assertAllClose(M3, M4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
         "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_n_dense={}"
         .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                 jtu.format_shape_dtype_string(rhs_shape, dtype),
                 lhs_contracting, rhs_contracting, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting,
       "n_dense": n_dense}
      for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
          [(5,), (6,), [], []],
          [(5,), (5,), [0], [0]],
          [(5, 7), (5,), [0], [0]],
          [(7, 5), (5,), [1], [0]],
          [(3, 5), (2, 5), [1], [1]],
          [(5, 3), (5, 2), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0], [0]],
          [(5, 3, 2), (5, 2, 4), [0,2], [0,1]],
          [(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]],
          [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
          [(3, 2), (2, 4), [1], [0]],
      ]
      for n_dense in range(len(lhs_shape) - max(lhs_contracting, default=0))
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_dot_general_contract_only(self, lhs_shape, rhs_shape, dtype,
                                          lhs_contracting, rhs_contracting, n_dense):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())
    def args_maker():
      lhs = rng_sparse(lhs_shape, dtype)
      rhs = rng(rhs_shape, dtype)
      data, indices = sparse.bcoo_fromdense(lhs, n_dense=n_dense)
      return data, indices, lhs, rhs
    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))

    def f_dense(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

    def f_sparse(data, indices, lhs, rhs):
      return sparse.bcoo_dot_general(data, indices, rhs,
                                     lhs_shape=lhs.shape,
                                     dimension_numbers=dimension_numbers)

    self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
    self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
    # TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
    # self._CompileAndCheck(f_sparse, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
         "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
         .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                 jtu.format_shape_dtype_string(rhs_shape, dtype),
                 dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_dot_general_contract_and_batch(self, lhs_shape, rhs_shape, dtype,
                                               dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())
    def args_maker():
      lhs = rng_sparse(lhs_shape, dtype)
      rhs = rng(rhs_shape, dtype)
      data, indices = sparse.bcoo_fromdense(lhs, n_batch=n_batch, n_dense=n_dense)
      return data, indices, lhs, rhs

    def f_dense(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

    def f_sparse(data, indices, lhs, rhs):
      return sparse.bcoo_dot_general(data, indices, rhs,
                                     lhs_shape=lhs.shape,
                                     dimension_numbers=dimension_numbers)

    self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
    self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
    # TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
    # self._CompileAndCheck(f_sparse, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
         "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
         .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                 jtu.format_shape_dtype_string(rhs_shape, dtype),
                 dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((3, 2, 4), (3, 3, 2), (([1], [2]), ([0], [0])), 1, 0),
          ((3, 2, 4), (3, 3, 2), (([1], [2]), ([0], [0])), 2, 0),
          ((2, 3, 4), (3, 3, 2), (([0], [2]), ([1], [0])), 1, 0),
          ((2, 3, 4), (3, 3, 2), (([0], [2]), ([1], [0])), 2, 0),
          ((3, 4, 3, 2), (3, 4, 2, 4), (([3], [2]), ([0], [0])), 1, 0),
          ((3, 4, 3, 2), (3, 4, 2, 4), (([3], [2]), ([0, 1], [0, 1])), 2, 0),
          ((3, 4, 3, 2), (3, 4, 2, 4), (([3], [2]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_rdot_general_contract_and_batch(self, lhs_shape, rhs_shape, dtype,
                                                dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())
    def args_maker():
      lhs = rng(lhs_shape, dtype)
      rhs = rng_sparse(rhs_shape, dtype)
      data, indices = sparse.bcoo_fromdense(rhs, n_batch=n_batch, n_dense=n_dense)
      return data, indices, lhs, rhs

    def f_dense(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

    def f_sparse(data, indices, lhs, rhs):
      return sparse.bcoo_rdot_general(lhs, data, indices,
                                      rhs_shape=rhs.shape,
                                      dimension_numbers=dimension_numbers)

    self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
    self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
    # TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
    # self._CompileAndCheck(f_sparse, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
         "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
         .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                 jtu.format_shape_dtype_string(rhs_shape, dtype),
                 dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    X = rng_sparse(lhs_shape, dtype)
    data, indices = sparse.bcoo_fromdense(X, n_batch=n_batch, n_dense=n_dense)
    Y = rng(rhs_shape, dtype)

    def f_dense(X, Y):
      return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

    def f_sparse(data, indices, Y):
      return sparse.bcoo_dot_general(data, indices, Y, lhs_shape=X.shape,
                                     dimension_numbers=dimension_numbers)

    for data, indices in itertools.product([data, data[:1]], [indices, indices[:1]]):
      X = sparse.bcoo_todense(data, indices, shape=X.shape)
      self.assertAllClose(f_dense(X, Y), f_sparse(data, indices, Y))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
         "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
         .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                 jtu.format_shape_dtype_string(rhs_shape, dtype),
                 dimension_numbers, n_batch, n_dense),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers,
       "n_batch": n_batch, "n_dense": n_dense}
      for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
          ((4, 5), (5, 3), (([1], [0]), ([], [])), 0, 0),
          ((2, 4, 5), (2, 5, 3), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 1, 0),
          ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 1, 0),
          ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 0),
          # These require contraction over batch & dense dimensions
          # which is not yet implemented:
          # ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0])), 2, 0),
          # ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1])), 2, 0),
          # ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1])), 2, 1),
      ]
      for dtype in jtu.dtypes.floating))
  def test_bcoo_dot_general_ad(self, lhs_shape, rhs_shape, dtype,
                               dimension_numbers, n_batch, n_dense):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())

    X = rng_sparse(lhs_shape, dtype)
    data, indices = sparse.bcoo_fromdense(X, n_batch=n_batch, n_dense=n_dense)
    Y = rng(rhs_shape, dtype)

    # gradient with respect to rhs
    def f_dense(Y):
      return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

    def f_sparse(Y):
      return sparse.bcoo_dot_general(data, indices, Y, lhs_shape=X.shape,
                                     dimension_numbers=dimension_numbers)

    jf_dense = jax.jacfwd(f_dense)(Y)
    jr_dense = jax.jacrev(f_dense)(Y)
    jf_sparse = jax.jacfwd(f_sparse)(Y)
    jr_sparse = jax.jacrev(f_sparse)(Y)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    self.assertAllClose(jf_dense, jf_sparse, rtol=tol)
    self.assertAllClose(jr_dense, jr_sparse, rtol=tol)
    self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

    # gradient with respect to lhs
    def g_dense(X):
      return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

    def g_sparse(data):
      return sparse.bcoo_dot_general(data, indices, Y, lhs_shape=X.shape,
                                     dimension_numbers=dimension_numbers)

    jf_dense = jax.jacfwd(g_dense)(X)
    jr_dense = jax.jacrev(g_dense)(X)
    jf_sparse = jax.jacfwd(g_sparse)(data)
    jr_sparse = jax.jacrev(g_sparse)(data)

    tol = {}
    if jtu.device_under_test() == "tpu":
      tol = {np.float32: 5E-3}

    self.assertAllClose(jf_dense, jr_dense, rtol=tol)
    self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

    # Extract the sparse jacobian from the dense & compare.
    def extract(X):
      return sparse.bcoo_extract(indices, X)
    for i in range(g_dense(X).ndim):
      extract = jax.vmap(extract)
    self.assertAllClose(extract(jf_dense), jf_sparse, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)))
  def test_bcoo_dedupe(self, shape, dtype, n_batch, n_dense):
    rng = self.rng()
    rng_sparse = rand_sparse(self.rng())
    M = rng_sparse(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
    for i, s in enumerate(shape[n_batch:len(shape) - n_dense]):
      indices = indices.at[..., i, :].set(rng.randint(0, s, size=indices.shape[-1]))
    data2, indices2 = _dedupe_bcoo(data, indices)
    M1 = sparse.bcoo_todense(data, indices, shape=shape)
    M2 = sparse.bcoo_todense(data2, indices2, shape=shape)

    self.assertAllClose(M1, M2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_nbatch={}_ndense={}_axes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, axes),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "axes": axes}
      for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)
      for naxes in range(len(shape))
      for axes in itertools.combinations(range(len(shape)), naxes)))
  def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
    data_out, indices_out, shape_out = sparse.bcoo_reduce_sum(data, indices, shape=shape, axes=axes)
    result_dense = M.sum(axes)
    result_sparse = sparse.bcoo_todense(data_out, indices_out, shape=shape_out)
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
        jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
        jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
      }
      for lhs_shape, rhs_shape in [[(3,), (3,)],
                                   [(3, 4), (4,)],
                                   [(4,), (4, 5)],
                                   [(3, 4), (4, 5)]]
      for lhs_dtype in all_dtypes
      for rhs_dtype in all_dtypes))
  def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
    rng = jtu.rand_default(self.rng())
    lhs = jnp.array(rng(lhs_shape, lhs_dtype))
    rhs = jnp.array(rng(rhs_shape, rhs_dtype))

    out1 = lhs @ rhs
    out2 = sparse.BCOO.fromdense(lhs) @ rhs
    out3 = lhs @ sparse.BCOO.fromdense(rhs)

    tol = {np.float64: 1E-13, np.complex128: 1E-13,
           np.float32: 1E-6, np.complex64: 1E-6}
    self.assertAllClose(out1, out2, rtol=tol)
    self.assertAllClose(out1, out3, rtol=tol)

  def test_bcoo_vmap_shape(self, shape=(2, 3, 4, 5), dtype=np.float32):
    # This test checks that BCOO shape metadata interacts correctly with vmap.
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)

    def make_bcoo(M):
      return sparse.BCOO.fromdense(M, nnz=np.prod(M.shape[:-1], dtype=int), n_dense=1)

    for _ in range(3):
      make_bcoo = vmap(make_bcoo)
    Msp = make_bcoo(M)
    self.assertEqual(Msp.shape, M.shape)
    self.assertArraysEqual(Msp.todense(), M)


class SparseGradTest(jtu.JaxTestCase):
  def test_sparse_grad(self):
    rng_sparse = rand_sparse(self.rng())
    rng = jtu.rand_default(self.rng())

    y = rng(5, "float32")
    X = rng_sparse((10, 5), "float32")
    Xsp = sparse.BCOO.fromdense(X)

    def f(X, y):
      return jnp.sum(X @ y)

    grad_dense = api.grad(f, argnums=0)(X, y)
    grad_sparse = sparse.grad(f, argnums=0)(Xsp, y)

    # extract sparse gradient from dense gradient
    indices = tuple(Xsp.indices)
    grad_sparse_from_dense = jnp.zeros_like(grad_dense).at[indices].set(grad_dense[indices])

    self.assertArraysEqual(grad_sparse.todense(), grad_sparse_from_dense)


class SparseObjectTest(jtu.JaxTestCase):
  @parameterized.named_parameters(
    {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
    for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
  def test_attrs(self, Obj, shape=(5, 8), dtype=np.float16):
    rng = rand_sparse(self.rng(), post=Obj.fromdense)
    M = rng(shape, dtype)

    assert isinstance(M, Obj)
    assert M.shape == shape
    assert M.dtype == dtype
    assert M.nnz == (M.todense() != 0).sum()
    assert M.data.dtype == dtype

    if isinstance(M, sparse.CSR):
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[0] + 1
    elif isinstance(M, sparse.CSC):
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[1] + 1
    elif isinstance(M, sparse.COO):
      assert len(M.data) == len(M.row) == len(M.col)
    elif isinstance(M, sparse.BCOO):
      assert M.data.shape[M.n_batch] == M.indices.shape[-1]
      assert M.indices.shape[-2] == M.n_sparse
    else:
      raise ValueError(f"Obj={Obj} not expected.")

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_Obj={}".format(
        jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
       "shape": shape, "dtype": dtype, "Obj": Obj}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
    for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]))
  def test_dense_round_trip(self, shape, dtype, Obj):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M, Msparse.todense())

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_Obj={}".format(
        jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
       "shape": shape, "dtype": dtype, "Obj": Obj}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
    for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]))
  def test_transpose(self, shape, dtype, Obj):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M.T, Msparse.T.todense())

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_Obj={}_bshape={}".format(
        jtu.format_shape_dtype_string(shape, dtype), Obj.__name__, bshape),
       "shape": shape, "dtype": dtype, "Obj": Obj, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [(), (3,), (4,)]]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
    for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]))
  def test_matmul(self, shape, dtype, Obj, bshape):
    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())
    M = rng(shape, dtype)
    Msp = Obj.fromdense(M)
    x = rng_b(bshape, dtype)
    x = jnp.asarray(x)

    self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())