mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #6574 from jakevdp:gpu-test
PiperOrigin-RevId: 370820776
This commit is contained in:
commit
2c7556e014
@ -19,6 +19,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from jax import config
|
||||
from jax.experimental import sparse_ops
|
||||
from jax.lib import cusparse
|
||||
from jax.lib import xla_bridge
|
||||
from jax import jit
|
||||
from jax import test_util as jtu
|
||||
@ -217,8 +218,10 @@ class cuSparseTest(jtu.JaxTestCase):
|
||||
version = xla_bridge.get_backend().platform_version
|
||||
cuda_version = None if version == "<unknown>" else int(version.split()[-1])
|
||||
if cuda_version is None or cuda_version < 11000:
|
||||
self.assertFalse(cusparse and cusparse.is_supported)
|
||||
self.assertNotIn(sparse_ops.csr_todense_p, xla.backend_specific_translations["gpu"])
|
||||
else:
|
||||
self.assertTrue(cusparse and cusparse.is_supported)
|
||||
self.assertIn(sparse_ops.csr_todense_p, xla.backend_specific_translations["gpu"])
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
Loading…
x
Reference in New Issue
Block a user