mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #22649 from ROCm:ci_jax_export_harness
PiperOrigin-RevId: 660096296
This commit is contained in:
commit
aec6efb44b
@ -935,6 +935,7 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
||||
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
|
||||
# eigh on GPU
|
||||
"cusolver_syevj", "cusolver_syevd",
|
||||
"hipsolver_syevj", "hipsolver_syevd",
|
||||
# eigh on TPU
|
||||
"Eigh",
|
||||
# eig on CPU
|
||||
@ -948,6 +949,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
||||
# qr on GPU
|
||||
"cusolver_geqrf", "cublas_geqrf_batched",
|
||||
"cusolver_orgqr",
|
||||
"hipsolver_geqrf", "hipblas_geqrf_batched",
|
||||
"hipsolver_orgqr",
|
||||
# qr and svd on TPU
|
||||
"Qr", "ProductOfElementaryHouseholderReflectors",
|
||||
# triangular_solve on CPU
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -34,9 +34,11 @@ from jax._src.internal_test_util import export_back_compat_test_util as bctu
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import rocm_eigh_hipsolver_syev
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_eigh_lapack_syev
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_lu_lapack_getrf
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cuda_qr_cusolver_geqrf
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import rocm_qr_hipsolver_geqrf
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_qr_lapack_geqrf
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_lapack_gees
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd
|
||||
@ -123,6 +125,8 @@ class CompatTest(bctu.CompatTestBase):
|
||||
cuda_threefry2x32.data_2023_03_15, cuda_threefry2x32.data_2024_07_30,
|
||||
cpu_lu_lapack_getrf.data_2023_06_14,
|
||||
cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17,
|
||||
rocm_qr_hipsolver_geqrf.data_2024_08_05,
|
||||
rocm_eigh_hipsolver_syev.data_2024_08_05,
|
||||
cpu_schur_lapack_gees.data_2023_07_16,
|
||||
cpu_svd_lapack_gesdd.data_2023_06_19,
|
||||
cpu_triangular_solve_blas_trsm.data_2023_07_16,
|
||||
@ -302,14 +306,19 @@ class CompatTest(bctu.CompatTestBase):
|
||||
for dtype_name in ("f32", "f64")
|
||||
# We use different custom calls for sizes <= 32
|
||||
for variant in ["syevj", "syevd"])
|
||||
def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"):
|
||||
def test_gpu_eigh_solver_syev(self, dtype_name="f32", variant="syevj"):
|
||||
if not config.enable_x64.value and dtype_name == "f64":
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
if (jtu.test_device_matches(["cuda"]) and
|
||||
_is_required_cusolver_version_satisfied(11600)):
|
||||
# The underlying problem is that this test assumes the workspace size can be
|
||||
# queried from an older version of cuSOLVER and then be used in a newer one.
|
||||
self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized")
|
||||
if jtu.test_device_matches(["cuda"]):
|
||||
if _is_required_cusolver_version_satisfied(11600):
|
||||
# The underlying problem is that this test assumes the workspace size can be
|
||||
# queried from an older version of cuSOLVER and then be used in a newer one.
|
||||
self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized")
|
||||
data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"])
|
||||
elif jtu.test_device_matches(["rocm"]):
|
||||
data = self.load_testdata(rocm_eigh_hipsolver_syev.data_2024_08_05[f"{dtype_name}_{variant}"])
|
||||
else:
|
||||
self.skipTest("Unsupported platform")
|
||||
# For lax.linalg.eigh
|
||||
dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
|
||||
size = dict(syevj=8, syevd=36)[variant]
|
||||
@ -317,7 +326,6 @@ class CompatTest(bctu.CompatTestBase):
|
||||
atol = dict(f32=1e-2, f64=1e-10)[dtype_name]
|
||||
operand = CompatTest.eigh_input((size, size), dtype)
|
||||
func = lambda: CompatTest.eigh_harness((size, size), dtype)
|
||||
data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"])
|
||||
self.run_one_test(func, data, rtol=rtol, atol=atol,
|
||||
check_results=partial(self.check_eigh_results, operand))
|
||||
|
||||
@ -359,15 +367,20 @@ class CompatTest(bctu.CompatTestBase):
|
||||
dict(testcase_name=f"_dtype={dtype_name}_{batched}",
|
||||
dtype_name=dtype_name, batched=batched)
|
||||
for dtype_name in ("f32",)
|
||||
# For batched qr we use cublas_geqrf_batched
|
||||
# For batched qr we use cublas_geqrf_batched/hipblas_geqrf_batched.
|
||||
for batched in ("batched", "unbatched"))
|
||||
def test_cuda_qr_cusolver_geqrf(self, dtype_name="f32", batched="unbatched"):
|
||||
def test_gpu_qr_solver_geqrf(self, dtype_name="f32", batched="unbatched"):
|
||||
if jtu.test_device_matches(["cuda"]):
|
||||
data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched])
|
||||
elif jtu.test_device_matches(["rocm"]):
|
||||
data = self.load_testdata(rocm_qr_hipsolver_geqrf.data_2024_08_05[batched])
|
||||
else:
|
||||
self.skipTest("Unsupported platform")
|
||||
# For lax.linalg.qr
|
||||
dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
|
||||
rtol = dict(f32=1e-3, f64=1e-5)[dtype_name]
|
||||
shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched]
|
||||
func = lambda: CompatTest.qr_harness(shape, dtype)
|
||||
data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched])
|
||||
self.run_one_test(func, data, rtol=rtol)
|
||||
|
||||
def test_tpu_Qr(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user