mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Skip a test when run with cuSolver >= 11.6
This version is shipped with CUDA 12.4. The test assumes that a workspace size baked in with an older version of cuSolver can be used with a newer version of cuSolver. This is not safe, and leads to an error when upgrading from 11.5 to 11.6.
This commit is contained in:
parent
a820387a79
commit
9ba77f8ecd
@ -61,9 +61,14 @@ from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import cuda_versions
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
def _is_required_cusolver_version_satisfied(required_version):
|
||||
if cuda_versions is None:
|
||||
return False
|
||||
return cuda_versions.cusolver_get_version() >= required_version
|
||||
|
||||
@jtu.with_config(jax_legacy_prng_key="allow",
|
||||
jax_debug_key_reuse=False,
|
||||
@ -294,6 +299,11 @@ class CompatTest(bctu.CompatTestBase):
|
||||
def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"):
|
||||
if not config.enable_x64.value and dtype_name == "f64":
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
if (jtu.test_device_matches(["cuda"]) and
|
||||
_is_required_cusolver_version_satisfied(11600)):
|
||||
# The underlying problem is that this test assumes the workspace size can be
|
||||
# queried from an older version of cuSOLVER and then be used in a newer one.
|
||||
self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized")
|
||||
# For lax.linalg.eigh
|
||||
dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
|
||||
size = dict(syevj=8, syevd=36)[variant]
|
||||
|
Loading…
x
Reference in New Issue
Block a user