287 Commits

Author SHA1 Message Date
jax authors
162f09fc8d Stop recursion in spectral bisection eigensolver when the remaining sub-matrix has norm less than epsilon times the input matrix norm, which means that it is pure numerical noise.
PiperOrigin-RevId: 528891206
2023-05-02 14:35:07 -07:00
jax authors
c226b308de Fix bugs in the JAX spectral bisection eigensolver implementing the QDWH-eigh algorthm.
1) Padding with the identity caused the convergence criterion to trigger prematurely in case the norm of the original matrix is small compared to the identity.
2) Fix incorrectly scaled error norm in the subspace iteration.
3) Avoid scale-dependent norm check.

This change also tightens a number of test tolerances for eigh and svd in linalg_test.py, and adds a unit test for a matrix with tiny norm.

PiperOrigin-RevId: 526800258
2023-04-24 17:31:45 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
jax authors
f25b701b26 [XLA] Change criterion for annihilating off-diagonal elements in the 2x2 symmetric Schur decomposition used by eigh. This significantly improves the accuracy, and makes eigh exact for the identity matrix.
Modify the QDWH test so it doesn't have a dependence on eigh.

PiperOrigin-RevId: 523171958
2023-04-10 11:43:56 -07:00
Jake VanderPlas
ad0fc8979b jax.scipy.linalg.expm: support batched inputs 2023-03-27 16:39:48 -07:00
Rahul Batra
01a10a1d06 [ROCm] Re-enable some linalg and sparse tests 2023-02-07 22:05:14 +00:00
Peter Hawkins
b730ed4645 Remove placeholder functions for unimplemented NumPy functions.
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
Rahul Batra
3391a5e385 [ROCm]: Disable some tests on ROCm platform 2022-12-19 21:33:13 +00:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Yotaro Kubo
1ade5f8592 Add jax.scipy.linalg.toeplitz. 2022-12-09 01:03:21 +09:00
Jake VanderPlas
924894fdd6 [x64] make tests more type-safe 2022-12-02 13:21:35 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Peter Hawkins
c1e1d64e66 [TPU] Add cutoff for nearly diagonal matrices in QDWH-eigh algorithm.
Fixes a bug where eigh returned NaNs for diagonal matrices, e.g., the identity matrix.

Nakatsukasa and Higham mention this stopping criterion in section 5.2 of Stable and Efficient Spectral Divide and Conquer Algorithms for the Symmetric Eigenvalue Decomposition and the SVD.

PiperOrigin-RevId: 490505832
2022-11-23 08:17:50 -08:00
Peter Hawkins
40e81c3a86 Revert: Use pinv to compute lstsq.
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.

This PR appears to have caused numerical problems in downstream tests.

PiperOrigin-RevId: 487942754
2022-11-11 16:24:05 -08:00
Peter Hawkins
7c3fb81310 Use pinv to compute lstsq.
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition.

