2022-09-22 12:26:48 -07:00
|
|
|
|
# Copyright 2021 The JAX Authors.
|
2021-06-25 10:45:16 -07:00
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
2021-09-13 17:24:44 -04:00
|
|
|
|
from functools import partial
|
2023-04-13 11:48:11 -07:00
|
|
|
|
import math
|
2021-06-25 10:45:16 -07:00
|
|
|
|
import operator
|
|
|
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
2022-06-14 12:45:26 -07:00
|
|
|
|
import jax
|
2024-04-11 13:23:27 -07:00
|
|
|
|
from jax import jit, lax
|
2021-06-25 10:45:16 -07:00
|
|
|
|
import jax.numpy as jnp
|
2021-09-24 07:02:08 -07:00
|
|
|
|
import jax._src.test_util as jtu
|
2023-07-26 13:00:24 -07:00
|
|
|
|
from jax.experimental.sparse import BCOO, BCSR, sparsify, todense, SparseTracer
|
2021-06-25 10:45:16 -07:00
|
|
|
|
from jax.experimental.sparse.transform import (
|
2022-03-07 12:48:03 -08:00
|
|
|
|
arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv)
|
2022-05-25 21:06:52 -07:00
|
|
|
|
from jax.experimental.sparse.util import CuSparseEfficiencyWarning
|
2023-02-01 16:16:14 -08:00
|
|
|
|
from jax.experimental.sparse import test_util as sptu
|
2021-06-25 10:45:16 -07:00
|
|
|
|
|
2024-04-11 13:23:27 -07:00
|
|
|
|
# Allow JAX configuration flags (e.g. --jax_*) to be passed via absl when
# this file is run as a test binary.
jax.config.parse_flags_with_absl()
|
2021-06-28 10:43:07 -07:00
|
|
|
|
|
2022-06-03 14:50:44 -07:00
|
|
|
|
def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
  """Return a sampler producing random arrays with at least ``nse`` zeros set.

  Args:
    rng: numpy RandomState, used both by ``rand_method`` and to choose which
      entries are zeroed.
    nse: number of specified (kept) elements; a value in ``[0, 1)`` is
      interpreted as a fraction of the total array size.
    post: callable applied to the generated array before it is returned.
    rand_method: jtu factory producing the dense value sampler.
  """
  def _rand_sparse(shape, dtype, nse=nse):
    sampler = rand_method(rng)
    total = math.prod(shape)
    # A fractional nse means "this fraction of the total number of elements".
    if 0 <= nse < 1:
      nse = nse * total
    nse = min(total, int(nse))
    dense = sampler(shape, dtype)
    # Zero out all but `nse` randomly chosen flat positions.
    zeroed = rng.choice(total, total - nse, replace=False)
    dense.flat[zeroed] = 0
    return post(dense)
  return _rand_sparse
|
|
|
|
|
|
2021-06-25 10:45:16 -07:00
|
|
|
|
|
|
|
|
|
class SparsifyTest(jtu.JaxTestCase):
  """Tests for sparsify() using the jaxpr-interpreter implementation.

  SparsifyTracerTest below re-runs these tests with use_tracer=True.
  """

  @classmethod
  def sparsify(cls, f):
    # Subclasses override this to select the tracer-based implementation.
    return sparsify(f, use_tracer=False)
|
|
|
|
|
|
2022-04-25 16:22:19 -07:00
|
|
|
|
  def testNotImplementedMessages(self):
    """Unsupported primitives should raise informative NotImplementedError."""
    x = BCOO.fromdense(jnp.arange(5.0))
    # Test a densifying primitive
    with self.assertRaisesRegex(NotImplementedError,
        r"^sparse rule for cos is not implemented because it would result in dense output\."):
      self.sparsify(lax.cos)(x)

    # Test a generic not implemented primitive.
    with self.assertRaisesRegex(NotImplementedError,
        r"^sparse rule for complex is not implemented\.$"):
      self.sparsify(lax.complex)(x, x)
|
|
|
|
|
|
2021-11-04 13:02:46 -07:00
|
|
|
|
  def testTracerIsInstanceCheck(self):
    """With use_tracer=False, arguments are not SparseTracer instances."""
    @self.sparsify
    def f(x):
      self.assertNotIsInstance(x, SparseTracer)
    f(jnp.arange(5))
