mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Change implementation of jax.scipy.linalg.polar() and jax._src.scipy.eigh to use the QDWH decomposition from jax._src.lax.qdwh.
Remove jax._src.lax.polar. PiperOrigin-RevId: 448241206
This commit is contained in:
parent
880cfc9c79
commit
7ba36fc178
@ -13,11 +13,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* Changes
|
||||
* {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
|
||||
that allows users to opt out of eigenvalue sorting on TPU.
|
||||
* Deprecations
|
||||
* Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked
|
||||
keyword-only. As a backward-compatibility step passing keyword-only
|
||||
arguments positionally yields a warning, but in a future JAX release passing
|
||||
keyword-only arguments positionally will fail.
|
||||
However, most users should prefer to use {mod}`jax.numpy.linalg` instead.
|
||||
* {func}`jax.scipy.linalg.polar_unitary`, which was a JAX extension to the
|
||||
scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead.
|
||||
|
||||
## jaxlib 0.3.11 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
||||
|
@ -1,343 +0,0 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
|
||||
"""
|
||||
Functions to compute the polar decomposition of the m x n matrix A, A = U @ H
|
||||
where U is unitary (an m x n isometry in the m > n case) and H is n x n and
|
||||
positive semidefinite (or positive definite if A is nonsingular). The method
|
||||
is described in the docstring to `polarU`. This file covers the serial
|
||||
case.
|
||||
"""
|
||||
import functools
|
||||
import jax
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy as jsp
|
||||
|
||||
|
||||
# TODO: Allow singular value estimates to be manually specified
|
||||
@jax.jit
|
||||
def _add_to_diagonal(X, val):
|
||||
new_diagonal = X.diagonal() + val
|
||||
diag_indices = jnp.diag_indices(X.shape[0])
|
||||
return X.at[diag_indices].set(new_diagonal)
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _dot(a, b):
|
||||
return jnp.dot(a, b, precision=lax.Precision.HIGHEST)
|
||||
|
||||
|
||||
def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
|
||||
""" Computes the polar decomposition.
|
||||
|
||||
Given the (m x n) matrix `a`, returns the factors of the polar decomposition
|
||||
`u` (m x n) and `p` such that `a = up` (if side is "right"; p is (n x n)) or
|
||||
`a = pu` (if side is "left"; p is (m x m)), where `p` is positive
|
||||
semidefinite. If `a` is nonsingular, `p` is positive definite and the
|
||||
decomposition is unique. `u` has orthonormal columns unless n > m, in which
|
||||
case it has orthonormal rows.
|
||||
|
||||
Writing an SVD of `a` as `a = u_svd @ s_svd @ v^h_svd`, we have
|
||||
`u = u_svd @ v^h_svd`. Thus the unitary factor `u` can be construed as
|
||||
the application of the signum function to the singular values of `a`;
|
||||
or, if `a` is Hermitian, the eigenvalues.
|
||||
|
||||
Several methods exist to compute the polar decomposition. Currently two
|
||||
are supported:
|
||||
`method`="svd": Computes the SVD of `a` and then forms
|
||||
`u = u_svd @ v^h_svd`. This fails on the TPU, since
|
||||
no SVD algorithm independent of the polar decomposition
|
||||
exists there.
|
||||
`method`="qdwh": Applies a certain iterative expansion of the matrix
|
||||
signum function to `a` based on QR and Cholesky
|
||||
decompositions.
|
||||
|
||||
Args:
|
||||
a: The m x n input matrix.
|
||||
side: Determines whether a right or left polar decomposition is computed.
|
||||
If side is "right" then `a = up`. If side is "left" then `a = pu`. The
|
||||
default is "right".
|
||||
method: Determines the algorithm used, as described above.
|
||||
precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
|
||||
|
||||
The remaining arguments are only meaningful if method is "qdwh".
|
||||
eps: The final result will satisfy |X_k - X_k-1| < |X_k| * (4*eps)**(1/3) .
|
||||
maxiter: Iterations will terminate after this many steps even if the
|
||||
above is unsatisfied.
|
||||
Returns:
|
||||
unitary: The unitary factor (m x n).
|
||||
posdef: The positive-semidefinite factor. Either (n, n) or (m, m)
|
||||
depending on whether side is "right" or "left", respectively.
|
||||
info: Stores convergence information.
|
||||
if method is "svd": None
|
||||
if method is "qdwh": j_qr: Number of QR iterations.
|
||||
j_chol: Number of Cholesky iterations.
|
||||
errs: Convergence history.
|
||||
"""
|
||||
return _polar(a, side, method, eps, maxiter)
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnums=(1, 2, 4))
|
||||
def _polar(a, side, method, eps, maxiter):
|
||||
if side not in ("left", "right"):
|
||||
raise ValueError(f"side={side} was invalid.")
|
||||
|
||||
unitary, info = _polar_unitary(a, method, eps, maxiter)
|
||||
if side == "right":
|
||||
posdef = _dot(unitary.conj().T, a)
|
||||
else:
|
||||
posdef = _dot(a, unitary.conj().T)
|
||||
posdef = 0.5 * (posdef + posdef.conj().T)
|
||||
return unitary, posdef, info
|
||||
|
||||
|
||||
def polar_unitary(a, method="qdwh", eps=None, maxiter=50):
|
||||
""" Computes the unitary factor u in the polar decomposition `a = u p`
|
||||
(or `a = p u`).
|
||||
"""
|
||||
return _polar_unitary(a, method, eps, maxiter)
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnums=(1, 3))
|
||||
def _polar_unitary(a, method, eps, maxiter):
|
||||
if method not in ("svd", "qdwh"):
|
||||
raise ValueError(f"method={method} is unsupported.")
|
||||
|
||||
if method == "svd":
|
||||
u_svd, _, vh_svd = jnp.linalg.svd(a, full_matrices=False)
|
||||
unitary = _dot(u_svd, vh_svd)
|
||||
info = None
|
||||
elif method == "qdwh":
|
||||
unitary, j_qr, j_chol, errs = _qdwh(a, eps, maxiter)
|
||||
info = (j_qr, j_chol, errs)
|
||||
else:
|
||||
raise ValueError("How did we get here?")
|
||||
return unitary, info
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnums=(2,))
|
||||
def _qdwh(matrix, eps, maxiter):
|
||||
""" Computes the unitary factor in the polar decomposition of A using
|
||||
the QDWH method. QDWH implements a 3rd order Pade approximation to the
|
||||
matrix sign function,
|
||||
|
||||
X' = X * (aI + b X^H X)(I + c X^H X)^-1, X0 = A / ||A||_2. (1)
|
||||
|
||||
The coefficients a, b, and c are chosen dynamically based on an evolving
|
||||
estimate of the matrix condition number. Specifically,
|
||||
|
||||
a = h(l), b = g(a), c = a + b - 1, h(x) = x g(x^2), g(x) = a + bx / (1 + cx)
|
||||
|
||||
where l is initially a lower bound on the smallest singular value of X0,
|
||||
and subsequently evolves according to l' = l (a + bl^2) / (1 + c l^2).
|
||||
|
||||
For poorly conditioned matrices
|
||||
(c > 100) the iteration (1) is rewritten in QR form,
|
||||
|
||||
X' = (b / c) X + (1 / c)(a - b/c) Q1 Q2^H, [Q1] R = [sqrt(c) X] (2)
|
||||
[Q2] [I ].
|
||||
|
||||
For well-conditioned matrices it is instead formulated using cheaper
|
||||
Cholesky iterations,
|
||||
|
||||
X' = (b / c) X + (a - b/c) (X W^-1) W^-H, W = chol(I + c X^H X). (3)
|
||||
|
||||
The QR iterations rapidly improve the condition number, and typically
|
||||
only 1 or 2 are required. A maximum of 6 iterations total are required
|
||||
for backwards stability to double precision.
|
||||
|
||||
Args:
|
||||
matrix: The m x n input matrix.
|
||||
eps: The final result will satisfy |X_k - X_k-1| < |X_k| * (4*eps)**(1/3) .
|
||||
maxiter: Iterations will terminate after this many steps even if the
|
||||
above is unsatisfied.
|
||||
Returns:
|
||||
matrix: The unitary factor (m x n).
|
||||
jq: The number of QR iterations (1).
|
||||
jc: The number of Cholesky iterations (2).
|
||||
errs: Convergence history.
|
||||
"""
|
||||
n_rows, n_cols = matrix.shape
|
||||
fat = n_cols > n_rows
|
||||
if fat:
|
||||
matrix = matrix.T
|
||||
matrix, q_factor, l0 = _initialize_qdwh(matrix)
|
||||
|
||||
if eps is None:
|
||||
eps = jnp.finfo(matrix.dtype).eps
|
||||
tol_lk = 5 * eps # stop when lk differs from 1 by less
|
||||
tol_delta = jnp.cbrt(tol_lk) # stop when the iterates change by less
|
||||
coefs = _qdwh_coefs(l0)
|
||||
errs = jnp.zeros(maxiter, dtype=matrix.real.dtype)
|
||||
matrix, j_qr, coefs, errs = _qdwh_qr(
|
||||
matrix, coefs, errs, tol_lk, tol_delta, maxiter)
|
||||
matrix, j_chol, errs = _qdwh_cholesky(
|
||||
matrix, coefs, errs, tol_lk, tol_delta, j_qr, maxiter)
|
||||
matrix = _dot(q_factor, matrix)
|
||||
|
||||
if fat:
|
||||
matrix = matrix.T
|
||||
return matrix, j_qr, j_chol, errs
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _initialize_qdwh(matrix):
|
||||
""" Does preparatory computations for QDWH:
|
||||
1. Computes an initial QR factorization of the input A. The iterations
|
||||
will be on the triangular factor R, whose condition is more easily
|
||||
estimated, and which is square even when A is rectangular.
|
||||
2. Computes R -> R / ||R||_F. Now 1 is used to upper-bound ||R||_2.
|
||||
3. Computes R^-1 by solving R R^-1 = I.
|
||||
4. Uses sqrt(N) * ||R^-1||_1 as a lower bound for ||R^-2||.
|
||||
1 / sqrt(N) * ||R^-1||_1 is then used as the initial l_0. It should be clear
|
||||
there is room for improvement here.
|
||||
|
||||
Returns:
|
||||
X = R / ||R||_F;
|
||||
Q from A -> Q @ R;
|
||||
l0, the initial estimate for the QDWH coefficients.
|
||||
"""
|
||||
q_factor, r_factor = jnp.linalg.qr(matrix, mode="reduced")
|
||||
alpha = jnp.linalg.norm(r_factor)
|
||||
r_factor /= alpha
|
||||
eye = jnp.eye(*r_factor.shape, dtype=r_factor.dtype)
|
||||
r_inv = jsp.linalg.solve_triangular(r_factor, eye, overwrite_b=True)
|
||||
one_norm_inv = jnp.linalg.norm(r_inv, ord=1)
|
||||
l0 = 1 / (jnp.sqrt(matrix.shape[1]) * one_norm_inv)
|
||||
eps = jnp.finfo(r_factor.dtype).eps
|
||||
l0 = jnp.array(l0, dtype=r_factor.real.dtype)
|
||||
l0 = jnp.where(l0 < eps, x=eps, y=l0)
|
||||
l0 = jnp.where(l0 > 1.0, x=1.0, y=l0)
|
||||
return r_factor, q_factor, l0
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _qdwh_coefs(lk):
|
||||
""" Computes a, b, c, l for the QDWH iterations.
|
||||
The input lk must be in (0, 1]; lk=1 is a fixed point.
|
||||
Some facts about the coefficients:
|
||||
-for lk = 1 we have a=3, b=1, c=3, lk_new = 1.
|
||||
-The float64 vs float32 computation of each coef appears to differ
|
||||
only by noise on the order of 1E-9 to 1E-7 for all values of lk.
|
||||
There is no apparent secular change in the (relative) error.
|
||||
-All coefs change roughly as power laws; over e.g. [1E-14, 1]:
|
||||
- a decreases from 5.43E9 to 3.
|
||||
- b decreases from 7.37E18 to 1.
|
||||
- c decreases from 7.37E18 to 3, only diverging from b near lk=1.
|
||||
- lk increases from 5.45E-5 to 1.
|
||||
|
||||
lk is an estimate of the scaled matrix's smallest singular value
|
||||
"""
|
||||
lk = jnp.where(lk > 1.0, x=1.0, y=lk)
|
||||
d = (4. * (1. - lk**2) / (lk**4))**(1 / 3)
|
||||
f = 8. * (2. - lk**2) / (lk**2 * (1. + d)**(1 / 2))
|
||||
a = (1. + d)**(1 / 2) + 0.5 * (8. - 4. * d + f)**0.5
|
||||
b = (a - 1.)**2 / 4
|
||||
c = a + b - 1.
|
||||
lk = lk * (a + b * lk**2) / (1 + c * lk**2)
|
||||
return a, b, c, lk
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _unconverged(lk, j, maxiter, err, tol_delta, tol_lk):
|
||||
changing = err > tol_delta
|
||||
far_from_end = jnp.abs(1 - lk) > tol_lk
|
||||
unconverged = jnp.logical_or(changing, far_from_end)
|
||||
iterating = j < maxiter
|
||||
return jnp.logical_and(iterating, unconverged)[0]
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _qdwh_qr(matrix, coefs, errs, tol_lk, tol_delta, maxiter):
|
||||
""" Applies the QDWH iteration formulated as
|
||||
|
||||
X' = (b / c) X + (1 / c)(a - b/c) Q1 Q2^H, [Q1] R = [sqrt(c) X]
|
||||
[Q2] [I ]
|
||||
|
||||
to X until either c < 100, ||X' - X|| < eps||X'||,
|
||||
or the iteration count exceeds maxiter.
|
||||
"""
|
||||
n_rows, n_cols = matrix.shape
|
||||
eye = jnp.eye(n_cols, dtype=matrix.dtype)
|
||||
|
||||
def _do_qr(args):
|
||||
_, j, coefs, _, err = args
|
||||
c = coefs[2]
|
||||
lk = coefs[-1]
|
||||
unconverged = _unconverged(lk, j, maxiter, err, tol_delta, tol_lk)
|
||||
ill_conditioned = c > 100.
|
||||
return jnp.logical_and(ill_conditioned, unconverged)
|
||||
|
||||
def _qr_work(args):
|
||||
matrix, j, coefs, errs, _ = args
|
||||
a, b, c, lk = coefs
|
||||
csqrt = jnp.sqrt(c)
|
||||
matrixI = jnp.vstack((csqrt * matrix, eye))
|
||||
# Note: it should be possible to compute the QR of csqrt * matrix
|
||||
# and build the concatenation with I at O(N).
|
||||
Q, _ = jnp.linalg.qr(matrixI, mode="reduced")
|
||||
Q1 = Q[:n_rows, :]
|
||||
Q2 = Q[n_rows:, :]
|
||||
coef = (1 / csqrt) * (a - (b / c))
|
||||
new_matrix = (b / c) * matrix + coef * _dot(Q1, Q2.T.conj())
|
||||
err = jnp.linalg.norm(matrix - new_matrix)
|
||||
err = jnp.full(1, err).astype(errs[0].dtype)
|
||||
errs = errs.at[j].set(err)
|
||||
coefs = _qdwh_coefs(lk)
|
||||
return new_matrix, j + 1, coefs, errs, err
|
||||
|
||||
j = jnp.zeros(1, dtype=jnp.int32)
|
||||
err = jnp.full(1, 2 * tol_delta).astype(matrix.real.dtype)
|
||||
matrix, j, coefs, errs, _ = jax.lax.while_loop(
|
||||
_do_qr, _qr_work, (matrix, j, coefs, errs, err))
|
||||
return matrix, j, coefs, errs
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _qdwh_cholesky(matrix, coefs, errs, tol_delta, tol_lk, j0, maxiter):
|
||||
""" Applies the QDWH iteration formulated as
|
||||
|
||||
matrix' = (b / c) matrix + (a - b/c) B,
|
||||
B = (matrix W^-1) W^-H, W = chol(I + c matrix^H matrix).
|
||||
|
||||
to matrix until either ||matrix' - matrix|| < eps * ||matrix'||,
|
||||
or the iteration count exceeds maxiter.
|
||||
"""
|
||||
|
||||
def _do_cholesky(args):
|
||||
_, j, coefs, errs = args
|
||||
lk = coefs[-1]
|
||||
return _unconverged(lk, j, maxiter, errs[j - 1], tol_delta, tol_lk)
|
||||
|
||||
def _cholesky_work(args):
|
||||
matrix, j, coefs, errs = args
|
||||
a, b, c, lk = coefs
|
||||
Z = c * _dot(matrix.T.conj(), matrix)
|
||||
Z = _add_to_diagonal(Z, 1.)
|
||||
W = jsp.linalg.cholesky(Z)
|
||||
B = jsp.linalg.solve_triangular(W.T, matrix.T, lower=True).conj()
|
||||
B = jsp.linalg.solve_triangular(W, B).conj().T
|
||||
new_matrix = (b / c) * matrix + (a - b / c) * B
|
||||
# possible instability if a ~ b / c
|
||||
err = jnp.linalg.norm(new_matrix - matrix).astype(errs[0].dtype)
|
||||
errs = errs.at[j].set(err)
|
||||
coefs = _qdwh_coefs(lk)
|
||||
return new_matrix, j + 1, coefs, errs
|
||||
|
||||
carry = (matrix, j0, coefs, errs)
|
||||
matrix, j_total, coefs, errs = jax.lax.while_loop(
|
||||
_do_cholesky, _cholesky_work, carry)
|
||||
return matrix, j_total - j0, errs
|
@ -65,14 +65,14 @@ def _use_cholesky(u, params):
|
||||
return u
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
|
||||
def _qdwh(x, is_symmetric, max_iterations):
|
||||
def _qdwh(x, is_hermitian, max_iterations, eps):
|
||||
"""QR-based dynamically weighted Halley iteration for polar decomposition."""
|
||||
|
||||
# Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of
|
||||
# norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for
|
||||
# the smallest singular value of x.
|
||||
eps = jnp.finfo(x.dtype).eps
|
||||
if eps is None:
|
||||
eps = jnp.finfo(x.dtype).eps
|
||||
alpha = (jnp.sqrt(jnp.linalg.norm(x, ord=1)) *
|
||||
jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf)))
|
||||
l = eps
|
||||
@ -113,7 +113,7 @@ def _qdwh(x, is_symmetric, max_iterations):
|
||||
|
||||
u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u))
|
||||
|
||||
if is_symmetric:
|
||||
if is_hermitian:
|
||||
u = (u + u.T.conj()) / 2.0
|
||||
|
||||
# Checks convergence.
|
||||
@ -145,13 +145,17 @@ def _qdwh(x, is_symmetric, max_iterations):
|
||||
|
||||
|
||||
# TODO: Add pivoting.
|
||||
def qdwh(x, is_symmetric, max_iterations=10):
|
||||
@functools.partial(jax.jit, static_argnames=('is_hermitian',))
|
||||
def qdwh(x, is_hermitian=False, max_iterations=None, eps=None):
|
||||
"""QR-based dynamically weighted Halley iteration for polar decomposition.
|
||||
|
||||
Args:
|
||||
x: A full-rank matrix of shape `m x n` with `m >= n`.
|
||||
is_symmetric: True if `x` is symmetric.
|
||||
max_iterations: The predefined maximum number of iterations.
|
||||
x: A full-rank matrix of shape `m x n`.
|
||||
is_hermitian: True if `x` is Hermitian. Default to `False`.
|
||||
eps: The final result will satisfy
|
||||
``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate.
|
||||
max_iterations: Iterations will terminate after this many steps even if the
|
||||
above is unsatisfied.
|
||||
|
||||
Returns:
|
||||
A four-tuple of (u, h, num_iters, is_converged) containing the
|
||||
@ -159,29 +163,19 @@ def qdwh(x, is_symmetric, max_iterations=10):
|
||||
and `is_converged`, whose value is `True` when the convergence is achieved
|
||||
within the maximum number of iterations.
|
||||
"""
|
||||
m, n = x.shape
|
||||
is_hermitian = core.concrete_or_error(
|
||||
bool, is_hermitian, 'The `is_hermitian` argument must be statically '
|
||||
'specified to use `qdwh` within JAX transformations.')
|
||||
|
||||
if max_iterations is None:
|
||||
max_iterations = 10
|
||||
|
||||
m, n = x.shape
|
||||
if m < n:
|
||||
raise ValueError('The input matrix of shape m x n must have m >= n.')
|
||||
|
||||
max_iterations = core.concrete_or_error(
|
||||
int, max_iterations, 'The `max_iterations` argument must be statically '
|
||||
'specified to use `qdwh` within JAX transformations.')
|
||||
|
||||
is_symmetric = core.concrete_or_error(
|
||||
bool, is_symmetric, 'The `is_symmetric` argument must be statically '
|
||||
'specified to use `qdwh` within JAX transformations.')
|
||||
|
||||
if is_symmetric:
|
||||
eps = jnp.finfo(x.dtype).eps
|
||||
tol = 50.0 * eps
|
||||
relative_diff = jnp.linalg.norm(x - x.T.conj()) / jnp.linalg.norm(x)
|
||||
if relative_diff > tol:
|
||||
raise ValueError('The input `x` is NOT symmetric because '
|
||||
'`norm(x-x.H) / norm(x)` is {}, which is greater than '
|
||||
'the tolerance {}.'.format(relative_diff, tol))
|
||||
|
||||
with jax.default_matmul_precision('float32'):
|
||||
u, h, num_iters, is_converged = _qdwh(x, is_symmetric, max_iterations)
|
||||
u, h, num_iters, is_converged = _qdwh(x, is_hermitian, max_iterations, eps)
|
||||
|
||||
|
||||
return u, h, num_iters, is_converged
|
||||
|
@ -20,6 +20,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy as jsp
|
||||
from jax import lax
|
||||
from jax._src.lax import qdwh
|
||||
|
||||
|
||||
def _similarity_transform(
|
||||
@ -141,7 +142,7 @@ def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
|
||||
return X.at[jnp.diag_indices(X.shape[0])].set(vals)
|
||||
|
||||
H_shift = _fill_diagonal(H, H.diagonal() - split_point)
|
||||
U, _ = jsp.linalg.polar_unitary(H_shift)
|
||||
U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True)
|
||||
P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.)
|
||||
rank = jnp.round(jnp.trace(P)).astype(jnp.int32)
|
||||
rank = int(rank)
|
||||
|
@ -18,11 +18,13 @@ from functools import partial
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
import textwrap
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax import jit, vmap, jvp
|
||||
from jax import lax
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.lax import polar as lax_polar
|
||||
from jax._src.lax import qdwh
|
||||
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import linalg as np_linalg
|
||||
@ -619,11 +621,111 @@ def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
|
||||
_, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
|
||||
return mid
|
||||
|
||||
@_wraps(scipy.linalg.polar)
|
||||
def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
|
||||
unitary, posdef, _ = lax_polar.polar(a, side=side, method=method, eps=eps,
|
||||
maxiter=maxiter)
|
||||
return unitary, posdef
|
||||
@partial(jit, static_argnames=('side', 'method'))
|
||||
@jax.default_matmul_precision("float32")
|
||||
def polar(a, side='right', *, method='qdwh', eps=None, max_iterations=None):
|
||||
r"""Computes the polar decomposition.
|
||||
|
||||
Given the :math:`m \times n` matrix :math:`a`, returns the factors of the polar
|
||||
decomposition :math:`u` (also :math:`m \times n`) and :math:`p` such that
|
||||
:math:`a = up` (if side is ``"right"``; :math:`p` is :math:`n \times n`) or
|
||||
:math:`a = pu` (if side is ``"left"``; :math:`p` is :math:`m \times m`),
|
||||
where :math:`p` is positive semidefinite. If :math:`a` is nonsingular,
|
||||
:math:`p` is positive definite and the
|
||||
decomposition is unique. :math:`u` has orthonormal columns unless
|
||||
:math:`n > m`, in which case it has orthonormal rows.
|
||||
|
||||
Writing the SVD of :math:`a` as
|
||||
:math:`a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}`, we
|
||||
have :math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`. Thus the unitary
|
||||
factor :math:`u` can be constructed as the application of the sign function to
|
||||
the singular values of :math:`a`; or, if :math:`a` is Hermitian, the
|
||||
eigenvalues.
|
||||
|
||||
Several methods exist to compute the polar decomposition. Currently two
|
||||
are supported:
|
||||
|
||||
* ``method="svd"``:
|
||||
|
||||
Computes the SVD of :math:`a` and then forms
|
||||
:math:`u = u_\mathit{svd} \cdot v^h_\mathit{svd}`.
|
||||
|
||||
* ``method="qdwh"``:
|
||||
|
||||
Applies the `QDWH`_ (QR-based Dynamically Weighted Halley) algorithm.
|
||||
|
||||
Args:
|
||||
a: The :math:`m \times n` input matrix.
|
||||
side: Determines whether a right or left polar decomposition is computed.
|
||||
If ``side`` is ``"right"`` then :math:`a = up`. If ``side`` is ``"left"``
|
||||
then :math:`a = pu`. The default is ``"right"``.
|
||||
method: Determines the algorithm used, as described above.
|
||||
precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
|
||||
eps: The final result will satisfy
|
||||
:math:`\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}`,
|
||||
where :math:`x_k` are the QDWH iterates. Ignored if ``method`` is not
|
||||
``"qdwh"``.
|
||||
max_iterations: Iterations will terminate after this many steps even if the
|
||||
above is unsatisfied. Ignored if ``method`` is not ``"qdwh"``.
|
||||
|
||||
Returns:
|
||||
A ``(unitary, posdef)`` tuple, where ``unitary`` is the unitary factor
|
||||
(:math:`m \times n`), and ``posdef`` is the positive-semidefinite factor.
|
||||
``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on
|
||||
whether ``side`` is ``"right"`` or ``"left"``, respectively.
|
||||
|
||||
.. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999
|
||||
"""
|
||||
a = jnp.asarray(a)
|
||||
if a.ndim != 2:
|
||||
raise ValueError("The input `a` must be a 2-D array.")
|
||||
|
||||
if side not in ["right", "left"]:
|
||||
raise ValueError("The argument `side` must be either 'right' or 'left'.")
|
||||
|
||||
m, n = a.shape
|
||||
if method == "qdwh":
|
||||
# TODO(phawkins): return info also if the user opts in?
|
||||
if m >= n and side == "right":
|
||||
unitary, posdef, _, _ = qdwh.qdwh(a, is_hermitian=False, eps=eps)
|
||||
elif m < n and side == "left":
|
||||
a = a.T.conj()
|
||||
unitary, posdef, _, _ = qdwh.qdwh(a, is_hermitian=False, eps=eps)
|
||||
posdef = posdef.T.conj()
|
||||
unitary = unitary.T.conj()
|
||||
else:
|
||||
raise NotImplementedError("method='qdwh' only supports mxn matrices "
|
||||
"where m < n where side='right' and m >= n "
|
||||
f"side='left', got {a.shape} with side={side}")
|
||||
elif method == "svd":
|
||||
u_svd, s_svd, vh_svd = lax_linalg.svd(a, full_matrices=False)
|
||||
unitary = u_svd @ vh_svd
|
||||
if side == "right":
|
||||
# a = u * p
|
||||
posdef = (vh_svd.T.conj() * s_svd[None, :]) @ vh_svd
|
||||
else:
|
||||
# a = p * u
|
||||
posdef = (u_svd * s_svd[None, :]) @ (u_svd.T.conj())
|
||||
else:
|
||||
raise ValueError(f"Unknown polar decomposition method {method}.")
|
||||
|
||||
return unitary, posdef
|
||||
|
||||
|
||||
def polar_unitary(a, *, method="qdwh", eps=None, max_iterations=None):
|
||||
""" Computes the unitary factor u in the polar decomposition ``a = u p``
|
||||
(or ``a = p u``).
|
||||
|
||||
.. warning::
|
||||
This function is deprecated. Use :func:`jax.scipy.linalg.polar` instead.
|
||||
"""
|
||||
# TODO(phawkins): delete this function after 2022/8/11.
|
||||
warnings.warn("jax.scipy.linalg.polar_unitary is deprecated. Call "
|
||||
"jax.scipy.linalg.polar instead.",
|
||||
DeprecationWarning)
|
||||
unitary, _ = polar(a, method, eps, max_iterations)
|
||||
return unitary
|
||||
|
||||
|
||||
@jit
|
||||
def _sqrtm_triu(T):
|
||||
|
@ -27,6 +27,7 @@ from jax._src.scipy.linalg import (
|
||||
lu_factor as lu_factor,
|
||||
lu_solve as lu_solve,
|
||||
polar as polar,
|
||||
polar_unitary as polar_unitary,
|
||||
qr as qr,
|
||||
rsf2csf as rsf2csf,
|
||||
schur as schur,
|
||||
@ -38,10 +39,6 @@ from jax._src.scipy.linalg import (
|
||||
triu as triu,
|
||||
)
|
||||
|
||||
from jax._src.lax.polar import (
|
||||
polar_unitary as polar_unitary,
|
||||
)
|
||||
|
||||
from jax._src.third_party.scipy.linalg import (
|
||||
funm as funm,
|
||||
)
|
||||
|
@ -490,7 +490,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
for method in methods
|
||||
for side in sides
|
||||
for nonzero_condition_number in nonzero_condition_numbers
|
||||
for dtype in jtu.dtypes.floating
|
||||
for dtype in jtu.dtypes.inexact
|
||||
for seed in seeds))
|
||||
@jtu.skip_on_devices("gpu") # Fails on A100.
|
||||
def testPolar(
|
||||
@ -500,8 +500,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
if jtu.device_under_test() != "cpu":
|
||||
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
|
||||
raise unittest.SkipTest("Skip half precision off CPU.")
|
||||
if method == "svd":
|
||||
raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.")
|
||||
|
||||
m, n = shape
|
||||
if (method == "qdwh" and ((side == "left" and m >= n) or
|
||||
(side == "right" and m < n))):
|
||||
raise unittest.SkipTest("method=qdwh does not support these sizes")
|
||||
|
||||
matrix, _ = _initialize_polar_test(self.rng(),
|
||||
shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
|
||||
@ -517,7 +520,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
should_be_eye = np.matmul(unitary.conj().T, unitary)
|
||||
else:
|
||||
should_be_eye = np.matmul(unitary, unitary.conj().T)
|
||||
tol = 10 * jnp.finfo(matrix.dtype).eps
|
||||
tol = 500 * jnp.finfo(matrix.dtype).eps
|
||||
eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
|
||||
with self.subTest('Test unitarity.'):
|
||||
self.assertAllClose(
|
||||
|
@ -45,12 +45,12 @@ def _check_symmetry(x: jnp.ndarray) -> bool:
|
||||
m, n = x.shape
|
||||
eps = jnp.finfo(x.dtype).eps
|
||||
tol = 50.0 * eps
|
||||
is_symmetric = False
|
||||
is_hermitian = False
|
||||
if m == n:
|
||||
if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol:
|
||||
is_symmetric = True
|
||||
is_hermitian = True
|
||||
|
||||
return is_symmetric
|
||||
return is_hermitian
|
||||
|
||||
def _compute_relative_diff(actual, expected):
|
||||
"""Computes relative difference between two matrices."""
|
||||
@ -75,11 +75,11 @@ class QdwhTest(jtu.JaxTestCase):
|
||||
cond = 10**log_cond
|
||||
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
|
||||
a = (u * s) @ v
|
||||
is_symmetric = _check_symmetry(a)
|
||||
is_hermitian = _check_symmetry(a)
|
||||
max_iterations = 2
|
||||
|
||||
_, _, actual_num_iterations, is_converged = qdwh.qdwh(
|
||||
a, is_symmetric, max_iterations)
|
||||
a, is_hermitian, max_iterations)
|
||||
|
||||
with self.subTest('Number of iterations.'):
|
||||
self.assertEqual(max_iterations, actual_num_iterations)
|
||||
@ -102,10 +102,10 @@ class QdwhTest(jtu.JaxTestCase):
|
||||
cond = 10**log_cond
|
||||
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
|
||||
a = (u * s) @ v
|
||||
is_symmetric = _check_symmetry(a)
|
||||
is_hermitian = _check_symmetry(a)
|
||||
max_iterations = 10
|
||||
|
||||
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations)
|
||||
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian, max_iterations)
|
||||
expected_u, expected_h = osp_linalg.polar(a)
|
||||
|
||||
# Sets the test tolerance.
|
||||
@ -145,12 +145,12 @@ class QdwhTest(jtu.JaxTestCase):
|
||||
cond = 10**log_cond
|
||||
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
|
||||
a = (u * s) @ v
|
||||
is_symmetric = _check_symmetry(a)
|
||||
is_hermitian = _check_symmetry(a)
|
||||
max_iterations = 10
|
||||
|
||||
def lsp_linalg_fn(a):
|
||||
u, h, _, _ = qdwh.qdwh(
|
||||
a, is_symmetric=is_symmetric, max_iterations=max_iterations)
|
||||
a, is_hermitian=is_hermitian, max_iterations=max_iterations)
|
||||
return u, h
|
||||
|
||||
args_maker = lambda: [a]
|
||||
@ -185,9 +185,9 @@ class QdwhTest(jtu.JaxTestCase):
|
||||
s = jnp.expand_dims(s.at[-1].set(0), range(u.ndim - 1))
|
||||
a = (u * s) @ v
|
||||
|
||||
is_symmetric = _check_symmetry(a)
|
||||
is_hermitian = _check_symmetry(a)
|
||||
max_iterations = 15
|
||||
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations)
|
||||
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian, max_iterations)
|
||||
_, expected_h = osp_linalg.polar(a)
|
||||
|
||||
# Sets the test tolerance.
|
||||
@ -221,13 +221,13 @@ class QdwhTest(jtu.JaxTestCase):
|
||||
tiny_elem = jnp.finfo(a).tiny
|
||||
a = a.at[r, c].set(tiny_elem)
|
||||
|
||||
is_symmetric = _check_symmetry(a)
|
||||
is_hermitian = _check_symmetry(a)
|
||||
max_iterations = 10
|
||||
|
||||
@jax.jit
|
||||
def lsp_linalg_fn(a):
|
||||
u, h, _, _ = qdwh.qdwh(
|
||||
a, is_symmetric=is_symmetric, max_iterations=max_iterations)
|
||||
a, is_hermitian=is_hermitian, max_iterations=max_iterations)
|
||||
return u, h
|
||||
|
||||
actual_u, actual_h = lsp_linalg_fn(a)
|
||||
|
Loading…
x
Reference in New Issue
Block a user