rocm_jax/jax/scipy/linalg.py
Peter Hawkins 1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00


# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.scipy.linalg import (
  block_diag as block_diag,
  cholesky as cholesky,
  cho_factor as cho_factor,
  cho_solve as cho_solve,
  det as det,
  eigh as eigh,
  eigh_tridiagonal as eigh_tridiagonal,
  expm as expm,
  expm_frechet as expm_frechet,
  hessenberg as hessenberg,
  inv as inv,
  lu as lu,
  lu_factor as lu_factor,
  lu_solve as lu_solve,
  polar as polar,
  polar_unitary as polar_unitary,
  qr as qr,
  rsf2csf as rsf2csf,
  schur as schur,
  sqrtm as sqrtm,
  solve as solve,
  solve_triangular as solve_triangular,
  svd as svd,
  tril as tril,
  triu as triu,
)
from jax._src.third_party.scipy.linalg import (
  funm as funm,
)
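
The commit above exports `hessenberg` through this module. As a rough usage sketch (not part of the file itself), the snippet below reduces a small matrix on CPU and checks the factorization. The test matrix and the variable names are illustrative assumptions, and it assumes the function follows SciPy's `calc_q=True` convention of returning `(H, Q)` with `A = Q H Q^T`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import hessenberg

# Illustrative 4x4 input; any square real matrix would do.
a = jnp.array([[4.0, 1.0, 2.0, 3.0],
               [1.0, 3.0, 0.0, 1.0],
               [2.0, 0.0, 2.0, 1.0],
               [3.0, 1.0, 1.0, 1.0]], dtype=jnp.float32)

# With calc_q=True the call returns the Hessenberg form H and the
# orthogonal matrix Q of the reduction.
h, q = hessenberg(a, calc_q=True)

# Q H Q^T should reproduce the input up to floating-point error.
reconstruction_error = float(jnp.max(jnp.abs(q @ h @ q.T - a)))

# An upper Hessenberg matrix has no entries below the first subdiagonal.
below_subdiagonal = float(jnp.max(jnp.abs(jnp.tril(h, k=-2))))

# Q should be orthogonal: Q^T Q = I.
orthogonality_error = float(jnp.max(jnp.abs(q.T @ q - jnp.eye(4))))
```

Per the commit message, these primitives are not differentiable at this revision, so a sketch like this should stay outside `jax.grad`.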