mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove additional info return value from jax.scipy.linalg.polar().
This commit is contained in:
parent
c95ef8799d
commit
0dfd76af97
@ -19,6 +19,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
|
||||
removed.
|
||||
|
||||
* New features:
|
||||
* Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`).
|
||||
|
||||
* Bug fixes:
|
||||
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
|
||||
not used with invalid `axis` value, or with an empty reduction dimension.
|
||||
|
@ -231,7 +231,7 @@ def svd(A, precision=lax.Precision.HIGHEST):
|
||||
V_dag: An `n` by `n` unitary matrix of `A`'s conjugate transposed
|
||||
right singular vectors.
|
||||
"""
|
||||
Up, H, _ = jsp.linalg.polar(A)
|
||||
Up, H = jsp.linalg.polar(A)
|
||||
S, V = eigh(H, precision=precision)
|
||||
U = jnp.dot(Up, V, precision=precision)
|
||||
return U, S, V.conj().T
|
||||
|
@ -22,6 +22,7 @@ import textwrap
|
||||
from jax import jit, vmap, jvp
|
||||
from jax import lax
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.lax import polar as lax_polar
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import linalg as np_linalg
|
||||
@ -589,3 +590,9 @@ def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
|
||||
|
||||
_, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
|
||||
return mid
|
||||
|
||||
@_wraps(scipy.linalg.polar)
|
||||
def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
|
||||
unitary, posdef, _ = lax_polar.polar(a, side=side, method=method, eps=eps,
|
||||
maxiter=maxiter)
|
||||
return unitary, posdef
|
||||
|
@ -28,6 +28,7 @@ from jax._src.scipy.linalg import (
|
||||
lu,
|
||||
lu_factor,
|
||||
lu_solve,
|
||||
polar,
|
||||
qr,
|
||||
solve,
|
||||
solve_triangular,
|
||||
@ -37,6 +38,5 @@ from jax._src.scipy.linalg import (
|
||||
)
|
||||
|
||||
from jax._src.lax.polar import (
|
||||
polar,
|
||||
polar_unitary
|
||||
)
|
||||
|
@ -499,7 +499,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
side=side)
|
||||
return
|
||||
|
||||
unitary, posdef, info = jsp.linalg.polar(matrix, method=method, side=side)
|
||||
unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
|
||||
if shape[0] >= shape[1]:
|
||||
should_be_eye = np.matmul(unitary.conj().T, unitary)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user