Fix error/hang when non-finite values are passed to the non-symmetric eigendecomposition.

Improve the documentation of lax.eig().

Fixes https://github.com/google/jax/issues/18226

PiperOrigin-RevId: 584170564
Author: Peter Hawkins, 2023-11-20 17:27:42 -08:00 (committed by jax authors)
Parent: 29eec05c92
Commit: 49c80e68d1
4 changed files with 91 additions and 19 deletions
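
Concretely, the change is observable from Python. A minimal repro sketch of the fixed behavior (assuming jaxlib >= 0.4.21 and a CPU backend, since the non-symmetric decomposition is CPU-only):

```python
import jax.numpy as jnp
from jax import lax

# Before this fix, a matrix containing non-finite values could hang or
# error inside LAPACK's *GEEV; it now returns NaN-filled outputs instead.
a = jnp.full((4, 4), jnp.nan, dtype=jnp.float32)
w, vl, vr = lax.linalg.eig(a)  # both eigenvector flags default to True
assert bool(jnp.all(jnp.isnan(w)))
assert bool(jnp.all(jnp.isnan(vl))) and bool(jnp.all(jnp.isnan(vr)))
```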

CHANGELOG.md

@@ -43,6 +43,10 @@ Remember to align the itemized text with the first line of an item within a list
 * On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
   1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
 * Bug fixes
+  * Fixed error/hang when an array with non-finite values is passed to a
+    non-symmetric eigendecomposition (#18226). Arrays with non-finite values now
+    produce arrays full of NaNs as outputs.
 
 ## jax 0.4.20 (Nov 2, 2023)

jax/_src/lax/linalg.py

@@ -140,6 +140,21 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
   """Eigendecomposition of a general matrix.
 
   Nonsymmetric eigendecomposition is at present only implemented on CPU.
+
+  Args:
+    x: A batch of square matrices with shape ``[..., n, n]``.
+    compute_left_eigenvectors: If true, the left eigenvectors will be computed.
+    compute_right_eigenvectors: If true, the right eigenvectors will be
+      computed.
+
+  Returns:
+    The eigendecomposition of ``x``, which is a tuple of the form
+    ``(w, vl, vr)`` where ``w`` are the eigenvalues, ``vl`` are the left
+    eigenvectors, and ``vr`` are the right eigenvectors. ``vl`` and ``vr`` are
+    optional and will only be included if ``compute_left_eigenvectors`` or
+    ``compute_right_eigenvectors`` respectively are ``True``.
+
+    If the eigendecomposition fails, then arrays full of NaNs will be returned
+    for that batch element.
   """
   return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
                     compute_right_eigenvectors=compute_right_eigenvectors)
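
As a usage sketch of the return convention documented above (outputs are complex; values illustrative, and CPU-only at present):

```python
import jax.numpy as jnp
from jax import lax

a = jnp.array([[0.0, 1.0],
               [-2.0, -3.0]])

# With both flags at their defaults, all three outputs are returned.
w, vl, vr = lax.linalg.eig(a)

# With both flags off, only the eigenvalues are returned.
w, = lax.linalg.eig(a, compute_left_eigenvectors=False,
                    compute_right_eigenvectors=False)
```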

jaxlib/cpu/lapack_kernels.cc

@@ -16,6 +16,7 @@ limitations under the License.
 
 #include "jaxlib/cpu/lapack_kernels.h"
 
+#include <cmath>
 #include <cstdint>
 #include <cstring>
 #include <limits>
@@ -562,22 +563,35 @@ void RealGeev<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
   lwork = static_cast<int>(work_query);
   T* work = new T[lwork];
+  auto is_finite = [](T* a_work, int64_t n) {
+    for (int64_t j = 0; j < n; ++j) {
+      for (int64_t k = 0; k < n; ++k) {
+        if (!std::isfinite(a_work[j * n + k])) {
+          return false;
+        }
+      }
+    }
+    return true;
+  };
   for (int i = 0; i < b; ++i) {
     size_t a_size = n * n * sizeof(T);
     std::memcpy(a_work, a_in, a_size);
-    fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int,
-       vr_work, &n_int, work, &lwork, info_out);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
-    if (info_out[0] == 0) {
-      UnpackEigenvectors(n, wi_out, vl_work, vl_out);
-      UnpackEigenvectors(n, wi_out, vr_work, vr_out);
+    if (is_finite(a_work, n)) {
+      fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work,
+         &n_int, vr_work, &n_int, work, &lwork, info_out);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
+      if (info_out[0] == 0) {
+        UnpackEigenvectors(n, wi_out, vl_work, vl_out);
+        UnpackEigenvectors(n, wi_out, vr_work, vr_out);
+      }
+    } else {
+      *info_out = -4;
     }
     a_in += n * n;
     wr_out += n;
     wi_out += n;
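
The guard added above is simple to state outside C++. A schematic NumPy mirror of the kernel's new per-batch-element control flow (names here are illustrative, not jaxlib API; the complex kernel below does the same with a scan of both real and imaginary parts):

```python
import numpy as np

def geev_with_guard(a):
  # LAPACK's *GEEV can hang or fail on non-finite input, so scan the
  # copied matrix first and skip the call entirely if the scan fails.
  if not np.all(np.isfinite(a)):
    # LAPACK convention: info == -i reports that the i-th argument was
    # illegal, and the matrix A is the 4th argument of SGEEV/DGEEV.
    return None, -4
  w = np.linalg.eigvals(a)  # stands in for the actual *GEEV call
  return w, 0
```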
@@ -621,16 +635,32 @@ void ComplexGeev<T>::Kernel(void* out_tuple, void** data,
   lwork = static_cast<int>(work_query.real());
   T* work = new T[lwork];
+  auto is_finite = [](T* a_work, int64_t n) {
+    for (int64_t j = 0; j < n; ++j) {
+      for (int64_t k = 0; k < n; ++k) {
+        T v = a_work[j * n + k];
+        if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) {
+          return false;
+        }
+      }
+    }
+    return true;
+  };
   for (int i = 0; i < b; ++i) {
     size_t a_size = n * n * sizeof(T);
     std::memcpy(a_work, a_in, a_size);
-    fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
-       &n_int, work, &lwork, r_work, info_out);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n);
-    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
+    if (is_finite(a_work, n)) {
+      fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int,
+         vr_out, &n_int, work, &lwork, r_work, info_out);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n);
+      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
+    } else {
+      *info_out = -4;
+    }
     a_in += n * n;
     w_out += n;
     vl_out += n * n;
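
A non-zero info by itself is not what produces the NaN outputs promised in the changelog entry: the kernel only reports failure, and JAX's lowering of eig masks the outputs of batch elements whose info is non-zero. Schematically (a sketch of the idea, not the actual internals):

```python
import jax.numpy as jnp

def mask_failures(w, info):
  # info has shape [...], w has shape [..., n]; broadcast the success
  # mask and replace failed batch elements' results with NaNs.
  ok = (info == 0)
  return jnp.where(ok[..., None], w, jnp.nan)
```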

tests/linalg_test.py

@@ -16,6 +16,7 @@
 from functools import partial
 import itertools
+import unittest
 
 import numpy as np
 import scipy
@@ -29,6 +30,7 @@ from jax import jit, grad, jvp, vmap
 from jax import lax
 from jax import numpy as jnp
 from jax import scipy as jsp
+from jax._src.lib import version as jaxlib_version
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
@@ -245,6 +247,27 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
                           rtol=1e-3)
 
+  @jtu.sample_product(
+      shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)],
+      dtype=float_types + complex_types,
+      compute_left_eigenvectors=[False, True],
+      compute_right_eigenvectors=[False, True],
+  )
+  # TODO(phawkins): enable when there is an eigendecomposition implementation
+  # for GPU/TPU.
+  @jtu.run_on_devices("cpu")
+  @unittest.skipIf(jaxlib_version < (0, 4, 21), "Test requires jaxlib 0.4.21")
+  def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
+                              compute_right_eigenvectors):
+    """Verifies that `eig` fails gracefully if given non-finite inputs."""
+    a = jnp.full(shape, jnp.nan, dtype)
+    results = lax.linalg.eig(
+        a, compute_left_eigenvectors=compute_left_eigenvectors,
+        compute_right_eigenvectors=compute_right_eigenvectors)
+    for result in results:
+      self.assertTrue(np.all(np.isnan(result)))
+
   @jtu.sample_product(
       shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)],
       dtype=float_types + complex_types,