Merge pull request #22649 from ROCm:ci_jax_export_harness

PiperOrigin-RevId: 660096296
This commit is contained in:
jax authors 2024-08-06 14:27:13 -07:00
commit aec6efb44b
4 changed files with 1674 additions and 10 deletions

View File

@ -935,6 +935,7 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
# eigh on GPU
"cusolver_syevj", "cusolver_syevd",
"hipsolver_syevj", "hipsolver_syevd",
# eigh on TPU
"Eigh",
# eig on CPU
@ -948,6 +949,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
# qr on GPU
"cusolver_geqrf", "cublas_geqrf_batched",
"cusolver_orgqr",
"hipsolver_geqrf", "hipblas_geqrf_batched",
"hipsolver_orgqr",
# qr and svd on TPU
"Qr", "ProductOfElementaryHouseholderReflectors",
# triangular_solve on CPU

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -34,9 +34,11 @@ from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf
from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev
from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev
from jax._src.internal_test_util.export_back_compat_test_data import rocm_eigh_hipsolver_syev
from jax._src.internal_test_util.export_back_compat_test_data import cpu_eigh_lapack_syev
from jax._src.internal_test_util.export_back_compat_test_data import cpu_lu_lapack_getrf
from jax._src.internal_test_util.export_back_compat_test_data import cuda_qr_cusolver_geqrf
from jax._src.internal_test_util.export_back_compat_test_data import rocm_qr_hipsolver_geqrf
from jax._src.internal_test_util.export_back_compat_test_data import cpu_qr_lapack_geqrf
from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_lapack_gees
from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd
@ -123,6 +125,8 @@ class CompatTest(bctu.CompatTestBase):
cuda_threefry2x32.data_2023_03_15, cuda_threefry2x32.data_2024_07_30,
cpu_lu_lapack_getrf.data_2023_06_14,
cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17,
rocm_qr_hipsolver_geqrf.data_2024_08_05,
rocm_eigh_hipsolver_syev.data_2024_08_05,
cpu_schur_lapack_gees.data_2023_07_16,
cpu_svd_lapack_gesdd.data_2023_06_19,
cpu_triangular_solve_blas_trsm.data_2023_07_16,
@ -302,14 +306,19 @@ class CompatTest(bctu.CompatTestBase):
for dtype_name in ("f32", "f64")
# We use different custom calls for sizes <= 32
for variant in ["syevj", "syevd"])
def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"):
def test_gpu_eigh_solver_syev(self, dtype_name="f32", variant="syevj"):
if not config.enable_x64.value and dtype_name == "f64":
self.skipTest("Test disabled for x32 mode")
if (jtu.test_device_matches(["cuda"]) and
_is_required_cusolver_version_satisfied(11600)):
# The underlying problem is that this test assumes the workspace size can be
# queried from an older version of cuSOLVER and then be used in a newer one.
self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized")
if jtu.test_device_matches(["cuda"]):
if _is_required_cusolver_version_satisfied(11600):
# The underlying problem is that this test assumes the workspace size can be
# queried from an older version of cuSOLVER and then be used in a newer one.
self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized")
data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"])
elif jtu.test_device_matches(["rocm"]):
data = self.load_testdata(rocm_eigh_hipsolver_syev.data_2024_08_05[f"{dtype_name}_{variant}"])
else:
self.skipTest("Unsupported platform")
# For lax.linalg.eigh
dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
size = dict(syevj=8, syevd=36)[variant]
@ -317,7 +326,6 @@ class CompatTest(bctu.CompatTestBase):
atol = dict(f32=1e-2, f64=1e-10)[dtype_name]
operand = CompatTest.eigh_input((size, size), dtype)
func = lambda: CompatTest.eigh_harness((size, size), dtype)
data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_eigh_results, operand))
@ -359,15 +367,20 @@ class CompatTest(bctu.CompatTestBase):
dict(testcase_name=f"_dtype={dtype_name}_{batched}",
dtype_name=dtype_name, batched=batched)
for dtype_name in ("f32",)
# For batched qr we use cublas_geqrf_batched
# For batched qr we use cublas_geqrf_batched/hipblas_geqrf_batched.
for batched in ("batched", "unbatched"))
def test_cuda_qr_cusolver_geqrf(self, dtype_name="f32", batched="unbatched"):
def test_gpu_qr_solver_geqrf(self, dtype_name="f32", batched="unbatched"):
if jtu.test_device_matches(["cuda"]):
data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched])
elif jtu.test_device_matches(["rocm"]):
data = self.load_testdata(rocm_qr_hipsolver_geqrf.data_2024_08_05[batched])
else:
self.skipTest("Unsupported platform")
# For lax.linalg.qr
dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
rtol = dict(f32=1e-3, f64=1e-5)[dtype_name]
shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched]
func = lambda: CompatTest.qr_harness(shape, dtype)
data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched])
self.run_one_test(func, data, rtol=rtol)
def test_tpu_Qr(self):