20 Commits

Author SHA1 Message Date
tttc3
c89be05b5b Enable pivoted QR on CPU devices.
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` -  see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
2025-01-09 20:44:45 +00:00
Jake VanderPlas
fd897745d3 Partial rollback of https://github.com/google/jax/pull/23353 as discussed in https://github.com/google/jax/pull/23353#issuecomment-2326604708
Reverts eed273c106af699efefc726eea1ff2b0f548f669

PiperOrigin-RevId: 670596159
2024-09-03 09:49:22 -07:00
Jake VanderPlas
7b41583414 refactor jax.lax to not depend on jax.numpy 2024-09-01 07:49:49 -07:00
rajasekharporeddy
b93da3873b Fix Typos 2024-06-17 13:55:46 +05:30
jax authors
b2654c08f9 Improve performance of SVD when batched by avoiding several cond() constructs.
This also simplifies the code by not special casing the code for all-zero inputs.

PiperOrigin-RevId: 628518807
2024-04-26 15:00:50 -07:00
jax authors
51763d8b5d Fix bug in rank-deficient fix-up code: Do not zero out the corresponding column of u_out if a diagonal entry of r is exactly zero.
PiperOrigin-RevId: 626056825
2024-04-18 09:20:48 -07:00
jax authors
7e7094c82d [JAX] Add an option subset_by_index that allows computing a contiguous subset of singular components from svd.
PiperOrigin-RevId: 607493941
2024-02-15 16:33:09 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
jax authors
21051fffc9 Fix corner cases in JAX SVD: a) Clamp negative singular values to zero. b) Return all NaN for matrices with non-finite values.
PiperOrigin-RevId: 540015938
2023-06-13 11:06:49 -07:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
4f62cef1f5 [x64] Make TPU svd compatible with strict type promotion 2022-06-08 15:38:27 -07:00
Peter Hawkins
db73670ec3 Add support for padded arrays in QDWH algorithm.
This change is in preparation for adding a jit-table QDWH-eig implementation.

PiperOrigin-RevId: 448571523
2022-05-13 13:57:36 -07:00
Tianjian Lu
4bc1c1c004 [linalg] Add svd on zero matrix.
PiperOrigin-RevId: 447521398
2022-05-09 11:29:22 -07:00
Tianjian Lu
1093559856 [linalg] Add matmul precision scope for svd.
PiperOrigin-RevId: 447095391
2022-05-06 16:33:50 -07:00
Tianjian Lu
020849076c [linalg] Add tpu svd lowering rule.
PiperOrigin-RevId: 445533767
2022-04-29 16:43:53 -07:00
Tianjian Lu
455c9f823e [linalg] Adds full_matrices option to TPU SVD.
PiperOrigin-RevId: 443163571
2022-04-20 12:32:00 -07:00
Tianjian Lu
5a1c5ba114 [linalg] Adds compute_uv to TPU SVD.
PiperOrigin-RevId: 442864883
2022-04-19 11:28:43 -07:00
Tianjian Lu
5a012d5e7b [JAX] Added jit-able singular value decomposition.
PiperOrigin-RevId: 426193395
2022-02-03 11:16:55 -08:00