mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Don't mask out zero elements on the diagonal of the matrix when inverting triangular matrices.
The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix. Fixes https://github.com/google/jax/issues/3589 Fixes https://github.com/google/jax/issues/15429 PiperOrigin-RevId: 653562611
This commit is contained in:
parent
174429d7cf
commit
47e6da3332
@ -42,6 +42,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Bug fixes
|
||||
* Fixed a bug that meant that negative static_argnums to a jit were mishandled
|
||||
by the jit dispatch fast path.
|
||||
* Fixed a bug that meant triangular solves of batches of singular matrices
|
||||
produce nonsensical finite values, instead of inf or nan (#3589, #15429).
|
||||
|
||||
## jax 0.4.30 (June 18, 2024)
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
from functools import partial
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
@ -33,6 +34,7 @@ from jax._src import config
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.numpy.util import promote_dtypes_inexact
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -1623,6 +1625,15 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
(a, b),
|
||||
(a, b))
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30")
|
||||
def testTriangularSolveSingularBatched(self):
|
||||
x = jnp.array([[1, 1], [0, 0]], dtype=np.float32)
|
||||
y = jnp.array([[1], [1.]], dtype=np.float32)
|
||||
out = jax.lax.linalg.triangular_solve(x[None], y[None], left_side=True)
|
||||
# x is singular. The triangular solve may contain either nans or infs, but
|
||||
# it should not consist of only finite values.
|
||||
self.assertFalse(np.all(np.isfinite(out)))
|
||||
|
||||
@jtu.sample_product(
|
||||
n=[1, 4, 5, 20, 50, 100],
|
||||
batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()],
|
||||
|
Loading…
x
Reference in New Issue
Block a user