mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #14169 from tttc3:polar_unitary
PiperOrigin-RevId: 505143284
This commit is contained in:
commit
e4c18b89fe
@ -7,6 +7,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
-->
|
||||
|
||||
## jax 0.4.3
|
||||
* Breaking changes
|
||||
* Deleted {func}`jax.scipy.linalg.polar_unitary`, which was a deprecated JAX
|
||||
extension to the scipy API. Use {func}`jax.scipy.linalg.polar` instead.
|
||||
|
||||
## jaxlib 0.4.3
|
||||
|
||||
|
@ -36,7 +36,6 @@ jax.scipy.linalg
|
||||
lu_factor
|
||||
lu_solve
|
||||
polar
|
||||
polar_unitary
|
||||
qr
|
||||
rsf2csf
|
||||
schur
|
||||
|
@ -875,22 +875,6 @@ def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: Optio
|
||||
return unitary, posdef
|
||||
|
||||
|
||||
def polar_unitary(a: ArrayLike, *, method: str = "qdwh", eps: Optional[float] = None,
|
||||
max_iterations: Optional[int] = None) -> Tuple[Array, Array]:
|
||||
""" Computes the unitary factor u in the polar decomposition ``a = u p``
|
||||
(or ``a = p u``).
|
||||
|
||||
.. warning::
|
||||
This function is deprecated. Use :func:`jax.scipy.linalg.polar` instead.
|
||||
"""
|
||||
# TODO(phawkins): delete this function after 2022/8/11.
|
||||
warnings.warn("jax.scipy.linalg.polar_unitary is deprecated. Call "
|
||||
"jax.scipy.linalg.polar instead.",
|
||||
DeprecationWarning)
|
||||
unitary, _ = polar(a, method, eps, max_iterations)
|
||||
return unitary
|
||||
|
||||
|
||||
@jit
|
||||
def _sqrtm_triu(T: Array) -> Array:
|
||||
"""
|
||||
|
@ -31,7 +31,6 @@ from jax._src.scipy.linalg import (
|
||||
lu_factor as lu_factor,
|
||||
lu_solve as lu_solve,
|
||||
polar as polar,
|
||||
polar_unitary as polar_unitary,
|
||||
qr as qr,
|
||||
rsf2csf as rsf2csf,
|
||||
schur as schur,
|
||||
|
Loading…
x
Reference in New Issue
Block a user