Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.

Improve the documentation of lax.eig().

Fixes https://github.com/google/jax/issues/18226

PiperOrigin-RevId: 584170564
This commit is contained in:
Peter Hawkins 2023-11-20 17:27:42 -08:00 committed by jax authors
parent 29eec05c92
commit 49c80e68d1
4 changed files with 91 additions and 19 deletions

View File

@ -43,6 +43,10 @@ Remember to align the itemized text with the first line of an item within a list
* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
  1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
* Bug fixes
* Fixed error/hang when an array with non-finite values is passed to a
non-symmetric eigendecomposition (#18226). Arrays with non-finite values now
produce arrays full of NaNs as outputs.
## jax 0.4.20 (Nov 2, 2023)

View File

@ -140,6 +140,21 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
"""Eigendecomposition of a general matrix. """Eigendecomposition of a general matrix.
Nonsymmetric eigendecomposition is at present only implemented on CPU. Nonsymmetric eigendecomposition is at present only implemented on CPU.
Args:
x: A batch of square matrices with shape ``[..., n, n]``.
compute_left_eigenvectors: If true, the left eigenvectors will be computed.
compute_right_eigenvectors: If true, the right eigenvectors will be
computed.
Returns:
The eigendecomposition of ``x``, which is a tuple of the form
``(w, vl, vr)`` where ``w`` are the eigenvalues, ``vl`` are the left
eigenvectors, and ``vr`` are the right eigenvectors. ``vl`` and ``vr`` are
optional and will only be included if ``compute_left_eigenvectors`` or
``compute_right_eigenvectors`` respectively are ``True``.
If the eigendecomposition fails, then arrays full of NaNs will be returned
for that batch element.
""" """
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors) compute_right_eigenvectors=compute_right_eigenvectors)

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/cpu/lapack_kernels.h"
#include <cmath> #include <cmath>
#include <cstdint>
#include <cstring> #include <cstring>
#include <limits> #include <limits>
@ -562,22 +563,35 @@ void RealGeev<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
lwork = static_cast<int>(work_query); lwork = static_cast<int>(work_query);
T* work = new T[lwork]; T* work = new T[lwork];
auto is_finite = [](T* a_work, int64_t n) {
for (int64_t j = 0; j < n; ++j) {
for (int64_t k = 0; k < n; ++k) {
if (!std::isfinite(a_work[j * n + k])) {
return false;
}
}
}
return true;
};
for (int i = 0; i < b; ++i) { for (int i = 0; i < b; ++i) {
size_t a_size = n * n * sizeof(T); size_t a_size = n * n * sizeof(T);
std::memcpy(a_work, a_in, a_size); std::memcpy(a_work, a_in, a_size);
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int, if (is_finite(a_work, n)) {
vr_work, &n_int, work, &lwork, info_out); fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work,
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); &n_int, vr_work, &n_int, work, &lwork, info_out);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n);
if (info_out[0] == 0) { ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
UnpackEigenvectors(n, wi_out, vl_work, vl_out); if (info_out[0] == 0) {
UnpackEigenvectors(n, wi_out, vr_work, vr_out); UnpackEigenvectors(n, wi_out, vl_work, vl_out);
UnpackEigenvectors(n, wi_out, vr_work, vr_out);
}
} else {
*info_out = -4;
} }
a_in += n * n; a_in += n * n;
wr_out += n; wr_out += n;
wi_out += n; wi_out += n;
@ -621,16 +635,32 @@ void ComplexGeev<T>::Kernel(void* out_tuple, void** data,
lwork = static_cast<int>(work_query.real()); lwork = static_cast<int>(work_query.real());
T* work = new T[lwork]; T* work = new T[lwork];
auto is_finite = [](T* a_work, int64_t n) {
for (int64_t j = 0; j < n; ++j) {
for (int64_t k = 0; k < n; ++k) {
T v = a_work[j * n + k];
if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) {
return false;
}
}
}
return true;
};
for (int i = 0; i < b; ++i) { for (int i = 0; i < b; ++i) {
size_t a_size = n * n * sizeof(T); size_t a_size = n * n * sizeof(T);
std::memcpy(a_work, a_in, a_size); std::memcpy(a_work, a_in, a_size);
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, if (is_finite(a_work, n)) {
&n_int, work, &lwork, r_work, info_out); fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); &n_int, work, &lwork, r_work, info_out);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
} else {
*info_out = -4;
}
a_in += n * n; a_in += n * n;
w_out += n; w_out += n;
vl_out += n * n; vl_out += n * n;

View File

@ -16,6 +16,7 @@
from functools import partial from functools import partial
import itertools import itertools
import unittest
import numpy as np import numpy as np
import scipy import scipy
@ -29,6 +30,7 @@ from jax import jit, grad, jvp, vmap
from jax import lax from jax import lax
from jax import numpy as jnp from jax import numpy as jnp
from jax import scipy as jsp from jax import scipy as jsp
from jax._src.lib import version as jaxlib_version
from jax._src import config from jax._src import config
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src import xla_bridge from jax._src import xla_bridge
@ -245,6 +247,27 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CompileAndCheck(partial(jnp.linalg.eig), args_maker, self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
rtol=1e-3) rtol=1e-3)
# Regression test for https://github.com/google/jax/issues/18226: eig used to
# error or hang when handed an input containing non-finite values.
@jtu.sample_product(
shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)],
dtype=float_types + complex_types,
compute_left_eigenvectors=[False, True],
compute_right_eigenvectors=[False, True],
)
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.run_on_devices("cpu")
# The graceful-failure behavior is implemented in the jaxlib C++ kernels, so
# this test only passes against jaxlib 0.4.21 or newer.
@unittest.skipIf(jaxlib_version < (0, 4, 21), "Test requires jaxlib 0.4.21")
def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
compute_right_eigenvectors):
"""Verifies that `eig` fails gracefully if given non-finite inputs."""
# An input full of NaNs must not hang or raise; per the changelog entry for
# this change, every returned array should instead be full of NaNs.
a = jnp.full(shape, jnp.nan, dtype)
results = lax.linalg.eig(
a, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
# `results` contains the eigenvalues plus whichever of (vl, vr) were
# requested; all of them should be entirely NaN for a NaN input.
for result in results:
self.assertTrue(np.all(np.isnan(result)))
@jtu.sample_product( @jtu.sample_product(
shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)], shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)],
dtype=float_types + complex_types, dtype=float_types + complex_types,