Remove additional info return value from jax.scipy.linalg.polar().

This commit is contained in:
Peter Hawkins 2021-07-20 13:12:13 -04:00
parent c95ef8799d
commit 0dfd76af97
5 changed files with 13 additions and 3 deletions

View File

@ -19,6 +19,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
removed.
* New features:
* Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`).
* Bug fixes:
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
not used with invalid `axis` value, or with an empty reduction dimension.

View File

@ -231,7 +231,7 @@ def svd(A, precision=lax.Precision.HIGHEST):
V_dag: An `n` by `n` unitary matrix of `A`'s conjugate transposed
right singular vectors.
"""
Up, H, _ = jsp.linalg.polar(A)
Up, H = jsp.linalg.polar(A)
S, V = eigh(H, precision=precision)
U = jnp.dot(Up, V, precision=precision)
return U, S, V.conj().T

View File

@ -22,6 +22,7 @@ import textwrap
from jax import jit, vmap, jvp
from jax import lax
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import polar as lax_polar
from jax._src.numpy.util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as np_linalg
@ -589,3 +590,9 @@ def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
_, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
return mid
@_wraps(scipy.linalg.polar)
def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
unitary, posdef, _ = lax_polar.polar(a, side=side, method=method, eps=eps,
maxiter=maxiter)
return unitary, posdef

View File

@ -28,6 +28,7 @@ from jax._src.scipy.linalg import (
lu,
lu_factor,
lu_solve,
polar,
qr,
solve,
solve_triangular,
@ -37,6 +38,5 @@ from jax._src.scipy.linalg import (
)
from jax._src.lax.polar import (
polar,
polar_unitary
)

View File

@ -499,7 +499,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
side=side)
return
unitary, posdef, info = jsp.linalg.polar(matrix, method=method, side=side)
unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
if shape[0] >= shape[1]:
should_be_eye = np.matmul(unitary.conj().T, unitary)
else: