From 49c80e68d105dc93e5f26ef15b434b279bf00a03 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 20 Nov 2023 17:27:42 -0800 Subject: [PATCH] Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition. Improve the documentation of lax.eig(). Fixes https://github.com/google/jax/issues/18226 PiperOrigin-RevId: 584170564 --- CHANGELOG.md | 4 +++ jax/_src/lax/linalg.py | 15 ++++++++ jaxlib/cpu/lapack_kernels.cc | 68 ++++++++++++++++++++++++++---------- tests/linalg_test.py | 23 ++++++++++++ 4 files changed, 91 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fe5420e2..0b2587c3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,10 @@ Remember to align the itemized text with the first line of an item within a list * On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to 1024x1024. The Jacobi solver appears faster than the non-Jacobi version. +* Bug fixes + * Fixed error/hang when an array with non-finite values is passed to a + non-symmetric eigendecomposition (#18226). Arrays with non-finite values now + produce arrays full of NaNs as outputs. ## jax 0.4.20 (Nov 2, 2023) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index cfeacd3e8..3cece6a47 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -140,6 +140,21 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, """Eigendecomposition of a general matrix. Nonsymmetric eigendecomposition is at present only implemented on CPU. + + Args: + x: A batch of square matrices with shape ``[..., n, n]``. + compute_left_eigenvectors: If true, the left eigenvectors will be computed. + compute_right_eigenvectors: If true, the right eigenvectors will be + computed. + Returns: + The eigendecomposition of ``x``, which is a tuple of the form + ``(w, vl, vr)`` where ``w`` are the eigenvalues, ``vl`` are the left + eigenvectors, and ``vr`` are the right eigenvectors. 
``vl`` and ``vr`` are + optional and will only be included if ``compute_left_eigenvectors`` or + ``compute_right_eigenvectors`` respectively are ``True``. + + If the eigendecomposition fails, then arrays full of NaNs will be returned + for that batch element. """ return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, compute_right_eigenvectors=compute_right_eigenvectors) diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index e52004053..211d11020 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/cpu/lapack_kernels.h" #include +#include #include #include @@ -562,22 +563,35 @@ void RealGeev<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { lwork = static_cast<int>(work_query); T* work = new T[lwork]; + auto is_finite = [](T* a_work, int64_t n) { + for (int64_t j = 0; j < n; ++j) { + for (int64_t k = 0; k < n; ++k) { + if (!std::isfinite(a_work[j * n + k])) { + return false; + } + } + } + return true; + }; for (int i = 0; i < b; ++i) { size_t a_size = n * n * sizeof(T); std::memcpy(a_work, a_in, a_size); - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int, - vr_work, &n_int, work, &lwork, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - if (info_out[0] == 0) { - UnpackEigenvectors(n, wi_out, vl_work, vl_out); - UnpackEigenvectors(n, wi_out, vr_work, vr_out); + if (is_finite(a_work, n)) { + fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, + &n_int, vr_work, &n_int, work, &lwork, info_out); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); +
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); + if (info_out[0] == 0) { + UnpackEigenvectors(n, wi_out, vl_work, vl_out); + UnpackEigenvectors(n, wi_out, vr_work, vr_out); + } + } else { + *info_out = -4; } - a_in += n * n; wr_out += n; wi_out += n; @@ -621,16 +635,32 @@ void ComplexGeev<T>::Kernel(void* out_tuple, void** data, lwork = static_cast<int>(work_query.real()); T* work = new T[lwork]; + auto is_finite = [](T* a_work, int64_t n) { + for (int64_t j = 0; j < n; ++j) { + for (int64_t k = 0; k < n; ++k) { + T v = a_work[j * n + k]; + if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) { + return false; + } + } + } + return true; + }; + for (int i = 0; i < b; ++i) { size_t a_size = n * n * sizeof(T); std::memcpy(a_work, a_in, a_size); - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, - &n_int, work, &lwork, r_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); + if (is_finite(a_work, n)) { + fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, + &n_int, work, &lwork, r_work, info_out); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); + } else { + *info_out = -4; + } a_in += n * n; w_out += n; vl_out += n * n; diff --git
a/tests/linalg_test.py b/tests/linalg_test.py index ece7a0ae9..d239fffc9 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,6 +16,7 @@ from functools import partial import itertools +import unittest import numpy as np import scipy @@ -29,6 +30,7 @@ from jax import jit, grad, jvp, vmap from jax import lax from jax import numpy as jnp from jax import scipy as jsp +from jax._src.lib import version as jaxlib_version from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge @@ -245,6 +247,27 @@ class NumpyLinalgTest(jtu.JaxTestCase): self._CompileAndCheck(partial(jnp.linalg.eig), args_maker, rtol=1e-3) + @jtu.sample_product( + shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)], + dtype=float_types + complex_types, + compute_left_eigenvectors=[False, True], + compute_right_eigenvectors=[False, True], + ) + # TODO(phawkins): enable when there is an eigendecomposition implementation + # for GPU/TPU. + @jtu.run_on_devices("cpu") + @unittest.skipIf(jaxlib_version < (0, 4, 21), "Test requires jaxlib 0.4.21") + def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, + compute_right_eigenvectors): + """Verifies that `eig` fails gracefully if given non-finite inputs.""" + a = jnp.full(shape, jnp.nan, dtype) + results = lax.linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors) + for result in results: + self.assertTrue(np.all(np.isnan(result))) + + @jtu.sample_product( shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)], dtype=float_types + complex_types,