300 Commits

Author SHA1 Message Date
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Skye Wanderman-Milne
72b1eb3205 Bump NumpyLinalgTest.testEighRankDeficient tolerance
Otherwise it sometimes fails on Cloud TPU v5e.
2023-09-29 18:43:33 +00:00
Peter Hawkins
ef6fd2ebb6 Bump test tolerance for sqrtm test.
This test fails on ARM with a LAPACK built with gfortran 11.

PiperOrigin-RevId: 569540626
2023-09-29 11:15:48 -07:00
Peter Hawkins
6be860bda8 Clean up some device opt-in/opt-outs in test suite.
Use allowlists rather than denylists in a few places.

PiperOrigin-RevId: 568968749
2023-09-27 14:56:00 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Peter Hawkins
65de2cf907 Relax tolerance of a test that fails on Linux aarch64.
PiperOrigin-RevId: 565749687
2023-09-15 12:31:43 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Peter Hawkins
619377ebc1 Second attempt at fixing funm tolerance for LLVM change.
An LLVM change seems to have made this test fail. The impact seems small, so we can just relax the test tolerance.

PiperOrigin-RevId: 556886248
2023-08-14 13:01:19 -07:00
Peter Hawkins
3a40cc3ca9 Relax test tolerance for funm test.
PiperOrigin-RevId: 556838400
2023-08-14 10:33:01 -07:00
Artem Belevich
d49b67a73a Disable tests that trigger a known bug in cublasDtrsmBatched in cuda-12 on sm_60.
PiperOrigin-RevId: 548727690
2023-07-17 10:17:21 -07:00
Rahul Batra
2650c14cf5 [ROCm]: Re-enable EighIdentity test 2023-06-23 17:51:43 +00:00
Peter Hawkins
0adfafe293 Relax test tolerances.
This makes the tests pass on CPU with a slightly different seed (+ 1).

PiperOrigin-RevId: 542877795
2023-06-23 09:22:11 -07:00
jax authors
68614b4dcc [XLA:TPU] Fix a bug in eigh that caused a slight loss of accuracy.
PiperOrigin-RevId: 529406623
2023-05-04 07:49:04 -07:00
jax authors
162f09fc8d Stop recursion in spectral bisection eigensolver when the remaining sub-matrix has norm less than epsilon times the input matrix norm, which means that it is pure numerical noise.
PiperOrigin-RevId: 528891206
2023-05-02 14:35:07 -07:00
jax authors
c226b308de Fix bugs in the JAX spectral bisection eigensolver implementing the QDWH-eigh algorthm.
1) Padding with the identity caused the convergence criterion to trigger prematurely in case the norm of the original matrix is small compared to the identity.
2) Fix incorrectly scaled error norm in the subspace iteration.
3) Avoid scale-dependent norm check.

This change also tightens a number of test tolerances for eigh and svd in linalg_test.py, and adds a unit test for a matrix with tiny norm.

PiperOrigin-RevId: 526800258
2023-04-24 17:31:45 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
jax authors
f25b701b26 [XLA] Change criterion for annihilating off-diagonal elements in the 2x2 symmetric Schur decomposition used by eigh. This significantly improves the accuracy, and makes eigh exact for the identity matrix.
Modify the QDWH test so it doesn't have a dependence on eigh.

PiperOrigin-RevId: 523171958
2023-04-10 11:43:56 -07:00
Jake VanderPlas
ad0fc8979b jax.scipy.linalg.expm: support batched inputs 2023-03-27 16:39:48 -07:00
Rahul Batra
01a10a1d06 [ROCm] Re-enable some linalg and sparse tests 2023-02-07 22:05:14 +00:00
Peter Hawkins
b730ed4645 Remove placeholder functions for unimplemented NumPy functions.
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
Rahul Batra
3391a5e385 [ROCm]: Disable some tests on ROCm platform 2022-12-19 21:33:13 +00:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Yotaro Kubo
1ade5f8592 Add jax.scipy.linalg.toeplitz. 2022-12-09 01:03:21 +09:00
Jake VanderPlas
924894fdd6 [x64] make tests more type-safe 2022-12-02 13:21:35 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Peter Hawkins
c1e1d64e66 [TPU] Add cutoff for nearly diagonal matrices in QDWH-eigh algorithm.
Fixes a bug where eigh returned NaNs for diagonal matrices, e.g., the identity matrix.

