2021-06-25 10:45:16 -07:00
|
|
|
|
# Copyright 2021 Google LLC
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
2021-09-13 17:24:44 -04:00
|
|
|
|
from functools import partial
|
2021-06-25 10:45:16 -07:00
|
|
|
|
import operator
|
|
|
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
2021-09-13 17:24:44 -04:00
|
|
|
|
from jax import config, core, jit, lax
|
2021-06-25 10:45:16 -07:00
|
|
|
|
import jax.numpy as jnp
|
2021-09-24 07:02:08 -07:00
|
|
|
|
import jax._src.test_util as jtu
|
2021-11-04 13:02:46 -07:00
|
|
|
|
from jax.experimental.sparse import BCOO, sparsify, todense, SparseTracer
|
2021-06-25 10:45:16 -07:00
|
|
|
|
from jax.experimental.sparse.transform import (
|
2022-03-07 12:48:03 -08:00
|
|
|
|
arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv)
|
2021-06-25 10:45:16 -07:00
|
|
|
|
|
2021-06-28 10:43:07 -07:00
|
|
|
|
# Route JAX's config flags through absl flag parsing so they can be set when
# this test module is run via absltest (see the __main__ guard at the bottom).
config.parse_flags_with_absl()
|
|
|
|
|
|
2021-06-25 10:45:16 -07:00
|
|
|
|
|
|
|
|
|
class SparsifyTest(jtu.JaxTestCase):
  """Tests for `sparsify` using the jaxpr-interpreter path (``use_tracer=False``).

  Subclasses override :meth:`sparsify` to re-run every test in this class
  against a different sparsification strategy.
  """

  @classmethod
  def sparsify(cls, f):
    """Return ``f`` transformed by ``sparsify`` with ``use_tracer=False``."""
    return sparsify(f, use_tracer=False)

  def testTracerIsInstanceCheck(self):
    # With use_tracer=False the wrapped function must not see SparseTracers.
    @self.sparsify
    def f(x):
      self.assertNotIsInstance(x, SparseTracer)
    f(jnp.arange(5))

  def assertBcooIdentical(self, x, y):
    """Assert ``x`` and ``y`` are BCOO arrays with equal shape, data and indices."""
    self.assertIsInstance(x, BCOO)
    self.assertIsInstance(y, BCOO)
    self.assertEqual(x.shape, y.shape)
    self.assertArraysEqual(x.data, y.data)
    self.assertArraysEqual(x.indices, y.indices)

  def testSparsifyValue(self):
    X = jnp.arange(5)
    X_BCOO = BCOO.fromdense(X)

    args = (X, X_BCOO, X_BCOO)

    # Independent index: each BCOO argument gets its own data and indices
    # buffers, so the env holds one dense buffer plus two per BCOO argument.
    spenv = SparsifyEnv()
    spvalues = arrays_to_spvalues(spenv, args)
    self.assertLen(spvalues, len(args))
    self.assertLen(spenv._buffers, 5)
    self.assertEqual(spvalues,
        (SparsifyValue(X.shape, 0, None), SparsifyValue(X.shape, 1, 2), SparsifyValue(X.shape, 3, 4)))

    args_out = spvalues_to_arrays(spenv, spvalues)
    self.assertLen(args_out, len(args))
    self.assertArraysEqual(args[0], args_out[0])
    self.assertBcooIdentical(args[1], args_out[1])
    self.assertBcooIdentical(args[2], args_out[2])

    # Shared index: both sparse values reference the same indices buffer (2).
    spvalues = (SparsifyValue(X.shape, 0, None), SparsifyValue(X.shape, 1, 2), SparsifyValue(X.shape, 3, 2))
    spenv = SparsifyEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data])

    args_out = spvalues_to_arrays(spenv, spvalues)
    self.assertLen(args_out, len(args))
    self.assertArraysEqual(args[0], args_out[0])
    self.assertBcooIdentical(args[1], args_out[1])
    self.assertBcooIdentical(args[2], args_out[2])

  def testUnitHandling(self):
    # A core.unit argument must pass through a jitted sparsified function.
    x = BCOO.fromdense(jnp.arange(5))
    f = jit(lambda x, y: x)
    result = self.sparsify(jit(f))(x, core.unit)
    self.assertBcooIdentical(result, x)

  def testDropvar(self):
    # The first output of the inner jit is discarded (a dropped variable).
    def inner(x):
      return x * 2, x * 3

    def f(x):
      _, y = jit(inner)(x)
      return y * 4

    x_dense = jnp.arange(5)
    x_sparse = BCOO.fromdense(x_dense)
    self.assertArraysEqual(self.sparsify(f)(x_sparse).todense(), f(x_dense))

  def testPytreeInput(self):
    f = self.sparsify(lambda x: x)
    args = (jnp.arange(4), BCOO.fromdense(jnp.arange(4)))
    out = f(args)
    self.assertLen(out, 2)
    self.assertArraysEqual(args[0], out[0])
    self.assertBcooIdentical(args[1], out[1])

  def testSparsify(self):
    M_dense = jnp.arange(24).reshape(4, 6)
    M_sparse = BCOO.fromdense(M_dense)
    v = jnp.arange(M_dense.shape[0])

    @self.sparsify
    def func(x, v):
      return -jnp.sin(jnp.pi * x).T @ (v + 1)

    result_dense = func(M_dense, v)
    result_sparse = func(M_sparse, v)

    self.assertAllClose(result_sparse, result_dense)

  def testSparsifyWithConsts(self):
    M_dense = jnp.arange(24).reshape(4, 6)
    M_sparse = BCOO.fromdense(M_dense)

    @self.sparsify
    def func(x):
      return jit(lambda x: jnp.sum(x, 1))(x)

    result_dense = func(M_dense)
    result_sparse = func(M_sparse)

    self.assertAllClose(result_sparse.todense(), result_dense)

  def testSparseMatmul(self):
    X = jnp.arange(16).reshape(4, 4)
    Xsp = BCOO.fromdense(X)
    Y = jnp.ones(4)
    Ysp = BCOO.fromdense(Y)

    # dot_general
    result_sparse = self.sparsify(operator.matmul)(Xsp, Y)
    result_dense = operator.matmul(X, Y)
    self.assertAllClose(result_sparse, result_dense)

    # rdot_general
    result_sparse = self.sparsify(operator.matmul)(Y, Xsp)
    result_dense = operator.matmul(Y, X)
    self.assertAllClose(result_sparse, result_dense)

    # spdot_general
    result_sparse = self.sparsify(operator.matmul)(Xsp, Ysp)
    result_dense = operator.matmul(X, Y)
    self.assertAllClose(result_sparse.todense(), result_dense)

  def testSparseAdd(self):
    x = BCOO.fromdense(jnp.arange(5))
    y = BCOO.fromdense(2 * jnp.arange(5))

    # Distinct indices
    out = self.sparsify(operator.add)(x, y)
    self.assertEqual(out.nse, 8)  # uses concatenation.
    self.assertArraysEqual(out.todense(), 3 * jnp.arange(5))

    # Shared indices - requires lower level call
    spenv = SparsifyEnv([x.indices, x.data, y.data])
    spvalues = [
        spenv.sparse(x.shape, data_ref=1, indices_ref=0),
        spenv.sparse(y.shape, data_ref=2, indices_ref=0)
    ]

    result = sparsify_raw(operator.add)(spenv, *spvalues)
    args_out, _ = result
    out, = spvalues_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() + y.todense())

  def testSparseMul(self):
    x = BCOO.fromdense(jnp.arange(5))
    y = BCOO.fromdense(2 * jnp.arange(5))

    # Scalar multiplication
    out = self.sparsify(operator.mul)(x, 2.5)
    self.assertArraysEqual(out.todense(), x.todense() * 2.5)

    # Shared indices - requires lower level call
    spenv = SparsifyEnv([x.indices, x.data, y.data])
    spvalues = [
        spenv.sparse(x.shape, data_ref=1, indices_ref=0),
        spenv.sparse(y.shape, data_ref=2, indices_ref=0)
    ]

    result = sparsify_raw(operator.mul)(spenv, *spvalues)
    args_out, _ = result
    out, = spvalues_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() * y.todense())

  def testSparseSubtract(self):
    x = BCOO.fromdense(3 * jnp.arange(5))
    y = BCOO.fromdense(jnp.arange(5))

    # Distinct indices
    out = self.sparsify(operator.sub)(x, y)
    self.assertEqual(out.nse, 8)  # uses concatenation.
    self.assertArraysEqual(out.todense(), 2 * jnp.arange(5))

    # Shared indices - requires lower level call
    spenv = SparsifyEnv([x.indices, x.data, y.data])
    spvalues = [
        spenv.sparse(x.shape, data_ref=1, indices_ref=0),
        spenv.sparse(y.shape, data_ref=2, indices_ref=0)
    ]

    result = sparsify_raw(operator.sub)(spenv, *spvalues)
    args_out, _ = result
    out, = spvalues_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() - y.todense())

  def testSparseSum(self):
    x = jnp.arange(20).reshape(4, 5)
    xsp = BCOO.fromdense(x)

    def f(x):
      return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

    result_dense = f(x)
    result_sparse = self.sparsify(f)(xsp)

    # Use a unittest assertion: a bare `assert` is stripped under `python -O`.
    self.assertLen(result_sparse, len(result_dense))

    for res_dense, res_sparse in zip(result_dense, result_sparse):
      if isinstance(res_sparse, BCOO):
        res_sparse = res_sparse.todense()
      self.assertArraysAllClose(res_dense, res_sparse)

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_dimensions={}_nbatch={}_ndense={}".format(
      jtu.format_shape_dtype_string(shape, np.float32), dimensions, n_batch, n_dense),
     "shape": shape, "dimensions": dimensions, "n_batch": n_batch, "n_dense": n_dense}
    for shape, dimensions in [
      [(1,), (0,)],
      [(1,), (-1,)],
      [(2, 1, 4), (1,)],
      [(2, 1, 3, 1), (1,)],
      [(2, 1, 3, 1), (1, 3)],
      [(2, 1, 3, 1), (3,)],
    ]
    for n_batch in range(len(shape) + 1)
    for n_dense in range(len(shape) - n_batch + 1)))
  def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
    rng = jtu.rand_default(self.rng())

    M_dense = rng(shape, np.float32)
    M_sparse = BCOO.fromdense(M_dense, n_batch=n_batch, n_dense=n_dense)
    func = self.sparsify(partial(lax.squeeze, dimensions=dimensions))

    result_dense = func(M_dense)
    result_sparse = func(M_sparse).todense()

    self.assertAllClose(result_sparse, result_dense)

  def testSparseWhileLoop(self):
    def cond_fun(params):
      i, _ = params
      return i < 5

    def body_fun(params):
      i, A = params
      return i + 1, 2 * A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A))

    A = jnp.arange(4)
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = self.sparsify(f)(Asp)

    self.assertLen(out_dense, 2)
    self.assertLen(out_sparse, 2)
    # Compare the dense loop counter against the sparsified result (the
    # previous check compared out_dense[0] with itself and was vacuous).
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())

  def testSparseWhileLoopDuplicateIndices(self):
    def cond_fun(params):
      i, _, _ = params
      return i < 5

    def body_fun(params):
      i, A, B = params
      # TODO(jakevdp): track shared indices through while loop & use this
      # version of the test, which requires shared indices in order for
      # the nse of the result to remain the same.
      # return i + 1, A, A + B

      # This version is fine without shared indices, and tests that we're
      # flattening non-shared indices consistently.
      return i + 1, B, A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A, A))

    A = jnp.arange(4).reshape((2, 2))
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = self.sparsify(f)(Asp)

    self.assertLen(out_dense, 3)
    self.assertLen(out_sparse, 3)
    # Compare the loop counter against the sparsified result, not itself.
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
    self.assertArraysEqual(out_dense[2], out_sparse[2].todense())

  def testSparsifyDenseXlaCall(self):
    # Test handling of dense xla_call within jaxpr interpreter.
    out = self.sparsify(jit(lambda x: x + 1))(0.0)
    self.assertEqual(out, 1.0)

  def testSparsifySparseXlaCall(self):
    # Test sparse lowering of XLA call
    def func(M):
      return 2 * M

    M = jnp.arange(6).reshape(2, 3)
    Msp = BCOO.fromdense(M)

    out_dense = func(M)
    out_sparse = self.sparsify(jit(func))(Msp)
    self.assertArraysEqual(out_dense, out_sparse.todense())

  def testSparseForiLoop(self):
    def func(M, x):
      body_fun = lambda i, val: (M @ val) / M.shape[1]
      return lax.fori_loop(0, 2, body_fun, x)

    x = jnp.arange(5.0)
    M = jnp.arange(25).reshape(5, 5)
    M_bcoo = BCOO.fromdense(M)

    result_dense = func(M, x)
    result_sparse = self.sparsify(func)(M_bcoo, x)

    self.assertArraysAllClose(result_dense, result_sparse)

  def testSparseCondSimple(self):
    def func(x):
      return lax.cond(False, lambda x: x, lambda x: 2 * x, x)

    x = jnp.arange(5.0)
    result_dense = func(x)

    x_bcoo = BCOO.fromdense(x)
    result_sparse = self.sparsify(func)(x_bcoo)

    self.assertArraysAllClose(result_dense, result_sparse.todense())

  def testSparseCondMismatchError(self):
    # Branches returning one sparse and one dense value must be rejected.
    @self.sparsify
    def func(x, y):
      return lax.cond(False, lambda x: x[0], lambda x: x[1], (x, y))

    x = jnp.arange(5.0)
    y = jnp.arange(5.0)

    x_bcoo = BCOO.fromdense(x)
    y_bcoo = BCOO.fromdense(y)

    func(x, y)  # No error
    func(x_bcoo, y_bcoo)  # No error

    with self.assertRaisesRegex(TypeError, "sparsified true_fun and false_fun output.*"):
      func(x_bcoo, y)

  def testToDense(self):
    M = jnp.arange(4)
    Msp = BCOO.fromdense(M)
    @self.sparsify
    def func(M):
      return todense(M) + 1
    self.assertArraysEqual(func(M), M + 1)
    self.assertArraysEqual(func(Msp), M + 1)
    self.assertArraysEqual(jit(func)(M), M + 1)
    self.assertArraysEqual(jit(func)(Msp), M + 1)

  def testWeakTypes(self):
    # Regression test for https://github.com/google/jax/issues/8267
    M = jnp.arange(12, dtype='int32').reshape(3, 4)
    Msp = BCOO.fromdense(M)
    self.assertArraysEqual(
      operator.mul(2, M),
      self.sparsify(operator.mul)(2, Msp).todense(),
      check_dtypes=True,
    )
|
|
|
|
|
|
2021-11-04 13:02:46 -07:00
|
|
|
|
class SparsifyTracerTest(SparsifyTest):
  """Re-runs every SparsifyTest case with ``sparsify(..., use_tracer=True)``."""

  @classmethod
  def sparsify(cls, f):
    """Wrap ``f`` with the tracer-based ``sparsify`` implementation."""
    wrapped = sparsify(f, use_tracer=True)
    return wrapped

  def testTracerIsInstanceCheck(self):
    # In tracer mode the transformed function must receive SparseTracer inputs.
    def check_is_tracer(arr):
      self.assertIsInstance(arr, SparseTracer)

    sparsified_check = self.sparsify(check_is_tracer)
    sparsified_check(jnp.arange(5))
|
2021-08-05 15:19:43 -07:00
|
|
|
|
|
2021-06-25 10:45:16 -07:00
|
|
|
|
if __name__ == "__main__":
  # Run the whole module through absltest using JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|