diff --git a/CHANGELOG.md b/CHANGELOG.md index 54ae487f7..0e5ca5383 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been removed. +* New features: + * Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`). + * Bug fixes: * Tightened the checks for lax.argmin and lax.argmax to ensure they are not used with invalid `axis` value, or with an empty reduction dimension. diff --git a/jax/_src/scipy/eigh.py b/jax/_src/scipy/eigh.py index 09205f05a..f2f47d32e 100644 --- a/jax/_src/scipy/eigh.py +++ b/jax/_src/scipy/eigh.py @@ -231,7 +231,7 @@ def svd(A, precision=lax.Precision.HIGHEST): V_dag: An `n` by `n` unitary matrix of `A`'s conjugate transposed right singular vectors. """ - Up, H, _ = jsp.linalg.polar(A) + Up, H = jsp.linalg.polar(A) S, V = eigh(H, precision=precision) U = jnp.dot(Up, V, precision=precision) return U, S, V.conj().T diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index acd42f468..65518564d 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -22,6 +22,7 @@ import textwrap from jax import jit, vmap, jvp from jax import lax from jax._src.lax import linalg as lax_linalg +from jax._src.lax import polar as lax_polar from jax._src.numpy.util import _wraps from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import linalg as np_linalg @@ -589,3 +590,9 @@ def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a', _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper)) return mid + +@_wraps(scipy.linalg.polar) +def polar(a, side='right', method='qdwh', eps=None, maxiter=50): + unitary, posdef, _ = lax_polar.polar(a, side=side, method=method, eps=eps, + maxiter=maxiter) + return unitary, posdef diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 74e792c09..0ab55ee67 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -28,6 +28,7 @@ from jax._src.scipy.linalg import ( lu, lu_factor, lu_solve, + polar, qr, solve, solve_triangular, @@ -37,6 +38,5 @@ from jax._src.scipy.linalg import ( ) from jax._src.lax.polar import ( - polar, polar_unitary ) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index f350680ea..25c5215a3 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -499,7 +499,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): side=side) return - unitary, posdef, info = jsp.linalg.polar(matrix, method=method, side=side) + unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side) if shape[0] >= shape[1]: should_be_eye = np.matmul(unitary.conj().T, unitary) else: