mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
Improve the documentation of lax.eig(). Fixes https://github.com/google/jax/issues/18226 PiperOrigin-RevId: 584170564
This commit is contained in:
parent
29eec05c92
commit
49c80e68d1
@ -43,6 +43,10 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
|
* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
|
||||||
1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
|
1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
|
||||||
|
|
||||||
|
* Bug fixes
|
||||||
|
* Fixed error/hang when an array with non-finite values is passed to a
|
||||||
|
non-symmetric eigendecomposition (#18226). Arrays with non-finite values now
|
||||||
|
produce arrays full of NaNs as outputs.
|
||||||
|
|
||||||
## jax 0.4.20 (Nov 2, 2023)
|
## jax 0.4.20 (Nov 2, 2023)
|
||||||
|
|
||||||
|
@ -140,6 +140,21 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
|
|||||||
"""Eigendecomposition of a general matrix.
|
"""Eigendecomposition of a general matrix.
|
||||||
|
|
||||||
Nonsymmetric eigendecomposition is at present only implemented on CPU.
|
Nonsymmetric eigendecomposition is at present only implemented on CPU.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: A batch of square matrices with shape ``[..., n, n]``.
|
||||||
|
compute_left_eigenvectors: If true, the left eigenvectors will be computed.
|
||||||
|
compute_right_eigenvectors: If true, the right eigenvectors will be
|
||||||
|
computed.
|
||||||
|
Returns:
|
||||||
|
The eigendecomposition of ``x``, which is a tuple of the form
|
||||||
|
``(w, vl, vr)`` where ``w`` are the eigenvalues, ``vl`` are the left
|
||||||
|
eigenvectors, and ``vr`` are the right eigenvectors. ``vl`` and ``vr`` are
|
||||||
|
optional and will only be included if ``compute_left_eigenvectors`` or
|
||||||
|
``compute_right_eigenvectors`` respectively are ``True``.
|
||||||
|
|
||||||
|
If the eigendecomposition fails, then arrays full of NaNs will be returned
|
||||||
|
for that batch element.
|
||||||
"""
|
"""
|
||||||
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
|
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
|
||||||
compute_right_eigenvectors=compute_right_eigenvectors)
|
compute_right_eigenvectors=compute_right_eigenvectors)
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "jaxlib/cpu/lapack_kernels.h"
|
#include "jaxlib/cpu/lapack_kernels.h"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
@ -562,22 +563,35 @@ void RealGeev<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
|
|||||||
lwork = static_cast<int>(work_query);
|
lwork = static_cast<int>(work_query);
|
||||||
T* work = new T[lwork];
|
T* work = new T[lwork];
|
||||||
|
|
||||||
|
auto is_finite = [](T* a_work, int64_t n) {
|
||||||
|
for (int64_t j = 0; j < n; ++j) {
|
||||||
|
for (int64_t k = 0; k < n; ++k) {
|
||||||
|
if (!std::isfinite(a_work[j * n + k])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
for (int i = 0; i < b; ++i) {
|
for (int i = 0; i < b; ++i) {
|
||||||
size_t a_size = n * n * sizeof(T);
|
size_t a_size = n * n * sizeof(T);
|
||||||
std::memcpy(a_work, a_in, a_size);
|
std::memcpy(a_work, a_in, a_size);
|
||||||
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int,
|
if (is_finite(a_work, n)) {
|
||||||
vr_work, &n_int, work, &lwork, info_out);
|
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work,
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
&n_int, vr_work, &n_int, work, &lwork, info_out);
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n);
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n);
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n);
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n);
|
||||||
if (info_out[0] == 0) {
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
||||||
UnpackEigenvectors(n, wi_out, vl_work, vl_out);
|
if (info_out[0] == 0) {
|
||||||
UnpackEigenvectors(n, wi_out, vr_work, vr_out);
|
UnpackEigenvectors(n, wi_out, vl_work, vl_out);
|
||||||
|
UnpackEigenvectors(n, wi_out, vr_work, vr_out);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
*info_out = -4;
|
||||||
}
|
}
|
||||||
|
|
||||||
a_in += n * n;
|
a_in += n * n;
|
||||||
wr_out += n;
|
wr_out += n;
|
||||||
wi_out += n;
|
wi_out += n;
|
||||||
@ -621,16 +635,32 @@ void ComplexGeev<T>::Kernel(void* out_tuple, void** data,
|
|||||||
lwork = static_cast<int>(work_query.real());
|
lwork = static_cast<int>(work_query.real());
|
||||||
T* work = new T[lwork];
|
T* work = new T[lwork];
|
||||||
|
|
||||||
|
auto is_finite = [](T* a_work, int64_t n) {
|
||||||
|
for (int64_t j = 0; j < n; ++j) {
|
||||||
|
for (int64_t k = 0; k < n; ++k) {
|
||||||
|
T v = a_work[j * n + k];
|
||||||
|
if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
for (int i = 0; i < b; ++i) {
|
for (int i = 0; i < b; ++i) {
|
||||||
size_t a_size = n * n * sizeof(T);
|
size_t a_size = n * n * sizeof(T);
|
||||||
std::memcpy(a_work, a_in, a_size);
|
std::memcpy(a_work, a_in, a_size);
|
||||||
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
|
if (is_finite(a_work, n)) {
|
||||||
&n_int, work, &lwork, r_work, info_out);
|
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
&n_int, work, &lwork, r_work, info_out);
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n);
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n);
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n);
|
||||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n);
|
||||||
|
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
||||||
|
} else {
|
||||||
|
*info_out = -4;
|
||||||
|
}
|
||||||
a_in += n * n;
|
a_in += n * n;
|
||||||
w_out += n;
|
w_out += n;
|
||||||
vl_out += n * n;
|
vl_out += n * n;
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import itertools
|
import itertools
|
||||||
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy
|
import scipy
|
||||||
@ -29,6 +30,7 @@ from jax import jit, grad, jvp, vmap
|
|||||||
from jax import lax
|
from jax import lax
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
from jax import scipy as jsp
|
from jax import scipy as jsp
|
||||||
|
from jax._src.lib import version as jaxlib_version
|
||||||
from jax._src import config
|
from jax._src import config
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
from jax._src import xla_bridge
|
from jax._src import xla_bridge
|
||||||
@ -245,6 +247,27 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
|||||||
self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
|
self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
|
||||||
rtol=1e-3)
|
rtol=1e-3)
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)],
|
||||||
|
dtype=float_types + complex_types,
|
||||||
|
compute_left_eigenvectors=[False, True],
|
||||||
|
compute_right_eigenvectors=[False, True],
|
||||||
|
)
|
||||||
|
# TODO(phawkins): enable when there is an eigendecomposition implementation
|
||||||
|
# for GPU/TPU.
|
||||||
|
@jtu.run_on_devices("cpu")
|
||||||
|
@unittest.skipIf(jaxlib_version < (0, 4, 21), "Test requires jaxlib 0.4.21")
|
||||||
|
def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
|
||||||
|
compute_right_eigenvectors):
|
||||||
|
"""Verifies that `eig` fails gracefully if given non-finite inputs."""
|
||||||
|
a = jnp.full(shape, jnp.nan, dtype)
|
||||||
|
results = lax.linalg.eig(
|
||||||
|
a, compute_left_eigenvectors=compute_left_eigenvectors,
|
||||||
|
compute_right_eigenvectors=compute_right_eigenvectors)
|
||||||
|
for result in results:
|
||||||
|
self.assertTrue(np.all(np.isnan(result)))
|
||||||
|
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)],
|
shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)],
|
||||||
dtype=float_types + complex_types,
|
dtype=float_types + complex_types,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user