From 9ba77f8ecd9e679f61cccc66ccf4b6c09011a206 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Thu, 16 May 2024 02:01:35 -0700 Subject: [PATCH] Skip a test when run with cuSolver >= 11.6 This version is shipped with CUDA 12.4. The test assumes that a workspace size baked in with an older version of cuSolver can be used with a newer version of cuSolver. This is not safe, and leads to an error when upgrading from 11.5 to 11.6. --- tests/export_back_compat_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 67ddf0cee..acb7dec33 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -61,9 +61,14 @@ from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu +from jax._src.lib import cuda_versions config.parse_flags_with_absl() +def _is_required_cusolver_version_satisfied(required_version): + if cuda_versions is None: + return False + return cuda_versions.cusolver_get_version() >= required_version @jtu.with_config(jax_legacy_prng_key="allow", jax_debug_key_reuse=False, @@ -294,6 +299,11 @@ class CompatTest(bctu.CompatTestBase): def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"): if not config.enable_x64.value and dtype_name == "f64": self.skipTest("Test disabled for x32 mode") + if (jtu.test_device_matches(["cuda"]) and + _is_required_cusolver_version_satisfied(11600)): + # The underlying problem is that this test assumes the workspace size can be + # queried from an older version of cuSOLVER and then be used in a newer one. + self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized") # For lax.linalg.eigh dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] size = dict(syevj=8, syevd=36)[variant]