Merge pull request #6574 from jakevdp:gpu-test

PiperOrigin-RevId: 370820776
This commit is contained in:
jax authors 2021-04-27 20:50:13 -07:00
commit 2c7556e014

View File

@ -19,6 +19,7 @@ from absl.testing import absltest
from absl.testing import parameterized
from jax import config
from jax.experimental import sparse_ops
from jax.lib import cusparse
from jax.lib import xla_bridge
from jax import jit
from jax import test_util as jtu
@ -217,8 +218,10 @@ class cuSparseTest(jtu.JaxTestCase):
version = xla_bridge.get_backend().platform_version
cuda_version = None if version == "<unknown>" else int(version.split()[-1])
if cuda_version is None or cuda_version < 11000:
self.assertFalse(cusparse and cusparse.is_supported)
self.assertNotIn(sparse_ops.csr_todense_p, xla.backend_specific_translations["gpu"])
else:
self.assertTrue(cusparse and cusparse.is_supported)
self.assertIn(sparse_ops.csr_todense_p, xla.backend_specific_translations["gpu"])
@parameterized.named_parameters(jtu.cases_from_list(