Change implementation of jax.scipy.linalg.polar() and jax._src.scipy.eigh to use the QDWH decomposition from jax._src.lax.qdwh.

Remove jax._src.lax.polar.

PiperOrigin-RevId: 448241206
This commit is contained in:
Peter Hawkins 2022-05-12 07:15:55 -07:00 committed by jax authors
parent 880cfc9c79
commit 7ba36fc178
8 changed files with 155 additions and 398 deletions

View File

@ -13,11 +13,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Changes * Changes
* {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument * {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
that allows users to opt out of eigenvalue sorting on TPU. that allows users to opt out of eigenvalue sorting on TPU.
* Deprecations
* Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked * Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked
keyword-only. As a backward-compatibility step passing keyword-only keyword-only. As a backward-compatibility step passing keyword-only
arguments positionally yields a warning, but in a future JAX release passing arguments positionally yields a warning, but in a future JAX release passing
keyword-only arguments positionally will fail. keyword-only arguments positionally will fail.
However, most users should prefer to use {mod}`jax.numpy.linalg` instead. However, most users should prefer to use {mod}`jax.numpy.linalg` instead.
* {func}`jax.scipy.linalg.polar_unitary`, which was a JAX extension to the
scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead.
## jaxlib 0.3.11 (Unreleased) ## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -1,343 +0,0 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""
Functions to compute the polar decomposition of the m x n matrix A, A = U @ H
where U is unitary (an m x n isometry in the m > n case) and H is n x n and
positive semidefinite (or positive definite if A is nonsingular). The method
is described in the docstring to `_qdwh`. This file covers the serial
case.
"""
import functools
import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
# TODO: Allow singular value estimates to be manually specified
@jax.jit
def _add_to_diagonal(X, val):
new_diagonal = X.diagonal() + val
diag_indices = jnp.diag_indices(X.shape[0])
return X.at[diag_indices].set(new_diagonal)
@jax.jit
def _dot(a, b):
return jnp.dot(a, b, precision=lax.Precision.HIGHEST)
def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
  """Computes the polar decomposition.

  Given the (m x n) matrix `a`, returns the factors of the polar decomposition
  `u` (m x n) and `p` such that `a = up` (if side is "right"; p is (n x n)) or
  `a = pu` (if side is "left"; p is (m x m)), where `p` is positive
  semidefinite. If `a` is nonsingular, `p` is positive definite and the
  decomposition is unique. `u` has orthonormal columns unless n > m, in which
  case it has orthonormal rows.

  Writing an SVD of `a` as `a = u_svd @ s_svd @ v^h_svd`, we have
  `u = u_svd @ v^h_svd`. Thus the unitary factor `u` can be construed as
  the application of the signum function to the singular values of `a`;
  or, if `a` is Hermitian, the eigenvalues.

  Several methods exist to compute the polar decomposition. Currently two
  are supported:
    `method`="svd": Computes the SVD of `a` and then forms
                    `u = u_svd @ v^h_svd`. This fails on the TPU, since
                    no SVD algorithm independent of the polar decomposition
                    exists there.
    `method`="qdwh": Applies a certain iterative expansion of the matrix
                     signum function to `a` based on QR and Cholesky
                     decompositions.

  Args:
    a: The m x n input matrix.
    side: Determines whether a right or left polar decomposition is computed.
      If side is "right" then `a = up`. If side is "left" then `a = pu`. The
      default is "right".
    method: Determines the algorithm used, as described above.

    The remaining arguments are only meaningful if method is "qdwh".
    eps: Tolerance parameter for the QDWH convergence tests. When None, the
      machine epsilon of the input dtype is used.
    maxiter: Iterations will terminate after this many steps even if the
      convergence tests are unsatisfied.

  Returns:
    unitary: The unitary factor (m x n).
    posdef: The positive-semidefinite factor. Either (n, n) or (m, m)
      depending on whether side is "right" or "left", respectively.
    info: Stores convergence information.
      if method is "svd": None
      if method is "qdwh": a tuple `(j_qr, j_chol, errs)` holding the number
        of QR iterations, the number of Cholesky iterations, and the
        convergence history.

  Raises:
    ValueError: if `side` is not "left" or "right", or if `method` is not
      "svd" or "qdwh".
  """
  # Arguments are forwarded positionally because the jitted helper marks
  # `side`, `method` and `maxiter` as static by position.
  return _polar(a, side, method, eps, maxiter)
@functools.partial(jax.jit, static_argnums=(1, 2, 4))
def _polar(a, side, method, eps, maxiter):
  """Jitted implementation of `polar`; `side`, `method`, `maxiter` are static.

  Computes the unitary factor, then recovers the positive-semidefinite factor
  from it and explicitly symmetrizes the result.

  Returns:
    A `(unitary, posdef, info)` tuple.

  Raises:
    ValueError: if `side` is neither "left" nor "right".
  """
  if side != "left" and side != "right":
    raise ValueError(f"side={side} was invalid.")

  unitary, info = _polar_unitary(a, method, eps, maxiter)
  unitary_h = unitary.conj().T
  # a = u p  =>  p = u^H a;   a = p u  =>  p = a u^H.
  product = _dot(unitary_h, a) if side == "right" else _dot(a, unitary_h)
  # Average with the conjugate transpose so the returned factor is Hermitian.
  hermitian_part = 0.5 * (product + product.conj().T)
  return unitary, hermitian_part, info
def polar_unitary(a, method="qdwh", eps=None, maxiter=50):
  """Computes the unitary factor u in the polar decomposition `a = u p`
  (or `a = p u`).

  Args:
    a: The input matrix.
    method: Either "svd" or "qdwh"; selects the algorithm.
    eps: Tolerance parameter, only used when method is "qdwh".
    maxiter: Iteration cap, only used when method is "qdwh".

  Returns:
    A `(unitary, info)` pair, not the bare unitary factor. `info` is None for
    method "svd" and a `(j_qr, j_chol, errs)` tuple for method "qdwh".

  Raises:
    ValueError: if `method` is not "svd" or "qdwh".
  """
  # Forwarded positionally: the jitted helper marks `method` and `maxiter`
  # as static by position.
  return _polar_unitary(a, method, eps, maxiter)
@functools.partial(jax.jit, static_argnums=(1, 3))
def _polar_unitary(a, method, eps, maxiter):
  """Jitted dispatch for the unitary polar factor; `method`, `maxiter` static.

  Args:
    a: The input matrix.
    method: "svd" or "qdwh".
    eps: Tolerance parameter forwarded to the QDWH iteration.
    maxiter: Iteration cap forwarded to the QDWH iteration.

  Returns:
    A `(unitary, info)` pair. `info` is None for "svd" and the tuple
    `(j_qr, j_chol, errs)` for "qdwh".

  Raises:
    ValueError: if `method` is not one of the supported names.
  """
  if method == "svd":
    u_svd, _, vh_svd = jnp.linalg.svd(a, full_matrices=False)
    # Dropping the singular values from the SVD leaves the unitary factor.
    return _dot(u_svd, vh_svd), None
  if method == "qdwh":
    unitary, j_qr, j_chol, errs = _qdwh(a, eps, maxiter)
    return unitary, (j_qr, j_chol, errs)
  # Single rejection point; the former trailing "How did we get here?" branch
  # could never execute once the method name had been validated.
  raise ValueError(f"method={method} is unsupported.")
@functools.partial(jax.jit, static_argnums=(2,))
def _qdwh(matrix, eps, maxiter):
  """Computes the unitary factor in the polar decomposition of A using
  the QDWH method. QDWH implements a 3rd order Pade approximation to the
  matrix sign function,

  X' = X * (aI + b X^H X)(I + c X^H X)^-1, X0 = A / ||A||_2.   (1)

  The coefficients a, b, and c are chosen dynamically (see `_qdwh_coefs`)
  based on an evolving estimate l of the smallest singular value of the
  scaled iterate. l is initially a lower bound on the smallest singular value
  of X0 and subsequently evolves according to
  l' = l (a + b l^2) / (1 + c l^2).

  For poorly conditioned matrices (c > 100) the iteration (1) is rewritten in
  QR form,

  X' = (b / c) X + (1 / sqrt(c)) (a - b/c) Q1 Q2^H,
       [Q1] R = [sqrt(c) X]                                    (2)
       [Q2]     [I        ].

  For well-conditioned matrices it is instead formulated using cheaper
  Cholesky iterations,

  X' = (b / c) X + (a - b/c) (X W^-1) W^-H, W = chol(I + c X^H X).   (3)

  The QR iterations rapidly improve the condition number, and typically
  only 1 or 2 are required. A maximum of 6 iterations total are required
  for backwards stability to double precision.

  Args:
    matrix: The m x n input matrix.
    eps: Tolerance parameter. When None, the machine epsilon of the working
      dtype is used. Iteration stops once the change between successive
      iterates is at most (5*eps)**(1/3) and the estimate l is within 5*eps
      of 1.
    maxiter: Iterations will terminate after this many steps even if the
      above is unsatisfied. Static under jit; it also fixes the length of
      `errs`.

  Returns:
    matrix: The unitary factor (m x n).
    j_qr: The number of QR iterations (2).
    j_chol: The number of Cholesky iterations (3).
    errs: Convergence history.
  """
  n_rows, n_cols = matrix.shape
  # The iteration below works on tall matrices; a wide input is transposed
  # here and the result is transposed back at the end.
  fat = n_cols > n_rows
  if fat:
    matrix = matrix.T
  matrix, q_factor, l0 = _initialize_qdwh(matrix)

  if eps is None:
    eps = jnp.finfo(matrix.dtype).eps
  tol_lk = 5 * eps  # stop when lk differs from 1 by less
  tol_delta = jnp.cbrt(tol_lk)  # stop when the iterates change by less

  coefs = _qdwh_coefs(l0)
  errs = jnp.zeros(maxiter, dtype=matrix.real.dtype)

  # The two stages declare their tolerance parameters in opposite orders, so
  # both tolerances are passed by keyword. Passing them positionally used to
  # hand the Cholesky stage swapped tolerances.
  matrix, j_qr, coefs, errs = _qdwh_qr(
      matrix, coefs, errs, tol_lk=tol_lk, tol_delta=tol_delta, maxiter=maxiter)
  matrix, j_chol, errs = _qdwh_cholesky(
      matrix, coefs, errs, tol_delta=tol_delta, tol_lk=tol_lk, j0=j_qr,
      maxiter=maxiter)

  # Undo the initial QR factorization: the iteration ran on the scaled
  # triangular factor only.
  matrix = _dot(q_factor, matrix)
  if fat:
    matrix = matrix.T
  return matrix, j_qr, j_chol, errs
@jax.jit
def _initialize_qdwh(matrix):
""" Does preparatory computations for QDWH:
1. Computes an initial QR factorization of the input A. The iterations
will be on the triangular factor R, whose condition is more easily
estimated, and which is square even when A is rectangular.
2. Computes R -> R / ||R||_F. Now 1 is used to upper-bound ||R||_2.
3. Computes R^-1 by solving R R^-1 = I.
4. Uses sqrt(N) * ||R^-1||_1 as a lower bound for ||R^-2||.
1 / sqrt(N) * ||R^-1||_1 is then used as the initial l_0. It should be clear
there is room for improvement here.
Returns:
X = R / ||R||_F;
Q from A -> Q @ R;
l0, the initial estimate for the QDWH coefficients.
"""
q_factor, r_factor = jnp.linalg.qr(matrix, mode="reduced")
alpha = jnp.linalg.norm(r_factor)
r_factor /= alpha
eye = jnp.eye(*r_factor.shape, dtype=r_factor.dtype)
r_inv = jsp.linalg.solve_triangular(r_factor, eye, overwrite_b=True)
one_norm_inv = jnp.linalg.norm(r_inv, ord=1)
l0 = 1 / (jnp.sqrt(matrix.shape[1]) * one_norm_inv)
eps = jnp.finfo(r_factor.dtype).eps
l0 = jnp.array(l0, dtype=r_factor.real.dtype)
l0 = jnp.where(l0 < eps, x=eps, y=l0)
l0 = jnp.where(l0 > 1.0, x=1.0, y=l0)
return r_factor, q_factor, l0
@jax.jit
def _qdwh_coefs(lk):
""" Computes a, b, c, l for the QDWH iterations.
The input lk must be in (0, 1]; lk=1 is a fixed point.
Some facts about the coefficients:
-for lk = 1 we have a=3, b=1, c=3, lk_new = 1.
-The float64 vs float32 computation of each coef appears to differ
only by noise on the order of 1E-9 to 1E-7 for all values of lk.
There is no apparent secular change in the (relative) error.
-All coefs change roughly as power laws; over e.g. [1E-14, 1]:
- a decreases from 5.43E9 to 3.
- b decreases from 7.37E18 to 1.
- c decreases from 7.37E18 to 3, only diverging from b near lk=1.
- lk increases from 5.45E-5 to 1.
lk is an estimate of the scaled matrix's smallest singular value
"""
lk = jnp.where(lk > 1.0, x=1.0, y=lk)
d = (4. * (1. - lk**2) / (lk**4))**(1 / 3)
f = 8. * (2. - lk**2) / (lk**2 * (1. + d)**(1 / 2))
a = (1. + d)**(1 / 2) + 0.5 * (8. - 4. * d + f)**0.5
b = (a - 1.)**2 / 4
c = a + b - 1.
lk = lk * (a + b * lk**2) / (1 + c * lk**2)
return a, b, c, lk
@jax.jit
def _unconverged(lk, j, maxiter, err, tol_delta, tol_lk):
changing = err > tol_delta
far_from_end = jnp.abs(1 - lk) > tol_lk
unconverged = jnp.logical_or(changing, far_from_end)
iterating = j < maxiter
return jnp.logical_and(iterating, unconverged)[0]
@jax.jit
def _qdwh_qr(matrix, coefs, errs, tol_lk, tol_delta, maxiter):
  """Applies the QDWH iteration formulated as

  X' = (b / c) X + (1 / sqrt(c)) (a - b/c) Q1 Q2^H,   [Q1] R = [sqrt(c) X]
                                                      [Q2]     [I        ]

  to X for as long as c > 100 and the iteration is unconverged, i.e. the
  iteration count is below maxiter and either the last change ||X' - X||
  exceeds tol_delta or lk is further than tol_lk from 1.

  Args:
    matrix: The current iterate X.
    coefs: The tuple (a, b, c, lk) of QDWH coefficients for the first step.
    errs: Buffer for the convergence history; entry j receives the change
      made by step j.
    tol_lk: Tolerance on |1 - lk|.
    tol_delta: Tolerance on the change between successive iterates.
    maxiter: Upper limit on the iteration count.

  Returns:
    matrix: The iterate after the QR steps.
    j: The number of QR steps taken, an int32 array of shape (1,).
    coefs: The coefficients to use for the next step.
    errs: The updated convergence history.
  """
  n_rows, n_cols = matrix.shape
  eye = jnp.eye(n_cols, dtype=matrix.dtype)

  def _do_qr(args):
    # Loop condition: continue only while ill-conditioned AND unconverged.
    _, j, coefs, _, err = args
    c = coefs[2]
    lk = coefs[-1]
    unconverged = _unconverged(lk, j, maxiter, err, tol_delta, tol_lk)
    ill_conditioned = c > 100.
    return jnp.logical_and(ill_conditioned, unconverged)

  def _qr_work(args):
    # Loop body: one QR-form step followed by a coefficient refresh.
    matrix, j, coefs, errs, _ = args
    a, b, c, lk = coefs
    csqrt = jnp.sqrt(c)
    # Stack sqrt(c) X on top of the identity and orthogonalize.
    matrixI = jnp.vstack((csqrt * matrix, eye))
    # Note: it should be possible to compute the QR of csqrt * matrix
    # and build the concatenation with I at O(N).
    Q, _ = jnp.linalg.qr(matrixI, mode="reduced")
    Q1 = Q[:n_rows, :]
    Q2 = Q[n_rows:, :]
    coef = (1 / csqrt) * (a - (b / c))
    new_matrix = (b / c) * matrix + coef * _dot(Q1, Q2.T.conj())
    err = jnp.linalg.norm(matrix - new_matrix)
    # Keep err as a shape-(1,) array of the history dtype so the loop carry
    # has the same type on every iteration.
    err = jnp.full(1, err).astype(errs[0].dtype)
    errs = errs.at[j].set(err)
    # coefs[-1] is already the updated lk, so this yields the next step's
    # coefficients.
    coefs = _qdwh_coefs(lk)
    return new_matrix, j + 1, coefs, errs, err

  j = jnp.zeros(1, dtype=jnp.int32)
  # Seed err above tol_delta so the "still changing" test passes initially.
  err = jnp.full(1, 2 * tol_delta).astype(matrix.real.dtype)
  matrix, j, coefs, errs, _ = jax.lax.while_loop(
      _do_qr, _qr_work, (matrix, j, coefs, errs, err))
  return matrix, j, coefs, errs
@jax.jit
def _qdwh_cholesky(matrix, coefs, errs, tol_delta, tol_lk, j0, maxiter):
  """Applies the QDWH iteration formulated as

  matrix' = (b / c) matrix + (a - b/c) B,
    B = (matrix W^-1) W^-H,  W = chol(I + c matrix^H matrix).

  to matrix until the iteration is converged (the last recorded change is at
  most tol_delta and lk is within tol_lk of 1) or the iteration count reaches
  maxiter.

  NOTE: tol_delta precedes tol_lk in this signature, which is the reverse of
  the parameter order used by `_qdwh_qr`. Callers that pass the tolerances
  positionally must account for that.

  Args:
    matrix: The current iterate.
    coefs: The tuple (a, b, c, lk) of QDWH coefficients for the first step.
    errs: Convergence history buffer; entry j receives the change made by
      step j.
    tol_delta: Tolerance on the change between successive iterates.
    tol_lk: Tolerance on |1 - lk|.
    j0: The iteration count on entry (steps already taken by earlier stages).
    maxiter: Upper limit on the total iteration count.

  Returns:
    matrix: The iterate after the Cholesky steps.
    The number of Cholesky steps taken (final count minus j0).
    errs: The updated convergence history.
  """

  def _do_cholesky(args):
    # Loop condition, judged on the change recorded by the previous step.
    _, j, coefs, errs = args
    lk = coefs[-1]
    # NOTE(review): when j is 0 on entry, j - 1 is -1; confirm which entry of
    # errs this reads and that it gives the intended first-pass behaviour.
    return _unconverged(lk, j, maxiter, errs[j - 1], tol_delta, tol_lk)

  def _cholesky_work(args):
    # Loop body: one Cholesky-form step followed by a coefficient refresh.
    matrix, j, coefs, errs = args
    a, b, c, lk = coefs
    # Z = I + c X^H X.
    Z = c * _dot(matrix.T.conj(), matrix)
    Z = _add_to_diagonal(Z, 1.)
    W = jsp.linalg.cholesky(Z)
    # Two triangular solves against the Cholesky factor build B from the
    # docstring formula without forming an explicit inverse.
    B = jsp.linalg.solve_triangular(W.T, matrix.T, lower=True).conj()
    B = jsp.linalg.solve_triangular(W, B).conj().T
    new_matrix = (b / c) * matrix + (a - b / c) * B
    # possible instability if a ~ b / c
    err = jnp.linalg.norm(new_matrix - matrix).astype(errs[0].dtype)
    errs = errs.at[j].set(err)
    # coefs[-1] is already the updated lk, so this yields the next step's
    # coefficients.
    coefs = _qdwh_coefs(lk)
    return new_matrix, j + 1, coefs, errs

  carry = (matrix, j0, coefs, errs)
  matrix, j_total, coefs, errs = jax.lax.while_loop(
      _do_cholesky, _cholesky_work, carry)
  return matrix, j_total - j0, errs

View File

@ -65,14 +65,14 @@ def _use_cholesky(u, params):
return u return u
@functools.partial(jax.jit, static_argnums=(1, 2, 3)) def _qdwh(x, is_hermitian, max_iterations, eps):
def _qdwh(x, is_symmetric, max_iterations):
"""QR-based dynamically weighted Halley iteration for polar decomposition.""" """QR-based dynamically weighted Halley iteration for polar decomposition."""
# Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of # Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of
# norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for # norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for
# the smallest singular value of x. # the smallest singular value of x.
eps = jnp.finfo(x.dtype).eps if eps is None:
eps = jnp.finfo(x.dtype).eps
alpha = (jnp.sqrt(jnp.linalg.norm(x, ord=1)) * alpha = (jnp.sqrt(jnp.linalg.norm(x, ord=1)) *
jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf))) jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf)))
l = eps l = eps
@ -113,7 +113,7 @@ def _qdwh(x, is_symmetric, max_iterations):
u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u)) u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u))
if is_symmetric: if is_hermitian:
u = (u + u.T.conj()) / 2.0 u = (u + u.T.conj()) / 2.0
# Checks convergence. # Checks convergence.
@ -145,13 +145,17 @@ def _qdwh(x, is_symmetric, max_iterations):
# TODO: Add pivoting. # TODO: Add pivoting.
def qdwh(x, is_symmetric, max_iterations=10): @functools.partial(jax.jit, static_argnames=('is_hermitian',))
def qdwh(x, is_hermitian=False, max_iterations=None, eps=None):
"""QR-based dynamically weighted Halley iteration for polar decomposition. """QR-based dynamically weighted Halley iteration for polar decomposition.
Args: Args:
x: A full-rank matrix of shape `m x n` with `m >= n`. x: A full-rank matrix of shape `m x n`.
is_symmetric: True if `x` is symmetric. is_hermitian: True if `x` is Hermitian. Default to `False`.
max_iterations: The predefined maximum number of iterations. eps: The final result will satisfy
``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate.
max_iterations: Iterations will terminate after this many steps even if the
above is unsatisfied.
Returns: Returns:
A four-tuple of (u, h, num_iters, is_converged) containing the A four-tuple of (u, h, num_iters, is_converged) containing the
@ -159,29 +163,19 @@ def qdwh(x, is_symmetric, max_iterations=10):
and `is_converged`, whose value is `True` when the convergence is achieved and `is_converged`, whose value is `True` when the convergence is achieved
within the maximum number of iterations. within the maximum number of iterations.
""" """
m, n = x.shape is_hermitian = core.concrete_or_error(
bool, is_hermitian, 'The `is_hermitian` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
if max_iterations is None:
max_iterations = 10
m, n = x.shape
if m < n: if m < n:
raise ValueError('The input matrix of shape m x n must have m >= n.') raise ValueError('The input matrix of shape m x n must have m >= n.')
max_iterations = core.concrete_or_error(
int, max_iterations, 'The `max_iterations` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
is_symmetric = core.concrete_or_error(
bool, is_symmetric, 'The `is_symmetric` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
if is_symmetric:
eps = jnp.finfo(x.dtype).eps
tol = 50.0 * eps
relative_diff = jnp.linalg.norm(x - x.T.conj()) / jnp.linalg.norm(x)
if relative_diff > tol:
raise ValueError('The input `x` is NOT symmetric because '
'`norm(x-x.H) / norm(x)` is {}, which is greater than '
'the tolerance {}.'.format(relative_diff, tol))
with jax.default_matmul_precision('float32'): with jax.default_matmul_precision('float32'):
u, h, num_iters, is_converged = _qdwh(x, is_symmetric, max_iterations) u, h, num_iters, is_converged = _qdwh(x, is_hermitian, max_iterations, eps)
return u, h, num_iters, is_converged return u, h, num_iters, is_converged

View File

@ -20,6 +20,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax.scipy as jsp import jax.scipy as jsp
from jax import lax from jax import lax
from jax._src.lax import qdwh
def _similarity_transform( def _similarity_transform(
@ -141,7 +142,7 @@ def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
return X.at[jnp.diag_indices(X.shape[0])].set(vals) return X.at[jnp.diag_indices(X.shape[0])].set(vals)
H_shift = _fill_diagonal(H, H.diagonal() - split_point) H_shift = _fill_diagonal(H, H.diagonal() - split_point)
U, _ = jsp.linalg.polar_unitary(H_shift) U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True)
P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.) P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.)
rank = jnp.round(jnp.trace(P)).astype(jnp.int32) rank = jnp.round(jnp.trace(P)).astype(jnp.int32)
rank = int(rank) rank = int(rank)

View File

@ -18,11 +18,13 @@ from functools import partial
import numpy as np import numpy as np
import scipy.linalg import scipy.linalg
import textwrap import textwrap
import warnings
import jax
from jax import jit, vmap, jvp from jax import jit, vmap, jvp
from jax import lax from jax import lax
from jax._src.lax import linalg as lax_linalg from jax._src.lax import linalg as lax_linalg
from jax._src.lax import polar as lax_polar from jax._src.lax import qdwh
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as np_linalg from jax._src.numpy import linalg as np_linalg
@ -619,11 +621,111 @@ def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
_, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper)) _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
return mid return mid
@partial(jit, static_argnames=('side', 'method'))
@jax.default_matmul_precision("float32")
def polar(a, side='right', *, method='qdwh', eps=None, max_iterations=None):
  r"""Computes the polar decomposition.

  Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar
  decomposition :math:`u` (also :math:`m \times n`) and :math:`p` such that
  :math:`a = up` (if side is ``"right"``; :math:`p` is :math:`n \times n`) or
  :math:`a = pu` (if side is ``"left"``; :math:`p` is :math:`m \times m`),
  where :math:`p` is positive semidefinite.  If :math:`a` is nonsingular,
  :math:`p` is positive definite and the
  decomposition is unique. :math:`u` has orthonormal columns unless
  :math:`n > m`, in which case it has orthonormal rows.

  Writing the SVD of :math:`a` as
  :math:`a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}`, we
  have :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. Thus the unitary
  factor :math:`u` can be constructed as the application of the sign function to
  the singular values of :math:`a`; or, if :math:`a` is Hermitian, the
  eigenvalues.

  Several methods exist to compute the polar decomposition. Currently two
  are supported:

  * ``method="svd"``:

    Computes the SVD of :math:`a` and then forms
    :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`.

  * ``method="qdwh"``:

    Applies the `QDWH`_ (QR-based Dynamically Weighted Halley) algorithm.
    This method only supports :math:`m \ge n` with ``side="right"`` and
    :math:`m < n` with ``side="left"``.

  Args:
    a: The :math:`m \times n` input matrix.
    side: Determines whether a right or left polar decomposition is computed.
      If ``side`` is ``"right"`` then :math:`a = up`. If ``side`` is ``"left"``
      then :math:`a = pu`. The default is ``"right"``.
    method: Determines the algorithm used, as described above.
    eps: The final result will satisfy
      :math:`\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}`,
      where :math:`x_k` are the QDWH iterates. Ignored if ``method`` is not
      ``"qdwh"``.
    max_iterations: Iterations will terminate after this many steps even if the
      above is unsatisfied.  Ignored if ``method`` is not ``"qdwh"``.

  Returns:
    A ``(unitary, posdef)`` tuple, where ``unitary`` is the unitary factor
    (:math:`m \times n`), and ``posdef`` is the positive-semidefinite factor.
    ``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on
    whether ``side`` is ``"right"`` or ``"left"``, respectively.

  Raises:
    ValueError: if ``a`` is not 2-D, or ``side`` / ``method`` is not one of
      the supported names.
    NotImplementedError: if ``method="qdwh"`` is used with a shape/side
      combination it does not support.

  .. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999
  """
  a = jnp.asarray(a)
  if a.ndim != 2:
    raise ValueError("The input `a` must be a 2-D array.")

  if side not in ["right", "left"]:
    raise ValueError("The argument `side` must be either 'right' or 'left'.")

  m, n = a.shape
  if method == "qdwh":
    # TODO(phawkins): return info also if the user opts in?
    if m >= n and side == "right":
      # `max_iterations` is forwarded so the documented cap takes effect;
      # None lets qdwh pick its own default.
      unitary, posdef, _, _ = qdwh.qdwh(
          a, is_hermitian=False, max_iterations=max_iterations, eps=eps)
    elif m < n and side == "left":
      # qdwh needs a tall input: decompose a^H = u' p', then conjugate
      # transpose both factors back to obtain a = p u.
      a = a.T.conj()
      unitary, posdef, _, _ = qdwh.qdwh(
          a, is_hermitian=False, max_iterations=max_iterations, eps=eps)
      posdef = posdef.T.conj()
      unitary = unitary.T.conj()
    else:
      raise NotImplementedError("method='qdwh' only supports mxn matrices "
                                "where m >= n when side='right' and m < n "
                                f"when side='left', got {a.shape} with "
                                f"side={side}")
  elif method == "svd":
    u_svd, s_svd, vh_svd = lax_linalg.svd(a, full_matrices=False)
    unitary = u_svd @ vh_svd
    if side == "right":
      # a = u * p
      posdef = (vh_svd.T.conj() * s_svd[None, :]) @ vh_svd
    else:
      # a = p * u
      posdef = (u_svd * s_svd[None, :]) @ (u_svd.T.conj())
  else:
    raise ValueError(f"Unknown polar decomposition method {method}.")

  return unitary, posdef
def polar_unitary(a, *, method="qdwh", eps=None, max_iterations=None):
  """ Computes the unitary factor u in the polar decomposition ``a = u p``
  (or ``a = p u``).

  .. warning::
    This function is deprecated. Use :func:`jax.scipy.linalg.polar` instead.

  Args:
    a: The input matrix.
    method: The algorithm to use; forwarded to ``polar``.
    eps: Convergence tolerance parameter; forwarded to ``polar``.
    max_iterations: Iteration cap; forwarded to ``polar``.

  Returns:
    The unitary factor of the polar decomposition computed by ``polar`` with
    its default ``side``.
  """
  # TODO(phawkins): delete this function after 2022/8/11.
  warnings.warn("jax.scipy.linalg.polar_unitary is deprecated. Call "
                "jax.scipy.linalg.polar instead.",
                DeprecationWarning)
  # `polar` takes `side` as its second positional parameter and everything
  # after it keyword-only, so the options must be passed by name. Passing
  # them positionally bound `method` to `side` and raised a TypeError.
  unitary, _ = polar(a, method=method, eps=eps, max_iterations=max_iterations)
  return unitary
@jit @jit
def _sqrtm_triu(T): def _sqrtm_triu(T):

View File

@ -27,6 +27,7 @@ from jax._src.scipy.linalg import (
lu_factor as lu_factor, lu_factor as lu_factor,
lu_solve as lu_solve, lu_solve as lu_solve,
polar as polar, polar as polar,
polar_unitary as polar_unitary,
qr as qr, qr as qr,
rsf2csf as rsf2csf, rsf2csf as rsf2csf,
schur as schur, schur as schur,
@ -38,10 +39,6 @@ from jax._src.scipy.linalg import (
triu as triu, triu as triu,
) )
from jax._src.lax.polar import (
polar_unitary as polar_unitary,
)
from jax._src.third_party.scipy.linalg import ( from jax._src.third_party.scipy.linalg import (
funm as funm, funm as funm,
) )

View File

@ -490,7 +490,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
for method in methods for method in methods
for side in sides for side in sides
for nonzero_condition_number in nonzero_condition_numbers for nonzero_condition_number in nonzero_condition_numbers
for dtype in jtu.dtypes.floating for dtype in jtu.dtypes.inexact
for seed in seeds)) for seed in seeds))
@jtu.skip_on_devices("gpu") # Fails on A100. @jtu.skip_on_devices("gpu") # Fails on A100.
def testPolar( def testPolar(
@ -500,8 +500,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
if jtu.device_under_test() != "cpu": if jtu.device_under_test() != "cpu":
if jnp.dtype(dtype).name in ("bfloat16", "float16"): if jnp.dtype(dtype).name in ("bfloat16", "float16"):
raise unittest.SkipTest("Skip half precision off CPU.") raise unittest.SkipTest("Skip half precision off CPU.")
if method == "svd":
raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.") m, n = shape
if (method == "qdwh" and ((side == "left" and m >= n) or
(side == "right" and m < n))):
raise unittest.SkipTest("method=qdwh does not support these sizes")
matrix, _ = _initialize_polar_test(self.rng(), matrix, _ = _initialize_polar_test(self.rng(),
shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
@ -517,7 +520,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
should_be_eye = np.matmul(unitary.conj().T, unitary) should_be_eye = np.matmul(unitary.conj().T, unitary)
else: else:
should_be_eye = np.matmul(unitary, unitary.conj().T) should_be_eye = np.matmul(unitary, unitary.conj().T)
tol = 10 * jnp.finfo(matrix.dtype).eps tol = 500 * jnp.finfo(matrix.dtype).eps
eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype) eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
with self.subTest('Test unitarity.'): with self.subTest('Test unitarity.'):
self.assertAllClose( self.assertAllClose(

View File

@ -45,12 +45,12 @@ def _check_symmetry(x: jnp.ndarray) -> bool:
m, n = x.shape m, n = x.shape
eps = jnp.finfo(x.dtype).eps eps = jnp.finfo(x.dtype).eps
tol = 50.0 * eps tol = 50.0 * eps
is_symmetric = False is_hermitian = False
if m == n: if m == n:
if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol: if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol:
is_symmetric = True is_hermitian = True
return is_symmetric return is_hermitian
def _compute_relative_diff(actual, expected): def _compute_relative_diff(actual, expected):
"""Computes relative difference between two matrices.""" """Computes relative difference between two matrices."""
@ -75,11 +75,11 @@ class QdwhTest(jtu.JaxTestCase):
cond = 10**log_cond cond = 10**log_cond
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1)) s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
a = (u * s) @ v a = (u * s) @ v
is_symmetric = _check_symmetry(a) is_hermitian = _check_symmetry(a)
max_iterations = 2 max_iterations = 2
_, _, actual_num_iterations, is_converged = qdwh.qdwh( _, _, actual_num_iterations, is_converged = qdwh.qdwh(
a, is_symmetric, max_iterations) a, is_hermitian, max_iterations)
with self.subTest('Number of iterations.'): with self.subTest('Number of iterations.'):
self.assertEqual(max_iterations, actual_num_iterations) self.assertEqual(max_iterations, actual_num_iterations)
@ -102,10 +102,10 @@ class QdwhTest(jtu.JaxTestCase):
cond = 10**log_cond cond = 10**log_cond
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1)) s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
a = (u * s) @ v a = (u * s) @ v
is_symmetric = _check_symmetry(a) is_hermitian = _check_symmetry(a)
max_iterations = 10 max_iterations = 10
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations) actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian, max_iterations)
expected_u, expected_h = osp_linalg.polar(a) expected_u, expected_h = osp_linalg.polar(a)
# Sets the test tolerance. # Sets the test tolerance.
@ -145,12 +145,12 @@ class QdwhTest(jtu.JaxTestCase):
cond = 10**log_cond cond = 10**log_cond
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1)) s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
a = (u * s) @ v a = (u * s) @ v
is_symmetric = _check_symmetry(a) is_hermitian = _check_symmetry(a)
max_iterations = 10 max_iterations = 10
def lsp_linalg_fn(a): def lsp_linalg_fn(a):
u, h, _, _ = qdwh.qdwh( u, h, _, _ = qdwh.qdwh(
a, is_symmetric=is_symmetric, max_iterations=max_iterations) a, is_hermitian=is_hermitian, max_iterations=max_iterations)
return u, h return u, h
args_maker = lambda: [a] args_maker = lambda: [a]
@ -185,9 +185,9 @@ class QdwhTest(jtu.JaxTestCase):
s = jnp.expand_dims(s.at[-1].set(0), range(u.ndim - 1)) s = jnp.expand_dims(s.at[-1].set(0), range(u.ndim - 1))
a = (u * s) @ v a = (u * s) @ v
is_symmetric = _check_symmetry(a) is_hermitian = _check_symmetry(a)
max_iterations = 15 max_iterations = 15
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations) actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian, max_iterations)
_, expected_h = osp_linalg.polar(a) _, expected_h = osp_linalg.polar(a)
# Sets the test tolerance. # Sets the test tolerance.
@ -221,13 +221,13 @@ class QdwhTest(jtu.JaxTestCase):
tiny_elem = jnp.finfo(a).tiny tiny_elem = jnp.finfo(a).tiny
a = a.at[r, c].set(tiny_elem) a = a.at[r, c].set(tiny_elem)
is_symmetric = _check_symmetry(a) is_hermitian = _check_symmetry(a)
max_iterations = 10 max_iterations = 10
@jax.jit @jax.jit
def lsp_linalg_fn(a): def lsp_linalg_fn(a):
u, h, _, _ = qdwh.qdwh( u, h, _, _ = qdwh.qdwh(
a, is_symmetric=is_symmetric, max_iterations=max_iterations) a, is_hermitian=is_hermitian, max_iterations=max_iterations)
return u, h return u, h
actual_u, actual_h = lsp_linalg_fn(a) actual_u, actual_h = lsp_linalg_fn(a)