mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #21258 from olupton:skip-cusolver-test-with-cuda-12.4
PiperOrigin-RevId: 635845926
This commit is contained in:
commit
5350bc960d
@ -61,9 +61,14 @@ from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import cuda_versions
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
def _is_required_cusolver_version_satisfied(required_version):
|
||||
if cuda_versions is None:
|
||||
return False
|
||||
return cuda_versions.cusolver_get_version() >= required_version
|
||||
|
||||
@jtu.with_config(jax_legacy_prng_key="allow",
|
||||
jax_debug_key_reuse=False,
|
||||
@ -294,6 +299,11 @@ class CompatTest(bctu.CompatTestBase):
|
||||
def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"):
|
||||
if not config.enable_x64.value and dtype_name == "f64":
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
if (jtu.test_device_matches(["cuda"]) and
|
||||
_is_required_cusolver_version_satisfied(11600)):
|
||||
# The underlying problem is that this test assumes the workspace size can be
|
||||
# queried from an older version of cuSOLVER and then be used in a newer one.
|
||||
self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized")
|
||||
# For lax.linalg.eigh
|
||||
dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
|
||||
size = dict(syevj=8, syevd=36)[variant]
|
||||
|
Loading…
x
Reference in New Issue
Block a user