mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove additional info return value from jax.scipy.linalg.polar().
This commit is contained in:
parent
c95ef8799d
commit
0dfd76af97
@ -19,6 +19,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
* The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
|
* The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
|
||||||
removed.
|
removed.
|
||||||
|
|
||||||
|
* New features:
|
||||||
|
* Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`).
|
||||||
|
|
||||||
* Bug fixes:
|
* Bug fixes:
|
||||||
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
|
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
|
||||||
not used with invalid `axis` value, or with an empty reduction dimension.
|
not used with invalid `axis` value, or with an empty reduction dimension.
|
||||||
|
@ -231,7 +231,7 @@ def svd(A, precision=lax.Precision.HIGHEST):
|
|||||||
V_dag: An `n` by `n` unitary matrix of `A`'s conjugate transposed
|
V_dag: An `n` by `n` unitary matrix of `A`'s conjugate transposed
|
||||||
right singular vectors.
|
right singular vectors.
|
||||||
"""
|
"""
|
||||||
Up, H, _ = jsp.linalg.polar(A)
|
Up, H = jsp.linalg.polar(A)
|
||||||
S, V = eigh(H, precision=precision)
|
S, V = eigh(H, precision=precision)
|
||||||
U = jnp.dot(Up, V, precision=precision)
|
U = jnp.dot(Up, V, precision=precision)
|
||||||
return U, S, V.conj().T
|
return U, S, V.conj().T
|
||||||
|
@ -22,6 +22,7 @@ import textwrap
|
|||||||
from jax import jit, vmap, jvp
|
from jax import jit, vmap, jvp
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax._src.lax import linalg as lax_linalg
|
from jax._src.lax import linalg as lax_linalg
|
||||||
|
from jax._src.lax import polar as lax_polar
|
||||||
from jax._src.numpy.util import _wraps
|
from jax._src.numpy.util import _wraps
|
||||||
from jax._src.numpy import lax_numpy as jnp
|
from jax._src.numpy import lax_numpy as jnp
|
||||||
from jax._src.numpy import linalg as np_linalg
|
from jax._src.numpy import linalg as np_linalg
|
||||||
@ -589,3 +590,9 @@ def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
|
|||||||
|
|
||||||
_, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
|
_, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
|
||||||
return mid
|
return mid
|
||||||
|
|
||||||
|
@_wraps(scipy.linalg.polar)
|
||||||
|
def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
|
||||||
|
unitary, posdef, _ = lax_polar.polar(a, side=side, method=method, eps=eps,
|
||||||
|
maxiter=maxiter)
|
||||||
|
return unitary, posdef
|
||||||
|
@ -28,6 +28,7 @@ from jax._src.scipy.linalg import (
|
|||||||
lu,
|
lu,
|
||||||
lu_factor,
|
lu_factor,
|
||||||
lu_solve,
|
lu_solve,
|
||||||
|
polar,
|
||||||
qr,
|
qr,
|
||||||
solve,
|
solve,
|
||||||
solve_triangular,
|
solve_triangular,
|
||||||
@ -37,6 +38,5 @@ from jax._src.scipy.linalg import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from jax._src.lax.polar import (
|
from jax._src.lax.polar import (
|
||||||
polar,
|
|
||||||
polar_unitary
|
polar_unitary
|
||||||
)
|
)
|
||||||
|
@ -499,7 +499,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
|||||||
side=side)
|
side=side)
|
||||||
return
|
return
|
||||||
|
|
||||||
unitary, posdef, info = jsp.linalg.polar(matrix, method=method, side=side)
|
unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
|
||||||
if shape[0] >= shape[1]:
|
if shape[0] >= shape[1]:
|
||||||
should_be_eye = np.matmul(unitary.conj().T, unitary)
|
should_be_eye = np.matmul(unitary.conj().T, unitary)
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user