PiperOrigin-RevId: 487903227
2022-11-11 13:28:48 -08:00
Peter Hawkins
047974dd0c Be more economical when computing the JVP of the SVD of non-square matrices.
(Note this isn't a regression from #13147: the previous change did not alter the order of operations.)

PiperOrigin-RevId: 487896154
2022-11-11 12:55:53 -08:00
Peter Hawkins
c9ebf60f4e Compute the JVP of jnp.linalg.pinv more economically for non-square matrices.
The order of the matrix products matters.

PiperOrigin-RevId: 487879202
2022-11-11 11:45:02 -08:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Peter Hawkins
ab8cde9ed4 Add support for the hermitian option on jnp.linalg.pinv.
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
2022-11-08 08:53:00 -05:00
Peter Hawkins
845f8df837 Avoid forming identity matrix in SVD JVP.
Set the default matmul precision in the SVD JVP, and use @ to express matmuls.
Also fix a flaky test failure in QR test on Mac ARM.
2022-11-07 13:55:45 -05:00
Jake VanderPlas
9ade89ea62 jnp.linalg.lstsq: handle zero-size inputs 2022-10-24 14:10:31 -07:00
Peter Hawkins
807269990e Enable more GPU and TPU tests that pass at head.
Increase precision of matmuls in LU decompositions, pseudo-inverse solves, and their gradients. It is unlikely users want to use low precision for these operations and high precision is probably the right default.

PiperOrigin-RevId: 482071629
2022-10-18 18:09:44 -07:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
0d3277b5c3 Port more tests from jtu.cases_from_list to jtu.sample_product. 2022-10-11 21:06:08 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
1e242be03b Bump test tolerance for eigenvalue test.
This test fails by a small amount on Linux aarch64.
2022-08-16 14:44:16 -04:00
Jake VanderPlas
1a327f1ddd jnp.linalg.matrix_rank: support stacks of matrices 2022-08-08 11:16:34 -07:00
Anthony
736bed0344 Fix jax numpy norm bug where passing ord=inf always returns one 2022-08-05 11:37:31 -07:00
Jake VanderPlas
9090dd179d jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0 2022-07-19 13:57:49 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Sharad Vikram
fcf65ac64e Bump minimum jaxlib version to 0.3.10 2022-06-28 15:39:21 -07:00
Jake VanderPlas
b5ba210097 [x64] make linalg functions & tests compatible with strict dtype promotion 2022-06-16 10:32:20 -07:00
Peter Hawkins
5ccdcc5cc6 [TPU] Switch the default eigendecomposition implementation on TPU to use QDWH-eig.
Adds a new non-differentiable primitive `eigh_jacobi` that calls the XLA Jacobi eigh implementation for use inside the TPU QDWH-eigh lowering rule.

PiperOrigin-RevId: 451471088
2022-05-27 13:50:37 -07:00
Peter Hawkins
2eebd5b26d Add more tests for batched QR decomposition.
PiperOrigin-RevId: 449585655
2022-05-18 15:03:04 -07:00
jax authors
6110be40dc Merge pull request #10678 from JeppeKlitgaard:precommit-pyupgrade
PiperOrigin-RevId: 449561541
2022-05-18 13:25:52 -07:00
Yash Katariya
85646dcdd1 Disable slogdet test for qr decomposition
PiperOrigin-RevId: 449539738
2022-05-18 11:57:52 -07:00
Peter Hawkins
720d09c7df Add more tests for batched QR decomposition.
PiperOrigin-RevId: 449364433
2022-05-17 18:27:47 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
1bcb5e073c Add an implementation of jnp.linalg.slogdet based on QR decomposition.
Adds a non-standard `method` argument to `jnp.linalg.slogdet` to select between the current LU decomposition based implementation (like NumPy) and the QR decomposition implementation.

QR decomposition is more amenable to a high performance batched implementation particularly on TPU hardware because it does not need row pivoting. The same may be true on other hardware also, and having the option is nice either way!

PiperOrigin-RevId: 449271317
2022-05-17 11:24:11 -07:00
Peter Hawkins
909c0328b0 Decompose lax.linalg.qr into two subprimitives geqrf and orgqr.
In essence, this lifts the implementation of QR decomposition out of the lowering rules and into the JAX level instead.

This is useful because it allows direct access to the raw form of the decomposition returned by geqrf; sometimes we actually want access to the Householder reflectors instead of their product. Currently neither geqrf nor orgqr are differentiable in isolation.

Change in preparation for adding an implementation of jnp.linalg.slogdet that uses QR decomposition instead of LU decomposition.

Fixes https://github.com/google/jax/issues/2322

PiperOrigin-RevId: 449033350
2022-05-16 12:59:57 -07:00
Peter Hawkins
705e241409 Change non-array arguments to jax.lax.linalg functions to be keyword-only arguments.
PiperOrigin-RevId: 448066207
2022-05-11 13:06:54 -07:00
Peter Hawkins
590b9161fe Add a sort_eigenvalues option to lax.linalg.eigh().
An upcoming change to add a more scalable QDWH-based TPU symmetric eigendecomposition requires that we can obtain the TPU eigenvalues unsorted. The option already exists in XLA, so we simply need to plumb it through to the lax primitive.

PiperOrigin-RevId: 448047584
2022-05-11 11:46:03 -07:00
jax authors
1b0be5095a Merge pull request #10658 from ROCmSoftwarePlatform:rocm_unit_test_enablement
PiperOrigin-RevId: 448010649
2022-05-11 09:18:47 -07:00
Yash Katariya
dfb2caf31e Add nightly __version__ string if building jaxlib nightly
PiperOrigin-RevId: 447822974
2022-05-10 14:05:35 -07:00
Rohit Santhanam
8d9f17df19 Disabled one and enabled several unit tests for ROCm. 2022-05-10 19:47:26 +00:00
jax authors
8ee8a75566 Merge pull request #10490 from ajcr:add_jax_scipy_linalg_funm
PiperOrigin-RevId: 446570023
2022-05-04 15:42:39 -07:00
Jake VanderPlas
c6343ddf8e jax.scipy.linalg.schur: error on 16-bit floats
Fixes https://github.com/google/jax/issues/10530

PiperOrigin-RevId: 446279906
2022-05-03 13:47:44 -07:00
Alex Riley
372371cec6 Add jax.scipy.linalg.funm 2022-05-02 21:46:41 +01:00