The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically typed state objects, which are friendlier to type checkers and IDEs.
This is a follow-up to #18008.
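A minimal sketch of the before/after pattern; the typed-state attribute shown (`config.enable_x64.value`) is my reading of the internal API and may not match the exact names used:

```python
import jax
from jax._src import config

# Before: a dynamic attribute lookup on jax.config, opaque to type checkers
# and IDEs.
x64_enabled = jax.config.jax_enable_x64

# After: a statically typed state object whose .value is visible to tooling.
x64_enabled = config.enable_x64.value
```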
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
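A hedged before/after sketch, assuming the jax._src.test_util helper names used in this change:

```python
from jax._src import test_util as jtu

# Before: an equality test against the backend string.
on_gpu = jtu.device_under_test() == "gpu"

# After: tag matching, which leaves room for "gpu" to later match both "cuda"
# and "rocm" devices.
on_gpu = jtu.test_device_matches(["gpu"])
```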
PiperOrigin-RevId: 568923117
1) Padding with the identity matrix caused the convergence criterion to trigger prematurely when the norm of the original matrix is small compared to the identity.
2) Fix an incorrectly scaled error norm in the subspace iteration.
3) Avoid a scale-dependent norm check.
This change also tightens a number of test tolerances for eigh and svd in linalg_test.py, and adds a unit test for a matrix with tiny norm.
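An illustrative sketch of the kind of input the new unit test covers (the values here are hypothetical; the fix applies to the QDWH-based eigh path, e.g. on TPU):

```python
import jax.numpy as jnp

# A Hermitian matrix whose norm is tiny compared to the identity used for
# padding; the convergence check should not trigger prematurely on it.
a = 1e-30 * jnp.eye(3, dtype=jnp.float32)
w, v = jnp.linalg.eigh(a)
print(jnp.allclose((v * w) @ v.conj().T, a, atol=1e-36))
```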
PiperOrigin-RevId: 526800258
These don't seem necessary now that JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED from several modules.
Fixes a bug where eigh returned NaNs for diagonal matrices, e.g., the identity matrix.
Nakatsukasa and Higham mention this stopping criterion in section 5.2 of "Stable and Efficient Spectral Divide and Conquer Algorithms for the Symmetric Eigenvalue Decomposition and the SVD".
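A minimal repro sketch of the fixed behavior (not the actual regression test; the bug was specific to the QDWH-based eigh path, e.g. on TPU):

```python
import jax.numpy as jnp

# eigh of a diagonal matrix such as the identity previously returned NaNs on
# the QDWH-based path; the results should be finite and exact.
w, v = jnp.linalg.eigh(jnp.eye(4, dtype=jnp.float32))
assert not jnp.isnan(w).any()
assert not jnp.isnan(v).any()
```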
PiperOrigin-RevId: 490505832
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way, we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.
This PR appears to have caused numerical problems in downstream tests.
PiperOrigin-RevId: 487942754
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way, we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.
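A hedged sketch of the equivalence being exploited (illustrative, not the library's implementation):

```python
import jax.numpy as jnp

a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
b = jnp.array([1., 2., 3.])

x_lstsq, residuals, rank, sv = jnp.linalg.lstsq(a, b)
# Writing the solve as pinv(A) @ b gives the same least-squares solution, and
# its derivative does not need to differentiate through the SVD directly.
x_pinv = jnp.linalg.pinv(a) @ b
print(jnp.allclose(x_lstsq, x_pinv, atol=1e-5))
```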
PiperOrigin-RevId: 487903227
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it, we can make this change without breaking compatibility.
Also fix incorrect dtypes for the d and e values in the CPU LAPACK sytrd wrapper.
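A hedged usage sketch of the new contract; the return layout below (packed reflectors, d, e, and the Householder taus) is my reading of this change rather than authoritative documentation:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[4., 1., 0.],
               [1., 3., 1.],
               [0., 1., 2.]])

arr, d, e, taus = lax.linalg.tridiagonal(x, lower=True)
# d is the diagonal and e the off-diagonal of the tridiagonal matrix, so
# callers no longer have to re-extract them from the packed output `arr`.
```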
PiperOrigin-RevId: 487621469
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs, but that's too late for test cases: the number of generated cases must be chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. Renaming it to carry the --jax_ prefix makes it work once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken when a larger number of generated cases is enabled.
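A hedged sketch of the test-module pattern this affects (class and test names are placeholders): the parse call runs at import time, so only --jax_-prefixed flags can influence how many cases get generated.

```python
from absl.testing import absltest

import jax
from jax._src import test_util as jtu

# Runs at module import time and only handles --jax_* flags; everything else
# waits for absl.app's main(), which is too late for test-case generation.
jax.config.parse_flags_with_absl()


class ExampleTest(jtu.JaxTestCase):

  def test_trivial(self):
    self.assertEqual(1 + 1, 2)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
```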
PiperOrigin-RevId: 487406670
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
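A hedged usage sketch of the new APIs; the keyword `calc_q` and the return layouts shown are my understanding of the SciPy-style wrapper and the low-level form, not definitive documentation (both were CPU/LAPACK-backed at the time, as far as I know):

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsp_linalg
from jax import lax

a = jnp.array([[2., 1., 3.],
               [4., 5., 6.],
               [7., 8., 9.]])

# SciPy-style wrapper: upper Hessenberg form H, optionally with the unitary Q.
h, q = jsp_linalg.hessenberg(a, calc_q=True)

# Low-level form: the reduced matrix with packed Householder reflectors, plus
# their tau scalars; householder_product recovers unitary matrices from such
# packed reflectors (with the minor index tweaks mentioned above for the
# Hessenberg case).
packed, taus = lax.linalg.hessenberg(a)
```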
PiperOrigin-RevId: 487224934
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values, so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
Increase the precision of matmuls in LU decompositions, pseudo-inverse solves, and their gradients. It is unlikely that users want low precision for these operations, and high precision is probably the right default.
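A sketch of the cutoff computation this describes, assuming the standard pinv-via-SVD formulation (illustrative, not the library's exact code):

```python
import jax.numpy as jnp

a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
rcond = 1e-6

u, s, vh = jnp.linalg.svd(a, full_matrices=False)
# svd returns singular values in descending order, so the largest one is
# s[..., 0]; no amax() reduction is needed to form the cutoff.
cutoff = rcond * s[..., 0:1]
s_inv = jnp.where(s > cutoff, 1.0 / s, 0.0)
pinv_a = (vh.conj().T * s_inv) @ u.conj().T
```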
PiperOrigin-RevId: 482071629
scipy accounts for around 400ms of JAX's 900ms import time. By loading scipy
lazily, we can bring the time for `import jax` down to about 500ms.
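A hedged sketch of the general technique (the helper name below is hypothetical, and this is not necessarily the exact mechanism used for the jax.scipy modules): defer the scipy import to first use so that `import jax` does not pay for it.

```python
def _expm_via_scipy(a):
  # Hypothetical helper: scipy is imported lazily, only when this function is
  # first called, rather than when the enclosing module is imported.
  import scipy.linalg
  return scipy.linalg.expm(a)
```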
Adds a new non-differentiable primitive `eigh_jacobi` that calls the XLA Jacobi eigh implementation for use inside the TPU QDWH-eigh lowering rule.
PiperOrigin-RevId: 451471088