|
2021-06-25 10:45:16 -07:00
|
|
|
|
|
|
|
|
|
  def assertBcooIdentical(self, x, y):
    """Assert x and y are BCOO arrays with identical shape, data & indices."""
    self.assertIsInstance(x, BCOO)
    self.assertIsInstance(y, BCOO)
    self.assertEqual(x.shape, y.shape)
    self.assertArraysEqual(x.data, y.data)
    self.assertArraysEqual(x.indices, y.indices)
|
|
|
|
|
|
2022-03-07 12:48:03 -08:00
|
|
|
|
  def testSparsifyValue(self):
    """Round-trip arrays <-> SparsifyValue through a SparsifyEnv."""
    X = jnp.arange(5)
    X_BCOO = BCOO.fromdense(X)

    args = (X, X_BCOO, X_BCOO)

    # Independent index: each BCOO gets its own data/indices buffer slots,
    # so three args occupy 5 buffers (1 dense + 2 * (data, indices)).
    spenv = SparsifyEnv()
    spvalues = arrays_to_spvalues(spenv, args)
    self.assertEqual(len(spvalues), len(args))
    self.assertLen(spenv._buffers, 5)
    self.assertEqual(spvalues,
        (SparsifyValue(X.shape, 0, None, indices_sorted=False,
                       unique_indices=False),
         SparsifyValue(X.shape, 1, 2, indices_sorted=True,
                       unique_indices=True),
         SparsifyValue(X.shape, 3, 4, indices_sorted=True,
                       unique_indices=True)))

    args_out = spvalues_to_arrays(spenv, spvalues)
    self.assertEqual(len(args_out), len(args))
    self.assertArraysEqual(args[0], args_out[0])
    self.assertBcooIdentical(args[1], args_out[1])
    self.assertBcooIdentical(args[2], args_out[2])

    # Shared index: two SparsifyValues may reference the same indices buffer.
    spvalues = (SparsifyValue(X.shape, 0, None), SparsifyValue(X.shape, 1, 2), SparsifyValue(X.shape, 3, 2))
    spenv = SparsifyEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data])

    args_out = spvalues_to_arrays(spenv, spvalues)
    self.assertEqual(len(args_out), len(args))
    self.assertArraysEqual(args[0], args_out[0])
    self.assertBcooIdentical(args[1], args_out[1])
    self.assertBcooIdentical(args[2], args_out[2])
