mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] fix bug in sparse.eye under JIT
This commit is contained in:
parent
cc1d2aaaed
commit
dce6a9f8ce
@ -2468,7 +2468,7 @@ class BCOO(JAXSparse):
|
||||
return cls.fromdense(jnp.eye(N, M, k, dtype=dtype),
|
||||
n_batch=n_batch, n_dense=n_dense,
|
||||
index_dtype=index_dtype)
|
||||
k = jnp.asarray(k)
|
||||
|
||||
if n_batch == 0:
|
||||
data = jnp.ones(diag_size, dtype=dtype)
|
||||
idx = jnp.arange(diag_size, dtype=index_dtype)
|
||||
|
@ -110,7 +110,6 @@ class COO(JAXSparse):
|
||||
# if k is out of range, return an empty matrix.
|
||||
return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)
|
||||
|
||||
k = jnp.asarray(k)
|
||||
data = jnp.ones(diag_size, dtype=dtype)
|
||||
idx = jnp.arange(diag_size, dtype=index_dtype)
|
||||
zero = _const(idx, 0)
|
||||
|
@ -82,7 +82,6 @@ class CSR(JAXSparse):
|
||||
# if k is out of range, return an empty matrix.
|
||||
return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)
|
||||
|
||||
k = jnp.asarray(k)
|
||||
data = jnp.ones(diag_size, dtype=dtype)
|
||||
idx = jnp.arange(diag_size, dtype=index_dtype)
|
||||
zero = _const(idx, 0)
|
||||
|
@ -2493,14 +2493,20 @@ class SparseObjectTest(sptu.SparseTestCase):
|
||||
for k in [-2, 0, 1])
|
||||
def test_eye(self, cls, N, M, k):
|
||||
sparse_format = cls.__name__.lower()
|
||||
mat = sparse.eye(N, M, k, sparse_format=sparse_format)
|
||||
func = partial(sparse.eye, N, M, k, sparse_format=sparse_format)
|
||||
expected = jnp.eye(N, M, k)
|
||||
expected_nse = jnp.count_nonzero(expected)
|
||||
|
||||
mat = func()
|
||||
self.assertIsInstance(mat, cls)
|
||||
self.assertArraysEqual(mat.todense(), expected)
|
||||
self.assertEqual(mat.nse, expected_nse)
|
||||
|
||||
mat_jit = jit(func)()
|
||||
self.assertIsInstance(mat_jit, cls)
|
||||
self.assertArraysEqual(mat_jit.todense(), expected)
|
||||
self.assertEqual(mat_jit.nse, expected_nse)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{nse}_BCOO{shape}", "shape": shape, "nse": nse}
|
||||
for shape in ([2, 5], [5, 3])
|
||||
|
Loading…
x
Reference in New Issue
Block a user