A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.
Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.
To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` - see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782