Change implementation of jax.scipy.linalg.polar() and jax._src.scipy.eigh to use the QDWH decomposition from jax._src.lax.qdwh.

Remove jax._src.lax.polar.

PiperOrigin-RevId: 448241206
Peter Hawkins 2022-05-12 07:15:55 -07:00 committed by jax authors
parent 880cfc9c79
commit 7ba36fc178
8 changed files with 155 additions and 398 deletions
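For orientation, here is a minimal usage sketch (an illustrative note, not part of the diff) of the public entry point that now routes through the QDWH code; the matrix and tolerances are arbitrary examples.

import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array([[3.0, 1.0], [1.0, 2.0], [0.0, 1.0]])   # m x n with m >= n
u, p = polar(a, side="right", method="qdwh")           # a ~= u @ p
print(jnp.allclose(u @ p, a, atol=1e-4))               # reconstruction holds to float32 accuracy
print(jnp.allclose(u.T @ u, jnp.eye(2), atol=1e-4))    # u has orthonormal columns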

View File

@ -13,11 +13,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Changes
* {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
that allows users to opt out of eigenvalue sorting on TPU.
* Deprecations
* Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked
keyword-only. As a backward-compatibility step, passing keyword-only
arguments positionally yields a warning, but in a future JAX release passing
them positionally will be an error.
However, most users should prefer to use {mod}`jax.numpy.linalg` instead.
* {func}`jax.scipy.linalg.polar_unitary`, which was a JAX extension to the
scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead.
## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
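The `polar_unitary` deprecation above amounts to keeping only the first element of `polar()`'s return value; a migration sketch (illustrative, not part of the diff):

import jax.numpy as jnp
import jax.scipy.linalg as jsl

a = jnp.eye(4, 3) + 0.1 * jnp.ones((4, 3))

# Previously, callers of the JAX-only extension might have written:
#   unitary, _info = jsl.polar_unitary(a)
# The suggested replacement keeps only the unitary factor from polar():
unitary, _posdef = jsl.polar(a)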

View File

@ -1,343 +0,0 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""
Functions to compute the polar decomposition of the m x n matrix A, A = U @ H
where U is unitary (an m x n isometry in the m > n case) and H is n x n and
positive semidefinite (or positive definite if A is nonsingular). The method
is described in the docstring of `polar`. This file covers the serial
case.
"""
import functools
import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
# TODO: Allow singular value estimates to be manually specified
@jax.jit
def _add_to_diagonal(X, val):
new_diagonal = X.diagonal() + val
diag_indices = jnp.diag_indices(X.shape[0])
return X.at[diag_indices].set(new_diagonal)
@jax.jit
def _dot(a, b):
return jnp.dot(a, b, precision=lax.Precision.HIGHEST)
def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
""" Computes the polar decomposition.
Given the (m x n) matrix `a`, returns the factors of the polar decomposition
`u` (m x n) and `p` such that `a = up` (if side is "right"; p is (n x n)) or
`a = pu` (if side is "left"; p is (m x m)), where `p` is positive
semidefinite. If `a` is nonsingular, `p` is positive definite and the
decomposition is unique. `u` has orthonormal columns unless n > m, in which
case it has orthonormal rows.
Writing an SVD of `a` as `a = u_svd @ s_svd @ v^h_svd`, we have
`u = u_svd @ v^h_svd`. Thus the unitary factor `u` can be construed as
the application of the signum function to the singular values of `a`;
or, if `a` is Hermitian, the eigenvalues.
Several methods exist to compute the polar decomposition. Currently two
are supported:
`method`="svd": Computes the SVD of `a` and then forms
`u = u_svd @ v^h_svd`. This fails on the TPU, since
no SVD algorithm independent of the polar decomposition
exists there.
`method`="qdwh": Applies a certain iterative expansion of the matrix
signum function to `a` based on QR and Cholesky
decompositions.
Args:
a: The m x n input matrix.
side: Determines whether a right or left polar decomposition is computed.
If side is "right" then `a = up`. If side is "left" then `a = pu`. The
default is "right".
method: Determines the algorithm used, as described above.
The remaining arguments are only meaningful if method is "qdwh".
eps: The final result will satisfy |X_k - X_k-1| < |X_k| * (4*eps)**(1/3) .
maxiter: Iterations will terminate after this many steps even if the
above is unsatisfied.
Returns:
unitary: The unitary factor (m x n).
posdef: The positive-semidefinite factor. Either (n, n) or (m, m)
depending on whether side is "right" or "left", respectively.
info: Stores convergence information. If method is "svd", this is None. If
method is "qdwh", it is the tuple (j_qr, j_chol, errs): the number of QR
iterations, the number of Cholesky iterations, and the convergence history.
"""
return _polar(a, side, method, eps, maxiter)
@functools.partial(jax.jit, static_argnums=(1, 2, 4))
def _polar(a, side, method, eps, maxiter):
if side not in ("left", "right"):
raise ValueError(f"side={side} was invalid.")
unitary, info = _polar_unitary(a, method, eps, maxiter)
if side == "right":
posdef = _dot(unitary.conj().T, a)
else:
posdef = _dot(a, unitary.conj().T)
posdef = 0.5 * (posdef + posdef.conj().T)
return unitary, posdef, info
def polar_unitary(a, method="qdwh", eps=None, maxiter=50):
""" Computes the unitary factor u in the polar decomposition `a = u p`
(or `a = p u`).
"""
return _polar_unitary(a, method, eps, maxiter)
@functools.partial(jax.jit, static_argnums=(1, 3))
def _polar_unitary(a, method, eps, maxiter):
if method not in ("svd", "qdwh"):
raise ValueError(f"method={method} is unsupported.")
if method == "svd":
u_svd, _, vh_svd = jnp.linalg.svd(a, full_matrices=False)
unitary = _dot(u_svd, vh_svd)
info = None
elif method == "qdwh":
unitary, j_qr, j_chol, errs = _qdwh(a, eps, maxiter)
info = (j_qr, j_chol, errs)
else:
raise ValueError("How did we get here?")
return unitary, info
@functools.partial(jax.jit, static_argnums=(2,))
def _qdwh(matrix, eps, maxiter):
""" Computes the unitary factor in the polar decomposition of A using
the QDWH method. QDWH implements a 3rd order Pade approximation to the
matrix sign function,
X' = X * (aI + b X^H X)(I + c X^H X)^-1, X0 = A / ||A||_2. (1)
The coefficients a, b, and c are chosen dynamically based on an evolving
estimate of the matrix condition number. Specifically,
a = h(l), b = g(a), c = a + b - 1, h(x) = x g(x^2), g(x) = a + bx / (1 + cx)
where l is initially a lower bound on the smallest singular value of X0,
and subsequently evolves according to l' = l (a + bl^2) / (1 + c l^2).
For poorly conditioned matrices
(c > 100) the iteration (1) is rewritten in QR form,
X' = (b / c) X + (1 / c)(a - b/c) Q1 Q2^H,  where  [Q1; Q2] R = [sqrt(c) X; I].   (2)
For well-conditioned matrices it is instead formulated using cheaper
Cholesky iterations,
X' = (b / c) X + (a - b/c) (X W^-1) W^-H, W = chol(I + c X^H X). (3)
The QR iterations rapidly improve the condition number, and typically
only 1 or 2 are required. A maximum of 6 iterations total are required
for backwards stability to double precision.
Args:
matrix: The m x n input matrix.
eps: The final result will satisfy |X_k - X_k-1| < |X_k| * (4*eps)**(1/3) .
maxiter: Iterations will terminate after this many steps even if the
above is unsatisfied.
Returns:
matrix: The unitary factor (m x n).
jq: The number of QR iterations (1).
jc: The number of Cholesky iterations (2).
errs: Convergence history.
"""
n_rows, n_cols = matrix.shape
fat = n_cols > n_rows
if fat:
matrix = matrix.T
matrix, q_factor, l0 = _initialize_qdwh(matrix)
if eps is None:
eps = jnp.finfo(matrix.dtype).eps
tol_lk = 5 * eps # stop when lk differs from 1 by less
tol_delta = jnp.cbrt(tol_lk) # stop when the iterates change by less
coefs = _qdwh_coefs(l0)
errs = jnp.zeros(maxiter, dtype=matrix.real.dtype)
matrix, j_qr, coefs, errs = _qdwh_qr(
matrix, coefs, errs, tol_lk, tol_delta, maxiter)
matrix, j_chol, errs = _qdwh_cholesky(
matrix, coefs, errs, tol_lk, tol_delta, j_qr, maxiter)
matrix = _dot(q_factor, matrix)
if fat:
matrix = matrix.T
return matrix, j_qr, j_chol, errs
@jax.jit
def _initialize_qdwh(matrix):
""" Does preparatory computations for QDWH:
1. Computes an initial QR factorization of the input A. The iterations
will be on the triangular factor R, whose condition is more easily
estimated, and which is square even when A is rectangular.
2. Computes R -> R / ||R||_F. Now 1 is used to upper-bound ||R||_2.
3. Computes R^-1 by solving R R^-1 = I.
4. Uses sqrt(N) * ||R^-1||_1 as an upper bound on ||R^-1||_2, so that
1 / (sqrt(N) * ||R^-1||_1) is a lower bound on the smallest singular value of
R; this is used as the initial l_0. It should be clear there is room for
improvement here.
Returns:
X = R / ||R||_F;
Q from A -> Q @ R;
l0, the initial estimate for the QDWH coefficients.
"""
q_factor, r_factor = jnp.linalg.qr(matrix, mode="reduced")
alpha = jnp.linalg.norm(r_factor)
r_factor /= alpha
eye = jnp.eye(*r_factor.shape, dtype=r_factor.dtype)
r_inv = jsp.linalg.solve_triangular(r_factor, eye, overwrite_b=True)
one_norm_inv = jnp.linalg.norm(r_inv, ord=1)
l0 = 1 / (jnp.sqrt(matrix.shape[1]) * one_norm_inv)
eps = jnp.finfo(r_factor.dtype).eps
l0 = jnp.array(l0, dtype=r_factor.real.dtype)
l0 = jnp.where(l0 < eps, x=eps, y=l0)
l0 = jnp.where(l0 > 1.0, x=1.0, y=l0)
return r_factor, q_factor, l0
@jax.jit
def _qdwh_coefs(lk):
""" Computes a, b, c, l for the QDWH iterations.
The input lk must be in (0, 1]; lk=1 is a fixed point.
Some facts about the coefficients:
-for lk = 1 we have a=3, b=1, c=3, lk_new = 1.
-The float64 vs float32 computation of each coef appears to differ
only by noise on the order of 1E-9 to 1E-7 for all values of lk.
There is no apparent secular change in the (relative) error.
-All coefs change roughly as power laws; over e.g. [1E-14, 1]:
- a decreases from 5.43E9 to 3.
- b decreases from 7.37E18 to 1.
- c decreases from 7.37E18 to 3, only diverging from b near lk=1.
- lk increases from 5.45E-5 to 1.
lk is an estimate of the scaled matrix's smallest singular value
"""
lk = jnp.where(lk > 1.0, x=1.0, y=lk)
d = (4. * (1. - lk**2) / (lk**4))**(1 / 3)
f = 8. * (2. - lk**2) / (lk**2 * (1. + d)**(1 / 2))
a = (1. + d)**(1 / 2) + 0.5 * (8. - 4. * d + f)**0.5
b = (a - 1.)**2 / 4
c = a + b - 1.
lk = lk * (a + b * lk**2) / (1 + c * lk**2)
return a, b, c, lk
@jax.jit
def _unconverged(lk, j, maxiter, err, tol_delta, tol_lk):
changing = err > tol_delta
far_from_end = jnp.abs(1 - lk) > tol_lk
unconverged = jnp.logical_or(changing, far_from_end)
iterating = j < maxiter
return jnp.logical_and(iterating, unconverged)[0]
@jax.jit
def _qdwh_qr(matrix, coefs, errs, tol_lk, tol_delta, maxiter):
""" Applies the QDWH iteration formulated as
X' = (b / c) X + (1 / c)(a - b/c) Q1 Q2^H,  where  [Q1; Q2] R = [sqrt(c) X; I]
to X until either c < 100, ||X' - X|| < eps||X'||,
or the iteration count exceeds maxiter.
"""
n_rows, n_cols = matrix.shape
eye = jnp.eye(n_cols, dtype=matrix.dtype)
def _do_qr(args):
_, j, coefs, _, err = args
c = coefs[2]
lk = coefs[-1]
unconverged = _unconverged(lk, j, maxiter, err, tol_delta, tol_lk)
ill_conditioned = c > 100.
return jnp.logical_and(ill_conditioned, unconverged)
def _qr_work(args):
matrix, j, coefs, errs, _ = args
a, b, c, lk = coefs
csqrt = jnp.sqrt(c)
matrixI = jnp.vstack((csqrt * matrix, eye))
# Note: it should be possible to compute the QR of csqrt * matrix
# and build the concatenation with I at O(N).
Q, _ = jnp.linalg.qr(matrixI, mode="reduced")
Q1 = Q[:n_rows, :]
Q2 = Q[n_rows:, :]
coef = (1 / csqrt) * (a - (b / c))
new_matrix = (b / c) * matrix + coef * _dot(Q1, Q2.T.conj())
err = jnp.linalg.norm(matrix - new_matrix)
err = jnp.full(1, err).astype(errs[0].dtype)
errs = errs.at[j].set(err)
coefs = _qdwh_coefs(lk)
return new_matrix, j + 1, coefs, errs, err
j = jnp.zeros(1, dtype=jnp.int32)
err = jnp.full(1, 2 * tol_delta).astype(matrix.real.dtype)
matrix, j, coefs, errs, _ = jax.lax.while_loop(
_do_qr, _qr_work, (matrix, j, coefs, errs, err))
return matrix, j, coefs, errs
@jax.jit
def _qdwh_cholesky(matrix, coefs, errs, tol_delta, tol_lk, j0, maxiter):
""" Applies the QDWH iteration formulated as
matrix' = (b / c) matrix + (a - b/c) B,
B = (matrix W^-1) W^-H, W = chol(I + c matrix^H matrix).
to matrix until either ||matrix' - matrix|| < eps * ||matrix'||,
or the iteration count exceeds maxiter.
"""
def _do_cholesky(args):
_, j, coefs, errs = args
lk = coefs[-1]
return _unconverged(lk, j, maxiter, errs[j - 1], tol_delta, tol_lk)
def _cholesky_work(args):
matrix, j, coefs, errs = args
a, b, c, lk = coefs
Z = c * _dot(matrix.T.conj(), matrix)
Z = _add_to_diagonal(Z, 1.)
W = jsp.linalg.cholesky(Z)
B = jsp.linalg.solve_triangular(W.T, matrix.T, lower=True).conj()
B = jsp.linalg.solve_triangular(W, B).conj().T
new_matrix = (b / c) * matrix + (a - b / c) * B
# possible instability if a ~ b / c
err = jnp.linalg.norm(new_matrix - matrix).astype(errs[0].dtype)
errs = errs.at[j].set(err)
coefs = _qdwh_coefs(lk)
return new_matrix, j + 1, coefs, errs
carry = (matrix, j0, coefs, errs)
matrix, j_total, coefs, errs = jax.lax.while_loop(
_do_cholesky, _cholesky_work, carry)
return matrix, j_total - j0, errs
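As a quick numerical sanity check on the coefficient formulas documented in `_qdwh_coefs` above (in particular the fixed point a=3, b=1, c=3 at l=1), here is a small NumPy sketch, not part of the diff, that mirrors the deleted code:

import numpy as np

def qdwh_coefs(l):
    # Mirrors _qdwh_coefs above: dynamically weighted Halley coefficients.
    d = (4.0 * (1.0 - l**2) / l**4) ** (1.0 / 3.0)
    f = 8.0 * (2.0 - l**2) / (l**2 * (1.0 + d) ** 0.5)
    a = (1.0 + d) ** 0.5 + 0.5 * (8.0 - 4.0 * d + f) ** 0.5
    b = (a - 1.0) ** 2 / 4.0
    c = a + b - 1.0
    l_new = l * (a + b * l**2) / (1.0 + c * l**2)
    return a, b, c, l_new

print(qdwh_coefs(1.0))    # ~ (3.0, 1.0, 3.0, 1.0), the fixed point
print(qdwh_coefs(1e-3))   # large a, b, c; l jumps from 1e-3 toward 1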

View File

@ -65,14 +65,14 @@ def _use_cholesky(u, params):
return u
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def _qdwh(x, is_symmetric, max_iterations):
def _qdwh(x, is_hermitian, max_iterations, eps):
"""QR-based dynamically weighted Halley iteration for polar decomposition."""
# Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of
# norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for
# the smallest singular value of x.
eps = jnp.finfo(x.dtype).eps
if eps is None:
eps = jnp.finfo(x.dtype).eps
alpha = (jnp.sqrt(jnp.linalg.norm(x, ord=1)) *
jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf)))
l = eps
@ -113,7 +113,7 @@ def _qdwh(x, is_symmetric, max_iterations):
u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u))
if is_symmetric:
if is_hermitian:
u = (u + u.T.conj()) / 2.0
# Checks convergence.
@ -145,13 +145,17 @@ def _qdwh(x, is_symmetric, max_iterations):
# TODO: Add pivoting.
def qdwh(x, is_symmetric, max_iterations=10):
@functools.partial(jax.jit, static_argnames=('is_hermitian',))
def qdwh(x, is_hermitian=False, max_iterations=None, eps=None):
"""QR-based dynamically weighted Halley iteration for polar decomposition.
Args:
x: A full-rank matrix of shape `m x n` with `m >= n`.
is_symmetric: True if `x` is symmetric.
max_iterations: The predefined maximum number of iterations.
x: A full-rank matrix of shape `m x n`.
is_hermitian: True if `x` is Hermitian. Defaults to `False`.
eps: The final result will satisfy
``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate.
max_iterations: Iterations will terminate after this many steps even if the
above is unsatisfied.
Returns:
A four-tuple of (u, h, num_iters, is_converged) containing the
@ -159,29 +163,19 @@ def qdwh(x, is_symmetric, max_iterations=10):
and `is_converged`, whose value is `True` when the convergence is achieved
within the maximum number of iterations.
"""
m, n = x.shape
is_hermitian = core.concrete_or_error(
bool, is_hermitian, 'The `is_hermitian` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
if max_iterations is None:
max_iterations = 10
m, n = x.shape
if m < n:
raise ValueError('The input matrix of shape m x n must have m >= n.')
max_iterations = core.concrete_or_error(
int, max_iterations, 'The `max_iterations` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
is_symmetric = core.concrete_or_error(
bool, is_symmetric, 'The `is_symmetric` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
if is_symmetric:
eps = jnp.finfo(x.dtype).eps
tol = 50.0 * eps
relative_diff = jnp.linalg.norm(x - x.T.conj()) / jnp.linalg.norm(x)
if relative_diff > tol:
raise ValueError('The input `x` is NOT symmetric because '
'`norm(x-x.H) / norm(x)` is {}, which is greater than '
'the tolerance {}.'.format(relative_diff, tol))
with jax.default_matmul_precision('float32'):
u, h, num_iters, is_converged = _qdwh(x, is_symmetric, max_iterations)
u, h, num_iters, is_converged = _qdwh(x, is_hermitian, max_iterations, eps)
return u, h, num_iters, is_converged
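A usage sketch (illustrative, not part of the diff) of the updated `qdwh()` signature, where `is_hermitian` is keyword-optional and `eps`/`max_iterations` may be left unset:

import jax.numpy as jnp
from jax._src.lax import qdwh

x = jnp.array([[2.0, 1.0], [1.0, 3.0], [0.0, 1.0]])   # m x n with m >= n
u, h, num_iters, is_converged = qdwh.qdwh(x, is_hermitian=False)
print(int(num_iters), bool(is_converged))
print(jnp.allclose(u @ h, x, atol=1e-4))               # polar reconstruction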

View File

@ -20,6 +20,7 @@ import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import lax
from jax._src.lax import qdwh
def _similarity_transform(
@ -141,7 +142,7 @@ def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
return X.at[jnp.diag_indices(X.shape[0])].set(vals)
H_shift = _fill_diagonal(H, H.diagonal() - split_point)
U, _ = jsp.linalg.polar_unitary(H_shift)
U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True)
P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.)
rank = jnp.round(jnp.trace(P)).astype(jnp.int32)
rank = int(rank)
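The projector construction used here relies on the fact that, for a Hermitian argument, the unitary polar factor is the matrix sign function, so (I - U) / 2 projects onto the eigenspace below the split point; a small NumPy check (illustrative, not part of the diff), using a split point of zero for simplicity:

import numpy as np

rng = np.random.default_rng(0)
h = rng.standard_normal((5, 5))
h = (h + h.T) / 2                          # real symmetric test matrix
w, v = np.linalg.eigh(h)
u = v @ np.diag(np.sign(w)) @ v.T          # unitary polar factor of h, i.e. sign(h)
p = 0.5 * (np.eye(5) - u)                  # projector onto eigenvalues below zero
rank = int(np.round(np.trace(p)))
print(rank == int(np.sum(w < 0)))          # True: the trace counts negative eigenvalues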

View File

@ -18,11 +18,13 @@ from functools import partial
import numpy as np
import scipy.linalg
import textwrap
import warnings
import jax
from jax import jit, vmap, jvp
from jax import lax
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import polar as lax_polar
from jax._src.lax import qdwh
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as np_linalg
@ -619,11 +621,111 @@ def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
_, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
return mid
@_wraps(scipy.linalg.polar)
def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
unitary, posdef, _ = lax_polar.polar(a, side=side, method=method, eps=eps,
maxiter=maxiter)
return unitary, posdef
@partial(jit, static_argnames=('side', 'method'))
@jax.default_matmul_precision("float32")
def polar(a, side='right', *, method='qdwh', eps=None, max_iterations=None):
r"""Computes the polar decomposition.
Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar
decomposition :math:`u` (also :math:`m \times n`) and :math:`p` such that
:math:`a = up` (if side is ``"right"``; :math:`p` is :math:`n \times n`) or
:math:`a = pu` (if side is ``"left"``; :math:`p` is :math:`m \times m`),
where :math:`p` is positive semidefinite. If :math:`a` is nonsingular,
:math:`p` is positive definite and the
decomposition is unique. :math:`u` has orthonormal columns unless
:math:`n > m`, in which case it has orthonormal rows.
Writing the SVD of :math:`a` as
:math:`a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}`, we
have :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. Thus the unitary
factor :math:`u` can be constructed as the application of the sign function to
the singular values of :math:`a`; or, if :math:`a` is Hermitian, the
eigenvalues.
Several methods exist to compute the polar decomposition. Currently two
are supported:
* ``method="svd"``:
Computes the SVD of :math:`a` and then forms
:math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`.
* ``method="qdwh"``:
Applies the `QDWH`_ (QR-based Dynamically Weighted Halley) algorithm.
Args:
a: The :math:`m \times n` input matrix.
side: Determines whether a right or left polar decomposition is computed.
If ``side`` is ``"right"`` then :math:`a = up`. If ``side`` is ``"left"``
then :math:`a = pu`. The default is ``"right"``.
method: Determines the algorithm used, as described above.
eps: The final result will satisfy
:math:`\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}`,
where :math:`x_k` are the QDWH iterates. Ignored if ``method`` is not
``"qdwh"``.
max_iterations: Iterations will terminate after this many steps even if the
above is unsatisfied. Ignored if ``method`` is not ``"qdwh"``.
Returns:
A ``(unitary, posdef)`` tuple, where ``unitary`` is the unitary factor
(:math:`m \times n`), and ``posdef`` is the positive-semidefinite factor.
``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on
whether ``side`` is ``"right"`` or ``"left"``, respectively.
.. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999
"""
a = jnp.asarray(a)
if a.ndim != 2:
raise ValueError("The input `a` must be a 2-D array.")
if side not in ["right", "left"]:
raise ValueError("The argument `side` must be either 'right' or 'left'.")
m, n = a.shape
if method == "qdwh":
# TODO(phawkins): return info also if the user opts in?
if m >= n and side == "right":
unitary, posdef, _, _ = qdwh.qdwh(a, is_hermitian=False, eps=eps)
elif m < n and side == "left":
a = a.T.conj()
unitary, posdef, _, _ = qdwh.qdwh(a, is_hermitian=False, eps=eps)
posdef = posdef.T.conj()
unitary = unitary.T.conj()
else:
raise NotImplementedError("method='qdwh' only supports m x n matrices "
"with m >= n when side='right', or m < n when "
f"side='left'; got shape {a.shape} with side={side}")
elif method == "svd":
u_svd, s_svd, vh_svd = lax_linalg.svd(a, full_matrices=False)
unitary = u_svd @ vh_svd
if side == "right":
# a = u * p
posdef = (vh_svd.T.conj() * s_svd[None, :]) @ vh_svd
else:
# a = p * u
posdef = (u_svd * s_svd[None, :]) @ (u_svd.T.conj())
else:
raise ValueError(f"Unknown polar decomposition method {method}.")
return unitary, posdef
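A usage sketch (illustrative, not part of the diff) of the two shape/side combinations the qdwh path accepts, plus the svd path; note the removed module's docstring warned that the svd route is unavailable on TPU:

import jax.numpy as jnp
from jax.scipy.linalg import polar

wide = jnp.array([[1.0, 2.0, 0.0], [0.0, 1.0, 1.0]])   # 2 x 3, so m < n
u_l, p_l = polar(wide, side="left", method="qdwh")      # wide ~= p_l @ u_l
print(jnp.allclose(p_l @ u_l, wide, atol=1e-4))

tall = wide.T                                           # 3 x 2, so m >= n
u_r, p_r = polar(tall, side="right", method="svd")      # tall ~= u_r @ p_r
print(jnp.allclose(u_r @ p_r, tall, atol=1e-4))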
def polar_unitary(a, *, method="qdwh", eps=None, max_iterations=None):
""" Computes the unitary factor u in the polar decomposition ``a = u p``
(or ``a = p u``).
.. warning::
This function is deprecated. Use :func:`jax.scipy.linalg.polar` instead.
"""
# TODO(phawkins): delete this function after 2022/8/11.
warnings.warn("jax.scipy.linalg.polar_unitary is deprecated. Call "
"jax.scipy.linalg.polar instead.",
DeprecationWarning)
unitary, _ = polar(a, method=method, eps=eps, max_iterations=max_iterations)
return unitary
@jit
def _sqrtm_triu(T):

View File

@ -27,6 +27,7 @@ from jax._src.scipy.linalg import (
lu_factor as lu_factor,
lu_solve as lu_solve,
polar as polar,
polar_unitary as polar_unitary,
qr as qr,
rsf2csf as rsf2csf,
schur as schur,
@ -38,10 +39,6 @@ from jax._src.scipy.linalg import (
triu as triu,
)
from jax._src.lax.polar import (
polar_unitary as polar_unitary,
)
from jax._src.third_party.scipy.linalg import (
funm as funm,
)

View File

@ -490,7 +490,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
for method in methods
for side in sides
for nonzero_condition_number in nonzero_condition_numbers
for dtype in jtu.dtypes.floating
for dtype in jtu.dtypes.inexact
for seed in seeds))
@jtu.skip_on_devices("gpu") # Fails on A100.
def testPolar(
@ -500,8 +500,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
if jtu.device_under_test() != "cpu":
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
raise unittest.SkipTest("Skip half precision off CPU.")
if method == "svd":
raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.")
m, n = shape
if (method == "qdwh" and ((side == "left" and m >= n) or
(side == "right" and m < n))):
raise unittest.SkipTest("method=qdwh does not support these sizes")
matrix, _ = _initialize_polar_test(self.rng(),
shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
@ -517,7 +520,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
should_be_eye = np.matmul(unitary.conj().T, unitary)
else:
should_be_eye = np.matmul(unitary, unitary.conj().T)
tol = 10 * jnp.finfo(matrix.dtype).eps
tol = 500 * jnp.finfo(matrix.dtype).eps
eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
with self.subTest('Test unitarity.'):
self.assertAllClose(

View File

@ -45,12 +45,12 @@ def _check_symmetry(x: jnp.ndarray) -> bool:
m, n = x.shape
eps = jnp.finfo(x.dtype).eps
tol = 50.0 * eps
is_symmetric = False
is_hermitian = False
if m == n:
if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol:
is_symmetric = True
is_hermitian = True
return is_symmetric
return is_hermitian
def _compute_relative_diff(actual, expected):
"""Computes relative difference between two matrices."""
@ -75,11 +75,11 @@ class QdwhTest(jtu.JaxTestCase):
cond = 10**log_cond
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
a = (u * s) @ v
is_symmetric = _check_symmetry(a)
is_hermitian = _check_symmetry(a)
max_iterations = 2
_, _, actual_num_iterations, is_converged = qdwh.qdwh(
a, is_symmetric, max_iterations)
a, is_hermitian, max_iterations)
with self.subTest('Number of iterations.'):
self.assertEqual(max_iterations, actual_num_iterations)
@ -102,10 +102,10 @@ class QdwhTest(jtu.JaxTestCase):
cond = 10**log_cond
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
a = (u * s) @ v
is_symmetric = _check_symmetry(a)
is_hermitian = _check_symmetry(a)
max_iterations = 10
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations)
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian, max_iterations)
expected_u, expected_h = osp_linalg.polar(a)
# Sets the test tolerance.
@ -145,12 +145,12 @@ class QdwhTest(jtu.JaxTestCase):
cond = 10**log_cond
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
a = (u * s) @ v
is_symmetric = _check_symmetry(a)
is_hermitian = _check_symmetry(a)
max_iterations = 10
def lsp_linalg_fn(a):
u, h, _, _ = qdwh.qdwh(
a, is_symmetric=is_symmetric, max_iterations=max_iterations)
a, is_hermitian=is_hermitian, max_iterations=max_iterations)
return u, h
args_maker = lambda: [a]
@ -185,9 +185,9 @@ class QdwhTest(jtu.JaxTestCase):
s = jnp.expand_dims(s.at[-1].set(0), range(u.ndim - 1))
a = (u * s) @ v
is_symmetric = _check_symmetry(a)
is_hermitian = _check_symmetry(a)
max_iterations = 15
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations)
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian, max_iterations)
_, expected_h = osp_linalg.polar(a)
# Sets the test tolerance.
@ -221,13 +221,13 @@ class QdwhTest(jtu.JaxTestCase):
tiny_elem = jnp.finfo(a).tiny
a = a.at[r, c].set(tiny_elem)
is_symmetric = _check_symmetry(a)
is_hermitian = _check_symmetry(a)
max_iterations = 10
@jax.jit
def lsp_linalg_fn(a):
u, h, _, _ = qdwh.qdwh(
a, is_symmetric=is_symmetric, max_iterations=max_iterations)
a, is_hermitian=is_hermitian, max_iterations=max_iterations)
return u, h
actual_u, actual_h = lsp_linalg_fn(a)