Merge pull request #21258 from olupton:skip-cusolver-test-with-cuda-12.4

PiperOrigin-RevId: 635845926
This commit is contained in:
jax authors 2024-05-21 10:11:54 -07:00
commit 5350bc960d

View File

@ -61,9 +61,14 @@ from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
config.parse_flags_with_absl()
def _is_required_cusolver_version_satisfied(required_version):
if cuda_versions is None:
return False
return cuda_versions.cusolver_get_version() >= required_version
@jtu.with_config(jax_legacy_prng_key="allow",
jax_debug_key_reuse=False,
@ -294,6 +299,11 @@ class CompatTest(bctu.CompatTestBase):
def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"):
if not config.enable_x64.value and dtype_name == "f64":
self.skipTest("Test disabled for x32 mode")
if (jtu.test_device_matches(["cuda"]) and
_is_required_cusolver_version_satisfied(11600)):
# The underlying problem is that this test assumes the workspace size can be
# queried from an older version of cuSOLVER and then be used in a newer one.
self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized")
# For lax.linalg.eigh
dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
size = dict(syevj=8, syevd=36)[variant]