# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for backwards compatibility of exporting code with custom calls.

See the export_back_compat_test_util module docstring for how to set up and
update these tests.
"""

import dataclasses
from functools import partial
import itertools
import math

from absl.testing import absltest, parameterized

import numpy as np

import jax
from jax import lax
from jax._src.export import _export
from jax._src.internal_test_util import export_back_compat_test_util as bctu

from jax._src.internal_test_util.export_back_compat_test_data import annotate_data_placement
from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf
from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev
from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev
from jax._src.internal_test_util.export_back_compat_test_data import rocm_eigh_hipsolver_syev
from jax._src.internal_test_util.export_back_compat_test_data import cpu_eigh_lapack_syev
from jax._src.internal_test_util.export_back_compat_test_data import cpu_lu_lapack_getrf
from jax._src.internal_test_util.export_back_compat_test_data import cuda_qr_cusolver_geqrf
from jax._src.internal_test_util.export_back_compat_test_data import cpu_qr_lapack_geqrf
from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_lapack_gees
from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd
from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm
from jax._src.internal_test_util.export_back_compat_test_data import cpu_hessenberg_lapack_gehrd
from jax._src.internal_test_util.export_back_compat_test_data import cpu_tridiagonal_lapack_sytrd_hetrd
from jax._src.internal_test_util.export_back_compat_test_data import cpu_tridiagonal_solve_lapack_gtsv
from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32
from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation
from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf
from jax._src.internal_test_util.export_back_compat_test_data import cuda_svd_cusolver_gesvd
from jax._src.internal_test_util.export_back_compat_test_data import cuda_tridiagonal_cusolver_sytrd
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu
from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Qr
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Sharding
from jax._src.internal_test_util.export_back_compat_test_data import tpu_stablehlo_dynamic_reduce_window
from jax._src.internal_test_util.export_back_compat_test_data import shardy_sharding_ops_with_different_meshes
from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_rng_bit_generator
from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_top_k
from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_approx_top_k

from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp

from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding as NS

from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions

config.parse_flags_with_absl()
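

# Helper used by the GPU tests below: they are skipped when the available
# cuSOLVER version is older than a required minimum.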
def _is_required_cusolver_version_satisfied(required_version):
  if cuda_versions is None:
    return False
  return cuda_versions.cusolver_get_version() >= required_version


@jtu.with_config(jax_legacy_prng_key="allow",
                 jax_debug_key_reuse=False,
                 jax_include_full_tracebacks_in_locations=False,
                 jax_threefry_gpu_kernel_lowering=True)
class CompatTest(bctu.CompatTestBase):

  def test_dummy(self):
    # Tests the testing mechanism. Let this test run on all platforms
    dummy_data = self.load_testdata(bctu.dummy_data_dict)
    platform_dummy_data = dataclasses.replace(
        dummy_data, platform=self.default_jax_backend())
    self.run_one_test(jnp.sin, platform_dummy_data)

  def test_detect_different_output(self):
    # Test the detection mechanism. Let this test run on all platforms
    dummy_data = self.load_testdata(bctu.dummy_data_dict)
    platform_dummy_data = dataclasses.replace(
        dummy_data,
        platform=self.default_jax_backend(),
        expected_outputs=(np.array(2.0, dtype=np.float32),))
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      self.run_one_test(jnp.sin, platform_dummy_data)

  def test_detect_different_custom_calls(self):
    # Test the detection mechanism. Let this test run on all platforms
    dummy_data = self.load_testdata(bctu.dummy_data_dict)
    platform_dummy_data = dataclasses.replace(
        dummy_data,
        platform=self.default_jax_backend(),
        custom_call_targets=["missing"])
    with self.assertRaisesRegex(AssertionError, "Element counts were not equal"):
      self.run_one_test(jnp.sin, platform_dummy_data)

  def test_custom_call_coverage(self):
    """Tests that the back compat tests cover all the targets declared stable."""
    targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
    cpu_ffi_testdatas = [
        cpu_cholesky_lapack_potrf.data_2024_05_31,
        cpu_qr_lapack_geqrf.data_2024_08_22,
        cpu_eig_lapack_geev.data_2024_08_19,
        cpu_eigh_lapack_syev.data_2024_08_19,
        cpu_lu_lapack_getrf.data_2024_05_31,
        cpu_schur_lapack_gees.data_2024_11_29,
        cpu_triangular_solve_blas_trsm.data_2024_12_02,
        cpu_svd_lapack_gesdd.data_2024_08_13,
        cpu_hessenberg_lapack_gehrd.data_2024_08_31,
        cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01,
    ]
    # Add here all the testdatas that should cover the targets guaranteed
    # stable
    covering_testdatas = [
        *cpu_ffi_testdatas,
        cpu_cholesky_lapack_potrf.data_2023_06_19,
        cpu_eig_lapack_geev.data_2023_06_19,
        cpu_eigh_lapack_syev.data_2023_03_17,
        cpu_qr_lapack_geqrf.data_2023_03_17,
        cuda_threefry2x32.data_2024_07_30,
        cpu_lu_lapack_getrf.data_2023_06_14,
        cuda_lu_pivots_to_permutation.data_2025_04_01,
        cuda_lu_cusolver_getrf.data_2024_08_19,
        cuda_qr_cusolver_geqrf.data_2024_09_26,
        cuda_eigh_cusolver_syev.data_2024_09_30,
        cuda_svd_cusolver_gesvd.data_2024_10_08,
        cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09,
        cuda_tridiagonal_cusolver_sytrd.data_2025_01_09,
        rocm_eigh_hipsolver_syev.data_2024_08_05,
        cpu_schur_lapack_gees.data_2023_07_16,
        cpu_svd_lapack_gesdd.data_2023_06_19,
        cpu_triangular_solve_blas_trsm.data_2023_07_16,
        cpu_hessenberg_lapack_gehrd.data_2024_08_30,
        cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03,
        tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17,
        tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17,
        tpu_ApproxTopK.data_2023_05_16,
        tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17,
        tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17,
        stablehlo_dynamic_rng_bit_generator.data_2023_06_17,
        stablehlo_dynamic_top_k.data_2023_07_16,
        stablehlo_dynamic_top_k.data_2023_08_11,  # with shape_assertion
        stablehlo_dynamic_approx_top_k.data_2024_05_30,
        annotate_data_placement.data_2025_04_07_tpu,
        annotate_data_placement.data_2025_04_07_cuda,
    ]
    # Some of the above are nested structures.
    covering_testdatas = itertools.chain(
        *[self.load_testdata_nested(d) for d in covering_testdatas])
    covered_targets = set()
    for data in covering_testdatas:
      self.assertIsInstance(data, bctu.CompatTestData)
      covered_targets = covered_targets.union(data.custom_call_targets)

    covered_targets = covered_targets.union({
        "tf.call_tf_function",  # tested in jax2tf/tests/back_compat_tf_test.py
        "tpu_custom_call",  # tested separately
        "__gpu$xla.gpu.triton",  # tested in pallas/export_back_compat_pallas_test.py
        # The following require ROCm to test
        "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
        "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", "hipsolver_syevd_ffi",
        "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi",
    })
    not_covered = targets_to_cover.difference(covered_targets)
    self.assertEmpty(not_covered,
                     msg=("The following custom call targets are declared "
                          "stable but are not covered by any tests: "
                          f"{not_covered}"))

  def cholesky_input(self, shape, dtype):
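    # Construct a Hermitian positive semi-definite matrix a @ conj(a.T) so
    # that the Cholesky factorization in the test below is well defined.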
    a = jtu.rand_default(self.rng())(shape, dtype)
    return np.matmul(a, np.conj(np.swapaxes(a, -1, -2)))

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    shape = (4, 4)
    input = self.cholesky_input(shape, dtype)
    del input  # Input is in the testdata, here for readability
    func = lax.linalg.cholesky

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

    info = cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol)

    data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name])
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      expect_current_custom_calls=info["custom_call_targets"])

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_cpu_eig_lapack_geev(self, dtype_name="f32"):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    shape = (4, 4)
    def func():
      # Compute the inputs to simplify the harness
      input = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape)
      return lax.linalg.eig(input,
                            compute_left_eigenvectors=True,
                            compute_right_eigenvectors=True)

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

    def check_eig_results(res_run, res_expected, *, rtol, atol):
      # Test ported from tests.linalg_test.testEig
      # Norm, adjusted for dimension and type.
      inner_dimension = shape[-1]
      operand = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
      def norm(x):
        norm = np.linalg.norm(x, axis=(-2, -1))
        return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps)

      def check_right_eigenvectors(a, w, vr):
        self.assertTrue(
            np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

      def check_left_eigenvectors(a, w, vl):
        rank = len(a.shape)
        aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
        wC = jnp.conj(w)
        check_right_eigenvectors(aH, wC, vl)

      def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
        closest_diff = min(abs(eigenvalues_array - eigenvalue))
        self.assertAllClose(
            closest_diff,
            np.array(0., closest_diff.dtype),
            atol=atol, rtol=rtol)

      all_w_run, all_w_exp = res_run[0], res_expected[0]
      for idx in itertools.product(*map(range, operand.shape[:-2])):
        w_run, w_exp = all_w_run[idx], all_w_exp[idx]
        for i in range(inner_dimension):
          check_eigenvalue_is_in_array(w_run[i], w_exp)
          check_eigenvalue_is_in_array(w_exp[i], w_run)

      check_left_eigenvectors(operand, all_w_run, res_run[1])
      check_right_eigenvectors(operand, all_w_run, res_run[2])

    info = cpu_eig_lapack_geev.data_2024_08_19[dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=check_eig_results)

    data = self.load_testdata(cpu_eig_lapack_geev.data_2023_06_19[dtype_name])
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=check_eig_results,
                      expect_current_custom_calls=info["custom_call_targets"])

  @staticmethod
  def eigh_input(shape, dtype):
    # In order to keep inputs small, we construct the input programmatically
    operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
    # Make operand self-adjoint
    operand = (operand + jnp.conj(jnp.swapaxes(operand, -1, -2))) / 2.
    return operand

  @staticmethod
  def eigh_harness(shape, dtype):
    operand = CompatTest.eigh_input(shape, dtype)
    return lax.linalg.eigh(jnp.tril(operand), lower=True, symmetrize_input=False)

  def check_eigh_results(self, operand, res_now, res_expected, *,
                         rtol, atol=None):
    v_now, w_now = res_now
    _, w_expected = res_expected
    n, m = operand.shape
    assert n == m
    assert v_now.shape == operand.shape
    assert w_now.shape == (n,)
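    # Check that the computed eigenvectors are orthonormal: ||I - V^H V|| <= rtol.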
    self.assertLessEqual(
        np.linalg.norm(np.eye(n) - np.matmul(np.conj(np.swapaxes(v_now, -1, -2)), v_now)),
        rtol)
    # w_now : f64[n] while v_now: c128[n, n]
    w_now_like_v = w_now[np.newaxis, :].astype(v_now.dtype)
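    # Check the eigenvalue equation: ||A V - V diag(w)|| <= rtol * ||A||.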
    self.assertLessEqual(
        np.linalg.norm(np.matmul(operand, v_now) - w_now_like_v * v_now),
        rtol * np.linalg.norm(operand))
    self.assertAllClose(w_expected, w_now, rtol=rtol, atol=atol)

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"):
    # For lax.linalg.eigh
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    size = 8
    operand = CompatTest.eigh_input((size, size), dtype)
    func = lambda: CompatTest.eigh_harness((8, 8), dtype)
    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

    info = cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_eigh_results, operand))

    # Legacy custom call test
    data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name])
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_eigh_results, operand),
                      expect_current_custom_calls=info["custom_call_targets"])

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}_{variant}",
           dtype_name=dtype_name, variant=variant)
      for dtype_name in ("f32", "f64")
      # We use different custom calls for sizes <= 32
      for variant in ["syevj", "syevd"])
  def test_gpu_eigh_solver_syev_legacy(self, dtype_name="f32", variant="syevj"):
    if not config.enable_x64.value and dtype_name == "f64":
      self.skipTest("Test disabled for x32 mode")
    if jtu.test_device_matches(["rocm"]):
      data = self.load_testdata(rocm_eigh_hipsolver_syev.data_2024_08_05[f"{dtype_name}_{variant}"])
      prefix = "hip"
    elif jtu.test_device_matches(["cuda"]):
      if _is_required_cusolver_version_satisfied(11600):
        # The underlying problem is that this test assumes the workspace size can be
        # queried from an older version of cuSOLVER and then be used in a newer one.
        self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized")
      data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"])
      prefix = "cu"
    else:
      self.skipTest("Unsupported platform")
    # For lax.linalg.eigh
    dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
    size = dict(syevj=8, syevd=36)[variant]
    rtol = dict(f32=1e-3, f64=1e-5)[dtype_name]
    atol = dict(f32=1e-2, f64=1e-10)[dtype_name]
    operand = CompatTest.eigh_input((size, size), dtype)
    func = lambda: CompatTest.eigh_harness((size, size), dtype)
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_eigh_results, operand),
                      expect_current_custom_calls=[f"{prefix}solver_syevd_ffi"])

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_gpu_eigh_solver_syev(self, dtype_name="f32"):
    if not jtu.test_device_matches(["cuda"]):
      self.skipTest("Unsupported platform")
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")
    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    size = 4
    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-2, f64=1e-10, c64=1e-2, c128=1e-10)[dtype_name]
    operand = CompatTest.eigh_input((size, size), dtype)
    data = self.load_testdata(cuda_eigh_cusolver_syev.data_2024_09_30[dtype_name])
    func = lambda: CompatTest.eigh_harness((size, size), dtype)
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_eigh_results, operand))

  def test_tpu_Eigh(self):
    self.skipTest(
        "TODO(b/280668311): Change input matrix to not be ill-conditioned."
    )
    # For lax.linalg.eigh
    shape = (8, 8)
    dtype = np.float32
    operand = CompatTest.eigh_input(shape, dtype)
    func = lambda: CompatTest.eigh_harness(shape, dtype)
    data = self.load_testdata(tpu_Eigh.data)
    self.run_one_test(func, data, rtol=1e-3,
                      check_results=partial(self.check_eigh_results, operand))

  @staticmethod
  def lu_pivots_to_permutation_harness(shape):
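    # The operand supplies synthetic int32 pivot indices, which the custom
    # call expands into full permutations of size 8.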
    operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=np.int32), shape)
    return lax.linalg.lu_pivots_to_permutation(operand, permutation_size=8)

  def test_cuda_lu_pivots_to_permutation(self):
    shape = (2, 3, 4)
    func = lambda: CompatTest.lu_pivots_to_permutation_harness(shape)
    data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2025_04_01)
    self.run_one_test(func, data)

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}",
           dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_cuda_lu_lapack_getrf(self, dtype_name: str):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")
    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    shape = (3, 4)
    func = lambda: CompatTest.lu_harness(shape, dtype)
    data = self.load_testdata(cuda_lu_cusolver_getrf.data_2024_08_19[dtype_name])
    self.run_one_test(func, data)

  @staticmethod
  def qr_harness(shape, dtype):
    # In order to keep inputs small, we construct the input programmatically
    operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
    return lax.linalg.qr(operand, full_matrices=True)

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")
    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    func = lambda: CompatTest.qr_harness((3, 3), dtype)

    info = cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol)

    # TODO(b/369826500): Remove legacy custom call test after mid March 2025.
    data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name])
    self.run_one_test(func, data, rtol=rtol,
                      expect_current_custom_calls=info["custom_call_targets"])

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_gpu_qr_solver_geqrf(self, dtype_name="f32"):
    if not jtu.test_device_matches(["cuda"]):
      self.skipTest("Unsupported platform")
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")
    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    shape = (2, 3, 3)
    func = lambda: CompatTest.qr_harness(shape, dtype)
    data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2024_09_26[dtype_name])
    self.run_one_test(func, data, rtol=rtol)

  def test_tpu_Qr(self):
    # For lax.linalg.qr
    func = lambda: CompatTest.qr_harness((3, 3), np.float32)
    data = self.load_testdata(tpu_Qr.data_2023_03_17)
    self.run_one_test(func, data, rtol=1e-3, atol=1e-3)

  @staticmethod
  def lu_harness(shape, dtype):
    operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
    return lax.linalg.lu(operand)

  def check_lu_results(self, operand, res_now, res_expected, *,
                       dtype, rtol=None, atol=None):
    # Same checker as in linalg_test.py
    del res_expected  # we do not check against expected
    lu_now, pivots_now, _ = res_now

    n, m = operand.shape
    self.assertEqual(n, m)
    l = np.tril(lu_now, -1) + np.eye(n, dtype=dtype)
    u = np.triu(lu_now)
    operand_copy = operand.copy()
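    # Apply the pivots as successive row interchanges so that operand_copy
    # becomes P @ A, which must match L @ U.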
    for i in range(n):
      operand_copy[[i, pivots_now[i]],] = operand_copy[[pivots_now[i], i],]
    self.assertAllClose(operand_copy, np.matmul(l, u), rtol=rtol, atol=atol)

  def test_tpu_Lu(self):
    # For lax.linalg.lu on TPU.
    shape = (3, 3)
    dtype = np.float32
    func = lambda: CompatTest.lu_harness(shape, dtype)
    data = self.load_testdata(tpu_Lu.data_2023_03_21)
    operand = np.reshape(np.arange(math.prod(shape), dtype=dtype), shape)
    self.run_one_test(func, data, rtol=1e-3,
                      check_results=partial(self.check_lu_results, operand,
                                            dtype=dtype))

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}",
           dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  def test_cpu_lu_lapack_getrf(self, dtype_name: str):
    # For lax.linalg.lu on CPU.
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")
    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    shape = (3, 3)
    func = lambda: CompatTest.lu_harness(shape, dtype)
    operand = np.reshape(np.arange(math.prod(shape), dtype=dtype), shape)
    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
    info = cpu_lu_lapack_getrf.data_2024_05_31[dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_lu_results, operand,
                                            dtype=dtype))

    # TODO(b/357034884): Remove legacy custom call test after mid March 2025.
    legacy_data = self.load_testdata(
        cpu_lu_lapack_getrf.data_2023_06_14[dtype_name])
    self.run_one_test(func, legacy_data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_lu_results, operand,
                                            dtype=dtype),
                      expect_current_custom_calls=info["custom_call_targets"])

  def check_svd_results(self, input, res_run, res_exp,
                        rtol=None, atol=None):
    # Following linalg_test.testSVD
    def compute_max_backward_error(operand, reconstructed_operand):
      error_norm = np.linalg.norm(operand - reconstructed_operand,
                                  axis=(-2, -1))
      backward_error = (error_norm /
                        np.linalg.norm(operand, axis=(-2, -1)))
      max_backward_error = np.amax(backward_error)
      return max_backward_error
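
    # Tolerances scale with the machine epsilon of the input dtype, following
    # linalg_test.testSVD.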
    tol = 80 * jnp.finfo(input.dtype).eps
    reconstruction_tol = 2 * tol
    unitariness_tol = tol

    out = res_run
    a = input
    compute_uv = True
    full_matrices = True
    b, m, n = input.shape
    T = lambda x: np.swapaxes(x, -1, -2)

    if compute_uv:
      # Check the reconstructed matrices
      out = list(out)
      out[1] = out[1].astype(out[0].dtype)  # for strict dtype promotion.
      if m and n:
        if full_matrices:
          k = min(m, n)
          if m < n:
            max_backward_error = compute_max_backward_error(
                a, np.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :]))
            self.assertLess(max_backward_error, reconstruction_tol)
          else:
            max_backward_error = compute_max_backward_error(
                a, np.matmul(out[1][..., None, :] * out[0][..., :, :k], out[2]))
            self.assertLess(max_backward_error, reconstruction_tol)
        else:
          max_backward_error = compute_max_backward_error(
              a, np.matmul(out[1][..., None, :] * out[0], out[2]))
          self.assertLess(max_backward_error, reconstruction_tol)

      # Check the unitary properties of the singular vector matrices.
      unitary_mat = np.real(np.matmul(np.conj(T(out[0])), out[0]))
      eye_slice = np.eye(out[0].shape[-1], dtype=unitary_mat.dtype)
      self.assertAllClose(np.broadcast_to(eye_slice, (b,) + eye_slice.shape),
                          unitary_mat, rtol=unitariness_tol,
                          atol=unitariness_tol)
      if m >= n:
        unitary_mat = np.real(np.matmul(np.conj(T(out[2])), out[2]))
        eye_slice = np.eye(out[2].shape[-1], dtype=unitary_mat.dtype)
        self.assertAllClose(np.broadcast_to(eye_slice, (b,) + eye_slice.shape),
                            unitary_mat, rtol=unitariness_tol,
                            atol=unitariness_tol)
      else:
        unitary_mat = np.real(np.matmul(out[2], np.conj(T(out[2]))))
        eye_slice = np.eye(out[2].shape[-2], dtype=unitary_mat.dtype)
        self.assertAllClose(np.broadcast_to(eye_slice, (b,) + eye_slice.shape),
                            unitary_mat, rtol=unitariness_tol,
                            atol=unitariness_tol)
    else:
      self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False),
                                  np.asarray(out), atol=1e-4, rtol=1e-4))

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}",
           dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  @jax.default_matmul_precision("float32")
  def test_cpu_schur_lapack_gees(self, dtype_name="f32"):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    shape = (4, 4)
    input = np.arange(math.prod(shape), dtype=dtype).reshape(shape)

    def func(input):
      return lax.linalg.schur(input, compute_schur_vectors=True)

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

    def check_schur_results(res_run, res_expected, *, rtol, atol):
      t_run, s_run = res_run
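      # Verify the decomposition: A must equal S @ T @ S^H up to tolerance.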
      self.assertAllClose(input, s_run @ t_run @ np.conj(s_run.T),
                          rtol=rtol, atol=atol)

    info = cpu_schur_lapack_gees.data_2024_11_29[dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=check_schur_results)
    data = self.load_testdata(cpu_schur_lapack_gees.data_2023_07_16[dtype_name])
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=check_schur_results,
                      expect_current_custom_calls=info["custom_call_targets"])

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  @jax.default_matmul_precision("float32")
  def test_cpu_svd_lapack_gesdd(self, dtype_name="f32"):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    def func(operand):
      return lax.linalg.svd(operand, full_matrices=True, compute_uv=True)

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

    info = cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_svd_results,
                                            *data.inputs))

    data = self.load_testdata(cpu_svd_lapack_gesdd.data_2023_06_19[dtype_name])
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_svd_results,
                                            *data.inputs),
                      expect_current_custom_calls=info["custom_call_targets"])

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}_algorithm={algorithm_name}",
           dtype_name=dtype_name, algorithm_name=algorithm_name)
      for dtype_name in ("f32", "f64", "c64", "c128")
      for algorithm_name in ("qr", "jacobi"))
  @jax.default_matmul_precision("float32")
  def test_gpu_svd_solver_gesvd(self, dtype_name, algorithm_name):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    def func(operand):
      return lax.linalg.svd(operand, full_matrices=True, compute_uv=True,
                            algorithm=algorithm)

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
    algorithm = dict(qr=lax.linalg.SvdAlgorithm.QR,
                     jacobi=lax.linalg.SvdAlgorithm.JACOBI)[algorithm_name]

    info = cuda_svd_cusolver_gesvd.data_2024_10_08[algorithm_name][dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=partial(self.check_svd_results,
                                            *data.inputs))

  @jtu.parameterized_filterable(
      kwargs=[
          dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
          for dtype_name in ("f32", "f64", "c64", "c128")])
  @jax.default_matmul_precision("float32")
  def test_cpu_triangular_solve_blas_trsm(self, dtype_name="f32"):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    a_shape = (4, 4)
    a = np.arange(math.prod(a_shape), dtype=dtype).reshape(a_shape)
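    # Shift the diagonal by 5 before taking the lower triangle so that the
    # triangular matrix is nonsingular.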
    a = np.tril(a + 5 * np.eye(a.shape[-1], dtype=a.dtype))
    b_shape = (4, 5)
    b = np.arange(math.prod(b_shape), dtype=dtype).reshape(b_shape)
    left_side = True
    def func(a, b):
      return lax.linalg.triangular_solve(a, b, lower=True,
                                         transpose_a=False,
                                         conjugate_a=False, unit_diagonal=False,
                                         left_side=left_side)

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

    def check_triangular_solve_results(res_run, res_expected, *, rtol, atol):
      x, = res_run
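      # Substitute the solution back: a @ x (or x @ a) must reproduce b.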
      matmul = partial(jnp.matmul, precision=lax.Precision.HIGHEST)
      y = matmul(a, x) if left_side else matmul(x, a)
      self.assertArraysAllClose(y, jnp.broadcast_to(b, y.shape), rtol=rtol, atol=atol)

    info = cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=check_triangular_solve_results)

    data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2023_07_16[dtype_name])
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      check_results=check_triangular_solve_results,
                      expect_current_custom_calls=info["custom_call_targets"])

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  @jax.default_matmul_precision("float32")
  def test_cpu_hessenberg_lapack_gehrd(self, dtype_name="f32"):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    shape = (2, 4, 4)
    input_data = jtu.rand_default(self.rng())(shape, dtype)
    # del input_data  # Input is in the testdata, here for readability
    def func():
      return lax.linalg.hessenberg(input_data)

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

    info = cpu_hessenberg_lapack_gehrd.data_2024_08_31[dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol)

    data = self.load_testdata(
        cpu_hessenberg_lapack_gehrd.data_2024_08_30[dtype_name]
    )
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      expect_current_custom_calls=info["custom_call_targets"])

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  @jax.default_matmul_precision("float32")
  def test_cpu_tridiagonal_lapack_sytrd_hetrd(self, dtype_name="f32"):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    dtype = dict(f32=np.float32, f64=np.float64,
                 c64=np.complex64, c128=np.complex128)[dtype_name]
    shape = (2, 4, 4)
    input_data = jtu.rand_default(self.rng())(shape, dtype)
    # del input_data  # Input is in the testdata, here for readability
    def func():
      return lax.linalg.tridiagonal(input_data, lower=True)

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

    info = cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01[dtype_name]
    data = self.load_testdata(info)
    self.run_one_test(func, data, rtol=rtol, atol=atol)

    data = self.load_testdata(
        cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03[dtype_name]
    )
    self.run_one_test(func, data, rtol=rtol, atol=atol,
                      expect_current_custom_calls=info["custom_call_targets"])

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  @jax.default_matmul_precision("float32")
  def test_cpu_tridiagonal_solve_lapack_gtsv(self, dtype_name):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
    data = self.load_testdata(
        cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09[dtype_name]
    )
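    # lax.linalg.tridiagonal_solve takes the lower, main, and upper diagonals
    # and the right-hand side separately; a minimal sketch of the call:
    #   dl = jnp.array([0., 1., 1.]); d = jnp.array([2., 2., 2.])
    #   du = jnp.array([1., 1., 0.]); b = jnp.ones((3, 1))
    #   x = lax.linalg.tridiagonal_solve(dl, d, du, b)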
    self.run_one_test(lax.linalg.tridiagonal_solve, data, rtol=rtol, atol=atol)

  @parameterized.named_parameters(
      dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
      for dtype_name in ("f32", "f64", "c64", "c128"))
  @jax.default_matmul_precision("float32")
  def test_gpu_tridiagonal_solver_sytrd(self, dtype_name):
    if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
      self.skipTest("Test disabled for x32 mode")

    def func(x):
      return lax.linalg.tridiagonal(x, lower=True)

    rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
    atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]

    data = self.load_testdata(
        cuda_tridiagonal_cusolver_sytrd.data_2025_01_09[dtype_name]
    )
    self.run_one_test(func, data, rtol=rtol, atol=atol)

  def test_tpu_approx_top_k(self):
    def func():
      x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0])
      y = lax.approx_max_k(x, 3)
      z = lax.approx_max_k(x, 3)
      return y + z

    data = self.load_testdata(tpu_ApproxTopK.data_2023_05_16)
    self.run_one_test(func, data)

  def test_cuda_threefry2x32(self):
    with config.threefry_partitionable(False):
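      # Note: the GPU threefry2x32 custom call is emitted only by the legacy
      # (non-partitionable) threefry implementation, hence the override above.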
      def func(x):
        return jax.random.uniform(x, (2, 4), dtype=np.float32)

      data = self.load_testdata(cuda_threefry2x32.data_2024_07_30)
      self.run_one_test(func, data)

  def test_tpu_sharding(self):
    # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU
    if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2:
      self.skipTest("Test runs only on TPU with at least 2 devices")

    # Must use exactly 2 devices for expected outputs from ppermute
    devices = jax.devices()[:2]
    mesh = Mesh(devices, axis_names=('a'))

    @partial(pjit.pjit,
             in_shardings=(P('a', None),), out_shardings=P('a', None))
    @partial(shard_map, mesh=mesh,
             in_specs=(P('a', None),), out_specs=P('a', None))
    def func(x):  # x: f32[2, 4]
      axis_size = lax.psum(1, 'a')
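      # This perm is a cyclic rotation: device j sends its shard to
      # device (j + 1) % axis_size.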
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(x, 'a', perm=perm)

    data = self.load_testdata(tpu_Sharding.data_2023_03_16)
    with mesh:
      self.run_one_test(func, data)

  @parameterized.named_parameters(
      dict(testcase_name=f"_platform={platform}", platform=platform)
      for platform in ("tpu", "gpu"))
  def test_annotate_device_placement(self, platform):
    if not jtu.test_device_matches([platform]):
      self.skipTest(f"Test enabled only for {platform}")

    mesh = Mesh(jax.local_devices()[0:1], axis_names=("a"))

    dev_sharding = NS(mesh, P("a"))
    host_sharding = NS(mesh, P("a"), memory_kind="pinned_host")
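
    # "pinned_host" places the operand in page-locked host memory instead of
    # device memory; the exported module records this placement via the
    # device-placement custom calls under test.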
    @partial(jax.jit,
             in_shardings=(dev_sharding, host_sharding),
             out_shardings=host_sharding)
    def func(x, y):
      return x + y

    if platform == "tpu":
      data = self.load_testdata(annotate_data_placement.data_2025_04_07_tpu)
    else:
      data = self.load_testdata(annotate_data_placement.data_2025_04_07_cuda)

    self.run_one_test(func, data)

  def test_tpu_stablehlo_dynamic_reduce_window_unary(self):
    # stablehlo.dynamic_reduce_window is used temporarily on TPU for a
    # reduce window with dynamic shapes.
    # See https://github.com/openxla/stablehlo/issues/1258 for the long term.
    # The inputs are already in the test data, here only for readability.
    shape = (3, 4)
    _ = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

    def func(x):
      return jnp.cumsum(x, axis=0)
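    # On TPU, the cumsum here lowers to a reduce_window; with the symbolic
    # dimension b it becomes the stablehlo.dynamic_reduce_window custom call.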

    data = self.load_testdata(tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17)
    self.run_one_test(
        func, data,
        polymorphic_shapes=("b, ...",),
        # Recent serializations also include shape_assertion, tested with dynamic_top_k
        expect_current_custom_calls=["stablehlo.dynamic_reduce_window", "shape_assertion"])

  def test_tpu_stablehlo_dynamic_reduce_window_variadic(self):
    # stablehlo.dynamic_reduce_window is used temporarily on TPU for a
    # reduce window with dynamic shapes.
    # See https://github.com/openxla/stablehlo/issues/1258 for the long term.
    # The inputs are already in the test data, here only for readability.
    shape = (3, 4)
    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    y = 100 + np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
    _ = (x, y)

    def func(x, y):  # x: f32[b, 2] y: i32[b, 2]
      return lax.reduce_window(
          (x, y), (np.array(1., np.float32), np.array(2, np.int32)),
          lambda xy0, xy1: (lax.add(xy0[0], xy1[0]),
                            lax.sub(xy0[1], xy1[1])),
          (2, x.shape[0]), (1, 1), "VALID")
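    # The window dimensions (2, x.shape[0]) depend on the symbolic dimension
    # b, which is what forces the dynamic_reduce_window lowering here.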

    data = self.load_testdata(tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17)
    self.run_one_test(
        func, data,
        polymorphic_shapes=("b, ...", "b, ..."),
        # Recent serializations also include shape_assertion, tested with dynamic_top_k
        expect_current_custom_calls=["stablehlo.dynamic_reduce_window", "shape_assertion"])

  @jtu.ignore_warning(
      category=FutureWarning,
      message="Raw arrays as random keys to jax.random functions are deprecated")
  def test_stablehlo_dynamic_rbg_bit_generator(self):
    # stablehlo.dynamic_rbg_bit_generator is used temporarily for a
    # rbg_bit_generator with dynamic shapes.
    # See https://github.com/openxla/stablehlo/issues/1344 for the long term.
    key = np.arange(42, 42+4, dtype=np.uint32)
    a_shape = (2, 3)
    a = np.arange(math.prod(a_shape), dtype=np.float32).reshape(a_shape)
    inputs = (key, a)
    del inputs  # already in the test data, here only for readability.

    def func(key, a):  # a is only used for its shape
      return jax.random.key_data(jax.random.split(key, a.shape[0] * a.shape[1]))
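    # With polymorphic_shapes=(None, "b0, b1") the number of split keys,
    # a.shape[0] * a.shape[1], is symbolic, which forces the
    # stablehlo.dynamic_rng_bit_generator lowering under "unsafe_rbg".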

    # Note that the test currently checks that the generated sequence is the
    # same. According to the StableHLO spec: "The output is guaranteed to be
    # deterministic function of initial_state, but it is not guaranteed to be
    # deterministic between implementations"
    # See https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng_bit_generator
    # This test will fail when the implementation changes. We expect this to
    # be rare, and most users may expect the RNG sequence to be the same
    # upon reloading of a saved model.
    # In case of an intended change in behavior we will have the option to
    # replace this strict check with something else.
    data = self.load_testdata(stablehlo_dynamic_rng_bit_generator.data_2023_06_17)

    with config.default_prng_impl("unsafe_rbg"):
      self.run_one_test(
          func, data, polymorphic_shapes=(None, "b0, b1"),
          # Recent serializations also include shape_assertion, tested with dynamic_top_k
          expect_current_custom_calls=["stablehlo.dynamic_rng_bit_generator", "shape_assertion"])

  def test_stablehlo_dynamic_top_k(self):
    # stablehlo.dynamic_top_k is used temporarily for a top_k with dynamism
    a = np.arange(12, dtype=np.float32).reshape((4, 3))

    def func(a):
      return lax.top_k(a, k=a.shape[-1] - 1)
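    # With polymorphic_shapes=("_, b",) below, k = b - 1 is symbolic, which
    # is what triggers the stablehlo.dynamic_top_k lowering.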

    data = self.load_testdata(stablehlo_dynamic_top_k.data_2023_07_16)

    def check_top_k_results(res_run, res_expected, *, rtol, atol):
      # The order of the results may differ, but they should be the same values
      values_expected, _ = res_expected
      values_run, indices_run = res_run
      # Check that the returned indices select the returned values
      self.assertAllClose(values_run,
                          a[np.arange(a.shape[0]).reshape(a.shape[0], 1),
                            indices_run], atol=atol, rtol=rtol)
      self.assertAllClose(np.sort(values_run), np.sort(values_expected),
                          atol=atol, rtol=rtol)

    self.run_one_test(func, data,
                      polymorphic_shapes=("_, b",),
                      check_results=check_top_k_results,
                      # Recent serializations also include shape_assertion
                      expect_current_custom_calls=["stablehlo.dynamic_top_k", "shape_assertion"])

    # Now a test with serialization version 7, including shape_assertion
    data_2 = self.load_testdata(stablehlo_dynamic_top_k.data_2023_08_11)
    self.run_one_test(func, data_2,
                      polymorphic_shapes=("_, b",),
                      check_results=check_top_k_results)

  def test_dynamic_approx_top_k(self):
    # stablehlo.dynamic_approx_top_k is used temporarily for an approx_top_k
    # with dynamism.
    # This is the input that was used to generate the test_data
    _ = np.arange(24, dtype=np.float32)

    def func(a):  # a: f32[b + 4]
      return lax.approx_max_k(a, k=a.shape[0] - 4)

    data = self.load_testdata(stablehlo_dynamic_approx_top_k.data_2024_05_30)

    def check_top_k_results(res_run, res_expected, *, rtol, atol):
      a = data.inputs[0]
      # The order of the results may differ, but they should be the same values
      values_expected, _ = res_expected
      values_run, indices_run = res_run
      # Check that the returned indices select the returned values
      self.assertAllClose(
          values_run,
          a[indices_run],
          atol=atol,
          rtol=rtol,
      )
      self.assertAllClose(
          np.sort(values_run), np.sort(values_expected), atol=atol, rtol=rtol
      )

    self.run_one_test(
        func,
        data,
        polymorphic_shapes=("b + 4,",),
        check_results=check_top_k_results,
        expect_current_custom_calls=[
            "stablehlo.dynamic_approx_top_k",
            "shape_assertion",
        ],
    )


@jtu.with_config(jax_use_shardy_partitioner=True)
class ShardyCompatTest(bctu.CompatTestBase):

  def test_shardy_sharding_ops_with_different_meshes(self):
    # Tests whether we can save and load a module with meshes that have the
    # same axis sizes (and same order) but different axis names.
    # Also tests "Sharding", "xla.sdy.GlobalToLocalShape",
    # "xla.sdy.LocalToGlobalShape".
    if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2:
      self.skipTest("Test runs only on TPU with at least 2 devices")

    # Must use exactly 2 devices for expected outputs from ppermute.
    devices = jax.devices()[:2]
    old_mesh = Mesh(devices, axis_names=('a'))

    def func(x):  # x: f32[4, 4]
      @partial(shard_map, mesh=old_mesh,
               in_specs=(P('a', None),), out_specs=P('a', None))
      def shard_map_func(x):  # x: f32[2, 4]
        axis_size = lax.psum(1, 'a')
        perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
        return lax.ppermute(x, 'a', perm=perm)
      x = jax.lax.with_sharding_constraint(x, NS(old_mesh, P('a', None)))
      return shard_map_func(x)
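
    # The module was serialized under mesh axis name 'a'; replaying it under
    # an axis named 'x' (same size and order) exercises loading with a
    # different mesh.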
    data = self.load_testdata(shardy_sharding_ops_with_different_meshes.data_2025_02_12)
    with Mesh(devices, axis_names=('x')):
      self.run_one_test(func, data)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())