mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13616 from jakevdp:fix-sparse-error
PiperOrigin-RevId: 494758906
This commit is contained in:
commit
b868cf7c07
@ -151,7 +151,7 @@ class cuSparseTest(sptu.SparseTestCase):
|
||||
return self.assertWarns(sparse.CuSparseEfficiencyWarning)
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def gpu_matmul_warning_context(self, dtype):
|
||||
def gpu_matmul_dtype_warning_context(self, dtype):
|
||||
if jtu.device_under_test() == "gpu" and dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
return self.assertWarns(sparse.CuSparseEfficiencyWarning)
|
||||
return contextlib.nullcontext()
|
||||
@ -318,7 +318,7 @@ class cuSparseTest(sptu.SparseTestCase):
|
||||
matvec = lambda *args: sparse.csr_matvec(*args, shape=M.shape, transpose=transpose)
|
||||
|
||||
self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
|
||||
with self.gpu_matmul_warning_context(dtype):
|
||||
with self.gpu_matmul_dtype_warning_context(dtype):
|
||||
self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -338,7 +338,7 @@ class cuSparseTest(sptu.SparseTestCase):
|
||||
matmat = lambda *args: sparse.csr_matmat(*args, shape=shape, transpose=transpose)
|
||||
|
||||
self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
|
||||
with self.gpu_matmul_warning_context(dtype):
|
||||
with self.gpu_matmul_dtype_warning_context(dtype):
|
||||
self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -397,7 +397,7 @@ class cuSparseTest(sptu.SparseTestCase):
|
||||
matvec = lambda *args: sparse_coo._coo_matvec(*args, spinfo=sparse_coo.COOInfo(shape=M.shape, rows_sorted=True), transpose=transpose)
|
||||
|
||||
self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
|
||||
with self.gpu_matmul_warning_context(dtype):
|
||||
with self.gpu_matmul_dtype_warning_context(dtype):
|
||||
self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -418,7 +418,7 @@ class cuSparseTest(sptu.SparseTestCase):
|
||||
matmat = lambda *args: sparse_coo._coo_matmat(*args, spinfo=sparse_coo.COOInfo(shape=shape, rows_sorted=True), transpose=transpose)
|
||||
|
||||
self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
|
||||
with self.gpu_matmul_warning_context(dtype):
|
||||
with self.gpu_matmul_dtype_warning_context(dtype):
|
||||
self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
|
||||
|
||||
def test_coo_matmat_layout(self):
|
||||
@ -654,6 +654,11 @@ class cuSparseTest(sptu.SparseTestCase):
|
||||
|
||||
class BCOOTest(sptu.SparseTestCase):
|
||||
|
||||
def gpu_matmul_warning_context(self, msg):
|
||||
if GPU_LOWERING_ENABLED and config.jax_bcoo_cusparse_lowering:
|
||||
return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg)
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def test_vmappable(self):
|
||||
"""Test does not depend on batching rules of BCOO primitives."""
|
||||
M = jnp.arange(9).reshape((3, 3))
|
||||
@ -1166,10 +1171,8 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
else:
|
||||
lhs_bcoo, lhs, rhs = args_maker()
|
||||
matmat_expected = f_dense(lhs_bcoo, lhs, rhs)
|
||||
with self.assertWarnsRegex(
|
||||
sparse.CuSparseEfficiencyWarning,
|
||||
"bcoo_dot_general GPU lowering currently does not support this "
|
||||
"batch-mode computation.*"):
|
||||
with self.gpu_matmul_warning_context(
|
||||
"bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"):
|
||||
matmat_default_lowering_fallback = jit(f_sparse)(lhs_bcoo, lhs, rhs)
|
||||
self.assertAllClose(matmat_expected, matmat_default_lowering_fallback,
|
||||
atol=1E-6, rtol=1E-6)
|
||||
@ -1204,13 +1207,9 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general,
|
||||
dimension_numbers=dimension_numbers))
|
||||
|
||||
if config.jax_bcoo_cusparse_lowering:
|
||||
with self.assertWarnsRegex(
|
||||
sparse.CuSparseEfficiencyWarning,
|
||||
"bcoo_dot_general GPU lowering currently does not support this "
|
||||
"batch-mode computation.*"):
|
||||
matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)
|
||||
|
||||
with self.gpu_matmul_warning_context(
|
||||
"bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"):
|
||||
matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)
|
||||
self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback)
|
||||
|
||||
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
|
||||
@ -1236,14 +1235,11 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
|
||||
matmat_expected = lax.dot_general(lhs_mat_dense, rhs,
|
||||
dimension_numbers=dimension_numbers_2d)
|
||||
if config.jax_bcoo_cusparse_lowering:
|
||||
with self.assertWarnsRegex(
|
||||
sparse.CuSparseEfficiencyWarning,
|
||||
with self.subTest(msg="2D"):
|
||||
with self.gpu_matmul_warning_context(
|
||||
"bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
|
||||
matmat_unsorted_fallback = sp_matmat(lhs_mat_bcoo_unsorted, rhs)
|
||||
|
||||
with self.subTest(msg="2D"):
|
||||
self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)
|
||||
self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)
|
||||
|
||||
lhs_vec_dense = jnp.array([0, 1, 0, 2, 0], dtype=jnp.float32)
|
||||
lhs_vec_bcoo = sparse.BCOO.fromdense(lhs_vec_dense, nse=5)
|
||||
@ -1260,14 +1256,11 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
vecmat_expected = lax.dot_general(lhs_vec_dense, rhs,
|
||||
dimension_numbers=dimension_numbers_1d)
|
||||
|
||||
if config.jax_bcoo_cusparse_lowering:
|
||||
with self.assertWarnsRegex(
|
||||
sparse.CuSparseEfficiencyWarning,
|
||||
with self.subTest(msg="1D"):
|
||||
with self.gpu_matmul_warning_context(
|
||||
"bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
|
||||
vecmat_unsorted_fallback = sp_vecmat(lhs_vec_bcoo_unsorted, rhs)
|
||||
|
||||
with self.subTest(msg="1D"):
|
||||
self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)
|
||||
self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)
|
||||
|
||||
@jtu.sample_product(
|
||||
props=_generate_bcoo_dot_general_properties(
|
||||
|
Loading…
x
Reference in New Issue
Block a user