|
|
|
|
|
|
2021-09-07 17:59:38 -07:00
|
|
|
|
def testDropvar(self):
|
|
|
|
|
def inner(x):
|
|
|
|
|
return x * 2, x * 3
|
|
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
|
_, y = jit(inner)(x)
|
|
|
|
|
return y * 4
|
|
|
|
|
|
|
|
|
|
x_dense = jnp.arange(5)
|
|
|
|
|
x_sparse = BCOO.fromdense(x_dense)
|
2021-11-04 13:02:46 -07:00
|
|
|
|
self.assertArraysEqual(self.sparsify(f)(x_sparse).todense(), f(x_dense))
|
2021-09-07 17:59:38 -07:00
|
|
|
|
|
2021-08-10 10:31:16 -07:00
|
|
|
|
def testPytreeInput(self):
|
2021-11-04 13:02:46 -07:00
|
|
|
|
f = self.sparsify(lambda x: x)
|
2021-08-10 10:31:16 -07:00
|
|
|
|
args = (jnp.arange(4), BCOO.fromdense(jnp.arange(4)))
|
|
|
|
|
out = f(args)
|
|
|
|
|
self.assertLen(out, 2)
|
|
|
|
|
self.assertArraysEqual(args[0], out[0])
|
|
|
|
|
self.assertBcooIdentical(args[1], out[1])
|
|
|
|
|
|
2022-06-14 12:45:26 -07:00
|
|
|
|
  @jax.numpy_dtype_promotion('standard')  # explicitly exercises implicit dtype promotion.
  def testSparsify(self):
    """End-to-end sparsify of a function mixing elementwise ops and matmul."""
    M_dense = jnp.arange(24).reshape(4, 6)
    M_sparse = BCOO.fromdense(M_dense)
    v = jnp.arange(M_dense.shape[0])

    @self.sparsify
    def func(x, v):
      return -jnp.sin(jnp.pi * x).T @ (v + 1)

    # On GPU, unsorted-index matmul emits an efficiency warning; ignore it.
    with jtu.ignore_warning(
        category=CuSparseEfficiencyWarning,
        message="bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
      result_sparse = func(M_sparse, v)
    result_dense = func(M_dense, v)
    self.assertAllClose(result_sparse, result_dense)
|
|
|
|
|
|
2021-09-07 21:08:03 -07:00
|
|
|
|
  def testSparsifyWithConsts(self):
    """sparsify of a function whose inner jit closes over constants."""
    M_dense = jnp.arange(24).reshape(4, 6)
    M_sparse = BCOO.fromdense(M_dense)

    @self.sparsify
    def func(x):
      return jit(lambda x: jnp.sum(x, 1))(x)

    result_dense = func(M_dense)
    result_sparse = func(M_sparse)

    self.assertAllClose(result_sparse.todense(), result_dense)
|
|
|
|
|
|
2023-07-14 18:23:34 -07:00
|
|
|
|
  @jax.numpy_dtype_promotion('standard')
  def testSparseMatmul(self):
    """lax.dot with sparse lhs, sparse rhs, and both sparse."""
    X = jnp.arange(16.0, dtype='float32').reshape(4, 4)
    Xsp = BCOO.fromdense(X)
    Y = jnp.ones(4, dtype='int32')
    Ysp = BCOO.fromdense(Y)

    # Note: deliberately testing with mixed precision
    assert Xsp.dtype != Ysp.dtype

    # dot_general
    result_sparse = self.sparsify(lax.dot)(Xsp, Y)
    result_dense = lax.dot(X, Y)
    self.assertAllClose(result_sparse, result_dense)

    # rdot_general
    result_sparse = self.sparsify(lax.dot)(Y, Xsp)
    result_dense = lax.dot(Y, X)
    self.assertAllClose(result_sparse, result_dense)

    # spdot_general
    result_sparse = self.sparsify(lax.dot)(Xsp, Ysp)
    result_dense = lax.dot(X, Y)
    self.assertAllClose(result_sparse.todense(), result_dense)
|
|
|
|
|
|
2021-06-25 10:45:16 -07:00
|
|
|
|
  def testSparseAdd(self):
    """Addition of sparse arrays with distinct, shared, and dense operands."""
    x = BCOO.fromdense(jnp.arange(5))
    y = BCOO.fromdense(2 * jnp.arange(5))

    # Distinct indices
    out = self.sparsify(operator.add)(x, y)
    self.assertEqual(out.nse, 8)  # uses concatenation.
    self.assertArraysEqual(out.todense(), 3 * jnp.arange(5))

    # Shared indices – requires lower level call
    spenv = SparsifyEnv([x.indices, x.data, y.data])
    spvalues = [
      spenv.sparse(x.shape, data_ref=1, indices_ref=0),
      spenv.sparse(y.shape, data_ref=2, indices_ref=0)
    ]

    result = sparsify_raw(operator.add)(spenv, *spvalues)
    args_out, _ = result
    out, = spvalues_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() + y.todense())

    # Sparse + dense: supported
    x = BCOO.fromdense(jnp.arange(6.)).reshape(2, 3)
    y = jnp.ones((2, 3))

    out = self.sparsify(operator.add)(x, y)
    self.assertAllClose(out, x.todense() + y)

    out = self.sparsify(operator.add)(y, x)
    self.assertAllClose(out, x.todense() + y)

    # Sparse + dense: unsupported
    # Adding a scalar to a sparse array is not implemented (would densify).
    msg = "Addition between a sparse array X and a dense array Y is not implemented"
    with self.assertRaisesRegex(NotImplementedError, msg):
      self.sparsify(operator.add)(x, 1.)
|
|
|
|
|
|
2022-10-12 13:51:11 +00:00
|
|
|
|
  @jtu.sample_product(
    [dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
     for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
     for n_batch in range(len(shape) + 1)
     for n_dense in range(len(shape) + 1 - n_batch)
    ],
    dtype=jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex,
    unique_indices=[True, False],
  )
  def testSparseMul(self, shape, dtype, n_batch, n_dense, unique_indices):
    """Multiplication: sparse * scalar, and sparse * sparse with shared indices."""
    rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
    x = BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch,
                       n_dense=n_dense)

    # Scalar multiplication
    scalar = 2
    y = self.sparsify(operator.mul)(x, scalar)
    self.assertArraysEqual(x.todense() * scalar, y.todense())

    # Shared indices – requires lower level call
    spenv = SparsifyEnv([x.indices, x.data, y.data])
    spvalues = [
      spenv.sparse(x.shape, data_ref=1, indices_ref=0,
                   unique_indices=unique_indices),
      spenv.sparse(y.shape, data_ref=2, indices_ref=0,
                   unique_indices=unique_indices)
    ]

    result = sparsify_raw(operator.mul)(spenv, *spvalues)
    args_out, _ = result
    out, = spvalues_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() * y.todense())
