mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
Improve the documentation of lax.eig(). Fixes https://github.com/google/jax/issues/18226 PiperOrigin-RevId: 584170564
This commit is contained in:
parent
29eec05c92
commit
49c80e68d1
@ -43,6 +43,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
|
||||
1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
|
||||
|
||||
* Bug fixes
|
||||
* Fixed error/hang when an array with non-finite values is passed to a
|
||||
non-symmetric eigendecomposition (#18226). Arrays with non-finite values now
|
||||
produce arrays full of NaNs as outputs.
|
||||
|
||||
## jax 0.4.20 (Nov 2, 2023)
|
||||
|
||||
|
@ -140,6 +140,21 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
|
||||
"""Eigendecomposition of a general matrix.
|
||||
|
||||
Nonsymmetric eigendecomposition is at present only implemented on CPU.
|
||||
|
||||
Args:
|
||||
x: A batch of square matrices with shape ``[..., n, n]``.
|
||||
compute_left_eigenvectors: If true, the left eigenvectors will be computed.
|
||||
compute_right_eigenvectors: If true, the right eigenvectors will be
|
||||
computed.
|
||||
Returns:
|
||||
The eigendecomposition of ``x``, which is a tuple of the form
|
||||
``(w, vl, vr)`` where ``w`` are the eigenvalues, ``vl`` are the left
|
||||
eigenvectors, and ``vr`` are the right eigenvectors. ``vl`` and ``vr`` are
|
||||
optional and will only be included if ``compute_left_eigenvectors`` or
|
||||
``compute_right_eigenvectors`` respectively are ``True``.
|
||||
|
||||
If the eigendecomposition fails, then arrays full of NaNs will be returned
|
||||
for that batch element.
|
||||
"""
|
||||
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
|
||||
compute_right_eigenvectors=compute_right_eigenvectors)
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "jaxlib/cpu/lapack_kernels.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
@ -562,22 +563,35 @@ void RealGeev<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
|
||||
lwork = static_cast<int>(work_query);
|
||||
T* work = new T[lwork];
|
||||
|
||||
auto is_finite = [](T* a_work, int64_t n) {
|
||||
for (int64_t j = 0; j < n; ++j) {
|
||||
for (int64_t k = 0; k < n; ++k) {
|
||||
if (!std::isfinite(a_work[j * n + k])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
for (int i = 0; i < b; ++i) {
|
||||
size_t a_size = n * n * sizeof(T);
|
||||
std::memcpy(a_work, a_in, a_size);
|
||||
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int,
|
||||
vr_work, &n_int, work, &lwork, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
||||
if (info_out[0] == 0) {
|
||||
UnpackEigenvectors(n, wi_out, vl_work, vl_out);
|
||||
UnpackEigenvectors(n, wi_out, vr_work, vr_out);
|
||||
if (is_finite(a_work, n)) {
|
||||
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work,
|
||||
&n_int, vr_work, &n_int, work, &lwork, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
||||
if (info_out[0] == 0) {
|
||||
UnpackEigenvectors(n, wi_out, vl_work, vl_out);
|
||||
UnpackEigenvectors(n, wi_out, vr_work, vr_out);
|
||||
}
|
||||
} else {
|
||||
*info_out = -4;
|
||||
}
|
||||
|
||||
a_in += n * n;
|
||||
wr_out += n;
|
||||
wi_out += n;
|
||||
@ -621,16 +635,32 @@ void ComplexGeev<T>::Kernel(void* out_tuple, void** data,
|
||||
lwork = static_cast<int>(work_query.real());
|
||||
T* work = new T[lwork];
|
||||
|
||||
auto is_finite = [](T* a_work, int64_t n) {
|
||||
for (int64_t j = 0; j < n; ++j) {
|
||||
for (int64_t k = 0; k < n; ++k) {
|
||||
T v = a_work[j * n + k];
|
||||
if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
for (int i = 0; i < b; ++i) {
|
||||
size_t a_size = n * n * sizeof(T);
|
||||
std::memcpy(a_work, a_in, a_size);
|
||||
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
|
||||
&n_int, work, &lwork, r_work, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
||||
if (is_finite(a_work, n)) {
|
||||
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
|
||||
&n_int, work, &lwork, r_work, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
||||
} else {
|
||||
*info_out = -4;
|
||||
}
|
||||
a_in += n * n;
|
||||
w_out += n;
|
||||
vl_out += n * n;
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
from functools import partial
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
@ -29,6 +30,7 @@ from jax import jit, grad, jvp, vmap
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import scipy as jsp
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
@ -245,6 +247,27 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
|
||||
rtol=1e-3)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)],
|
||||
dtype=float_types + complex_types,
|
||||
compute_left_eigenvectors=[False, True],
|
||||
compute_right_eigenvectors=[False, True],
|
||||
)
|
||||
# TODO(phawkins): enable when there is an eigendecomposition implementation
|
||||
# for GPU/TPU.
|
||||
@jtu.run_on_devices("cpu")
|
||||
@unittest.skipIf(jaxlib_version < (0, 4, 21), "Test requires jaxlib 0.4.21")
|
||||
def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
|
||||
compute_right_eigenvectors):
|
||||
"""Verifies that `eig` fails gracefully if given non-finite inputs."""
|
||||
a = jnp.full(shape, jnp.nan, dtype)
|
||||
results = lax.linalg.eig(
|
||||
a, compute_left_eigenvectors=compute_left_eigenvectors,
|
||||
compute_right_eigenvectors=compute_right_eigenvectors)
|
||||
for result in results:
|
||||
self.assertTrue(np.all(np.isnan(result)))
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)],
|
||||
dtype=float_types + complex_types,
|
||||
|
Loading…
x
Reference in New Issue
Block a user