[sparse] fix bug in sparse.eye under JIT

This commit is contained in:
Jake VanderPlas 2022-11-28 10:54:09 -08:00
parent cc1d2aaaed
commit dce6a9f8ce
4 changed files with 8 additions and 4 deletions

View File

@ -2468,7 +2468,7 @@ class BCOO(JAXSparse):
return cls.fromdense(jnp.eye(N, M, k, dtype=dtype),
n_batch=n_batch, n_dense=n_dense,
index_dtype=index_dtype)
k = jnp.asarray(k)
if n_batch == 0:
data = jnp.ones(diag_size, dtype=dtype)
idx = jnp.arange(diag_size, dtype=index_dtype)

View File

@ -110,7 +110,6 @@ class COO(JAXSparse):
# if k is out of range, return an empty matrix.
return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)
k = jnp.asarray(k)
data = jnp.ones(diag_size, dtype=dtype)
idx = jnp.arange(diag_size, dtype=index_dtype)
zero = _const(idx, 0)

View File

@ -82,7 +82,6 @@ class CSR(JAXSparse):
# if k is out of range, return an empty matrix.
return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)
k = jnp.asarray(k)
data = jnp.ones(diag_size, dtype=dtype)
idx = jnp.arange(diag_size, dtype=index_dtype)
zero = _const(idx, 0)

View File

@ -2493,14 +2493,20 @@ class SparseObjectTest(sptu.SparseTestCase):
for k in [-2, 0, 1])
def test_eye(self, cls, N, M, k):
sparse_format = cls.__name__.lower()
mat = sparse.eye(N, M, k, sparse_format=sparse_format)
func = partial(sparse.eye, N, M, k, sparse_format=sparse_format)
expected = jnp.eye(N, M, k)
expected_nse = jnp.count_nonzero(expected)
mat = func()
self.assertIsInstance(mat, cls)
self.assertArraysEqual(mat.todense(), expected)
self.assertEqual(mat.nse, expected_nse)
mat_jit = jit(func)()
self.assertIsInstance(mat_jit, cls)
self.assertArraysEqual(mat_jit.todense(), expected)
self.assertEqual(mat_jit.nse, expected_nse)
@parameterized.named_parameters(
{"testcase_name": f"{nse}_BCOO{shape}", "shape": shape, "nse": nse}
for shape in ([2, 5], [5, 3])