mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

This tests saving a module with one set of axis names but loading it with another set of axis names. It also tests the custom calls: - `@Sharding` - `@xla.sdy.GlobalToLocalShape` - `@xla.sdy.LocalToGlobalShape` Note that there are a number of other custom calls that will be tested in the Shardy and XLA codebases. The way the testing utilities are set up here doesn't allow me to set `out_shardings`, for example. So JAX can rely on the existence of those tests as stability guarantees, just like for StableHLO. PiperOrigin-RevId: 732893432
1060 lines
48 KiB
Python
1060 lines
48 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Tests for backwards compatibility of exporting code with custom calls.
|
|
|
|
See the export_back_compat_test_util module docstring for how to setup and
|
|
update these tests.
|
|
"""
|
|
import dataclasses
|
|
from functools import partial
|
|
import itertools
|
|
import math
|
|
|
|
from absl.testing import absltest, parameterized
|
|
|
|
import numpy as np
|
|
|
|
import jax
|
|
from jax import lax
|
|
from jax._src.export import _export
|
|
|
|
from jax._src.internal_test_util import export_back_compat_test_util as bctu
|
|
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev
|
|
from jax._src.internal_test_util.export_back_compat_test_data import rocm_eigh_hipsolver_syev
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_eigh_lapack_syev
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_lu_lapack_getrf
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cuda_qr_cusolver_geqrf
|
|
from jax._src.internal_test_util.export_back_compat_test_data import rocm_qr_hipsolver_geqrf
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_qr_lapack_geqrf
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_lapack_gees
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_hessenberg_lapack_gehrd
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_tridiagonal_lapack_sytrd_hetrd
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cpu_tridiagonal_solve_lapack_gtsv
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cuda_svd_cusolver_gesvd
|
|
from jax._src.internal_test_util.export_back_compat_test_data import cuda_tridiagonal_cusolver_sytrd
|
|
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh
|
|
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu
|
|
from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK
|
|
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Qr
|
|
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Sharding
|
|
from jax._src.internal_test_util.export_back_compat_test_data import tpu_stablehlo_dynamic_reduce_window
|
|
from jax._src.internal_test_util.export_back_compat_test_data import shardy_sharding_ops_with_different_meshes
|
|
from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_rng_bit_generator
|
|
from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_top_k
|
|
from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_approx_top_k
|
|
|
|
from jax.experimental import pjit
|
|
from jax.experimental.shard_map import shard_map
|
|
import jax.numpy as jnp
|
|
|
|
from jax.sharding import Mesh
|
|
from jax.sharding import PartitionSpec as P
|
|
from jax.sharding import NamedSharding as NS
|
|
|
|
from jax._src import config
|
|
from jax._src import test_util as jtu
|
|
from jax._src.lib import cuda_versions
|
|
|
|
config.parse_flags_with_absl()
|
|
|
|
|
|
def _is_required_cusolver_version_satisfied(required_version):
|
|
if cuda_versions is None:
|
|
return False
|
|
return cuda_versions.cusolver_get_version() >= required_version
|
|
|
|
|
|
@jtu.with_config(jax_legacy_prng_key="allow",
|
|
jax_debug_key_reuse=False,
|
|
jax_include_full_tracebacks_in_locations=False,
|
|
jax_threefry_gpu_kernel_lowering=True)
|
|
class CompatTest(bctu.CompatTestBase):
|
|
def test_dummy(self):
  """Exercise the testing mechanism itself; meant to run on all platforms."""
  loaded = self.load_testdata(bctu.dummy_data_dict)
  # Retarget the dummy data at whichever backend is the current default.
  retargeted = dataclasses.replace(
      loaded, platform=self.default_jax_backend())
  self.run_one_test(jnp.sin, retargeted)
|
|
|
|
def test_detect_different_output(self):
  """Check that a mismatching expected output is detected (all platforms)."""
  loaded = self.load_testdata(bctu.dummy_data_dict)
  # Replace the expected outputs with a value `jnp.sin` will not reproduce.
  tampered = dataclasses.replace(
      loaded,
      platform=self.default_jax_backend(),
      expected_outputs=(np.array(2.0, dtype=np.float32),))
  expected_failure = self.assertRaisesRegex(
      AssertionError, "Not equal to tolerance")
  with expected_failure:
    self.run_one_test(jnp.sin, tampered)
|
|
|
|
def test_detect_different_custom_calls(self):
  """Check that a mismatching custom-call list is detected (all platforms)."""
  loaded = self.load_testdata(bctu.dummy_data_dict)
  # Declare a custom call target that the `jnp.sin` harness does not use.
  tampered = dataclasses.replace(
      loaded,
      platform=self.default_jax_backend(),
      custom_call_targets=["missing"])
  expected_failure = self.assertRaisesRegex(
      AssertionError, "Element counts were not equal")
  with expected_failure:
    self.run_one_test(jnp.sin, tampered)
|
|
|
|
def test_custom_call_coverage(self):
  """Tests that the back compat tests cover all the targets declared stable."""
  stable_targets = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
  ffi_cpu_testdatas = [
      cpu_cholesky_lapack_potrf.data_2024_05_31,
      cpu_qr_lapack_geqrf.data_2024_08_22,
      cpu_eig_lapack_geev.data_2024_08_19,
      cpu_eigh_lapack_syev.data_2024_08_19,
      cpu_lu_lapack_getrf.data_2024_05_31,
      cpu_schur_lapack_gees.data_2024_11_29,
      cpu_triangular_solve_blas_trsm.data_2024_12_02,
      cpu_svd_lapack_gesdd.data_2024_08_13,
      cpu_hessenberg_lapack_gehrd.data_2024_08_31,
      cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01,
  ]
  # Register here every testdata that should count towards covering the
  # targets guaranteed stable.
  registered_testdatas = [
      *ffi_cpu_testdatas,
      cpu_cholesky_lapack_potrf.data_2023_06_19,
      cpu_eig_lapack_geev.data_2023_06_19,
      cpu_eigh_lapack_syev.data_2023_03_17,
      cpu_qr_lapack_geqrf.data_2023_03_17,
      cuda_threefry2x32.data_2024_07_30,
      cpu_lu_lapack_getrf.data_2023_06_14,
      cuda_lu_pivots_to_permutation.data_2024_08_08,
      cuda_lu_cusolver_getrf.data_2024_08_19,
      cuda_qr_cusolver_geqrf.data_2024_09_26,
      cuda_eigh_cusolver_syev.data_2024_09_30,
      cuda_svd_cusolver_gesvd.data_2024_10_08,
      cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09,
      cuda_tridiagonal_cusolver_sytrd.data_2025_01_09,
      rocm_qr_hipsolver_geqrf.data_2024_08_05,
      rocm_eigh_hipsolver_syev.data_2024_08_05,
      cpu_schur_lapack_gees.data_2023_07_16,
      cpu_svd_lapack_gesdd.data_2023_06_19,
      cpu_triangular_solve_blas_trsm.data_2023_07_16,
      cpu_hessenberg_lapack_gehrd.data_2024_08_30,
      cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03,
      tpu_Eigh.data,
      tpu_Lu.data_2023_03_21,
      tpu_Qr.data_2023_03_17,
      tpu_Sharding.data_2023_03_16,
      tpu_ApproxTopK.data_2023_04_17,
      tpu_ApproxTopK.data_2023_05_16,
      tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17,
      tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17,
      stablehlo_dynamic_rng_bit_generator.data_2023_06_17,
      stablehlo_dynamic_top_k.data_2023_07_16,
      stablehlo_dynamic_top_k.data_2023_08_11,  # with shape_assertion
      stablehlo_dynamic_approx_top_k.data_2024_05_30,
  ]
  # Some of the entries are nested structures; flatten them into one stream.
  flattened = itertools.chain.from_iterable(
      [self.load_testdata_nested(entry) for entry in registered_testdatas])
  exercised_targets = set()
  for data in flattened:
    self.assertIsInstance(data, bctu.CompatTestData)
    exercised_targets.update(data.custom_call_targets)

  # Targets that are accounted for outside of this test file.
  exercised_targets.update({
      "tf.call_tf_function",  # tested in jax2tf/tests/back_compat_tf_test.py
      "tpu_custom_call",  # tested separately
      "__gpu$xla.gpu.triton",  # tested in pallas/export_back_compat_pallas_test.py
      # The following require ROCm to test
      "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
      "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", "hipsolver_syevd_ffi",
      "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi",
  })
  not_covered = stable_targets - exercised_targets
  self.assertEmpty(not_covered,
                   msg=("The following custom call targets are declared "
                        "stable but are not covered by any tests: "
                        f"{not_covered}"))
|
|
|
|
def cholesky_input(self, shape, dtype):
  """Return `a @ conj(a)^T` for a random `a` of the given shape and dtype.

  `a` is drawn with `jtu.rand_default` seeded from `self.rng()`; the
  transpose swaps the last two axes.
  """
  sampler = jtu.rand_default(self.rng())
  factor = sampler(shape, dtype)
  adjoint = np.conj(np.swapaxes(factor, -1, -2))
  return np.matmul(factor, adjoint)
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
    for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"):
  """Run the current and the 2023-06-19 Cholesky testdata for `lax.linalg.cholesky`."""
  if dtype_name in ["f64", "c128"] and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")

  dtype = dict(f32=np.float32, f64=np.float64,
               c64=np.complex64, c128=np.complex128)[dtype_name]
  shape = (4, 4)
  unused_operand = self.cholesky_input(shape, dtype)
  del unused_operand  # Input is in the testdata, here for readability
  func = lax.linalg.cholesky

  # (rtol, atol) per dtype.
  rtol, atol = dict(f32=(1e-3, 1e-4), f64=(1e-5, 1e-12),
                    c64=(1e-3, 1e-4), c128=(1e-5, 1e-12))[dtype_name]

  info = cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name]
  current_data = self.load_testdata(info)
  self.run_one_test(func, current_data, rtol=rtol, atol=atol)

  # The older testdata is run expecting the custom calls of the current one.
  older_data = self.load_testdata(
      cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name])
  self.run_one_test(func, older_data, rtol=rtol, atol=atol,
                    expect_current_custom_calls=info["custom_call_targets"])
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
    for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_eig_lapack_geev(self, dtype_name="f32"):
  """Run the current and the 2023-06-19 testdata for `lax.linalg.eig`.

  Results are validated with a custom checker (eigenvalue set comparison plus
  eigenvector residuals) rather than by direct comparison of all outputs.
  """
  if dtype_name in ["f64", "c128"] and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")

  dtype = dict(f32=np.float32, f64=np.float64,
               c64=np.complex64, c128=np.complex128)[dtype_name]
  shape = (4, 4)
  num_elements = math.prod(shape)

  def func():
    # The operand is built inside the harness, to simplify it.
    matrix = jnp.arange(num_elements, dtype=dtype).reshape(shape)
    return lax.linalg.eig(matrix,
                          compute_left_eigenvectors=True,
                          compute_right_eigenvectors=True)

  # (rtol, atol) per dtype.
  rtol, atol = dict(f32=(1e-3, 1e-4), f64=(1e-5, 1e-12),
                    c64=(1e-3, 1e-4), c128=(1e-5, 1e-12))[dtype_name]

  def check_eig_results(res_run, res_expected, *, rtol, atol):
    # Test ported from tests.linalg_test.testEig
    dim = shape[-1]
    operand = np.arange(num_elements, dtype=dtype).reshape(shape)

    def scaled_norm(x):
      # Norm, adjusted for dimension and type.
      raw = np.linalg.norm(x, axis=(-2, -1))
      return raw / ((dim + 1) * jnp.finfo(dtype).eps)

    def check_right_eigenvectors(a, w, vr):
      residual = np.matmul(a, vr) - w[..., None, :] * vr
      self.assertTrue(np.all(scaled_norm(residual) < 100))

    def check_left_eigenvectors(a, w, vl):
      # Left eigenvectors are checked as right eigenvectors of the
      # conjugate transpose, against the conjugated eigenvalues.
      ndim = len(a.shape)
      swap_last_two = list(range(ndim - 2)) + [ndim - 1, ndim - 2]
      check_right_eigenvectors(
          jnp.conj(a.transpose(swap_last_two)), jnp.conj(w), vl)

    def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
      nearest = min(abs(eigenvalues_array - eigenvalue))
      self.assertAllClose(
          nearest,
          np.array(0., nearest.dtype),
          atol=atol, rtol=rtol)

    w_run_all = res_run[0]
    w_exp_all = res_expected[0]
    batch_ranges = [range(extent) for extent in operand.shape[:-2]]
    for idx in itertools.product(*batch_ranges):
      w_run, w_exp = w_run_all[idx], w_exp_all[idx]
      # Every eigenvalue of one result must appear in the other, both ways.
      for i in range(dim):
        check_eigenvalue_is_in_array(w_run[i], w_exp)
        check_eigenvalue_is_in_array(w_exp[i], w_run)

    check_left_eigenvectors(operand, w_run_all, res_run[1])
    check_right_eigenvectors(operand, w_run_all, res_run[2])

  info = cpu_eig_lapack_geev.data_2024_08_19[dtype_name]
  current_data = self.load_testdata(info)
  self.run_one_test(func, current_data, rtol=rtol, atol=atol,
                    check_results=check_eig_results)
  older_data = self.load_testdata(
      cpu_eig_lapack_geev.data_2023_06_19[dtype_name])
  self.run_one_test(func, older_data, rtol=rtol, atol=atol,
                    check_results=check_eig_results,
                    expect_current_custom_calls=info["custom_call_targets"])
|
|
|
|
@staticmethod
|
|
def eigh_input(shape, dtype):
|
|
# In order to keep inputs small, we construct the input programmatically
|
|
operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
|
|
# Make operand self-adjoint
|
|
operand = (operand + jnp.conj(jnp.swapaxes(operand, -1, -2))) / 2.
|
|
return operand
|
|
|
|
@staticmethod
def eigh_harness(shape, dtype):
  """Call `lax.linalg.eigh` on the lower triangle of `eigh_input(shape, dtype)`."""
  lower_triangle = jnp.tril(CompatTest.eigh_input(shape, dtype))
  return lax.linalg.eigh(lower_triangle, lower=True, symmetrize_input=False)
|
|
|
|
def check_eigh_results(self, operand, res_now, res_expected, *,
                       rtol, atol=None):
  """Validate an eigendecomposition result against `operand`.

  Args:
    operand: the square 2D matrix that was decomposed.
    res_now: pair `(v_now, w_now)` from the current run; `v_now` has the
      shape of `operand` and `w_now` has shape `(n,)`.
    res_expected: pair from the serialized testdata; only its second element
      is compared.
    rtol: relative tolerance, used for the orthonormality and residual bounds
      and for the comparison with the expected second element.
    atol: absolute tolerance for the comparison with the expected values.
  """
  v_now, w_now = res_now
  _, w_expected = res_expected
  n, m = operand.shape
  # Use unittest assertions rather than bare `assert`: they are not stripped
  # under `python -O` and they report the mismatching values.
  self.assertEqual(n, m)
  self.assertEqual(v_now.shape, operand.shape)
  self.assertEqual(w_now.shape, (n,))
  # v_now^H @ v_now must be close to the identity.
  self.assertLessEqual(
      np.linalg.norm(np.eye(n) - np.matmul(np.conj(np.swapaxes(v_now, -1, -2)), v_now)),
      rtol)
  # w_now : f64[n] while v_now: c128[n, n]
  w_now_like_v = w_now[np.newaxis, :].astype(v_now.dtype)
  # operand @ v_now must match v_now with each column scaled by its w_now.
  self.assertLessEqual(
      np.linalg.norm(np.matmul(operand, v_now) - w_now_like_v * v_now),
      rtol * np.linalg.norm(operand))
  self.assertAllClose(w_expected, w_now, rtol=rtol, atol=atol)
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
    for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"):
  """Run the current and the legacy (2023-03-17) testdata for `lax.linalg.eigh`."""
  # For lax.linalg.eigh
  if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
    self.skipTest("Test disabled for x32 mode")

  dtype = dict(f32=np.float32, f64=np.float64,
               c64=np.complex64, c128=np.complex128)[dtype_name]
  size = 8
  shape = (size, size)
  operand = CompatTest.eigh_input(shape, dtype)
  # The harness and the checker's operand share one shape, so they cannot
  # drift apart.
  func = lambda: CompatTest.eigh_harness(shape, dtype)
  rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
  atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

  info = cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]
  data = self.load_testdata(info)
  self.run_one_test(func, data, rtol=rtol, atol=atol,
                    check_results=partial(self.check_eigh_results, operand))

  # Legacy custom call test
  data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name])
  self.run_one_test(func, data, rtol=rtol, atol=atol,
                    check_results=partial(self.check_eigh_results, operand),
                    expect_current_custom_calls=info["custom_call_targets"])
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}_{variant}",
         dtype_name=dtype_name, variant=variant)
    for dtype_name in ("f32", "f64")
    # We use different custom calls for sizes <= 32
    for variant in ["syevj", "syevd"])
def test_gpu_eigh_solver_syev_legacy(self, dtype_name="f32", variant="syevj"):
  """Run legacy GPU eigh testdata (ROCm or CUDA) for `lax.linalg.eigh`."""
  if dtype_name == "f64" and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")

  testdata_key = f"{dtype_name}_{variant}"
  if jtu.test_device_matches(["rocm"]):
    data = self.load_testdata(
        rocm_eigh_hipsolver_syev.data_2024_08_05[testdata_key])
    prefix = "hip"
  elif jtu.test_device_matches(["cuda"]):
    if _is_required_cusolver_version_satisfied(11600):
      # The underlying problem is that this test assumes the workspace size can be
      # queried from an older version of cuSOLVER and then be used in a newer one.
      self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized")
    data = self.load_testdata(
        cuda_eigh_cusolver_syev.data_2023_03_17[testdata_key])
    prefix = "cu"
  else:
    self.skipTest("Unsupported platform")

  # For lax.linalg.eigh
  dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
  size = dict(syevj=8, syevd=36)[variant]
  # (rtol, atol) per dtype.
  rtol, atol = dict(f32=(1e-3, 1e-2), f64=(1e-5, 1e-10))[dtype_name]
  shape = (size, size)
  operand = CompatTest.eigh_input(shape, dtype)

  def func():
    return CompatTest.eigh_harness(shape, dtype)

  self.run_one_test(func, data, rtol=rtol, atol=atol,
                    check_results=partial(self.check_eigh_results, operand),
                    expect_current_custom_calls=[f"{prefix}solver_syevd_ffi"])
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
    for dtype_name in ("f32", "f64", "c64", "c128"))
def test_gpu_eigh_solver_syev(self, dtype_name="f32"):
  """Run the 2024-09-30 CUDA eigh testdata for `lax.linalg.eigh`."""
  if not jtu.test_device_matches(["cuda"]):
    self.skipTest("Unsupported platform")
  if dtype_name in ["f64", "c128"] and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")

  dtype = dict(f32=np.float32, f64=np.float64,
               c64=np.complex64, c128=np.complex128)[dtype_name]
  size = 4
  shape = (size, size)
  # (rtol, atol) per dtype.
  rtol, atol = dict(f32=(1e-3, 1e-2), f64=(1e-5, 1e-10),
                    c64=(1e-3, 1e-2), c128=(1e-5, 1e-10))[dtype_name]
  operand = CompatTest.eigh_input(shape, dtype)
  data = self.load_testdata(cuda_eigh_cusolver_syev.data_2024_09_30[dtype_name])

  def func():
    return CompatTest.eigh_harness(shape, dtype)

  self.run_one_test(func, data, rtol=rtol, atol=atol,
                    check_results=partial(self.check_eigh_results, operand))
|
|
|
|
def test_tpu_Eigh(self):
  """TPU eigh back-compat test; currently skipped unconditionally."""
  self.skipTest(
      "TODO(b/280668311): Change input matrix to not be ill-conditioned."
  )
  # For lax.linalg.eigh
  dtype = np.float32
  shape = (8, 8)
  operand = CompatTest.eigh_input(shape, dtype)

  def func():
    return CompatTest.eigh_harness(shape, dtype)

  data = self.load_testdata(tpu_Eigh.data)
  checker = partial(self.check_eigh_results, operand)
  self.run_one_test(func, data, rtol=1e-3, check_results=checker)
|
|
|
|
@staticmethod
|
|
def lu_pivots_to_permutation_harness(shape):
|
|
operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=np.int32), shape)
|
|
return lax.linalg.lu_pivots_to_permutation(operand, permutation_size=8)
|
|
|
|
def test_cuda_lu_pivots_to_permutation(self):
  """Run the 2024-08-08 CUDA testdata for `lu_pivots_to_permutation`."""
  shape = (2, 3, 4)

  def func():
    return CompatTest.lu_pivots_to_permutation_harness(shape)

  testdata = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08)
  self.run_one_test(func, testdata)
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}",
         dtype_name=dtype_name)
    for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cuda_lu_lapack_getrf(self, dtype_name:str):
  """Run the 2024-08-19 CUDA LU testdata through the `lu_harness`."""
  if dtype_name in ["f64", "c128"] and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")
  dtype = dict(f32=np.float32, f64=np.float64,
               c64=np.complex64, c128=np.complex128)[dtype_name]
  shape = (3, 4)

  def func():
    return CompatTest.lu_harness(shape, dtype)

  testdata = self.load_testdata(
      cuda_lu_cusolver_getrf.data_2024_08_19[dtype_name])
  self.run_one_test(func, testdata)
|
|
|
|
@staticmethod
|
|
def qr_harness(shape, dtype):
|
|
# In order to keep inputs small, we construct the input programmatically
|
|
operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
|
|
return lax.linalg.qr(operand, full_matrices=True)
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
    for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"):
  """Run the current and the legacy (2023-03-17) CPU QR testdata."""
  if dtype_name in ["f64", "c128"] and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")
  rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
  dtype = dict(f32=np.float32, f64=np.float64,
               c64=np.complex64, c128=np.complex128)[dtype_name]

  def func():
    return CompatTest.qr_harness((3, 3), dtype)

  info = cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
  current_data = self.load_testdata(info)
  self.run_one_test(func, current_data, rtol=rtol)

  # TODO(b/369826500): Remove legacy custom call test after mid March 2025.
  legacy_data = self.load_testdata(
      cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name])
  self.run_one_test(func, legacy_data, rtol=rtol,
                    expect_current_custom_calls=info["custom_call_targets"])
|
|
|
|
# TODO(b/369826500): Remove legacy custom call test after mid March 2025.
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}_{batched}",
         dtype_name=dtype_name, batched=batched)
    for dtype_name in ("f32",)
    # For batched qr we use cublas_geqrf_batched/hipblas_geqrf_batched.
    for batched in ("batched", "unbatched"))
def test_gpu_qr_solver_geqrf_legacy(self, dtype_name, batched):
  """Run legacy GPU QR testdata (ROCm or CUDA) through the `qr_harness`."""
  if jtu.test_device_matches(["rocm"]):
    data = self.load_testdata(rocm_qr_hipsolver_geqrf.data_2024_08_05[batched])
    prefix = "hip"
  elif jtu.test_device_matches(["cuda"]):
    data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched])
    prefix = "cu"
  else:
    self.skipTest("Unsupported platform")

  dtype = dict(f32=np.float32)[dtype_name]
  rtol = dict(f32=1e-3)[dtype_name]
  shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched]

  def func():
    return CompatTest.qr_harness(shape, dtype)

  # Custom calls the current lowering is expected to emit on this platform.
  current_targets = [
      f"{prefix}solver_geqrf_ffi", f"{prefix}solver_orgqr_ffi"]
  self.run_one_test(func, data, rtol=rtol,
                    expect_current_custom_calls=current_targets)
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
    for dtype_name in ("f32", "f64", "c64", "c128"))
def test_gpu_qr_solver_geqrf(self, dtype_name="f32"):
  """Run the 2024-09-26 CUDA QR testdata through the `qr_harness`."""
  if not jtu.test_device_matches(["cuda"]):
    self.skipTest("Unsupported platform")
  if dtype_name in ["f64", "c128"] and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")
  dtype = dict(f32=np.float32, f64=np.float64,
               c64=np.complex64, c128=np.complex128)[dtype_name]
  rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
  shape = (2, 3, 3)

  def func():
    return CompatTest.qr_harness(shape, dtype)

  testdata = self.load_testdata(
      cuda_qr_cusolver_geqrf.data_2024_09_26[dtype_name])
  self.run_one_test(func, testdata, rtol=rtol)
|
|
|
|
def test_tpu_Qr(self):
  """Run the 2023-03-17 TPU QR testdata for `lax.linalg.qr`."""
  def func():
    return CompatTest.qr_harness((3, 3), np.float32)

  testdata = self.load_testdata(tpu_Qr.data_2023_03_17)
  self.run_one_test(func, testdata, rtol=1e-3, atol=1e-3)
|
|
|
|
@staticmethod
|
|
def lu_harness(shape, dtype):
|
|
operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
|
|
return lax.linalg.lu(operand)
|
|
|
|
def check_lu_results(self, operand, res_now, res_expected, *,
                     dtype, rtol=None, atol=None):
  """Check that the LU factors reproduce the row-permuted `operand`.

  Same checker as in linalg_test.py. The expected results are ignored; only
  the first two elements of `res_now` (packed LU matrix and pivots) are used.
  """
  del res_expected  # we do not check against expected
  packed_lu, pivots, _ = res_now

  n, m = operand.shape
  self.assertEqual(n, m)
  # Unit-diagonal lower factor and upper factor, unpacked from `packed_lu`.
  lower = np.tril(packed_lu, -1) + np.eye(n, dtype=dtype)
  upper = np.triu(packed_lu)
  # Apply the recorded row swaps, in order, to a copy of the operand.
  permuted = operand.copy()
  for i in range(n):
    pivot = pivots[i]
    permuted[[i, pivot],] = permuted[[pivot, i],]
  self.assertAllClose(permuted, np.matmul(lower, upper), rtol=rtol, atol=atol)
|
|
|
|
def test_tpu_Lu(self):
  """Run the 2023-03-21 TPU testdata for `lax.linalg.lu`."""
  # For lax.linalg.lu on TPU.
  dtype = np.float32
  shape = (3, 3)

  def func():
    return CompatTest.lu_harness(shape, dtype)

  testdata = self.load_testdata(tpu_Lu.data_2023_03_21)
  # NumPy copy of the operand that the harness builds, for the checker.
  operand = np.reshape(np.arange(math.prod(shape), dtype=dtype), shape)
  checker = partial(self.check_lu_results, operand, dtype=dtype)
  self.run_one_test(func, testdata, rtol=1e-3, check_results=checker)
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}",
         dtype_name=dtype_name)
    for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_lu_lapack_getrf(self, dtype_name:str):
  """Run the current and the legacy (2023-06-14) CPU testdata for `lax.linalg.lu`."""
  # For lax.linalg.lu on CPU.
  if dtype_name in ["f64", "c128"] and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")
  dtype = dict(f32=np.float32, f64=np.float64,
               c64=np.complex64, c128=np.complex128)[dtype_name]
  shape = (3, 3)

  def func():
    return CompatTest.lu_harness(shape, dtype)

  # NumPy copy of the operand that the harness builds, for the checker.
  operand = np.reshape(np.arange(math.prod(shape), dtype=dtype), shape)
  # (rtol, atol) per dtype.
  rtol, atol = dict(f32=(1e-3, 1e-4), f64=(1e-5, 1e-12),
                    c64=(1e-3, 1e-4), c128=(1e-5, 1e-12))[dtype_name]
  info = cpu_lu_lapack_getrf.data_2024_05_31[dtype_name]
  current_data = self.load_testdata(info)
  self.run_one_test(func, current_data, rtol=rtol, atol=atol,
                    check_results=partial(self.check_lu_results, operand,
                                          dtype=dtype))

  # TODO(b/357034884): Remove legacy custom call test after mid March 2025.
  legacy_data = self.load_testdata(
      cpu_lu_lapack_getrf.data_2023_06_14[dtype_name])
  self.run_one_test(func, legacy_data, rtol=rtol, atol=atol,
                    check_results=partial(self.check_lu_results, operand,
                                          dtype=dtype),
                    expect_current_custom_calls=info["custom_call_targets"])
|
|
|
|
def check_svd_results(self, input, res_run, res_exp,
                      rtol=None, atol=None):
  """Validate an SVD result by reconstruction and unitarity checks.

  Follows linalg_test.testSVD. `res_run` is indexed the way that test indexes
  its output: element 1 holds the singular values, elements 0 and 2 the
  singular-vector matrices.

  Args:
    input: the decomposed operand; unpacked as a 3D `(b, m, n)` array.
    res_run: the outputs of the current run.
    res_exp: the serialized expected outputs; not used by this checker.
    rtol: not used; tolerances are derived from the dtype of `input`.
    atol: not used; tolerances are derived from the dtype of `input`.
  """
  def compute_max_backward_error(operand, reconstructed_operand):
    # Largest relative reconstruction error over the batch.
    error_norm = np.linalg.norm(operand - reconstructed_operand,
                                axis=(-2, -1))
    backward_error = (error_norm /
                      np.linalg.norm(operand, axis=(-2, -1)))
    max_backward_error = np.amax(backward_error)
    return max_backward_error

  tol = 80 * jnp.finfo(input.dtype).eps
  reconstruction_tol = 2 * tol
  unitariness_tol = tol

  out = res_run
  a = input
  # These flags mirror the arguments the harnesses pass to lax.linalg.svd.
  compute_uv = True
  full_matrices = True
  b, m, n = input.shape
  T = lambda x: np.swapaxes(x, -1, -2)

  if compute_uv:
    # Check the reconstructed matrices
    out = list(out)
    out[1] = out[1].astype(out[0].dtype)  # for strict dtype promotion.
    if m and n:
      if full_matrices:
        k = min(m, n)
        if m < n:
          max_backward_error = compute_max_backward_error(
              a, np.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :]))
          self.assertLess(max_backward_error, reconstruction_tol)
        else:
          max_backward_error = compute_max_backward_error(
              a, np.matmul(out[1][..., None, :] * out[0][..., :, :k], out[2]))
          self.assertLess(max_backward_error, reconstruction_tol)
      else:
        max_backward_error = compute_max_backward_error(
            a, np.matmul(out[1][..., None, :] * out[0], out[2]))
        self.assertLess(max_backward_error, reconstruction_tol)

    # Check the unitary properties of the singular vector matrices.
    unitary_mat = np.real(np.matmul(np.conj(T(out[0])), out[0]))
    eye_slice = np.eye(out[0].shape[-1], dtype=unitary_mat.dtype)
    self.assertAllClose(np.broadcast_to(eye_slice, (b,) + eye_slice.shape),
                        unitary_mat, rtol=unitariness_tol,
                        atol=unitariness_tol)
    if m >= n:
      unitary_mat = np.real(np.matmul(np.conj(T(out[2])), out[2]))
      eye_slice = np.eye(out[2].shape[-1], dtype=unitary_mat.dtype)
      self.assertAllClose(np.broadcast_to(eye_slice, (b,) + eye_slice.shape),
                          unitary_mat, rtol=unitariness_tol,
                          atol=unitariness_tol)
    else:
      # Use the batched transpose helper `T`; NumPy has no module-level
      # `np.T`, so the previous `np.T(out[2])` raised AttributeError
      # whenever m < n.
      unitary_mat = np.real(np.matmul(out[2], np.conj(T(out[2]))))
      eye_slice = np.eye(out[2].shape[-2], dtype=unitary_mat.dtype)
      self.assertAllClose(np.broadcast_to(eye_slice, (b,) + eye_slice.shape),
                          unitary_mat, rtol=unitariness_tol,
                          atol=unitariness_tol)
  else:
    self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False),
                                np.asarray(out), atol=1e-4, rtol=1e-4))
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}",
         dtype_name=dtype_name)
    for dtype_name in ("f32", "f64", "c64", "c128"))
@jax.default_matmul_precision("float32")
def test_cpu_schur_lapack_gees(self, dtype_name="f32"):
  """Run the current and the 2023-07-16 testdata for `lax.linalg.schur`."""
  if dtype_name in ["f64", "c128"] and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")

  dtype = dict(f32=np.float32, f64=np.float64,
               c64=np.complex64, c128=np.complex128)[dtype_name]
  shape = (4, 4)
  matrix = np.arange(math.prod(shape), dtype=dtype).reshape(shape)

  def func(input):
    return lax.linalg.schur(input, compute_schur_vectors=True)

  # (rtol, atol) per dtype.
  rtol, atol = dict(f32=(1e-3, 1e-4), f64=(1e-5, 1e-12),
                    c64=(1e-3, 1e-4), c128=(1e-5, 1e-12))[dtype_name]

  def check_schur_results(res_run, res_expected, *, rtol, atol):
    # Only the reconstruction s @ t @ s^H is compared with the operand; the
    # expected results are not consulted.
    t_run, s_run = res_run
    reconstructed = s_run @ t_run @ np.conj(s_run.T)
    self.assertAllClose(matrix, reconstructed, rtol=rtol, atol=atol)

  info = cpu_schur_lapack_gees.data_2024_11_29[dtype_name]
  current_data = self.load_testdata(info)
  self.run_one_test(func, current_data, rtol=rtol, atol=atol,
                    check_results=check_schur_results)
  older_data = self.load_testdata(
      cpu_schur_lapack_gees.data_2023_07_16[dtype_name])
  self.run_one_test(func, older_data, rtol=rtol, atol=atol,
                    check_results=check_schur_results,
                    expect_current_custom_calls=info["custom_call_targets"])
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
    for dtype_name in ("f32", "f64", "c64", "c128"))
@jax.default_matmul_precision("float32")
def test_cpu_svd_lapack_gesdd(self, dtype_name="f32"):
  """Run the current and the 2023-06-19 CPU testdata for `lax.linalg.svd`."""
  if dtype_name in ["f64", "c128"] and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")

  def func(operand):
    return lax.linalg.svd(operand, full_matrices=True, compute_uv=True)

  # (rtol, atol) per dtype.
  rtol, atol = dict(f32=(1e-3, 1e-4), f64=(1e-5, 1e-12),
                    c64=(1e-3, 1e-4), c128=(1e-5, 1e-12))[dtype_name]

  info = cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
  current_data = self.load_testdata(info)
  # The checker receives the serialized inputs of the testdata being run.
  current_checker = partial(self.check_svd_results, *current_data.inputs)
  self.run_one_test(func, current_data, rtol=rtol, atol=atol,
                    check_results=current_checker)

  older_data = self.load_testdata(
      cpu_svd_lapack_gesdd.data_2023_06_19[dtype_name])
  older_checker = partial(self.check_svd_results, *older_data.inputs)
  self.run_one_test(func, older_data, rtol=rtol, atol=atol,
                    check_results=older_checker,
                    expect_current_custom_calls=info["custom_call_targets"])
|
|
|
|
@parameterized.named_parameters(
    dict(testcase_name=f"_dtype={dtype_name}_algorithm={algorithm_name}",
         dtype_name=dtype_name, algorithm_name=algorithm_name)
    for dtype_name in ("f32", "f64", "c64", "c128")
    for algorithm_name in ("qr", "jacobi"))
@jax.default_matmul_precision("float32")
def test_gpu_svd_solver_gesvd(self, dtype_name, algorithm_name):
  """Run the 2024-10-08 CUDA SVD testdata for the selected algorithm."""
  if dtype_name in ["f64", "c128"] and not config.enable_x64.value:
    self.skipTest("Test disabled for x32 mode")

  # (rtol, atol) per dtype.
  rtol, atol = dict(f32=(1e-3, 1e-4), f64=(1e-5, 1e-12),
                    c64=(1e-3, 1e-4), c128=(1e-5, 1e-12))[dtype_name]
  algorithm = dict(qr=lax.linalg.SvdAlgorithm.QR,
                   jacobi=lax.linalg.SvdAlgorithm.JACOBI)[algorithm_name]

  def func(operand):
    return lax.linalg.svd(operand, full_matrices=True, compute_uv=True,
                          algorithm=algorithm)

  info = cuda_svd_cusolver_gesvd.data_2024_10_08[algorithm_name][dtype_name]
  testdata = self.load_testdata(info)
  # The checker receives the serialized inputs of the testdata being run.
  checker = partial(self.check_svd_results, *testdata.inputs)
  self.run_one_test(func, testdata, rtol=rtol, atol=atol,
                    check_results=checker)
|
|
|
|
@jtu.parameterized_filterable(
    kwargs=[
        {"testcase_name": f"_dtype={name}", "dtype_name": name}
        for name in ("f32", "f64", "c64", "c128")])
@jax.default_matmul_precision("float32")
def test_cpu_triangular_solve_blas_trsm(self, dtype_name="f32"):
  # Runs lax.linalg.triangular_solve against two stored snapshots of
  # cpu_triangular_solve_blas_trsm: data_2024_12_02 and data_2023_07_16.
  x64_enabled = config.enable_x64.value
  if dtype_name in ("f64", "c128") and not x64_enabled:
    self.skipTest("Test disabled for x32 mode")

  np_dtypes = {
      "f32": np.float32,
      "f64": np.float64,
      "c64": np.complex64,
      "c128": np.complex128,
  }
  dtype = np_dtypes[dtype_name]

  # Lower-triangular 4x4 matrix with 5 added along the diagonal.
  a_shape = (4, 4)
  dense = np.arange(math.prod(a_shape), dtype=dtype).reshape(a_shape)
  diagonal_shift = 5 * np.eye(dense.shape[-1], dtype=dense.dtype)
  tri = np.tril(dense + diagonal_shift)

  # 4x5 right-hand side.
  b_shape = (4, 5)
  rhs = np.arange(math.prod(b_shape), dtype=dtype).reshape(b_shape)

  left_side = True

  def func(a, b):
    return lax.linalg.triangular_solve(
        a, b,
        lower=True,
        transpose_a=False,
        conjugate_a=False,
        unit_diagonal=False,
        left_side=left_side)

  # (rtol, atol) for each supported dtype name.
  tolerances = {
      "f32": (1e-3, 1e-4),
      "f64": (1e-5, 1e-12),
      "c64": (1e-3, 1e-4),
      "c128": (1e-5, 1e-12),
  }
  rtol, atol = tolerances[dtype_name]

  def check_triangular_solve_results(res_run, res_expected, *, rtol, atol):
    # Multiplies the solution back by the triangular matrix and compares the
    # product with the right-hand side; res_expected is not consulted.
    (solution,) = res_run
    matmul = partial(jnp.matmul, precision=lax.Precision.HIGHEST)
    if left_side:
      product = matmul(tri, solution)
    else:
      product = matmul(solution, tri)
    target = jnp.broadcast_to(rhs, product.shape)
    self.assertArraysAllClose(product, target, rtol=rtol, atol=atol)

  newest_info = cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name]
  newest = self.load_testdata(newest_info)
  self.run_one_test(func, newest, rtol=rtol, atol=atol,
                    check_results=check_triangular_solve_results)

  # The older snapshot is run while expecting the custom call targets that
  # are recorded in the newer snapshot.
  older = self.load_testdata(
      cpu_triangular_solve_blas_trsm.data_2023_07_16[dtype_name])
  self.run_one_test(
      func, older, rtol=rtol, atol=atol,
      check_results=check_triangular_solve_results,
      expect_current_custom_calls=newest_info["custom_call_targets"])
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_dtype={name}", "dtype_name": name}
    for name in ("f32", "f64", "c64", "c128"))
@jax.default_matmul_precision("float32")
def test_cpu_hessenberg_lapack_gehrd(self, dtype_name="f32"):
  # Runs lax.linalg.hessenberg against two stored snapshots of
  # cpu_hessenberg_lapack_gehrd: data_2024_08_31 and data_2024_08_30.
  x64_enabled = config.enable_x64.value
  if dtype_name in ("f64", "c128") and not x64_enabled:
    self.skipTest("Test disabled for x32 mode")

  np_dtypes = {
      "f32": np.float32,
      "f64": np.float64,
      "c64": np.complex64,
      "c128": np.complex128,
  }
  dtype = np_dtypes[dtype_name]
  shape = (2, 4, 4)
  make_random = jtu.rand_default(self.rng())
  input_data = make_random(shape, dtype)

  # func takes no arguments: it closes over input_data, so the array must
  # stay alive here.
  def func():
    return lax.linalg.hessenberg(input_data)

  # (rtol, atol) for each supported dtype name.
  tolerances = {
      "f32": (1e-3, 1e-4),
      "f64": (1e-5, 1e-12),
      "c64": (1e-3, 1e-4),
      "c128": (1e-5, 1e-12),
  }
  rtol, atol = tolerances[dtype_name]

  newest_info = cpu_hessenberg_lapack_gehrd.data_2024_08_31[dtype_name]
  newest = self.load_testdata(newest_info)
  self.run_one_test(func, newest, rtol=rtol, atol=atol)

  # The older snapshot is run while expecting the custom call targets that
  # are recorded in the newer snapshot.
  older = self.load_testdata(
      cpu_hessenberg_lapack_gehrd.data_2024_08_30[dtype_name])
  self.run_one_test(
      func, older, rtol=rtol, atol=atol,
      expect_current_custom_calls=newest_info["custom_call_targets"])
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_dtype={name}", "dtype_name": name}
    for name in ("f32", "f64", "c64", "c128"))
@jax.default_matmul_precision("float32")
def test_cpu_tridiagonal_lapack_sytrd_hetrd(self, dtype_name="f32"):
  # Runs lax.linalg.tridiagonal (lower=True) against two stored snapshots of
  # cpu_tridiagonal_lapack_sytrd_hetrd: data_2024_12_01 and data_2024_09_03.
  x64_enabled = config.enable_x64.value
  if dtype_name in ("f64", "c128") and not x64_enabled:
    self.skipTest("Test disabled for x32 mode")

  np_dtypes = {
      "f32": np.float32,
      "f64": np.float64,
      "c64": np.complex64,
      "c128": np.complex128,
  }
  dtype = np_dtypes[dtype_name]
  shape = (2, 4, 4)
  make_random = jtu.rand_default(self.rng())
  input_data = make_random(shape, dtype)

  # func takes no arguments: it closes over input_data, so the array must
  # stay alive here.
  def func():
    return lax.linalg.tridiagonal(input_data, lower=True)

  # (rtol, atol) for each supported dtype name.
  tolerances = {
      "f32": (1e-3, 1e-4),
      "f64": (1e-5, 1e-12),
      "c64": (1e-3, 1e-4),
      "c128": (1e-5, 1e-12),
  }
  rtol, atol = tolerances[dtype_name]

  newest_info = cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01[dtype_name]
  newest = self.load_testdata(newest_info)
  self.run_one_test(func, newest, rtol=rtol, atol=atol)

  # The older snapshot is run while expecting the custom call targets that
  # are recorded in the newer snapshot.
  older = self.load_testdata(
      cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03[dtype_name])
  self.run_one_test(
      func, older, rtol=rtol, atol=atol,
      expect_current_custom_calls=newest_info["custom_call_targets"])
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_dtype={name}", "dtype_name": name}
    for name in ("f32", "f64", "c64", "c128"))
@jax.default_matmul_precision("float32")
def test_cpu_tridiagonal_solve_lapack_gtsv(self, dtype_name):
  # Runs lax.linalg.tridiagonal_solve directly (no wrapper function) against
  # the cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09 snapshot.
  x64_enabled = config.enable_x64.value
  if dtype_name in ("f64", "c128") and not x64_enabled:
    self.skipTest("Test disabled for x32 mode")

  # (rtol, atol) for each supported dtype name.
  tolerances = {
      "f32": (1e-3, 1e-4),
      "f64": (1e-5, 1e-12),
      "c64": (1e-3, 1e-4),
      "c128": (1e-5, 1e-12),
  }
  rtol, atol = tolerances[dtype_name]

  snapshot = cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09[dtype_name]
  testdata = self.load_testdata(snapshot)
  self.run_one_test(lax.linalg.tridiagonal_solve, testdata,
                    rtol=rtol, atol=atol)
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_dtype={name}", "dtype_name": name}
    for name in ("f32", "f64", "c64", "c128"))
@jax.default_matmul_precision("float32")
def test_gpu_tridiagonal_solver_sytrd(self, dtype_name):
  # Runs lax.linalg.tridiagonal (lower=True) against the
  # cuda_tridiagonal_cusolver_sytrd.data_2025_01_09 snapshot.
  x64_enabled = config.enable_x64.value
  if dtype_name in ("f64", "c128") and not x64_enabled:
    self.skipTest("Test disabled for x32 mode")

  def func(x):
    return lax.linalg.tridiagonal(x, lower=True)

  # (rtol, atol) for each supported dtype name.
  tolerances = {
      "f32": (1e-3, 1e-4),
      "f64": (1e-5, 1e-12),
      "c64": (1e-3, 1e-4),
      "c128": (1e-5, 1e-12),
  }
  rtol, atol = tolerances[dtype_name]

  snapshot = cuda_tridiagonal_cusolver_sytrd.data_2025_01_09[dtype_name]
  testdata = self.load_testdata(snapshot)
  self.run_one_test(func, testdata, rtol=rtol, atol=atol)
|
|
|
|
def test_approx_top_k(self):
  # Runs two identical lax.approx_max_k calls against the
  # tpu_ApproxTopK.data_2023_05_16 snapshot.
  def func():
    x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0])
    first = lax.approx_max_k(x, 3)
    second = lax.approx_max_k(x, 3)
    # The two results are joined with `+`.
    # NOTE(review): presumably each result is a sequence of outputs, so this
    # concatenates them rather than adding element-wise -- confirm.
    return first + second

  testdata = self.load_testdata(tpu_ApproxTopK.data_2023_05_16)
  self.run_one_test(func, testdata)
|
|
|
|
def test_cuda_threefry2x32(self):
  # Runs jax.random.uniform against the cuda_threefry2x32.data_2024_07_30
  # snapshot with config.threefry_partitionable(False) active.
  with config.threefry_partitionable(False):
    def func(x):
      sample_shape = (2, 4)
      return jax.random.uniform(x, sample_shape, dtype=np.float32)

    testdata = self.load_testdata(cuda_threefry2x32.data_2024_07_30)
    self.run_one_test(func, testdata)
|
|
|
|
def test_sharding(self):
  # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU
  if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2:
    self.skipTest("Test runs only on TPU with at least 2 devices")

  # Must use exactly 2 devices for expected outputs from ppermute
  devices = jax.devices()[:2]
  # axis_names is written as a real one-element tuple; the previous ('a') was
  # just a parenthesised string.
  mesh = Mesh(devices, axis_names=('a',))

  @partial(pjit.pjit,
           in_shardings=(P('a', None),), out_shardings=P('a', None))
  @partial(shard_map, mesh=mesh,
           in_specs=(P('a', None),), out_specs=P('a', None))
  def func(x):  # x: f32[2, 4]
    # Rotate the per-device blocks by one position along mesh axis 'a'.
    axis_size = lax.psum(1, 'a')
    perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
    return lax.ppermute(x, 'a', perm=perm)

  data = self.load_testdata(tpu_Sharding.data_2023_03_16)
  with mesh:
    self.run_one_test(func, data)
|
|
|
|
def test_tpu_stablehlo_dynamic_reduce_window_unary(self):
  # stablehlo.dynamic_reduce_window is a temporary TPU lowering for reduce
  # window with dynamic shapes; the long-term plan is tracked in
  # https://github.com/openxla/stablehlo/issues/1258.
  # The array below mirrors the input stored in the test data and is built
  # only for readability.
  shape = (3, 4)
  num_elements = math.prod(shape)
  _ = np.arange(num_elements, dtype=np.float32).reshape(shape)

  def func(x):
    return jnp.cumsum(x, axis=0)

  testdata = self.load_testdata(
      tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17)
  # Recent serializations also include shape_assertion, tested with
  # dynamic_top_k.
  expected_targets = ["stablehlo.dynamic_reduce_window", "shape_assertion"]
  self.run_one_test(
      func, testdata,
      polymorphic_shapes=("b, ...",),
      expect_current_custom_calls=expected_targets)
|
|
|
|
def test_tpu_stablehlo_dynamic_reduce_window_variadic(self):
  # stablehlo.dynamic_reduce_window is a temporary TPU lowering for reduce
  # window with dynamic shapes; the long-term plan is tracked in
  # https://github.com/openxla/stablehlo/issues/1258.
  # The arrays below mirror the inputs stored in the test data and are built
  # only for readability.
  shape = (3, 4)
  num_elements = math.prod(shape)
  x = np.arange(num_elements, dtype=np.float32).reshape(shape)
  y = 100 + np.arange(num_elements, dtype=np.int32).reshape(shape)
  _ = (x, y)

  def func(x, y):  # x: f32[b, 2]  y: i32[b, 2]
    operands = (x, y)
    init_values = (np.array(1., np.float32), np.array(2, np.int32))
    window_dimensions = (2, x.shape[0])
    window_strides = (1, 1)
    # Pairwise reduction: add on the first operand, subtract on the second.
    return lax.reduce_window(
        operands, init_values,
        lambda xy0, xy1: (lax.add(xy0[0], xy1[0]),
                          lax.sub(xy0[1], xy1[1])),
        window_dimensions, window_strides, "VALID")

  testdata = self.load_testdata(
      tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17)
  # Recent serializations also include shape_assertion, tested with
  # dynamic_top_k.
  expected_targets = ["stablehlo.dynamic_reduce_window", "shape_assertion"]
  self.run_one_test(
      func, testdata,
      polymorphic_shapes=("b, ...", "b, ..."),
      expect_current_custom_calls=expected_targets)
|
|
|
|
@jtu.ignore_warning(
    category=FutureWarning,
    message="Raw arrays as random keys to jax.random functions are deprecated")
def test_stablehlo_dynamic_rbg_bit_generator(self):
  # stablehlo.dynamic_rbg_bit_generator is a temporary lowering for
  # rbg_bit_generator with dynamic shapes; the long-term plan is tracked in
  # https://github.com/openxla/stablehlo/issues/1344.
  # The sample inputs below are already in the test data; they are built
  # only for readability and then discarded.
  sample_key = np.arange(42, 42 + 4, dtype=np.uint32)
  a_shape = (2, 3)
  sample_a = np.arange(math.prod(a_shape), dtype=np.float32).reshape(a_shape)
  inputs = (sample_key, sample_a)
  del inputs

  def func(key, a):  # a is only used for its shape
    num_keys = a.shape[0] * a.shape[1]
    subkeys = jax.random.split(key, num_keys)
    return jax.random.key_data(subkeys)

  # This test currently checks that the generated sequence is unchanged.
  # The StableHLO spec says: "The output is guaranteed to be deterministic
  # function of initial_state, but it is not guaranteed to be deterministic
  # between implementations"
  # See https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng_bit_generator
  # The test will therefore fail when the implementation changes. That is
  # expected to be rare, and most users may expect the RNG sequence to be
  # the same upon reloading of a saved model. If a behavior change is
  # intended, this strict check can be replaced with something else.
  testdata = self.load_testdata(
      stablehlo_dynamic_rng_bit_generator.data_2023_06_17)

  # Recent serializations also include shape_assertion, tested with
  # dynamic_top_k.
  expected_targets = ["stablehlo.dynamic_rng_bit_generator",
                      "shape_assertion"]
  with config.default_prng_impl("unsafe_rbg"):
    self.run_one_test(
        func, testdata,
        polymorphic_shapes=(None, "b0, b1"),
        expect_current_custom_calls=expected_targets)
|
|
|
|
def test_stablehlo_dynamic_top_k(self):
  # stablehlo.dynamic_top_k is used temporarily for a top_k with dynamism.
  # `operand` is read by the result checker below to validate the indices.
  operand = np.arange(12, dtype=np.float32).reshape((4, 3))

  def func(a):
    return lax.top_k(a, k=a.shape[-1] - 1)

  first_data = self.load_testdata(stablehlo_dynamic_top_k.data_2023_07_16)

  def check_top_k_results(res_run, res_expected, *, rtol, atol):
    # The results may come back in a different order, but they must be the
    # same set of values.
    values_expected, _ = res_expected
    values_run, indices_run = res_run
    # Each returned index must point at the value that was returned with it.
    num_rows = operand.shape[0]
    row_ids = np.arange(num_rows).reshape(num_rows, 1)
    self.assertAllClose(values_run, operand[row_ids, indices_run],
                        atol=atol, rtol=rtol)
    self.assertAllClose(np.sort(values_run), np.sort(values_expected),
                        atol=atol, rtol=rtol)

  shared_kwargs = dict(polymorphic_shapes=("_, b",),
                       check_results=check_top_k_results)

  # Recent serializations also include shape_assertion
  self.run_one_test(
      func, first_data,
      expect_current_custom_calls=["stablehlo.dynamic_top_k",
                                   "shape_assertion"],
      **shared_kwargs)

  # Now a test with serialization version 7, including shape_assertion
  second_data = self.load_testdata(stablehlo_dynamic_top_k.data_2023_08_11)
  self.run_one_test(func, second_data, **shared_kwargs)
|
|
|
|
def test_dynamic_approx_top_k(self):
  # stablehlo.dynamic_approx_top_k is used temporarily for an approx_top_k
  # with dynamism.
  # This is the input that was used to generate the test data; it is built
  # only for readability.
  _ = np.arange(24, dtype=np.float32)

  def func(a):  # a: f32[b + 4]
    return lax.approx_max_k(a, k=a.shape[0] - 4)

  testdata = self.load_testdata(stablehlo_dynamic_approx_top_k.data_2024_05_30)

  def check_top_k_results(res_run, res_expected, *, rtol, atol):
    operand = testdata.inputs[0]
    # The results may come back in a different order, but they must be the
    # same set of values.
    values_expected, _ = res_expected
    values_run, indices_run = res_run
    # Each returned index must point at the value that was returned with it.
    gathered = operand[indices_run]
    self.assertAllClose(values_run, gathered, atol=atol, rtol=rtol)
    self.assertAllClose(np.sort(values_run), np.sort(values_expected),
                        atol=atol, rtol=rtol)

  expected_targets = ["stablehlo.dynamic_approx_top_k", "shape_assertion"]
  self.run_one_test(func, testdata,
                    polymorphic_shapes=("b + 4,",),
                    check_results=check_top_k_results,
                    expect_current_custom_calls=expected_targets)
|
|
|
|
|
|
@jtu.with_config(jax_use_shardy_partitioner=True)
class ShardyCompatTest(bctu.CompatTestBase):
  """Back-compat tests run with `jax_use_shardy_partitioner=True`."""

  def test_shardy_sharding_ops_with_different_meshes(self):
    # Tests whether we can save and load a module with meshes that have the
    # same axis sizes (and same order) but different axis names.
    # Also tests "Sharding", "xla.sdy.GlobalToLocalShape",
    # "xla.sdy.LocalToGlobalShape".
    if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2:
      self.skipTest("Test runs only on TPU with at least 2 devices")

    # Must use exactly 2 devices for expected outputs from ppermute.
    devices = jax.devices()[:2]
    # axis_names is written as a real one-element tuple; the previous ('a')
    # was just a parenthesised string.
    old_mesh = Mesh(devices, axis_names=('a',))

    def func(x):  # x: f32[4, 4]
      @partial(shard_map, mesh=old_mesh,
               in_specs=(P('a', None),), out_specs=P('a', None))
      def shard_map_func(x):  # x: f32[2, 4]
        # Rotate the per-device blocks by one position along mesh axis 'a'.
        axis_size = lax.psum(1, 'a')
        perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
        return lax.ppermute(x, 'a', perm=perm)
      x = jax.lax.with_sharding_constraint(x, NS(old_mesh, P('a', None)))
      return shard_map_func(x)

    data = self.load_testdata(shardy_sharding_ops_with_different_meshes.data_2025_02_12)
    # The test runs under a mesh over the same devices whose axis is named
    # 'x' instead of 'a'.
    with Mesh(devices, axis_names=('x',)):
      self.run_one_test(func, data)
|
|
|
|
|
|
if __name__ == "__main__":
  # Script entry point: hand control to absltest using JAX's test loader.
  jax_test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_test_loader)
|