|
|
|
|
|
|
2023-03-14 11:58:43 -07:00
|
|
|
|
  @jtu.sample_product(
    [dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
     for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
     for n_batch in range(len(shape) + 1)
     for n_dense in range(len(shape) + 1 - n_batch)
    ],
    dtype=jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex,
  )
  def testSparseDiv(self, shape, dtype, n_batch, n_dense):
    """Division of a sparse array by a scalar and by a dense array."""
    rng_dense = jtu.rand_nonzero(self.rng())  # nonzero: avoid division by zero
    rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
    x = BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch,
                       n_dense=n_dense)
    spdiv = self.sparsify(operator.truediv)

    # Scalar division
    divisor = 2
    expected = x.todense() / divisor
    self.assertAllClose(expected, spdiv(x, divisor).todense())
    self.assertAllClose(expected, (x / divisor).todense())

    # Array division
    divisor = rng_dense(shape, dtype)
    expected = x.todense() / divisor
    self.assertAllClose(expected, spdiv(x, divisor).todense())
    self.assertAllClose(expected, (x / divisor).todense())
|
|
|
|
|
|
2021-12-13 12:16:17 -08:00
|
|
|
|
  def testSparseSubtract(self):
    """Subtraction of sparse arrays with distinct and shared indices."""
    x = BCOO.fromdense(3 * jnp.arange(5))
    y = BCOO.fromdense(jnp.arange(5))

    # Distinct indices
    out = self.sparsify(operator.sub)(x, y)
    self.assertEqual(out.nse, 8)  # uses concatenation.
    self.assertArraysEqual(out.todense(), 2 * jnp.arange(5))

    # Shared indices – requires lower level call
    spenv = SparsifyEnv([x.indices, x.data, y.data])
    spvalues = [
      spenv.sparse(x.shape, data_ref=1, indices_ref=0),
      spenv.sparse(y.shape, data_ref=2, indices_ref=0)
    ]

    result = sparsify_raw(operator.sub)(spenv, *spvalues)
    args_out, _ = result
    out, = spvalues_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() - y.todense())
|
|
|
|
|
|
2023-03-08 09:24:30 -08:00
|
|
|
|
  def testSparsePow(self):
    """Integer powers of a sparse array; non-positive exponents must raise."""
    x = jnp.arange(20.0).reshape(4, 5)
    xsp = BCOO.fromdense(x)

    result_dense = x ** 2
    result_sparse = xsp ** 2

    self.assertAllClose(result_dense, result_sparse.todense())

    # x ** -1 would densify (0 ** -1 = inf), so the rule rejects it.
    with self.assertRaisesRegex(NotImplementedError,
        "sparse rule for integer_pow with non-positive exponent"):
      _ = xsp ** -1
|
|
|
|
|
|
2021-06-25 10:45:16 -07:00
|
|
|
|
  def testSparseSum(self):
    """Reductions over various axis combinations of a sparse array."""
    x = jnp.arange(20).reshape(4, 5)
    xsp = BCOO.fromdense(x)

    def f(x):
      return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

    result_dense = f(x)
    result_sparse = self.sparsify(f)(xsp)

    assert len(result_dense) == len(result_sparse)

    for res_dense, res_sparse in zip(result_dense, result_sparse):
      # Partial reductions may return sparse results; densify to compare.
      if isinstance(res_sparse, BCOO):
        res_sparse = res_sparse.todense()
      self.assertArraysAllClose(res_dense, res_sparse)
|
|
|
|
|
|
2022-11-16 09:58:06 -08:00
|
|
|
|
  def testSparseSqueeze(self):
    """lax.squeeze on a sparse array matches the dense result."""
    # Note: more comprehensive tests in sparse_test.py:test_bcoo_squeeze
    rng = jtu.rand_default(self.rng())

    M_dense = rng((2, 3, 1, 4), np.float32)
    M_sparse = BCOO.fromdense(M_dense)
    func = self.sparsify(partial(lax.squeeze, dimensions=(2,)))

    result_dense = func(M_dense)
    result_sparse = func(M_sparse).todense()

    self.assertAllClose(result_sparse, result_dense)
