Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00

`scipy.linalg.qr` supports a column-pivoted QR factorization through LAPACK's `geqp3` routine. To provide the same functionality in JAX, we implement a new primitive, `geqp3_p`, which calls the LAPACK routine via the FFI on CPU devices. Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support column pivoting on CPU devices. A GPU implementation of `geqp3` may require MAGMA, because cuSOLVER does not provide a `geqp3` routine; see ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for an example of using MAGMA in GPU lowerings. Such a GPU implementation can be considered in the future.