diff --git a/CHANGELOG.md b/CHANGELOG.md index e9e62eb6e..0ca89a2c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,11 +13,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * Changes * {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument that allows users to opt out of eigenvalue sorting on TPU. +* Deprecations * Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked keyword-only. As a backward-compatibility step passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use {mod}`jax.numpy.linalg` instead. + * {func}`jax.scipy.linalg.polar_unitary`, which was a JAX extension to the + scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead. ## jaxlib 0.3.11 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). diff --git a/jax/_src/lax/polar.py b/jax/_src/lax/polar.py deleted file mode 100644 index 5c6c3cce6..000000000 --- a/jax/_src/lax/polar.py +++ /dev/null @@ -1,343 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License - - -""" -Functions to compute the polar decomposition of the m x n matrix A, A = U @ H -where U is unitary (an m x n isometry in the m > n case) and H is n x n and -positive semidefinite (or positive definite if A is nonsingular). 
The method -is described in the docstring to `polarU`. This file covers the serial -case. -""" -import functools -import jax -from jax import lax -import jax.numpy as jnp -import jax.scipy as jsp - - -# TODO: Allow singular value estimates to be manually specified -@jax.jit -def _add_to_diagonal(X, val): - new_diagonal = X.diagonal() + val - diag_indices = jnp.diag_indices(X.shape[0]) - return X.at[diag_indices].set(new_diagonal) - - -@jax.jit -def _dot(a, b): - return jnp.dot(a, b, precision=lax.Precision.HIGHEST) - - -def polar(a, side='right', method='qdwh', eps=None, maxiter=50): - """ Computes the polar decomposition. - - Given the (m x n) matrix `a`, returns the factors of the polar decomposition - `u` (m x n) and `p` such that `a = up` (if side is "right"; p is (n x n)) or - `a = pu` (if side is "left"; p is (m x m)), where `p` is positive - semidefinite. If `a` is nonsingular, `p` is positive definite and the - decomposition is unique. `u` has orthonormal columns unless n > m, in which - case it has orthonormal rows. - - Writing an SVD of `a` as `a = u_svd @ s_svd @ v^h_svd`, we have - `u = u_svd @ v^h_svd`. Thus the unitary factor `u` can be construed as - the application of the signum function to the singular values of `a`; - or, if `a` is Hermitian, the eigenvalues. - - Several methods exist to compute the polar decomposition. Currently two - are supported: - `method`="svd": Computes the SVD of `a` and then forms - `u = u_svd @ v^h_svd`. This fails on the TPU, since - no SVD algorithm independent of the polar decomposition - exists there. - `method`="qdwh": Applies a certain iterative expansion of the matrix - signum function to `a` based on QR and Cholesky - decompositions. - - Args: - a: The m x n input matrix. - side: Determines whether a right or left polar decomposition is computed. - If side is "right" then `a = up`. If side is "left" then `a = pu`. The - default is "right". - method: Determines the algorithm used, as described above. 
- precision: :class:`~jax.lax.Precision` object specifying the matmul precision. - - The remaining arguments are only meaningful if method is "qdwh". - eps: The final result will satisfy |X_k - X_k-1| < |X_k| * (4*eps)**(1/3) . - maxiter: Iterations will terminate after this many steps even if the - above is unsatisfied. - Returns: - unitary: The unitary factor (m x n). - posdef: The positive-semidefinite factor. Either (n, n) or (m, m) - depending on whether side is "right" or "left", respectively. - info: Stores convergence information. - if method is "svd": None - if method is "qdwh": j_qr: Number of QR iterations. - j_chol: Number of Cholesky iterations. - errs: Convergence history. - """ - return _polar(a, side, method, eps, maxiter) - - -@functools.partial(jax.jit, static_argnums=(1, 2, 4)) -def _polar(a, side, method, eps, maxiter): - if side not in ("left", "right"): - raise ValueError(f"side={side} was invalid.") - - unitary, info = _polar_unitary(a, method, eps, maxiter) - if side == "right": - posdef = _dot(unitary.conj().T, a) - else: - posdef = _dot(a, unitary.conj().T) - posdef = 0.5 * (posdef + posdef.conj().T) - return unitary, posdef, info - - -def polar_unitary(a, method="qdwh", eps=None, maxiter=50): - """ Computes the unitary factor u in the polar decomposition `a = u p` - (or `a = p u`). 
- """ - return _polar_unitary(a, method, eps, maxiter) - - -@functools.partial(jax.jit, static_argnums=(1, 3)) -def _polar_unitary(a, method, eps, maxiter): - if method not in ("svd", "qdwh"): - raise ValueError(f"method={method} is unsupported.") - - if method == "svd": - u_svd, _, vh_svd = jnp.linalg.svd(a, full_matrices=False) - unitary = _dot(u_svd, vh_svd) - info = None - elif method == "qdwh": - unitary, j_qr, j_chol, errs = _qdwh(a, eps, maxiter) - info = (j_qr, j_chol, errs) - else: - raise ValueError("How did we get here?") - return unitary, info - - -@functools.partial(jax.jit, static_argnums=(2,)) -def _qdwh(matrix, eps, maxiter): - """ Computes the unitary factor in the polar decomposition of A using - the QDWH method. QDWH implements a 3rd order Pade approximation to the - matrix sign function, - - X' = X * (aI + b X^H X)(I + c X^H X)^-1, X0 = A / ||A||_2. (1) - - The coefficients a, b, and c are chosen dynamically based on an evolving - estimate of the matrix condition number. Specifically, - - a = h(l), b = g(a), c = a + b - 1, h(x) = x g(x^2), g(x) = a + bx / (1 + cx) - - where l is initially a lower bound on the smallest singular value of X0, - and subsequently evolves according to l' = l (a + bl^2) / (1 + c l^2). - - For poorly conditioned matrices - (c > 100) the iteration (1) is rewritten in QR form, - - X' = (b / c) X + (1 / c)(a - b/c) Q1 Q2^H, [Q1] R = [sqrt(c) X] (2) - [Q2] [I ]. - - For well-conditioned matrices it is instead formulated using cheaper - Cholesky iterations, - - X' = (b / c) X + (a - b/c) (X W^-1) W^-H, W = chol(I + c X^H X). (3) - - The QR iterations rapidly improve the condition number, and typically - only 1 or 2 are required. A maximum of 6 iterations total are required - for backwards stability to double precision. - - Args: - matrix: The m x n input matrix. - eps: The final result will satisfy |X_k - X_k-1| < |X_k| * (4*eps)**(1/3) . 
- maxiter: Iterations will terminate after this many steps even if the - above is unsatisfied. - Returns: - matrix: The unitary factor (m x n). - jq: The number of QR iterations (1). - jc: The number of Cholesky iterations (2). - errs: Convergence history. - """ - n_rows, n_cols = matrix.shape - fat = n_cols > n_rows - if fat: - matrix = matrix.T - matrix, q_factor, l0 = _initialize_qdwh(matrix) - - if eps is None: - eps = jnp.finfo(matrix.dtype).eps - tol_lk = 5 * eps # stop when lk differs from 1 by less - tol_delta = jnp.cbrt(tol_lk) # stop when the iterates change by less - coefs = _qdwh_coefs(l0) - errs = jnp.zeros(maxiter, dtype=matrix.real.dtype) - matrix, j_qr, coefs, errs = _qdwh_qr( - matrix, coefs, errs, tol_lk, tol_delta, maxiter) - matrix, j_chol, errs = _qdwh_cholesky( - matrix, coefs, errs, tol_lk, tol_delta, j_qr, maxiter) - matrix = _dot(q_factor, matrix) - - if fat: - matrix = matrix.T - return matrix, j_qr, j_chol, errs - - -@jax.jit -def _initialize_qdwh(matrix): - """ Does preparatory computations for QDWH: - 1. Computes an initial QR factorization of the input A. The iterations - will be on the triangular factor R, whose condition is more easily - estimated, and which is square even when A is rectangular. - 2. Computes R -> R / ||R||_F. Now 1 is used to upper-bound ||R||_2. - 3. Computes R^-1 by solving R R^-1 = I. - 4. Uses sqrt(N) * ||R^-1||_1 as a lower bound for ||R^-2||. - 1 / sqrt(N) * ||R^-1||_1 is then used as the initial l_0. It should be clear - there is room for improvement here. - - Returns: - X = R / ||R||_F; - Q from A -> Q @ R; - l0, the initial estimate for the QDWH coefficients. 
- """ - q_factor, r_factor = jnp.linalg.qr(matrix, mode="reduced") - alpha = jnp.linalg.norm(r_factor) - r_factor /= alpha - eye = jnp.eye(*r_factor.shape, dtype=r_factor.dtype) - r_inv = jsp.linalg.solve_triangular(r_factor, eye, overwrite_b=True) - one_norm_inv = jnp.linalg.norm(r_inv, ord=1) - l0 = 1 / (jnp.sqrt(matrix.shape[1]) * one_norm_inv) - eps = jnp.finfo(r_factor.dtype).eps - l0 = jnp.array(l0, dtype=r_factor.real.dtype) - l0 = jnp.where(l0 < eps, x=eps, y=l0) - l0 = jnp.where(l0 > 1.0, x=1.0, y=l0) - return r_factor, q_factor, l0 - - -@jax.jit -def _qdwh_coefs(lk): - """ Computes a, b, c, l for the QDWH iterations. - The input lk must be in (0, 1]; lk=1 is a fixed point. - Some facts about the coefficients: - -for lk = 1 we have a=3, b=1, c=3, lk_new = 1. - -The float64 vs float32 computation of each coef appears to differ - only by noise on the order of 1E-9 to 1E-7 for all values of lk. - There is no apparent secular change in the (relative) error. - -All coefs change roughly as power laws; over e.g. [1E-14, 1]: - - a decreases from 5.43E9 to 3. - - b decreases from 7.37E18 to 1. - - c decreases from 7.37E18 to 3, only diverging from b near lk=1. - - lk increases from 5.45E-5 to 1. - - lk is an estimate of the scaled matrix's smallest singular value - """ - lk = jnp.where(lk > 1.0, x=1.0, y=lk) - d = (4. * (1. - lk**2) / (lk**4))**(1 / 3) - f = 8. * (2. - lk**2) / (lk**2 * (1. + d)**(1 / 2)) - a = (1. + d)**(1 / 2) + 0.5 * (8. - 4. * d + f)**0.5 - b = (a - 1.)**2 / 4 - c = a + b - 1. 
- lk = lk * (a + b * lk**2) / (1 + c * lk**2) - return a, b, c, lk - - -@jax.jit -def _unconverged(lk, j, maxiter, err, tol_delta, tol_lk): - changing = err > tol_delta - far_from_end = jnp.abs(1 - lk) > tol_lk - unconverged = jnp.logical_or(changing, far_from_end) - iterating = j < maxiter - return jnp.logical_and(iterating, unconverged)[0] - - -@jax.jit -def _qdwh_qr(matrix, coefs, errs, tol_lk, tol_delta, maxiter): - """ Applies the QDWH iteration formulated as - - X' = (b / c) X + (1 / c)(a - b/c) Q1 Q2^H, [Q1] R = [sqrt(c) X] - [Q2] [I ] - - to X until either c < 100, ||X' - X|| < eps||X'||, - or the iteration count exceeds maxiter. - """ - n_rows, n_cols = matrix.shape - eye = jnp.eye(n_cols, dtype=matrix.dtype) - - def _do_qr(args): - _, j, coefs, _, err = args - c = coefs[2] - lk = coefs[-1] - unconverged = _unconverged(lk, j, maxiter, err, tol_delta, tol_lk) - ill_conditioned = c > 100. - return jnp.logical_and(ill_conditioned, unconverged) - - def _qr_work(args): - matrix, j, coefs, errs, _ = args - a, b, c, lk = coefs - csqrt = jnp.sqrt(c) - matrixI = jnp.vstack((csqrt * matrix, eye)) - # Note: it should be possible to compute the QR of csqrt * matrix - # and build the concatenation with I at O(N). 
- Q, _ = jnp.linalg.qr(matrixI, mode="reduced") - Q1 = Q[:n_rows, :] - Q2 = Q[n_rows:, :] - coef = (1 / csqrt) * (a - (b / c)) - new_matrix = (b / c) * matrix + coef * _dot(Q1, Q2.T.conj()) - err = jnp.linalg.norm(matrix - new_matrix) - err = jnp.full(1, err).astype(errs[0].dtype) - errs = errs.at[j].set(err) - coefs = _qdwh_coefs(lk) - return new_matrix, j + 1, coefs, errs, err - - j = jnp.zeros(1, dtype=jnp.int32) - err = jnp.full(1, 2 * tol_delta).astype(matrix.real.dtype) - matrix, j, coefs, errs, _ = jax.lax.while_loop( - _do_qr, _qr_work, (matrix, j, coefs, errs, err)) - return matrix, j, coefs, errs - - -@jax.jit -def _qdwh_cholesky(matrix, coefs, errs, tol_delta, tol_lk, j0, maxiter): - """ Applies the QDWH iteration formulated as - - matrix' = (b / c) matrix + (a - b/c) B, - B = (matrix W^-1) W^-H, W = chol(I + c matrix^H matrix). - - to matrix until either ||matrix' - matrix|| < eps * ||matrix'||, - or the iteration count exceeds maxiter. - """ - - def _do_cholesky(args): - _, j, coefs, errs = args - lk = coefs[-1] - return _unconverged(lk, j, maxiter, errs[j - 1], tol_delta, tol_lk) - - def _cholesky_work(args): - matrix, j, coefs, errs = args - a, b, c, lk = coefs - Z = c * _dot(matrix.T.conj(), matrix) - Z = _add_to_diagonal(Z, 1.) 
- W = jsp.linalg.cholesky(Z) - B = jsp.linalg.solve_triangular(W.T, matrix.T, lower=True).conj() - B = jsp.linalg.solve_triangular(W, B).conj().T - new_matrix = (b / c) * matrix + (a - b / c) * B - # possible instability if a ~ b / c - err = jnp.linalg.norm(new_matrix - matrix).astype(errs[0].dtype) - errs = errs.at[j].set(err) - coefs = _qdwh_coefs(lk) - return new_matrix, j + 1, coefs, errs - - carry = (matrix, j0, coefs, errs) - matrix, j_total, coefs, errs = jax.lax.while_loop( - _do_cholesky, _cholesky_work, carry) - return matrix, j_total - j0, errs diff --git a/jax/_src/lax/qdwh.py b/jax/_src/lax/qdwh.py index 003cd4fa5..73a69735e 100644 --- a/jax/_src/lax/qdwh.py +++ b/jax/_src/lax/qdwh.py @@ -65,14 +65,14 @@ def _use_cholesky(u, params): return u -@functools.partial(jax.jit, static_argnums=(1, 2, 3)) -def _qdwh(x, is_symmetric, max_iterations): +def _qdwh(x, is_hermitian, max_iterations, eps): """QR-based dynamically weighted Halley iteration for polar decomposition.""" # Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of # norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for # the smallest singular value of x. - eps = jnp.finfo(x.dtype).eps + if eps is None: + eps = jnp.finfo(x.dtype).eps alpha = (jnp.sqrt(jnp.linalg.norm(x, ord=1)) * jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf))) l = eps @@ -113,7 +113,7 @@ def _qdwh(x, is_symmetric, max_iterations): u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u)) - if is_symmetric: + if is_hermitian: u = (u + u.T.conj()) / 2.0 # Checks convergence. @@ -145,13 +145,17 @@ def _qdwh(x, is_symmetric, max_iterations): # TODO: Add pivoting. -def qdwh(x, is_symmetric, max_iterations=10): +@functools.partial(jax.jit, static_argnames=('is_hermitian',)) +def qdwh(x, is_hermitian=False, max_iterations=None, eps=None): """QR-based dynamically weighted Halley iteration for polar decomposition. Args: - x: A full-rank matrix of shape `m x n` with `m >= n`. 
- is_symmetric: True if `x` is symmetric. - max_iterations: The predefined maximum number of iterations. + x: A full-rank matrix of shape `m x n` with `m >= n`. + is_hermitian: True if `x` is Hermitian. Defaults to `False`. + eps: The final result will satisfy + ``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate. + max_iterations: Iterations will terminate after this many steps even if the + above is unsatisfied. Returns: A four-tuple of (u, h, num_iters, is_converged) containing the @@ -159,29 +163,19 @@ def qdwh(x, is_symmetric, max_iterations=10): and `is_converged`, whose value is `True` when the convergence is achieved within the maximum number of iterations. """ - m, n = x.shape + is_hermitian = core.concrete_or_error( + bool, is_hermitian, 'The `is_hermitian` argument must be statically ' + 'specified to use `qdwh` within JAX transformations.') + if max_iterations is None: + max_iterations = 10 + + m, n = x.shape if m < n: raise ValueError('The input matrix of shape m x n must have m >= n.') - max_iterations = core.concrete_or_error( - int, max_iterations, 'The `max_iterations` argument must be statically ' - 'specified to use `qdwh` within JAX transformations.') - - is_symmetric = core.concrete_or_error( - bool, is_symmetric, 'The `is_symmetric` argument must be statically ' - 'specified to use `qdwh` within JAX transformations.') - - if is_symmetric: - eps = jnp.finfo(x.dtype).eps - tol = 50.0 * eps - relative_diff = jnp.linalg.norm(x - x.T.conj()) / jnp.linalg.norm(x) - if relative_diff > tol: - raise ValueError('The input `x` is NOT symmetric because ' - '`norm(x-x.H) / norm(x)` is {}, which is greater than ' - 'the tolerance {}.'.format(relative_diff, tol)) - with jax.default_matmul_precision('float32'): - u, h, num_iters, is_converged = _qdwh(x, is_symmetric, max_iterations) + u, h, num_iters, is_converged = _qdwh(x, is_hermitian, max_iterations, eps) + return u, h, num_iters, is_converged diff --git a/jax/_src/scipy/eigh.py 
b/jax/_src/scipy/eigh.py index f172c3396..8f54c3da0 100644 --- a/jax/_src/scipy/eigh.py +++ b/jax/_src/scipy/eigh.py @@ -20,6 +20,7 @@ import jax import jax.numpy as jnp import jax.scipy as jsp from jax import lax +from jax._src.lax import qdwh def _similarity_transform( @@ -141,7 +142,7 @@ def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST): return X.at[jnp.diag_indices(X.shape[0])].set(vals) H_shift = _fill_diagonal(H, H.diagonal() - split_point) - U, _ = jsp.linalg.polar_unitary(H_shift) + U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True) P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.) rank = jnp.round(jnp.trace(P)).astype(jnp.int32) rank = int(rank) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 3d25a0816..8b90d5e43 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -18,11 +18,13 @@ from functools import partial import numpy as np import scipy.linalg import textwrap +import warnings +import jax from jax import jit, vmap, jvp from jax import lax from jax._src.lax import linalg as lax_linalg -from jax._src.lax import polar as lax_polar +from jax._src.lax import qdwh from jax._src.numpy.util import _wraps, _promote_dtypes_inexact from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import linalg as np_linalg @@ -619,11 +621,111 @@ def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a', _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper)) return mid -@_wraps(scipy.linalg.polar) -def polar(a, side='right', method='qdwh', eps=None, maxiter=50): - unitary, posdef, _ = lax_polar.polar(a, side=side, method=method, eps=eps, - maxiter=maxiter) - return unitary, posdef +@partial(jit, static_argnames=('side', 'method')) +@jax.default_matmul_precision("float32") +def polar(a, side='right', *, method='qdwh', eps=None, max_iterations=None): + r"""Computes the polar decomposition. 
+ + Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar + decomposition :math:`u` (also :math:`m \times n`) and :math:`p` such that + :math:`a = up` (if side is ``"right"``; :math:`p` is :math:`n \times n`) or + :math:`a = pu` (if side is ``"left"``; :math:`p` is :math:`m \times m`), + where :math:`p` is positive semidefinite. If :math:`a` is nonsingular, + :math:`p` is positive definite and the + decomposition is unique. :math:`u` has orthonormal columns unless + :math:`n > m`, in which case it has orthonormal rows. + + Writing the SVD of :math:`a` as + :math:`a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}`, we + have :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. Thus the unitary + factor :math:`u` can be constructed as the application of the sign function to + the singular values of :math:`a`; or, if :math:`a` is Hermitian, the + eigenvalues. + + Several methods exist to compute the polar decomposition. Currently two + are supported: + + * ``method="svd"``: + + Computes the SVD of :math:`a` and then forms + :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. + + * ``method="qdwh"``: + + Applies the `QDWH`_ (QR-based Dynamically Weighted Halley) algorithm. + + Args: + a: The :math:`m \times n` input matrix. + side: Determines whether a right or left polar decomposition is computed. + If ``side`` is ``"right"`` then :math:`a = up`. If ``side`` is ``"left"`` + then :math:`a = pu`. The default is ``"right"``. + method: Determines the algorithm used, as described above. + precision: :class:`~jax.lax.Precision` object specifying the matmul precision. + eps: The final result will satisfy + :math:`\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}`, + where :math:`x_k` are the QDWH iterates. Ignored if ``method`` is not + ``"qdwh"``. + max_iterations: Iterations will terminate after this many steps even if the + above is unsatisfied. Ignored if ``method`` is not ``"qdwh"``. 
+ + Returns: + A ``(unitary, posdef)`` tuple, where ``unitary`` is the unitary factor + (:math:`m \times n`), and ``posdef`` is the positive-semidefinite factor. + ``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on + whether ``side`` is ``"right"`` or ``"left"``, respectively. + + .. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999 + """ + a = jnp.asarray(a) + if a.ndim != 2: + raise ValueError("The input `a` must be a 2-D array.") + + if side not in ["right", "left"]: + raise ValueError("The argument `side` must be either 'right' or 'left'.") + + m, n = a.shape + if method == "qdwh": + # TODO(phawkins): return info also if the user opts in? + if m >= n and side == "right": + unitary, posdef, _, _ = qdwh.qdwh(a, is_hermitian=False, max_iterations=max_iterations, eps=eps) + elif m < n and side == "left": + a = a.T.conj() + unitary, posdef, _, _ = qdwh.qdwh(a, is_hermitian=False, max_iterations=max_iterations, eps=eps) + posdef = posdef.T.conj() + unitary = unitary.T.conj() + else: + raise NotImplementedError("method='qdwh' only supports mxn matrices " + "where m >= n where side='right' and m < n " + f"where side='left', got {a.shape} with side={side}") + elif method == "svd": + u_svd, s_svd, vh_svd = lax_linalg.svd(a, full_matrices=False) + unitary = u_svd @ vh_svd + if side == "right": + # a = u * p + posdef = (vh_svd.T.conj() * s_svd[None, :]) @ vh_svd + else: + # a = p * u + posdef = (u_svd * s_svd[None, :]) @ (u_svd.T.conj()) + else: + raise ValueError(f"Unknown polar decomposition method {method}.") + + return unitary, posdef + + +def polar_unitary(a, *, method="qdwh", eps=None, max_iterations=None): + """ Computes the unitary factor u in the polar decomposition ``a = u p`` + (or ``a = p u``). + + .. warning:: + This function is deprecated. Use :func:`jax.scipy.linalg.polar` instead. + """ + # TODO(phawkins): delete this function after 2022/8/11. + warnings.warn("jax.scipy.linalg.polar_unitary is deprecated. 
Call " + "jax.scipy.linalg.polar instead.", + DeprecationWarning) + unitary, _ = polar(a, method=method, eps=eps, max_iterations=max_iterations) + return unitary + @jit def _sqrtm_triu(T): diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 4dadcf30a..1a7b9cee5 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -27,6 +27,7 @@ from jax._src.scipy.linalg import ( lu_factor as lu_factor, lu_solve as lu_solve, polar as polar, + polar_unitary as polar_unitary, qr as qr, rsf2csf as rsf2csf, schur as schur, @@ -38,10 +39,6 @@ from jax._src.scipy.linalg import ( triu as triu, ) -from jax._src.lax.polar import ( - polar_unitary as polar_unitary, -) - from jax._src.third_party.scipy.linalg import ( funm as funm, ) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 5dcc15249..7c363cffb 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -490,7 +490,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): for method in methods for side in sides for nonzero_condition_number in nonzero_condition_numbers - for dtype in jtu.dtypes.floating + for dtype in jtu.dtypes.inexact for seed in seeds)) @jtu.skip_on_devices("gpu") # Fails on A100. 
def testPolar( @@ -500,8 +500,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase): if jtu.device_under_test() != "cpu": if jnp.dtype(dtype).name in ("bfloat16", "float16"): raise unittest.SkipTest("Skip half precision off CPU.") - if method == "svd": - raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.") + + m, n = shape + if (method == "qdwh" and ((side == "left" and m >= n) or + (side == "right" and m < n))): + raise unittest.SkipTest("method=qdwh does not support these sizes") matrix, _ = _initialize_polar_test(self.rng(), shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv, @@ -517,7 +520,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): should_be_eye = np.matmul(unitary.conj().T, unitary) else: should_be_eye = np.matmul(unitary, unitary.conj().T) - tol = 10 * jnp.finfo(matrix.dtype).eps + tol = 500 * jnp.finfo(matrix.dtype).eps eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype) with self.subTest('Test unitarity.'): self.assertAllClose( diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py index a1fcf611c..02ca41814 100644 --- a/tests/qdwh_test.py +++ b/tests/qdwh_test.py @@ -45,12 +45,12 @@ def _check_symmetry(x: jnp.ndarray) -> bool: m, n = x.shape eps = jnp.finfo(x.dtype).eps tol = 50.0 * eps - is_symmetric = False + is_hermitian = False if m == n: if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol: - is_symmetric = True + is_hermitian = True - return is_symmetric + return is_hermitian def _compute_relative_diff(actual, expected): """Computes relative difference between two matrices.""" @@ -75,11 +75,11 @@ class QdwhTest(jtu.JaxTestCase): cond = 10**log_cond s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1)) a = (u * s) @ v - is_symmetric = _check_symmetry(a) + is_hermitian = _check_symmetry(a) max_iterations = 2 _, _, actual_num_iterations, is_converged = qdwh.qdwh( - a, is_symmetric, max_iterations) + a, is_hermitian, max_iterations) with self.subTest('Number of iterations.'): 
self.assertEqual(max_iterations, actual_num_iterations) @@ -102,10 +102,10 @@ class QdwhTest(jtu.JaxTestCase): cond = 10**log_cond s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1)) a = (u * s) @ v - is_symmetric = _check_symmetry(a) + is_hermitian = _check_symmetry(a) max_iterations = 10 - actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations) + actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian, max_iterations) expected_u, expected_h = osp_linalg.polar(a) # Sets the test tolerance. @@ -145,12 +145,12 @@ class QdwhTest(jtu.JaxTestCase): cond = 10**log_cond s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1)) a = (u * s) @ v - is_symmetric = _check_symmetry(a) + is_hermitian = _check_symmetry(a) max_iterations = 10 def lsp_linalg_fn(a): u, h, _, _ = qdwh.qdwh( - a, is_symmetric=is_symmetric, max_iterations=max_iterations) + a, is_hermitian=is_hermitian, max_iterations=max_iterations) return u, h args_maker = lambda: [a] @@ -185,9 +185,9 @@ class QdwhTest(jtu.JaxTestCase): s = jnp.expand_dims(s.at[-1].set(0), range(u.ndim - 1)) a = (u * s) @ v - is_symmetric = _check_symmetry(a) + is_hermitian = _check_symmetry(a) max_iterations = 15 - actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations) + actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian, max_iterations) _, expected_h = osp_linalg.polar(a) # Sets the test tolerance. @@ -221,13 +221,13 @@ class QdwhTest(jtu.JaxTestCase): tiny_elem = jnp.finfo(a).tiny a = a.at[r, c].set(tiny_elem) - is_symmetric = _check_symmetry(a) + is_hermitian = _check_symmetry(a) max_iterations = 10 @jax.jit def lsp_linalg_fn(a): u, h, _, _ = qdwh.qdwh( - a, is_symmetric=is_symmetric, max_iterations=max_iterations) + a, is_hermitian=is_hermitian, max_iterations=max_iterations) return u, h actual_u, actual_h = lsp_linalg_fn(a)