|
|
|
|
|
|
2023-02-01 15:43:47 -08:00
|
|
|
|
  def testSparseRev(self):
    """lax.rev on a sparse array matches the dense result."""
    # Note: more comprehensive tests in sparse_test.py:test_bcoo_rev
    rng = jtu.rand_default(self.rng())

    M_dense = rng((2, 3, 4), np.float32)
    M_sparse = BCOO.fromdense(M_dense)
    func = self.sparsify(partial(lax.rev, dimensions=(1, 2)))

    result_dense = func(M_dense)
    result_sparse = func(M_sparse).todense()

    self.assertAllClose(result_sparse, result_dense)
|
|
|
|
|
|
2022-10-12 13:51:11 +00:00
|
|
|
|
  @jtu.sample_product(
    [dict(shapes=shapes, func=func, n_batch=n_batch)
     for shapes, func, n_batch in [
         ([(4,), (4,)], "concatenate", 0),
         ([(4,), (4,)], "stack", 0),
         ([(4,), (4,)], "hstack", 0),
         ([(4,), (4,)], "vstack", 0),
         ([(4,), (4,)], "concatenate", 1),
         ([(4,), (4,)], "stack", 1),
         ([(4,), (4,)], "hstack", 1),
         ([(4,), (4,)], "vstack", 1),
         ([(2, 4), (2, 4)], "stack", 0),
         ([(2, 4), (3, 4)], "vstack", 0),
         ([(2, 4), (2, 5)], "hstack", 0),
         ([(2, 4), (3, 4)], "vstack", 1),
         ([(2, 4), (2, 5)], "hstack", 1),
         ([(2, 4), (3, 4)], "vstack", 2),
         ([(2, 4), (2, 5)], "hstack", 2),
         ([(2, 4), (4,), (3, 4)], "vstack", 0),
         ([(1, 4), (4,), (1, 4)], "vstack", 0),
     ]
    ]
  )
  def testSparseConcatenateBCOO(self, shapes, func, n_batch):
    """jnp.concatenate/stack/hstack/vstack on BCOO inputs match dense."""
    f = self.sparsify(getattr(jnp, func))
    rng = jtu.rand_some_zero(self.rng())
    arrs = [rng(shape, 'int32') for shape in shapes]
    sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
    self.assertArraysEqual(f(arrs), f(sparrs).todense())
|
|
|
|
|
|
|
|
|
|
@jtu.sample_product(
|
|
|
|
|
[dict(shapes=shapes, func=func, n_batch=n_batch)
|
|
|
|
|
for shapes, func, n_batch in [
|
|
|
|
|
([(2, 4), (2, 4)], "stack", 0),
|
|
|
|
|
([(2, 4), (3, 4)], "vstack", 0),
|
|
|
|
|
([(2, 4), (2, 5)], "hstack", 0),
|
|
|
|
|
([(2, 4), (3, 4)], "vstack", 1),
|
|
|
|
|
([(2, 4), (2, 5)], "hstack", 1),
|
|
|
|
|
([(2, 4), (3, 4)], "vstack", 2),
|
|
|
|
|
([(2, 4), (2, 5)], "hstack", 2),
|
|
|
|
|
]
|
|
|
|
|
]
|
2022-10-12 13:51:11 +00:00
|
|
|
|
)
|
2023-02-15 12:38:57 -08:00
|
|
|
|
def testSparseConcatenateBCSR(self, shapes, func, n_batch):
|
2022-04-28 10:59:11 -07:00
|
|
|
|
f = self.sparsify(getattr(jnp, func))
|
|
|
|
|
rng = jtu.rand_some_zero(self.rng())
|
|
|
|
|
arrs = [rng(shape, 'int32') for shape in shapes]
|
|
|
|
|
sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
|
|
|
|
|
self.assertArraysEqual(f(arrs), f(sparrs).todense())
|
|
|
|
|
|
2023-02-03 15:59:42 -08:00
|
|
|
|
  @jax.default_matmul_precision("float32")
  def testSparseConvolve(self, lhs_shape=(10,), rhs_shape=(5,),
                         dtype='float32', mode='full'):
    """jnp.convolve with a sparse lhs and dense rhs matches dense."""
    # Note: more comprehensive tests in sparse_test.py:test_bcoo_conv_general_dilated
    dense_fun = partial(jnp.convolve, mode=mode)
    sparse_fun = self.sparsify(dense_fun)

    sprng = sptu.rand_bcoo(self.rng())
    rng = jtu.rand_default(self.rng())

    lhs = sprng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)

    expected = dense_fun(lhs.todense(), rhs)
    actual = sparse_fun(lhs, rhs).todense()

    tol = {np.float32: 1E-5, np.complex64: 1E-5, np.float64: 1E-14, np.complex128: 1E-14}
    self.assertAllClose(expected, actual, atol=tol, rtol=tol)
