1) Padding with the identity caused the convergence criterion to trigger prematurely in case the norm of the original matrix is small compared to the identity.
2) Fix incorrectly scaled error norm in the subspace iteration.
3) Avoid scale-dependent norm check.
This change also tightens a number of test tolerances for eigh and svd in linalg_test.py, and adds a unit test for a matrix with tiny norm.
PiperOrigin-RevId: 526800258
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
Fixes a bug where eigh returned NaNs for diagonal matrices, e.g., the identity matrix.
Nakatsukasa and Higham mention this stopping criterion in section 5.2 of Stable and Efficient Spectral Divide and Conquer Algorithms for the Symmetric Eigenvalue Decomposition and the SVD.
PiperOrigin-RevId: 490505832
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.
This PR appears to have caused numerical problems in downstream tests.
PiperOrigin-RevId: 487942754
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.
PiperOrigin-RevId: 487903227
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.
Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.
PiperOrigin-RevId: 487621469
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken with a larger number of test cases enabled.
PiperOrigin-RevId: 487406670
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
Increase precision of matmuls in LU decompositions, pseudo-inverse solves, and their gradients. It is unlikely users want to use low precision for these operations and high precision is probably the right default.
PiperOrigin-RevId: 482071629
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
Adds a new non-differentiable primitive `eigh_jacobi` that calls the XLA Jacobi eigh implementation for use inside the TPU QDWH-eigh lowering rule.
PiperOrigin-RevId: 451471088
Adds a non-standard `method` argument to `jnp.linalg.slogdet` to select between the current LU decomposition based implementation (like NumPy) and the QR decomposition implementation.
QR decomposition is more amenable to a high performance batched implementation particularly on TPU hardware because it does not need row pivoting. The same may be true on other hardware also, and having the option is nice either way!
PiperOrigin-RevId: 449271317
In essence, this lifts the implementation of QR decomposition out of the lowering rules and into the JAX level instead.
This is useful because it allows direct access to the raw form of the decomposition returned by geqrf; sometimes we actually want access to the Householder reflectors instead of their product. Currently neither geqrf nor orgqr are differentiable in isolation.
Change in preparation for adding an implementation of jnp.linalg.slogdet that uses QR decomposition instead of LU decomposition.
Fixes https://github.com/google/jax/issues/2322
PiperOrigin-RevId: 449033350
An upcoming change to add a more scalable QDWH-based TPU symmetric eigendecomposition requires that we can obtain the TPU eigenvalues unsorted. The option already exists in XLA, so we simply need to plumb it through to the lax primitive.
PiperOrigin-RevId: 448047584