Don't mask out zero elements on the diagonal of the matrix when inverting triangular matrices.

The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes https://github.com/google/jax/issues/3589
Fixes https://github.com/google/jax/issues/15429

PiperOrigin-RevId: 653562611
This commit is contained in:
Peter Hawkins 2024-07-18 04:08:55 -07:00 committed by jax authors
parent 174429d7cf
commit 47e6da3332
2 changed files with 13 additions and 0 deletions

View File

@ -42,6 +42,8 @@ Remember to align the itemized text with the first line of an item within a list
* Bug fixes
* Fixed a bug that meant that negative static_argnums to a jit were mishandled
by the jit dispatch fast path.
* Fixed a bug that meant triangular solves of batches of singular matrices
produce nonsensical finite values, instead of inf or nan (#3589, #15429).
## jax 0.4.30 (June 18, 2024)

View File

@ -16,6 +16,7 @@
from functools import partial
import itertools
import unittest
import numpy as np
import scipy
@ -33,6 +34,7 @@ from jax._src import config
from jax._src.lax import linalg as lax_linalg
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_extension_version
from jax._src.numpy.util import promote_dtypes_inexact
config.parse_flags_with_absl()
@ -1623,6 +1625,15 @@ class ScipyLinalgTest(jtu.JaxTestCase):
(a, b),
(a, b))
@unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30")
def testTriangularSolveSingularBatched(self):
x = jnp.array([[1, 1], [0, 0]], dtype=np.float32)
y = jnp.array([[1], [1.]], dtype=np.float32)
out = jax.lax.linalg.triangular_solve(x[None], y[None], left_side=True)
# x is singular. The triangular solve may contain either nans or infs, but
# it should not consist of only finite values.
self.assertFalse(np.all(np.isfinite(out)))
@jtu.sample_product(
n=[1, 4, 5, 20, 50, 100],
batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()],