|
2023-02-01 16:16:14 -08:00
|
|
|
|
|
2022-11-14 13:16:57 -08:00
|
|
|
|
  def testSparseReshapeMethod(self):
    """BCOO.reshape matches dense reshape, both eagerly and under jit."""
    # Note: this is more fully tested in sparse_test.py:test_bcoo_reshape
    shape = (2, 3, 4)
    new_shape = (2, 6, 2)

    rng = jtu.rand_some_zero(self.rng())
    arr = rng(shape, 'int32')
    arr_sparse = BCOO.fromdense(arr, n_batch=1)

    arr2 = arr.reshape(new_shape)
    arr2_sparse = arr_sparse.reshape(new_shape)
    arr2_sparse_jit = jax.jit(lambda x: x.reshape(new_shape))(arr_sparse)

    self.assertArraysEqual(arr2, arr2_sparse.todense())
    self.assertArraysEqual(arr2, arr2_sparse_jit.todense())
|
2022-05-02 09:11:55 -07:00
|
|
|
|
|
2021-07-12 16:53:24 -07:00
|
|
|
|
  def testSparseWhileLoop(self):
    """lax.while_loop with a sparse carry matches the dense computation."""
    def cond_fun(params):
      i, A = params
      return i < 5

    def body_fun(params):
      i, A = params
      return i + 1, 2 * A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A))

    A = jnp.arange(4)
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = self.sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 2)
    self.assertEqual(len(out_sparse), 2)
    self.assertArraysEqual(out_dense[0], out_dense[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
|
|
|
|
|
|
|
|
|
|
  def testSparseWhileLoopDuplicateIndices(self):
    """while_loop whose carry holds two sparse arrays built from one source."""
    def cond_fun(params):
      i, A, B = params
      return i < 5

    def body_fun(params):
      i, A, B = params
      # TODO(jakevdp): track shared indices through while loop & use this
      # version of the test, which requires shared indices in order for
      # the nse of the result to remain the same.
      # return i + 1, A, A + B

      # This version is fine without shared indices, and tests that we're
      # flattening non-shared indices consistently.
      return i + 1, B, A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A, A))

    A = jnp.arange(4).reshape((2, 2))
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = self.sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 3)
    self.assertEqual(len(out_sparse), 3)
    self.assertArraysEqual(out_dense[0], out_dense[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
    self.assertArraysEqual(out_dense[2], out_sparse[2].todense())
|
2021-06-25 10:45:16 -07:00
|
|
|
|
|
2021-07-13 15:23:14 -07:00
|
|
|
|
def testSparsifyDenseXlaCall(self):
|
|
|
|
|
# Test handling of dense xla_call within jaxpr interpreter.
|
2021-11-04 13:02:46 -07:00
|
|
|
|
out = self.sparsify(jit(lambda x: x + 1))(0.0)
|
2021-07-13 13:31:21 -07:00
|
|
|
|
self.assertEqual(out, 1.0)
|
|
|
|
|
|
2021-07-13 15:23:14 -07:00
|
|
|
|
  def testSparsifySparseXlaCall(self):
    """A jit-ted function applied to a sparse input lowers sparsely."""
    def func(M):
      return 2 * M

    M = jnp.arange(6).reshape(2, 3)
    Msp = BCOO.fromdense(M)

    out_dense = func(M)
    out_sparse = self.sparsify(jit(func))(Msp)
    self.assertArraysEqual(out_dense, out_sparse.todense())
|
|
|
|
|
|
2021-08-05 15:19:43 -07:00
|
|
|
|
  def testSparseForiLoop(self):
    """lax.fori_loop whose body applies a sparse matrix-vector product."""
    def func(M, x):
      body_fun = lambda i, val: (M @ val) / M.shape[1]
      return lax.fori_loop(0, 2, body_fun, x)

    x = jnp.arange(5.0)
    M = jnp.arange(25).reshape(5, 5)
    M_bcoo = BCOO.fromdense(M)

    # int matrix @ float vector requires standard promotion semantics.
    with jax.numpy_dtype_promotion('standard'):
      result_dense = func(M, x)
      result_sparse = self.sparsify(func)(M_bcoo, x)

    self.assertArraysAllClose(result_dense, result_sparse)
|
|
|
|
|
|
2021-08-06 13:32:07 -07:00
|
|
|
|
  def testSparseCondSimple(self):
    """lax.cond with a sparse operand matches the dense result."""
    def func(x):
      return lax.cond(False, lambda x: x, lambda x: 2 * x, x)

    x = jnp.arange(5.0)
    result_dense = func(x)

    x_bcoo = BCOO.fromdense(x)
    result_sparse = self.sparsify(func)(x_bcoo)

    self.assertArraysAllClose(result_dense, result_sparse.todense())
|
|
|
|
|
|
|
|
|
|
  def testSparseCondMismatchError(self):
    """cond branches returning mismatched sparse/dense structures must error."""
    @self.sparsify
    def func(x, y):
      return lax.cond(False, lambda x: x[0], lambda x: x[1], (x, y))

    x = jnp.arange(5.0)
    y = jnp.arange(5.0)

    x_bcoo = BCOO.fromdense(x)
    y_bcoo = BCOO.fromdense(y)

    func(x, y)  # No error
    func(x_bcoo, y_bcoo)  # No error

    # One sparse and one dense operand: branch outputs disagree in structure.
    with self.assertRaisesRegex(
        TypeError,
        "sparsified true_fun output must have same type structure as sparsified false_fun output.*"):
      func(x_bcoo, y)
|
|
|
|
|
|
2023-07-26 13:00:24 -07:00
|
|
|
|
  @parameterized.named_parameters(
      {"testcase_name": f"_{fmt}", "fmt": fmt}
      for fmt in ["BCSR", "BCOO"]
  )
  def testToDense(self, fmt):
    """todense() inside a sparsified function works for BCOO and BCSR."""
    M = jnp.arange(4).reshape(2, 2)
    if fmt == "BCOO":
      Msp = BCOO.fromdense(M)
    elif fmt == "BCSR":
      Msp = BCSR.fromdense(M)
    else:
      raise ValueError(f"Unrecognized {fmt=}")

    @self.sparsify
    def func(M):
      return todense(M) + 1

    self.assertArraysEqual(func(M), M + 1)
    self.assertArraysEqual(func(Msp), M + 1)
    self.assertArraysEqual(jit(func)(M), M + 1)
    self.assertArraysEqual(jit(func)(Msp), M + 1)
|
|
|
|
|
|
2022-09-01 13:42:02 -07:00
|
|
|
|
def testSparseSlice(self):
|
|
|
|
|
M = jnp.arange(24).reshape(2, 3, 4)
|
|
|
|
|
Msp = BCOO.fromdense(M)
|
|
|
|
|
@self.sparsify
|
|
|
|
|
def func(M):
|
|
|
|
|
return lax.slice(M, (0, 1, 2), (1, 3, 3))
|
|
|
|
|
expected = M[:1, 1:3, 2:3]
|
|
|
|
|
self.assertArraysEqual(func(M), expected)
|
|
|
|
|
self.assertArraysEqual(func(Msp).todense(), expected)
|
|
|
|
|
self.assertArraysEqual(jit(func)(M), expected)
|
|
|
|
|
self.assertArraysEqual(jit(func)(Msp).todense(), expected)
|
|
|
|
|
|
|
|
|
|
  def testSparseDynamicSlice(self):
    """lax.dynamic_slice on a sparse array matches dense behavior."""
    M = jnp.arange(24).reshape(2, 3, 4)
    Msp = BCOO.fromdense(M)
    @self.sparsify
    def func(M):
      return lax.dynamic_slice(M, (0, 1, 2), (1, 1, 3))
    # dynamic_slice clips out-of-bounds start indices: start 2 + size 3
    # exceeds axis length 4, so the start on the last axis is clipped to 1.
    expected = M[:1, 1:2, 1:4]
    self.assertArraysEqual(func(M), expected)
    self.assertArraysEqual(func(Msp).todense(), expected)
    self.assertArraysEqual(jit(func)(M), expected)
    self.assertArraysEqual(jit(func)(Msp).todense(), expected)
|
|
|
|
|
|
2021-10-18 13:35:20 -07:00
|
|
|
|
  def testWeakTypes(self):
    """Weakly-typed scalar * sparse preserves the sparse array's dtype."""
    # Regression test for https://github.com/jax-ml/jax/issues/8267
    M = jnp.arange(12, dtype='int32').reshape(3, 4)
    Msp = BCOO.fromdense(M)
    self.assertArraysEqual(
      operator.mul(2, M),
      self.sparsify(operator.mul)(2, Msp).todense(),
      check_dtypes=True,
    )
|
|
|
|
|
|
2022-07-07 13:49:50 -07:00
|
|
|
|
  @parameterized.named_parameters(
      {"testcase_name": f"_{op.__name__}_{fmt}", "op": op, "dtype": dtype,
       "kwds": kwds, "fmt": fmt}
      for fmt in ["BCSR", "BCOO"]
      for op, dtype, kwds in [
          (jnp.copy, jnp.float32, {}),
          (lax.conj, jnp.complex64, {}),
          (lax.abs, jnp.float32, {}),
          (lax.asin, jnp.float32, {}),
          (lax.asinh, jnp.float32, {}),
          (lax.atan, jnp.float32, {}),
          (lax.atanh, jnp.float32, {}),
          (lax.bessel_i1e, jnp.float32, {}),
          (lax.expm1, jnp.float32, {}),
          (lax.log1p, jnp.float32, {}),
          (lax.neg, jnp.float32, {}),
          (lax.real, jnp.complex64, {}),
          (lax.imag, jnp.complex64, {}),
          (lax.sign, jnp.float32, {}),
          (lax.sin, jnp.float32, {}),
          (lax.sinh, jnp.float32, {}),
          (lax.sqrt, jnp.float32, {}),
          (lax.tan, jnp.float32, {}),
          (lax.tanh, jnp.float32, {}),
          (lax.convert_element_type, jnp.float32, {"new_dtype": np.dtype('complex64')}),
          (lax.integer_pow, jnp.float32, {'y': 2})])
  def testUnaryOperationsNonUniqueIndices(self, fmt, op, dtype, kwds):
    """Elementwise unary ops on sparse arrays with possibly-duplicate indices."""
    shape = (4, 5)

    # Note: we deliberately test non-unique indices here.
    if fmt == "BCOO":
      rng = sptu.rand_bcoo(self.rng())
    elif fmt == "BCSR":
      rng = sptu.rand_bcsr(self.rng())
    else:
      raise ValueError(f"Unrecognized {fmt=}")
    mat = rng(shape, dtype)

    sparse_result = self.sparsify(partial(op, **kwds))(mat)
    dense_result = op(mat.todense(), **kwds)

    self.assertArraysAllClose(sparse_result.todense(), dense_result)

    # Ops that commute with addition should not deduplicate indices.
    if op in [jnp.copy, lax.neg, lax.real, lax.imag]:
      self.assertArraysAllClose(sparse_result.indices, mat.indices)
      if fmt == "BCSR":
        self.assertArraysAllClose(sparse_result.indptr, mat.indptr)
|
2022-07-07 13:49:50 -07:00
|
|
|
|
|
2023-06-21 02:36:07 -07:00
|
|
|
|
  def testCustomJVP(self):
    """custom_jvp functions work under sparsify, for both call and grad."""
    square = jax.custom_derivatives.custom_jvp(lambda x: x ** 2)
    square.defjvp(lambda p, t: (p[0] ** 2, 2 * t[0] * p[0]))
    x = BCOO.fromdense(jnp.arange(5.0))

    # Test calling the function itself.
    result = self.sparsify(square)(x)
    expected = self.sparsify(lambda x: x ** 2)(x)
    self.assertArraysEqual(result.indices, expected.indices)
    self.assertArraysAllClose(result.data, expected.data)

    # Test evaluating the custom gradient.
    grad_square_sum = jax.grad(lambda x: square(x).sum())
    result = self.sparsify(grad_square_sum)(x)
    expected = self.sparsify(jax.grad(lambda x: jnp.sum(x ** 2)))(x)
    self.assertArraysEqual(result.indices, expected.indices)
    self.assertArraysAllClose(result.data, expected.data)
|
|
|
|
|
|
2022-07-07 13:49:50 -07:00
|
|
|
|
|
2021-11-04 13:02:46 -07:00
|
|
|
|
class SparsifyTracerTest(SparsifyTest):
  """Re-runs all SparsifyTest cases with the tracer-based implementation."""

  @classmethod
  def sparsify(cls, f):
    return sparsify(f, use_tracer=True)

  def testTracerIsInstanceCheck(self):
    """With use_tracer=True, arguments ARE SparseTracer instances."""
    @self.sparsify
    def f(x):
      self.assertIsInstance(x, SparseTracer)
    f(jnp.arange(5))
|
2021-08-05 15:19:43 -07:00
|
|
|
|
|
2021-06-25 10:45:16 -07:00
|
|
|
|
if __name__ == "__main__":
  # JaxTestLoader enables JAX-specific test sharding/filtering.
  absltest.main(testLoader=jtu.JaxTestLoader())
|