Nakatsukasa and Higham mention this stopping criterion in section 5.2 of Stable and Efficient Spectral Divide and Conquer Algorithms for the Symmetric Eigenvalue Decomposition and the SVD.

PiperOrigin-RevId: 490505832
2022-11-23 08:17:50 -08:00
Peter Hawkins
40e81c3a86 Revert: Use pinv to compute lstsq.
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.

This PR appears to have caused numerical problems in downstream tests.

PiperOrigin-RevId: 487942754
2022-11-11 16:24:05 -08:00
Peter Hawkins
7c3fb81310 Use pinv to compute lstsq.
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.

PiperOrigin-RevId: 487903227
2022-11-11 13:28:48 -08:00
Peter Hawkins
047974dd0c Be more economical when computing the JVP of the SVD of non-square matrices.
(Note this isn't a regression from #13147: the previous change did not alter the order of operations.)

PiperOrigin-RevId: 487896154
2022-11-11 12:55:53 -08:00
Peter Hawkins
c9ebf60f4e Compute the JVP of jnp.linalg.pinv more economically for non-square matrices.
The order of the matrix products matters.

PiperOrigin-RevId: 487879202
2022-11-11 11:45:02 -08:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Peter Hawkins
ab8cde9ed4 Add support for the hermitian option on jnp.linalg.pinv.
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
2022-11-08 08:53:00 -05:00
Peter Hawkins
845f8df837 Avoid forming identity matrix in SVD JVP.
Set the default matmul precision in the SVD JVP, and use @ to express matmuls.
Also fix a flaky test failure in QR test on Mac ARM.
2022-11-07 13:55:45 -05:00
Jake VanderPlas
9ade89ea62 jnp.linalg.lstsq: handle zero-size inputs 2022-10-24 14:10:31 -07:00
Peter Hawkins
807269990e Enable more GPU and TPU tests that pass at head.
Increase precision of matmuls in LU decompositions, pseudo-inverse solves, and their gradients. It is unlikely users want to use low precision for these operations and high precision is probably the right default.

PiperOrigin-RevId: 482071629
2022-10-18 18:09:44 -07:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
0d3277b5c3 Port more tests from jtu.cases_from_list to jtu.sample_product. 2022-10-11 21:06:08 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
1e242be03b Bump test tolerance for eigenvalue test.
This test fails by a small amount on Linux aarch64.
2022-08-16 14:44:16 -04:00
Jake VanderPlas
1a327f1ddd jnp.linalg.matrix_rank: support stacks of matrices 2022-08-08 11:16:34 -07:00
Anthony
736bed0344 Fix jax numpy norm bug where passing ord=inf always returns one 2022-08-05 11:37:31 -07:00
Jake VanderPlas
9090dd179d jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0 2022-07-19 13:57:49 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Sharad Vikram
fcf65ac64e Bump minimum jaxlib version to 0.3.10 2022-06-28 15:39:21 -07:00
Jake VanderPlas
b5ba210097 [x64] make linalg functions & tests compatible with strict dtype promotion 2022-06-16 10:32:20 -07:00
Peter Hawkins
5ccdcc5cc6 [TPU] Switch the default eigendecomposition implementation on TPU to use QDWH-eig.
Adds a new non-differentiable primitive `eigh_jacobi` that calls the XLA Jacobi eigh implementation for use inside the TPU QDWH-eigh lowering rule.

PiperOrigin-RevId: 451471088
2022-05-27 13:50:37 -07:00
Peter Hawkins
2eebd5b26d Add more tests for batched QR decomposition.
PiperOrigin-RevId: 449585655
2022-05-18 15:03:04 -07:00
jax authors
6110be40dc Merge pull request #10678 from JeppeKlitgaard:precommit-pyupgrade
PiperOrigin-RevId: 449561541
2022-05-18 13:25:52 -07:00