Merge pull request #7335 from hawkinsp:qdwh

PiperOrigin-RevId: 385796794
jax authors 2021-07-20 08:54:16 -07:00
commit c95ef8799d
4 changed files with 777 additions and 0 deletions

jax/_src/lax/polar.py  (new file, 342 lines)

@@ -0,0 +1,342 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Functions to compute the polar decomposition of the m x n matrix A, A = U @ H
where U is unitary (an m x n isometry in the m > n case) and H is n x n and
positive semidefinite (or positive definite if A is nonsingular). The method
is described in the docstring of `polar`. This file covers the serial
case.
"""
import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
# TODO: Allow singular value estimates to be manually specified
@jax.jit
def _add_to_diagonal(X, val):
new_diagonal = X.diagonal() + val
diag_indices = jnp.diag_indices(X.shape[0])
return jax.ops.index_update(X, diag_indices, new_diagonal)
@jax.jit
def _dot(a, b):
return jnp.dot(a, b, precision=lax.Precision.HIGHEST)
def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
""" Computes the polar decomposition.
Given the (m x n) matrix `a`, returns the factors of the polar decomposition
`u` (m x n) and `p` such that `a = up` (if side is "right"; p is (n x n)) or
`a = pu` (if side is "left"; p is (m x m)), where `p` is positive
semidefinite. If `a` is nonsingular, `p` is positive definite and the
decomposition is unique. `u` has orthonormal columns unless n > m, in which
case it has orthonormal rows.
Writing an SVD of `a` as `a = u_svd @ s_svd @ v^h_svd`, we have
`u = u_svd @ v^h_svd`. Thus the unitary factor `u` can be construed as
the application of the signum function to the singular values of `a`;
or, if `a` is Hermitian, the eigenvalues.
Several methods exist to compute the polar decomposition. Currently two
are supported:
`method`="svd": Computes the SVD of `a` and then forms
`u = u_svd @ v^h_svd`. This fails on the TPU, since
no SVD algorithm independent of the polar decomposition
exists there.
`method`="qdwh": Applies a certain iterative expansion of the matrix
signum function to `a` based on QR and Cholesky
decompositions.
Args:
a: The m x n input matrix.
side: Determines whether a right or left polar decomposition is computed.
If side is "right" then `a = up`. If side is "left" then `a = pu`. The
default is "right".
method: Determines the algorithm used, as described above.
The remaining arguments are only meaningful if method is "qdwh".
eps: The final result will satisfy
  |X_k - X_{k-1}| < |X_k| * (4*eps)**(1/3).
maxiter: Iterations will terminate after this many steps even if the
above is unsatisfied.
Returns:
unitary: The unitary factor (m x n).
posdef: The positive-semidefinite factor. Either (n, n) or (m, m)
depending on whether side is "right" or "left", respectively.
info: Stores convergence information.
if method is "svd": None
if method is "qdwh": j_qr: Number of QR iterations.
j_chol: Number of Cholesky iterations.
errs: Convergence history.
"""
return _polar(a, side, method, eps, maxiter)
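# A minimal usage sketch (illustrative only; `_demo_polar` is a hypothetical
# helper, not part of this module's API). It exercises the contract above:
# for a tall input, a = u @ p with u having orthonormal columns.
def _demo_polar():
  a = jnp.array([[2., 1.], [0., 3.], [1., 1.]])  # 3 x 2, full column rank
  u, p, _ = polar(a, side="right", method="qdwh")
  # u is 3 x 2 with orthonormal columns; p is 2 x 2 positive definite.
  assert jnp.allclose(_dot(u, p), a, atol=1e-4)
  assert jnp.allclose(_dot(u.conj().T, u), jnp.eye(2), atol=1e-4)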
@jax.partial(jax.jit, static_argnums=(1, 2, 4))
def _polar(a, side, method, eps, maxiter):
if side not in ("left", "right"):
raise ValueError(f"side={side} was invalid.")
unitary, info = _polar_unitary(a, method, eps, maxiter)
if side == "right":
posdef = _dot(unitary.conj().T, a)
else:
posdef = _dot(a, unitary.conj().T)
posdef = 0.5 * (posdef + posdef.conj().T)
return unitary, posdef, info
def polar_unitary(a, method="qdwh", eps=None, maxiter=50):
""" Computes the unitary factor u in the polar decomposition `a = u p`
(or `a = p u`).
"""
return _polar_unitary(a, method, eps, maxiter)
@jax.partial(jax.jit, static_argnums=(1, 3))
def _polar_unitary(a, method, eps, maxiter):
if method not in ("svd", "qdwh"):
raise ValueError(f"method={method} is unsupported.")
if method == "svd":
u_svd, _, vh_svd = jnp.linalg.svd(a, full_matrices=False)
unitary = _dot(u_svd, vh_svd)
info = None
elif method == "qdwh":
unitary, j_qr, j_chol, errs = _qdwh(a, eps, maxiter)
info = (j_qr, j_chol, errs)
else:
raise ValueError("How did we get here?")
return unitary, info
@jax.partial(jax.jit, static_argnums=(2,))
def _qdwh(matrix, eps, maxiter):
""" Computes the unitary factor in the polar decomposition of A using
the QDWH method. QDWH implements a 3rd order Pade approximation to the
matrix sign function,
X' = X * (aI + b X^H X)(I + c X^H X)^-1, X0 = A / ||A||_2. (1)
The coefficients a, b, and c are chosen dynamically based on an evolving
estimate of the matrix condition number. Specifically,
a = h(l), b = (a - 1)^2 / 4, c = a + b - 1,
h(l) = sqrt(1 + d) + 0.5 * sqrt(8 - 4*d + 8*(2 - l^2) / (l^2 * sqrt(1 + d))),
d = (4 * (1 - l^2) / l^4)^(1/3),
where l is initially a lower bound on the smallest singular value of X0,
and subsequently evolves according to l' = l (a + b*l^2) / (1 + c*l^2).
For poorly conditioned matrices
(c > 100) the iteration (1) is rewritten in QR form,
X' = (b / c) X + (1 / c)(a - b/c) Q1 Q2^H,  [Q1] R = [sqrt(c) X]   (2)
                                            [Q2]     [I        ].
For well-conditioned matrices it is instead formulated using cheaper
Cholesky iterations,
X' = (b / c) X + (a - b/c) (X W^-1) W^-H, W = chol(I + c X^H X). (3)
The QR iterations rapidly improve the condition number, and typically
only 1 or 2 are required. At most 6 iterations in total are needed
for backward stability at double precision.
Args:
matrix: The m x n input matrix.
eps: The final result will satisfy |X_k - X_k-1| < |X_k| * (4*eps)**(1/3) .
maxiter: Iterations will terminate after this many steps even if the
above is unsatisfied.
Returns:
matrix: The unitary factor (m x n).
j_qr: The number of QR iterations (equation (2)).
j_chol: The number of Cholesky iterations (equation (3)).
errs: Convergence history.
"""
n_rows, n_cols = matrix.shape
fat = n_cols > n_rows
if fat:
matrix = matrix.T
matrix, q_factor, l0 = _initialize_qdwh(matrix)
if eps is None:
eps = jnp.finfo(matrix.dtype).eps
tol_lk = 5 * eps # stop when lk differs from 1 by less
tol_delta = jnp.cbrt(tol_lk) # stop when the iterates change by less
coefs = _qdwh_coefs(l0)
errs = jnp.zeros(maxiter, dtype=matrix.real.dtype)
matrix, j_qr, coefs, errs = _qdwh_qr(
matrix, coefs, errs, tol_lk, tol_delta, maxiter)
matrix, j_chol, errs = _qdwh_cholesky(
matrix, coefs, errs, tol_lk, tol_delta, j_qr, maxiter)
matrix = _dot(q_factor, matrix)
if fat:
matrix = matrix.T
return matrix, j_qr, j_chol, errs
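# How iteration (1) acts spectrally: writing X = U diag(s) V^H, one step maps
# each singular value through the rational function
#   r(s) = s * (a + b * s^2) / (1 + c * s^2),
# which fixes s = 1 and pushes every s in [l0, 1] toward 1, while leaving
# U and V^H untouched. The iterates therefore converge to U @ V^H, the
# unitary polar factor.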
@jax.jit
def _initialize_qdwh(matrix):
""" Does preparatory computations for QDWH:
1. Computes an initial QR factorization of the input A. The iterations
will be on the triangular factor R, whose condition is more easily
estimated, and which is square even when A is rectangular.
2. Computes R -> R / ||R||_F. Now 1 is used to upper-bound ||R||_2.
3. Computes R^-1 by solving R R^-1 = I.
4. Uses sqrt(N) * ||R^-1||_1 as an upper bound for ||R^-1||_2, so that
l_0 = 1 / (sqrt(N) * ||R^-1||_1) lower-bounds the smallest singular value
of the scaled R. It should be clear there is room for improvement here.
Returns:
X = R / ||R||_F;
Q from A -> Q @ R;
l0, the initial estimate for the QDWH coefficients.
"""
q_factor, r_factor = jnp.linalg.qr(matrix, mode="reduced")
alpha = jnp.linalg.norm(r_factor)
r_factor /= alpha
eye = jnp.eye(*r_factor.shape, dtype=r_factor.dtype)
r_inv = jsp.linalg.solve_triangular(r_factor, eye, overwrite_b=True)
one_norm_inv = jnp.linalg.norm(r_inv, ord=1)
l0 = 1 / (jnp.sqrt(matrix.shape[1]) * one_norm_inv)
eps = jnp.finfo(r_factor.dtype).eps
l0 = jnp.array(l0, dtype=r_factor.real.dtype)
l0 = jnp.where(l0 < eps, x=eps, y=l0)
l0 = jnp.where(l0 > 1.0, x=1.0, y=l0)
return r_factor, q_factor, l0
@jax.jit
def _qdwh_coefs(lk):
""" Computes a, b, c, l for the QDWH iterations.
The input lk must be in (0, 1]; lk=1 is a fixed point.
Some facts about the coefficients:
-for lk = 1 we have a=3, b=1, c=3, lk_new = 1.
-The float64 vs float32 computation of each coef appears to differ
only by noise on the order of 1E-9 to 1E-7 for all values of lk.
There is no apparent secular change in the (relative) error.
-All coefs change roughly as power laws; over e.g. [1E-14, 1]:
- a decreases from 5.43E9 to 3.
- b decreases from 7.37E18 to 1.
- c decreases from 7.37E18 to 3, only diverging from b near lk=1.
- lk increases from 5.45E-5 to 1.
lk is an estimate of the scaled matrix's smallest singular value.
"""
lk = jnp.where(lk > 1.0, x=1.0, y=lk)
d = (4. * (1. - lk**2) / (lk**4))**(1 / 3)
f = 8. * (2. - lk**2) / (lk**2 * (1. + d)**(1 / 2))
a = (1. + d)**(1 / 2) + 0.5 * (8. - 4. * d + f)**0.5
b = (a - 1.)**2 / 4
c = a + b - 1.
lk = lk * (a + b * lk**2) / (1 + c * lk**2)
return a, b, c, lk
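# A small sketch tracing the scalar recursion above (illustrative only;
# `_demo_coef_flow` is a hypothetical helper). At lk = 1 the recursion sits
# at its fixed point (a, b, c) = (3, 1, 3); starting from l0 = 1e-3, both lk
# and a representative singular value x reach ~1 within the handful of steps
# behind the 6-iteration bound quoted in `_qdwh`.
def _demo_coef_flow(x=1e-3, l0=1e-3, steps=6):
  lk = jnp.float32(l0)
  for _ in range(steps):
    a, b, c, lk = _qdwh_coefs(lk)
    x = x * (a + b * x**2) / (1 + c * x**2)
  return x, lk  # both approximately 1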
@jax.jit
def _unconverged(lk, j, maxiter, err, tol_delta, tol_lk):
changing = err > tol_delta
far_from_end = jnp.abs(1 - lk) > tol_lk
unconverged = jnp.logical_or(changing, far_from_end)
iterating = j < maxiter
return jnp.logical_and(iterating, unconverged)[0]
@jax.jit
def _qdwh_qr(matrix, coefs, errs, tol_lk, tol_delta, maxiter):
""" Applies the QDWH iteration formulated as
X' = (b / c) X + (1 / c)(a - b/c) Q1 Q2^H,  [Q1] R = [sqrt(c) X]
                                            [Q2]     [I        ]
to X until either c < 100, ||X' - X|| < eps||X'||,
or the iteration count exceeds maxiter.
"""
n_rows, n_cols = matrix.shape
eye = jnp.eye(n_cols, dtype=matrix.dtype)
def _do_qr(args):
_, j, coefs, _, err = args
c = coefs[2]
lk = coefs[-1]
unconverged = _unconverged(lk, j, maxiter, err, tol_delta, tol_lk)
ill_conditioned = c > 100.
return jnp.logical_and(ill_conditioned, unconverged)
def _qr_work(args):
matrix, j, coefs, errs, _ = args
a, b, c, lk = coefs
csqrt = jnp.sqrt(c)
matrixI = jnp.vstack((csqrt * matrix, eye))
# Note: it should be possible to compute the QR of csqrt * matrix
# and build the concatenation with I at O(N).
Q, _ = jnp.linalg.qr(matrixI, mode="reduced")
Q1 = Q[:n_rows, :]
Q2 = Q[n_rows:, :]
coef = (1 / csqrt) * (a - (b / c))
new_matrix = (b / c) * matrix + coef * _dot(Q1, Q2.T.conj())
err = jnp.linalg.norm(matrix - new_matrix)
err = jnp.full(1, err).astype(errs[0].dtype)
errs = errs.at[j].set(err)
coefs = _qdwh_coefs(lk)
return new_matrix, j + 1, coefs, errs, err
j = jnp.zeros(1, dtype=jnp.int32)
err = jnp.full(1, 2 * tol_delta).astype(matrix.real.dtype)
matrix, j, coefs, errs, _ = jax.lax.while_loop(
_do_qr, _qr_work, (matrix, j, coefs, errs, err))
return matrix, j, coefs, errs
@jax.jit
def _qdwh_cholesky(matrix, coefs, errs, tol_lk, tol_delta, j0, maxiter):
""" Applies the QDWH iteration formulated as
matrix' = (b / c) matrix + (a - b/c) B,
B = (matrix W^-1) W^-H, W = chol(I + c matrix^H matrix).
to matrix until either ||matrix' - matrix|| < eps * ||matrix'||,
or the iteration count exceeds maxiter.
"""
def _do_cholesky(args):
_, j, coefs, errs = args
lk = coefs[-1]
return _unconverged(lk, j, maxiter, errs[j - 1], tol_delta, tol_lk)
def _cholesky_work(args):
matrix, j, coefs, errs = args
a, b, c, lk = coefs
Z = c * _dot(matrix.T.conj(), matrix)
Z = _add_to_diagonal(Z, 1.)
W = jsp.linalg.cholesky(Z)
B = jsp.linalg.solve_triangular(W.T, matrix.T, lower=True).conj()
B = jsp.linalg.solve_triangular(W, B).conj().T
new_matrix = (b / c) * matrix + (a - b / c) * B
# possible instability if a ~ b / c
err = jnp.linalg.norm(new_matrix - matrix).astype(errs[0].dtype)
errs = errs.at[j].set(err)
coefs = _qdwh_coefs(lk)
return new_matrix, j + 1, coefs, errs
carry = (matrix, j0, coefs, errs)
matrix, j_total, coefs, errs = jax.lax.while_loop(
_do_cholesky, _cholesky_work, carry)
return matrix, j_total - j0, errs

jax/_src/scipy/eigh.py  (new file, 237 lines)

@@ -0,0 +1,237 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serial algorithm for eigh."""
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import lax
def _similarity_transform(
matrix_in, matrix_out, precision=lax.Precision.HIGHEST):
""" Returns matrix_out.conj().T @ matrix_in @ matrix_out, done in
order from left to right.
"""
out = jnp.dot(matrix_out.conj().T, matrix_in, precision=precision)
return jnp.dot(out, matrix_out, precision=precision)
def _projector_subspace(P, H, rank, maxiter=2):
""" Decomposes the `n x n` rank `rank` Hermitian projector `P` into
an `n x rank` isometry `Vm` such that `P = Vm @ Vm.conj().T` and
an `n x (n - rank)` isometry `Vm` such that -(I - P) = Vp @ Vp.conj().T`.
The subspaces are computed using the naiive QR eigendecomposition
algorithm, which converges very quickly due to the sharp separation
between the relevant eigenvalues of the projector.
Args:
P: A rank-`rank` Hermitian projector into the space of `H`'s
first `rank` eigenpairs.
H: The aforementioned Hermitian matrix, which is used to track
convergence.
rank: Rank of `P`.
maxiter: Maximum number of iterations.
Returns:
Vm, Vp: Isometries into the eigenspaces described in the docstring.
"""
# Choose an initial guess: the `rank` largest-norm columns of P.
column_norms = jnp.linalg.norm(P, axis=1)
sort_idxs = jnp.argsort(column_norms)[::-1]  # descending: largest norms first
X = P[:, sort_idxs]
X = X[:, :rank]
H_norm = jnp.linalg.norm(H)
thresh = 10 * jnp.finfo(X.dtype).eps * H_norm
# First iteration skips the matmul.
def body_f_after_matmul(X):
Q, _ = jnp.linalg.qr(X, mode="complete")
V1 = Q[:, :rank]
V2 = Q[:, rank:]
# TODO: might be able to get away with lower precision here
error_matrix = jnp.dot(V2.conj().T, H, precision=lax.Precision.HIGHEST)
error_matrix = jnp.dot(error_matrix, V1, precision=lax.Precision.HIGHEST)
error = jnp.linalg.norm(error_matrix) / H_norm
return V1, V2, error
def cond_f(args):
_, _, j, error = args
still_counting = j < maxiter
unconverged = error > thresh
return jnp.logical_and(still_counting, unconverged)[0]
def body_f(args):
V1, _, j, _ = args
X = jnp.dot(P, V1, precision=lax.Precision.HIGHEST)
V1, V2, error = body_f_after_matmul(X)
return V1, V2, j + 1, error
V1, V2, error = body_f_after_matmul(X)
one = jnp.ones(1, dtype=jnp.int32)
V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
return V1, V2
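# A tiny sketch of the contract (illustrative only; `_demo_projector_subspace`
# is a hypothetical helper): for a diagonal H and the projector onto its
# lowest eigenpair, the recovered isometries span that eigenpair and its
# complement.
def _demo_projector_subspace():
  H = jnp.diag(jnp.array([1., 3.]))
  P = jnp.diag(jnp.array([1., 0.]))  # rank-1 projector onto the first eigenpair
  Vm, Vp = _projector_subspace(P, H, rank=1)
  assert jnp.allclose(jnp.dot(Vm, Vm.conj().T), P, atol=1e-6)
  assert jnp.allclose(jnp.dot(Vp, Vp.conj().T), jnp.eye(2) - P, atol=1e-6)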
@jax.partial(jax.jit, static_argnums=(3, 4))
def _split_spectrum_jittable(P, H, V0, rank, precision):
""" The jittable portion of `split_spectrum`. At this point the sizes of the
relevant matrix blocks have been concretized.
Args:
P: Projection matrix.
H: Matrix to be projected.
V0: Accumulates the isometries into the projected subspaces.
rank: Rank of P.
precision: The matmul precision.
Returns:
H1, V1: Projection of H into the column space of P, and the accumulated
isometry performing that projection.
H2, V2: Projection of H into the null space of P, and the accumulated
isometry performing that projection.
"""
Vm, Vp = _projector_subspace(P, H, rank)
Hm = _similarity_transform(H, Vm, precision)
Hp = _similarity_transform(H, Vp, precision)
if V0 is not None:
Vm = jnp.dot(V0, Vm, precision=precision)
Vp = jnp.dot(V0, Vp, precision=precision)
return Hm, Vm, Hp, Vp
def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
""" The Hermitian matrix `H` is split into two matrices `Hm`
`Hp`, respectively sharing its eigenspaces beneath and above
its `split_point`th eigenvalue.
Returns, in addition, `Vm` and `Vp`, isometries such that
`Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
returned instead; this allows the overall isometries mapping from
an initial input matrix to progressively smaller blocks to be formed.
Args:
H: The Hermitian matrix to split.
split_point: The eigenvalue to split along.
V0: Matrix of isometries to be updated.
precision: TPU matmul precision.
Returns:
Hm: A Hermitian matrix sharing the eigenvalues of `H` beneath
`split_point`.
Vm: An isometry from the input space of `V0` to `Hm`.
Hp: A Hermitian matrix sharing the eigenvalues of `H` above
`split_point`.
Vp: An isometry from the input space of `V0` to `Hp`.
"""
def _fill_diagonal(X, vals):
return jax.ops.index_update(X, jnp.diag_indices(X.shape[0]), vals)
H_shift = _fill_diagonal(H, H.diagonal() - split_point)
U, _ = jsp.linalg.polar_unitary(H_shift)
P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.)
rank = jnp.round(jnp.trace(P)).astype(jnp.int32)
rank = int(rank)
return _split_spectrum_jittable(P, H, V0, rank, precision)
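# Sketch of split_spectrum on a concrete input (illustrative only;
# `_demo_split_spectrum` is a hypothetical helper). Splitting at 5 separates
# the eigenvalues {1, 2} from {10, 11} into two 2 x 2 blocks.
def _demo_split_spectrum():
  H = jnp.diag(jnp.array([1., 2., 10., 11.]))
  Hm, Vm, Hp, Vp = split_spectrum(H, split_point=5.)
  # Vm, Vp are 4 x 2 isometries; Hm, Hp carry the split eigenvalues.
  assert jnp.allclose(jnp.linalg.eigvalsh(Hm), jnp.array([1., 2.]), atol=1e-3)
  assert jnp.allclose(jnp.linalg.eigvalsh(Hp), jnp.array([10., 11.]), atol=1e-3)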
def _eigh_work(
H, V=None, precision=lax.Precision.HIGHEST, termination_size=128):
""" The main work loop performing the symmetric eigendecomposition of H.
Each step computes a projector onto the space of eigenvalues above the
median of H's diagonal, jnp.median(jnp.diag(H)). The projections of H into
and out of that space, along with the isometries accomplishing these, are
then computed. This is performed recursively until the projected blocks
reach `termination_size`, at which point they are diagonalized directly
with jnp.linalg.eigh; the corresponding isometries hold the eigenvectors.
The results are then composed.
This function cannot be Jitted because the internal split_spectrum cannot
be.
Args:
H: The Hermitian input.
V: Stores the isometries projecting H into its subspaces.
precision: The matmul precision.
Returns:
H, V: The result of the projection.
"""
if H.shape[0] <= termination_size:
evals, evecs = jnp.linalg.eigh(H)
if V is not None:
evecs = jnp.dot(V, evecs, precision=precision)
return evals, evecs
split_point = jnp.median(jnp.diag(H)) # TODO: Improve this?
Hm, Vm, Hp, Vp = split_spectrum(H, split_point, V0=V, precision=precision)
Hm, Vm = _eigh_work(
Hm, V=Vm, precision=precision, termination_size=termination_size)
Hp, Vp = _eigh_work(
Hp, V=Vp, precision=precision, termination_size=termination_size)
if Hm.ndim != 1 or Hp.ndim != 1:
raise ValueError(f"One of Hm.ndim={Hm.ndim} or Hp.ndim={Hp.ndim} != 1 ",
"indicating recursion terminated unexpectedly.")
evals = jnp.hstack((Hm, Hp))
evecs = jnp.hstack((Vm, Vp))
return evals, evecs
def eigh(
H, precision=lax.Precision.HIGHEST, symmetrize=True, termination_size=128):
""" Computes the eigendecomposition of the symmetric/Hermitian matrix H.
Args:
H: The `n x n` Hermitian input.
precision: The matmul precision.
symmetrize: If True, `0.5 * (H + H.conj().T)` rather than `H` is used.
termination_size: Recursion ends once the blocks reach this linear size.
Returns:
vals: The `n` eigenvalues of `H`, sorted from lowest to highest.
vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
to numerical error.
"""
nrows, ncols = H.shape
if nrows != ncols:
raise TypeError(f"Input H of shape {H.shape} must be square.")
if ncols <= termination_size:
return jnp.linalg.eigh(H)
evals, evecs = _eigh_work(H, precision=precision)
sort_idxs = jnp.argsort(evals)
evals = evals[sort_idxs]
evecs = evecs[:, sort_idxs]
return evals, evecs
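# A usage sketch comparing against jnp.linalg.eigh (illustrative only;
# `_demo_eigh` is a hypothetical helper). A size above `termination_size`
# exercises the recursive spectrum splitting.
def _demo_eigh():
  import numpy as np  # local import, used only to build test data
  M = np.random.RandomState(0).randn(256, 256).astype(np.float32)
  H = jnp.asarray(0.5 * (M + M.T))
  vals, vecs = eigh(H)
  # The eigenvalue equation holds columnwise: H @ vecs[:, i] ~= vals[i] * vecs[:, i].
  assert jnp.allclose(jnp.dot(H, vecs), vecs * vals, atol=5e-3)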
def svd(A, precision=lax.Precision.HIGHEST):
""" Computes an SVD of `A`.
Args:
A: The `m` by `n` input matrix.
precision: TPU matmul precision.
Returns:
U: An `m` by `n` isometry whose columns are `A`'s left singular vectors.
S: A length-`n` vector of `A`'s singular values.
V_dag: An `n` by `n` unitary matrix of `A`'s conjugate transposed
right singular vectors.
"""
Up, H, _ = jsp.linalg.polar(A)
S, V = eigh(H, precision=precision)
U = jnp.dot(Up, V, precision=precision)
return U, S, V.conj().T
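# Usage sketch (illustrative only; `_demo_svd` is a hypothetical helper):
# reconstructs A from the polar-then-eigh SVD assembled above.
def _demo_svd():
  import numpy as np  # local import, used only to build test data
  A = jnp.asarray(np.random.RandomState(0).randn(128, 64), dtype=jnp.float32)
  U, S, V_dag = svd(A)
  # U * S scales U's columns by the singular values, so (U * S) @ V_dag ~= A.
  recon = jnp.dot(U * S, V_dag, precision=lax.Precision.HIGHEST)
  assert jnp.allclose(A, recon, atol=1e-3)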


@@ -35,3 +35,8 @@ from jax._src.scipy.linalg import (
tril,
triu,
)
from jax._src.lax.polar import (
polar,
polar_unitary
)


@@ -27,8 +27,11 @@ import scipy.special as osp_special
from jax._src import api
from jax import numpy as jnp
from jax import lax
from jax import scipy as jsp
from jax import test_util as jtu
from jax.scipy import special as lsp_special
import jax._src.scipy.eigh
from jax.config import config
config.parse_flags_with_absl()
@@ -44,6 +47,50 @@ float_dtypes = jtu.dtypes.floating
complex_dtypes = jtu.dtypes.complex
int_dtypes = jtu.dtypes.integer
# Params for the polar tests.
polar_shapes = [(16, 12), (12, 16), (128, 128)]
n_zero_svs = [0, 4]
degeneracies = [0, 4]
geometric_spectra = [False, True]
max_svs = [0.1, 10.]
nonzero_condition_numbers = [0.1, 100000]
sides = ["right", "left"]
methods = ["qdwh", "svd"]
seeds = [1, 10]
linear_sizes = [16, 128, 256]
def _initialize_polar_test(shape, n_zero_svs, degeneracy, geometric_spectrum,
max_sv, nonzero_condition_number, dtype):
n_rows, n_cols = shape
min_dim = min(shape)
left_vecs = np.random.randn(n_rows, min_dim).astype(np.float64)
left_vecs, _ = np.linalg.qr(left_vecs)
right_vecs = np.random.randn(n_cols, min_dim).astype(np.float64)
right_vecs, _ = np.linalg.qr(right_vecs)
min_nonzero_sv = max_sv / nonzero_condition_number
num_nonzero_svs = min_dim - n_zero_svs
if geometric_spectrum:
nonzero_svs = np.geomspace(min_nonzero_sv, max_sv, num=num_nonzero_svs,
dtype=np.float64)
else:
nonzero_svs = np.linspace(min_nonzero_sv, max_sv, num=num_nonzero_svs,
dtype=np.float64)
half_point = n_zero_svs // 2
for i in range(half_point, half_point + degeneracy):
nonzero_svs[i] = nonzero_svs[half_point]
svs = np.zeros(min(shape), dtype=np.float64)
svs[n_zero_svs:] = nonzero_svs
svs = svs[::-1]
result = np.dot(left_vecs * svs, right_vecs.conj().T)
result = jnp.array(result).astype(dtype)
spectrum = jnp.array(svs).astype(dtype)
return result, spectrum
OpRecord = collections.namedtuple(
"OpRecord",
["name", "nargs", "dtypes", "rng_factory", "test_autodiff", "nondiff_argnums", "test_name"])
@@ -99,6 +146,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
def _GetArgsMaker(self, rng, shapes, dtypes):
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
@@ -405,6 +453,151 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name':
'_shape={}'
'_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
'_max_sv={}_method={}_side={}'
'_nonzero_condition_number={}_seed={}'.format(
jtu.format_shape_dtype_string(
shape, jnp.dtype(dtype).name).replace(" ", ""),
n_zero_sv, degeneracy, geometric_spectrum, max_sv,
method, side, nonzero_condition_number, seed
),
'n_zero_sv': n_zero_sv, 'degeneracy': degeneracy,
'geometric_spectrum': geometric_spectrum,
'max_sv': max_sv, 'shape': shape, 'method': method,
'side': side, 'nonzero_condition_number': nonzero_condition_number,
'dtype': dtype, 'seed': seed}
for n_zero_sv in n_zero_svs
for degeneracy in degeneracies
for geometric_spectrum in geometric_spectra
for max_sv in max_svs
for shape in polar_shapes
for method in methods
for side in sides
for nonzero_condition_number in nonzero_condition_numbers
for dtype in jtu.dtypes.floating
for seed in seeds))
def testPolar(
self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape, method,
side, nonzero_condition_number, dtype, seed):
""" Tests jax.scipy.linalg.polar."""
if jtu.device_under_test() != "cpu":
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
raise unittest.SkipTest("Skip half precision off CPU.")
if method == "svd":
raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.")
np.random.seed(seed)
matrix, _ = _initialize_polar_test(
shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
nonzero_condition_number, dtype)
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
self.assertRaises(
NotImplementedError, jsp.linalg.polar, matrix, method=method,
side=side)
return
unitary, posdef, info = jsp.linalg.polar(matrix, method=method, side=side)
if shape[0] >= shape[1]:
should_be_eye = np.matmul(unitary.conj().T, unitary)
else:
should_be_eye = np.matmul(unitary, unitary.conj().T)
tol = 10 * jnp.finfo(matrix.dtype).eps
eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
with self.subTest('Test unitarity.'):
self.assertAllClose(
eye_mat, should_be_eye, atol=tol * min(shape))
with self.subTest('Test Hermiticity.'):
self.assertAllClose(
posdef, posdef.conj().T, atol=tol * jnp.linalg.norm(posdef))
ev, _ = np.linalg.eigh(posdef)
ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
negative_ev = jnp.sum(ev < 0.)
with self.subTest('Test positive definiteness.'):
assert negative_ev == 0.
if side == "right":
recon = jnp.matmul(unitary, posdef, precision=lax.Precision.HIGHEST)
elif side == "left":
recon = jnp.matmul(posdef, unitary, precision=lax.Precision.HIGHEST)
with self.subTest('Test reconstruction.'):
self.assertAllClose(
matrix, recon, atol=tol * jnp.linalg.norm(matrix))
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name':
'_linear_size_={}_seed={}_dtype={}'.format(
linear_size, seed, jnp.dtype(dtype).name
),
'linear_size': linear_size, 'seed': seed, 'dtype': dtype}
for linear_size in linear_sizes
for seed in seeds
for dtype in jtu.dtypes.floating))
def test_spectral_dac_eigh(self, linear_size, seed, dtype):
if jtu.device_under_test() != "cpu":
raise unittest.SkipTest("Skip eigh off CPU for now.")
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
if jtu.device_under_test() != "cpu":
raise unittest.SkipTest("Skip half precision off CPU.")
np.random.seed(seed)
H = np.random.randn(linear_size, linear_size)
H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
self.assertRaises(
NotImplementedError, jax._src.scipy.eigh.eigh, H)
return
evs, V = jax._src.scipy.eigh.eigh(H)
ev_exp, _ = jnp.linalg.eigh(H)
HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
vV = evs * V
eps = jnp.finfo(H.dtype).eps
atol = jnp.linalg.norm(H) * eps
self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
self.assertAllClose(HV, vV, atol=30 * atol)
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name':
'_linear_size_={}_seed={}_dtype={}'.format(
linear_size, seed, jnp.dtype(dtype).name
),
'linear_size': linear_size, 'seed': seed, 'dtype': dtype}
for linear_size in linear_sizes
for seed in seeds
for dtype in jtu.dtypes.floating))
def test_spectral_dac_svd(self, linear_size, seed, dtype):
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
if jtu.device_under_test() != "cpu":
raise unittest.SkipTest("Skip half precision off CPU.")
np.random.seed(seed)
A = np.random.randn(linear_size, linear_size).astype(dtype)
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
self.assertRaises(
NotImplementedError, jax._src.scipy.eigh.svd, A)
return
S_expected = np.linalg.svd(A, compute_uv=False)
U, S, V = jax._src.scipy.eigh.svd(A)
recon = jnp.dot((U * S), V, precision=lax.Precision.HIGHEST)
eps = jnp.finfo(dtype).eps
eps = eps * jnp.linalg.norm(A) * 10
self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
self.assertAllClose(A, recon, atol=eps)
# U is unitary.
u_unitary_delta = jnp.dot(U.conj().T, U, precision=lax.Precision.HIGHEST)
u_eye = jnp.eye(u_unitary_delta.shape[0], dtype=dtype)
self.assertAllClose(u_unitary_delta, u_eye, atol=eps)
# V is unitary.
v_unitary_delta = jnp.dot(V.conj().T, V, precision=lax.Precision.HIGHEST)
v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
self.assertAllClose(v_unitary_delta, v_eye, atol=eps)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())