mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #7335 from hawkinsp:qdwh
PiperOrigin-RevId: 385796794
This commit is contained in:
commit
c95ef8799d
342
jax/_src/lax/polar.py
Normal file
342
jax/_src/lax/polar.py
Normal file
@ -0,0 +1,342 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
|
||||
"""
|
||||
Functions to compute the polar decomposition of the m x n matrix A, A = U @ H
|
||||
where U is unitary (an m x n isometry in the m > n case) and H is n x n and
|
||||
positive semidefinite (or positive definite if A is nonsingular). The method
|
||||
is described in the docstring to `polarU`. This file covers the serial
|
||||
case.
|
||||
"""
|
||||
import jax
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy as jsp
|
||||
|
||||
|
||||
# TODO: Allow singular value estimates to be manually specified
|
||||
@jax.jit
|
||||
def _add_to_diagonal(X, val):
|
||||
new_diagonal = X.diagonal() + val
|
||||
diag_indices = jnp.diag_indices(X.shape[0])
|
||||
return jax.ops.index_update(X, diag_indices, new_diagonal)
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _dot(a, b):
|
||||
return jnp.dot(a, b, precision=lax.Precision.HIGHEST)
|
||||
|
||||
|
||||
def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
  """ Computes the polar decomposition.

  Given the (m x n) matrix `a`, returns the factors of the polar decomposition
  `u` (m x n) and `p` such that `a = up` (if side is "right"; p is (n x n)) or
  `a = pu` (if side is "left"; p is (m x m)), where `p` is positive
  semidefinite. If `a` is nonsingular, `p` is positive definite and the
  decomposition is unique. `u` has orthonormal columns unless n > m, in which
  case it has orthonormal rows.

  Writing an SVD of `a` as `a = u_svd @ s_svd @ v^h_svd`, we have
  `u = u_svd @ v^h_svd`. Thus the unitary factor `u` can be construed as
  the application of the signum function to the singular values of `a`;
  or, if `a` is Hermitian, the eigenvalues.

  Several methods exist to compute the polar decomposition. Currently two
  are supported:
    `method`="svd": Computes the SVD of `a` and then forms
                    `u = u_svd @ v^h_svd`. This fails on the TPU, since
                    no SVD algorithm independent of the polar decomposition
                    exists there.
    `method`="qdwh": Applies a certain iterative expansion of the matrix
                     signum function to `a` based on QR and Cholesky
                     decompositions.

  Args:
    a: The m x n input matrix.
    side: Determines whether a right or left polar decomposition is computed.
      If side is "right" then `a = up`. If side is "left" then `a = pu`. The
      default is "right".
    method: Determines the algorithm used, as described above.

  The remaining arguments are only meaningful if method is "qdwh".
    eps: Sets the convergence tolerances. If None, the machine epsilon of
      the working dtype is used. `_qdwh` derives two thresholds from it:
      `5 * eps` for the singular-value bound and `cbrt(5 * eps)` for the
      change between successive iterates.
    maxiter: Iterations will terminate after this many steps even if the
      above is unsatisfied.
  Returns:
    unitary: The unitary factor (m x n).
    posdef: The positive-semidefinite factor. Either (n, n) or (m, m)
      depending on whether side is "right" or "left", respectively.
    info: Stores convergence information.
      if method is "svd": None
      if method is "qdwh": the tuple `(j_qr, j_chol, errs)`, where
                           j_qr: Number of QR iterations.
                           j_chol: Number of Cholesky iterations.
                           errs: Convergence history.
  Raises:
    ValueError: if `side` or `method` is not one of the supported values.
  """
  return _polar(a, side, method, eps, maxiter)
|
||||
|
||||
|
||||
@jax.partial(jax.jit, static_argnums=(1, 2, 4))
def _polar(a, side, method, eps, maxiter):
  """Jitted body of `polar`; `side`, `method` and `maxiter` are static.

  Returns `(unitary, posdef, info)` exactly as documented on `polar`.
  Raises ValueError if `side` is neither "left" nor "right".
  """
  if side not in ("left", "right"):
    raise ValueError(f"side={side} was invalid.")

  unitary, info = _polar_unitary(a, method, eps, maxiter)
  u_dagger = unitary.conj().T
  # side == "right": a = u p, so p is formed as u^H a.
  # side == "left":  a = p u, so p is formed as a u^H.
  posdef = _dot(u_dagger, a) if side == "right" else _dot(a, u_dagger)
  # Average with the adjoint so the returned factor is Hermitian.
  posdef = 0.5 * (posdef + posdef.conj().T)
  return unitary, posdef, info
|
||||
|
||||
|
||||
def polar_unitary(a, method="qdwh", eps=None, maxiter=50):
  """ Computes the unitary factor u in the polar decomposition `a = u p`
  (or `a = p u`).

  Args:
    a: The m x n input matrix.
    method: "qdwh" (default) or "svd"; see `polar`.
    eps: Tolerance parameter forwarded to the QDWH iteration; ignored by the
      "svd" method.
    maxiter: Maximum number of QDWH iterations; ignored by the "svd" method.
  Returns:
    The pair `(unitary, info)`: the unitary factor and the convergence
    information (`None` for "svd", `(j_qr, j_chol, errs)` for "qdwh").
  """
  return _polar_unitary(a, method, eps, maxiter)
|
||||
|
||||
|
||||
@jax.partial(jax.jit, static_argnums=(1, 3))
def _polar_unitary(a, method, eps, maxiter):
  """Jitted body of `polar_unitary`; `method` and `maxiter` are static.

  Args:
    a: The m x n input matrix.
    method: "svd" or "qdwh".
    eps: Tolerance parameter for QDWH (unused by "svd").
    maxiter: Iteration cap for QDWH (unused by "svd").
  Returns:
    unitary: The unitary polar factor.
    info: `None` for "svd"; `(j_qr, j_chol, errs)` for "qdwh".
  Raises:
    ValueError: if `method` is not "svd" or "qdwh".
  """
  if method not in ("svd", "qdwh"):
    raise ValueError(f"method={method} is unsupported.")

  if method == "svd":
    u_svd, _, vh_svd = jnp.linalg.svd(a, full_matrices=False)
    unitary = _dot(u_svd, vh_svd)
    info = None
  else:
    # Only "qdwh" can reach here: `method` was validated above, so the
    # former trailing `else: raise` branch was unreachable and is dropped.
    unitary, j_qr, j_chol, errs = _qdwh(a, eps, maxiter)
    info = (j_qr, j_chol, errs)
  return unitary, info
|
||||
|
||||
|
||||
@jax.partial(jax.jit, static_argnums=(2,))
def _qdwh(matrix, eps, maxiter):
  """ Computes the unitary factor in the polar decomposition of A using
  the QDWH method. QDWH implements a 3rd order Pade approximation to the
  matrix sign function,

  X' = X * (aI + b X^H X)(I + c X^H X)^-1,                              (1)

  started from a scaled copy X0 of A (see `_initialize_qdwh`).

  The coefficients a, b, and c are chosen dynamically by `_qdwh_coefs` from
  l, an evolving lower bound on the smallest singular value of the iterate,
  which is itself updated as l' = l (a + b l^2) / (1 + c l^2).

  For poorly conditioned matrices (c > 100) the iteration (1) is rewritten
  in QR form,

  X' = (b / c) X + (1 / sqrt(c)) (a - b/c) Q1 Q2^H,
       where [Q1] R = [sqrt(c) X]                                       (2)
             [Q2]     [I        ].

  For well-conditioned matrices it is instead formulated using cheaper
  Cholesky iterations,

  X' = (b / c) X + (a - b/c) (X W^-1) W^-H,  W = chol(I + c X^H X).     (3)

  The QR iterations rapidly improve the condition number, and typically
  only 1 or 2 are required. A maximum of 6 iterations total are required
  for backwards stability to double precision.

  Args:
    matrix: The m x n input matrix.
    eps: Base tolerance; defaults to the machine epsilon of the working
      dtype. Iteration stops once |1 - l| <= 5 * eps and the Frobenius norm
      of the change between iterates is <= cbrt(5 * eps).
    maxiter: Iterations will terminate after this many steps even if the
      above is unsatisfied. Static: it also sizes `errs`.
  Returns:
    matrix: The unitary factor (m x n).
    jq: The number of QR iterations (2).
    jc: The number of Cholesky iterations (3).
    errs: Convergence history (length `maxiter`).
  """
  n_rows, n_cols = matrix.shape
  fat = n_cols > n_rows
  if fat:
    # Iterate on the tall transpose; the result is transposed back below.
    matrix = matrix.T
  matrix, q_factor, l0 = _initialize_qdwh(matrix)

  if eps is None:
    eps = jnp.finfo(matrix.dtype).eps
  tol_lk = 5 * eps  # stop when lk differs from 1 by less
  tol_delta = jnp.cbrt(tol_lk)  # stop when the iterates change by less
  coefs = _qdwh_coefs(l0)
  errs = jnp.zeros(maxiter, dtype=matrix.real.dtype)
  matrix, j_qr, coefs, errs = _qdwh_qr(
      matrix, coefs, errs, tol_lk, tol_delta, maxiter)
  # `_qdwh_cholesky` takes (tol_delta, tol_lk) in that order, unlike
  # `_qdwh_qr`. They were previously passed swapped, so the Cholesky phase
  # tested the iterate change against 5*eps and |1 - l| against cbrt(5*eps).
  matrix, j_chol, errs = _qdwh_cholesky(
      matrix, coefs, errs, tol_delta, tol_lk, j_qr, maxiter)
  # Undo the initial QR factorization: A's polar factor is Q times R's.
  matrix = _dot(q_factor, matrix)

  if fat:
    matrix = matrix.T
  return matrix, j_qr, j_chol, errs
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _initialize_qdwh(matrix):
|
||||
""" Does preparatory computations for QDWH:
|
||||
1. Computes an initial QR factorization of the input A. The iterations
|
||||
will be on the triangular factor R, whose condition is more easily
|
||||
estimated, and which is square even when A is rectangular.
|
||||
2. Computes R -> R / ||R||_F. Now 1 is used to upper-bound ||R||_2.
|
||||
3. Computes R^-1 by solving R R^-1 = I.
|
||||
4. Uses sqrt(N) * ||R^-1||_1 as a lower bound for ||R^-2||.
|
||||
1 / sqrt(N) * ||R^-1||_1 is then used as the initial l_0. It should be clear
|
||||
there is room for improvement here.
|
||||
|
||||
Returns:
|
||||
X = R / ||R||_F;
|
||||
Q from A -> Q @ R;
|
||||
l0, the initial estimate for the QDWH coefficients.
|
||||
"""
|
||||
q_factor, r_factor = jnp.linalg.qr(matrix, mode="reduced")
|
||||
alpha = jnp.linalg.norm(r_factor)
|
||||
r_factor /= alpha
|
||||
eye = jnp.eye(*r_factor.shape, dtype=r_factor.dtype)
|
||||
r_inv = jsp.linalg.solve_triangular(r_factor, eye, overwrite_b=True)
|
||||
one_norm_inv = jnp.linalg.norm(r_inv, ord=1)
|
||||
l0 = 1 / (jnp.sqrt(matrix.shape[1]) * one_norm_inv)
|
||||
eps = jnp.finfo(r_factor.dtype).eps
|
||||
l0 = jnp.array(l0, dtype=r_factor.real.dtype)
|
||||
l0 = jnp.where(l0 < eps, x=eps, y=l0)
|
||||
l0 = jnp.where(l0 > 1.0, x=1.0, y=l0)
|
||||
return r_factor, q_factor, l0
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _qdwh_coefs(lk):
|
||||
""" Computes a, b, c, l for the QDWH iterations.
|
||||
The input lk must be in (0, 1]; lk=1 is a fixed point.
|
||||
Some facts about the coefficients:
|
||||
-for lk = 1 we have a=3, b=1, c=3, lk_new = 1.
|
||||
-The float64 vs float32 computation of each coef appears to differ
|
||||
only by noise on the order of 1E-9 to 1E-7 for all values of lk.
|
||||
There is no apparent secular change in the (relative) error.
|
||||
-All coefs change roughly as power laws; over e.g. [1E-14, 1]:
|
||||
- a decreases from 5.43E9 to 3.
|
||||
- b decreases from 7.37E18 to 1.
|
||||
- c decreases from 7.37E18 to 3, only diverging from b near lk=1.
|
||||
- lk increases from 5.45E-5 to 1.
|
||||
|
||||
lk is an estimate of the scaled matrix's smallest singular value
|
||||
"""
|
||||
lk = jnp.where(lk > 1.0, x=1.0, y=lk)
|
||||
d = (4. * (1. - lk**2) / (lk**4))**(1 / 3)
|
||||
f = 8. * (2. - lk**2) / (lk**2 * (1. + d)**(1 / 2))
|
||||
a = (1. + d)**(1 / 2) + 0.5 * (8. - 4. * d + f)**0.5
|
||||
b = (a - 1.)**2 / 4
|
||||
c = a + b - 1.
|
||||
lk = lk * (a + b * lk**2) / (1 + c * lk**2)
|
||||
return a, b, c, lk
|
||||
|
||||
|
||||
@jax.jit
|
||||
def _unconverged(lk, j, maxiter, err, tol_delta, tol_lk):
|
||||
changing = err > tol_delta
|
||||
far_from_end = jnp.abs(1 - lk) > tol_lk
|
||||
unconverged = jnp.logical_or(changing, far_from_end)
|
||||
iterating = j < maxiter
|
||||
return jnp.logical_and(iterating, unconverged)[0]
|
||||
|
||||
|
||||
@jax.jit
def _qdwh_qr(matrix, coefs, errs, tol_lk, tol_delta, maxiter):
  """ Applies the QDWH iteration formulated as

  X' = (b / c) X + (1 / sqrt(c)) (a - b/c) Q1 Q2^H,
       where [Q1] R = [sqrt(c) X]
             [Q2]     [I        ]

  to X until either c <= 100, the convergence test in `_unconverged` is
  met, or the iteration count reaches maxiter.

  Args:
    matrix: The current iterate X.
    coefs: The tuple (a, b, c, lk) from `_qdwh_coefs`.
    errs: Convergence history buffer; entry j receives ||X - X'||_F.
    tol_lk: Tolerance on |1 - lk|.
    tol_delta: Tolerance on the change between iterates.
    maxiter: Maximum number of iterations.
  Returns:
    matrix: The updated iterate.
    j: Length-1 int32 array holding the number of QR iterations performed.
    coefs: The coefficients for the next iteration.
    errs: The updated convergence history.
  """
  n_rows, n_cols = matrix.shape
  eye = jnp.eye(n_cols, dtype=matrix.dtype)

  def _do_qr(args):
    _, j, coefs, _, err = args
    c = coefs[2]
    lk = coefs[-1]
    unconverged = _unconverged(lk, j, maxiter, err, tol_delta, tol_lk)
    ill_conditioned = c > 100.
    return jnp.logical_and(ill_conditioned, unconverged)

  def _qr_work(args):
    matrix, j, coefs, errs, _ = args
    a, b, c, lk = coefs
    csqrt = jnp.sqrt(c)
    matrixI = jnp.vstack((csqrt * matrix, eye))
    # Note: it should be possible to compute the QR of csqrt * matrix
    # and build the concatenation with I at O(N).
    Q, _ = jnp.linalg.qr(matrixI, mode="reduced")
    Q1 = Q[:n_rows, :]
    Q2 = Q[n_rows:, :]
    # The prefactor is 1 / sqrt(c), not 1 / c.
    coef = (1 / csqrt) * (a - (b / c))
    new_matrix = (b / c) * matrix + coef * _dot(Q1, Q2.T.conj())
    err = jnp.linalg.norm(matrix - new_matrix)
    # Keep `err` as a length-1 array so the loop carry has a fixed shape.
    err = jnp.full(1, err).astype(errs[0].dtype)
    errs = errs.at[j].set(err)
    coefs = _qdwh_coefs(lk)
    return new_matrix, j + 1, coefs, errs, err

  j = jnp.zeros(1, dtype=jnp.int32)
  # Seed the error above tolerance so the first convergence test cannot
  # report "converged" before any work has been done.
  err = jnp.full(1, 2 * tol_delta).astype(matrix.real.dtype)
  matrix, j, coefs, errs, _ = jax.lax.while_loop(
      _do_qr, _qr_work, (matrix, j, coefs, errs, err))
  return matrix, j, coefs, errs
|
||||
|
||||
|
||||
@jax.jit
def _qdwh_cholesky(matrix, coefs, errs, tol_delta, tol_lk, j0, maxiter):
  """ Applies the QDWH iteration formulated as

  matrix' = (b / c) matrix + (a - b/c) B,
    B = (matrix W^-1) W^-H,  W = chol(I + c matrix^H matrix).

  to matrix until the convergence test in `_unconverged` is met or the
  iteration count reaches maxiter.

  Note the argument order: `tol_delta` precedes `tol_lk` here.

  Returns:
    The final iterate, the number of Cholesky iterations performed (the
    final counter minus `j0`), and the updated convergence history.
  """

  def _keep_iterating(carry):
    _, step, cur_coefs, history = carry
    # The most recent error lives one slot behind the step counter.
    return _unconverged(
        cur_coefs[-1], step, maxiter, history[step - 1], tol_delta, tol_lk)

  def _one_step(carry):
    x, step, cur_coefs, history = carry
    a, b, c, lk = cur_coefs
    gram = _add_to_diagonal(c * _dot(x.T.conj(), x), 1.)
    chol = jsp.linalg.cholesky(gram)
    # Two triangular solves apply the inverse of `gram` to x from the right.
    half = jsp.linalg.solve_triangular(chol.T, x.T, lower=True).conj()
    update = jsp.linalg.solve_triangular(chol, half).conj().T
    x_next = (b / c) * x + (a - b / c) * update
    # possible instability if a ~ b / c
    delta = jnp.linalg.norm(x_next - x).astype(history[0].dtype)
    history = history.at[step].set(delta)
    return x_next, step + 1, _qdwh_coefs(lk), history

  matrix, j_total, _, errs = jax.lax.while_loop(
      _keep_iterating, _one_step, (matrix, j0, coefs, errs))
  return matrix, j_total - j0, errs
|
237
jax/_src/scipy/eigh.py
Normal file
237
jax/_src/scipy/eigh.py
Normal file
@ -0,0 +1,237 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
|
||||
"""Serial algorithm for eigh."""
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy as jsp
|
||||
from jax import lax
|
||||
|
||||
|
||||
def _similarity_transform(
|
||||
matrix_in, matrix_out, precision=lax.Precision.HIGHEST):
|
||||
""" Returns matrix_out.conj().T @ matrix_in @ matrix_out, done in
|
||||
order from left to right.
|
||||
"""
|
||||
out = jnp.dot(matrix_out.conj().T, matrix_in, precision=precision)
|
||||
return jnp.dot(out, matrix_out, precision=precision)
|
||||
|
||||
|
||||
def _projector_subspace(P, H, rank, maxiter=2):
|
||||
""" Decomposes the `n x n` rank `rank` Hermitian projector `P` into
|
||||
an `n x rank` isometry `Vm` such that `P = Vm @ Vm.conj().T` and
|
||||
an `n x (n - rank)` isometry `Vm` such that -(I - P) = Vp @ Vp.conj().T`.
|
||||
|
||||
The subspaces are computed using the naiive QR eigendecomposition
|
||||
algorithm, which converges very quickly due to the sharp separation
|
||||
between the relevant eigenvalues of the projector.
|
||||
|
||||
Args:
|
||||
P: A rank-`rank` Hermitian projector into the space of `H`'s
|
||||
first `rank` eigenpairs.
|
||||
H: The aforementioned Hermitian matrix, which is used to track
|
||||
convergence.
|
||||
rank: Rank of `P`.
|
||||
maxiter: Maximum number of iterations.
|
||||
Returns:
|
||||
Vm, Vp: Isometries into the eigenspaces described in the docstring.
|
||||
"""
|
||||
# Choose an initial guess: the `rank` largest-norm columns of P.
|
||||
column_norms = jnp.linalg.norm(P, axis=1)
|
||||
sort_idxs = jnp.argsort(column_norms)
|
||||
X = P[:, sort_idxs]
|
||||
X = X[:, :rank]
|
||||
|
||||
H_norm = jnp.linalg.norm(H)
|
||||
thresh = 10 * jnp.finfo(X.dtype).eps * H_norm
|
||||
|
||||
# First iteration skips the matmul.
|
||||
def body_f_after_matmul(X):
|
||||
Q, _ = jnp.linalg.qr(X, mode="complete")
|
||||
V1 = Q[:, :rank]
|
||||
V2 = Q[:, rank:]
|
||||
# TODO: might be able to get away with lower precision here
|
||||
error_matrix = jnp.dot(V2.conj().T, H, precision=lax.Precision.HIGHEST)
|
||||
error_matrix = jnp.dot(error_matrix, V1, precision=lax.Precision.HIGHEST)
|
||||
error = jnp.linalg.norm(error_matrix) / H_norm
|
||||
return V1, V2, error
|
||||
|
||||
def cond_f(args):
|
||||
_, _, j, error = args
|
||||
still_counting = j < maxiter
|
||||
unconverged = error > thresh
|
||||
return jnp.logical_and(still_counting, unconverged)[0]
|
||||
|
||||
def body_f(args):
|
||||
V1, _, j, _ = args
|
||||
X = jnp.dot(P, V1, precision=lax.Precision.HIGHEST)
|
||||
V1, V2, error = body_f_after_matmul(X)
|
||||
return V1, V2, j + 1, error
|
||||
|
||||
V1, V2, error = body_f_after_matmul(X)
|
||||
one = jnp.ones(1, dtype=jnp.int32)
|
||||
V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
|
||||
return V1, V2
|
||||
|
||||
|
||||
@jax.partial(jax.jit, static_argnums=(3, 4))
def _split_spectrum_jittable(P, H, V0, rank, precision):
  """ The jittable portion of `split_spectrum`. At this point the sizes of the
  relevant matrix blocks have been concretized (`rank` is static).

  Args:
    P: Projection matrix.
    H: Matrix to be projected.
    V0: Accumulates the isometries into the projected subspaces, or None.
    rank: Rank of P.
    precision: The matmul precision.
  Returns:
    The 4-tuple `(Hm, Vm, Hp, Vp)`:
    Hm, Vm: Projection of H into the column space of P, and the accumulated
      isometry performing that projection.
    Hp, Vp: Projection of H into the null space of P, and the accumulated
      isometry performing that projection.
  """
  Vm, Vp = _projector_subspace(P, H, rank)
  Hm = _similarity_transform(H, Vm, precision)
  Hp = _similarity_transform(H, Vp, precision)
  if V0 is None:
    return Hm, Vm, Hp, Vp
  # Compose with the isometries accumulated so far.
  return (Hm, jnp.dot(V0, Vm, precision=precision),
          Hp, jnp.dot(V0, Vp, precision=precision))
|
||||
|
||||
|
||||
def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
  """ The Hermitian matrix `H` is split into two matrices `Hm`
  `Hp`, respectively sharing its eigenspaces beneath and above
  its `split_point`th eigenvalue.

  Returns, in addition, `Vm` and `Vp`, isometries such that
  `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
  returned instead; this allows the overall isometries mapping from
  an initial input matrix to progressively smaller blocks to be formed.

  This function is not jittable: the rank of the projector is converted to
  a concrete Python int so that it can size the output blocks.

  Args:
    H: The Hermitian matrix to split.
    split_point: The eigenvalue to split along.
    V0: Matrix of isometries to be updated.
    precision: TPU matmul precision.
  Returns:
    Hm: A Hermitian matrix sharing the eigenvalues of `H` beneath
      `split_point`.
    Vm: An isometry from the input space of `V0` to `Hm`.
    Hp: A Hermitian matrix sharing the eigenvalues of `H` above
      `split_point`.
    Vp: An isometry from the input space of `V0` to `Hp`.
  """
  def _fill_diagonal(X, vals):
    # `.at[].set` replaces the deprecated `jax.ops.index_update`.
    return X.at[jnp.diag_indices(X.shape[0])].set(vals)

  # H - split_point * I
  H_shift = _fill_diagonal(H, H.diagonal() - split_point)
  U, _ = jsp.linalg.polar_unitary(H_shift)
  # P = (I - U) / 2
  P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.)
  rank = jnp.round(jnp.trace(P)).astype(jnp.int32)
  rank = int(rank)  # concretize; this is what prevents jitting
  return _split_spectrum_jittable(P, H, V0, rank, precision)
|
||||
|
||||
|
||||
def _eigh_work(
|
||||
H, V=None, precision=lax.Precision.HIGHEST, termination_size=128):
|
||||
""" The main work loop performing the symmetric eigendecomposition of H.
|
||||
Each step recursively computes a projector into the space of eigenvalues
|
||||
above jnp.mean(jnp.diag(H)). The result of the projections into and out of
|
||||
that space, along with the isometries accomplishing these, are then computed.
|
||||
This is performed recursively until the projections have size 1, and thus
|
||||
store an eigenvalue of the original input; the corresponding isometry is
|
||||
the related eigenvector. The results are then composed.
|
||||
|
||||
This function cannot be Jitted because the internal split_spectrum cannot
|
||||
be.
|
||||
|
||||
Args:
|
||||
H: The Hermitian input.
|
||||
V: Stores the isometries projecting H into its subspaces.
|
||||
precision: The matmul precision.
|
||||
|
||||
Returns:
|
||||
H, V: The result of the projection.
|
||||
"""
|
||||
if H.shape[0] <= termination_size:
|
||||
evals, evecs = jnp.linalg.eigh(H)
|
||||
if V is not None:
|
||||
evecs = jnp.dot(V, evecs, precision=precision)
|
||||
return evals, evecs
|
||||
|
||||
split_point = jnp.median(jnp.diag(H)) # TODO: Improve this?
|
||||
Hm, Vm, Hp, Vp = split_spectrum(H, split_point, V0=V, precision=precision)
|
||||
Hm, Vm = _eigh_work(
|
||||
Hm, V=Vm, precision=precision, termination_size=termination_size)
|
||||
Hp, Vp = _eigh_work(
|
||||
Hp, V=Vp, precision=precision, termination_size=termination_size)
|
||||
|
||||
if Hm.ndim != 1 or Hp.ndim != 1:
|
||||
raise ValueError(f"One of Hm.ndim={Hm.ndim} or Hp.ndim={Hp.ndim} != 1 ",
|
||||
"indicating recursion terminated unexpectedly.")
|
||||
|
||||
evals = jnp.hstack((Hm, Hp))
|
||||
evecs = jnp.hstack((Vm, Vp))
|
||||
return evals, evecs
|
||||
|
||||
|
||||
def eigh(
    H, precision=lax.Precision.HIGHEST, symmetrize=True, termination_size=128):
  """ Computes the eigendecomposition of the symmetric/Hermitian matrix H.

  Args:
    H: The `n x n` Hermitian input.
    precision: The matmul precision.
    symmetrize: If True, `0.5 * (H + H.conj().T)` rather than `H` is used.
    termination_size: Recursion ends once the blocks reach this linear size.
  Returns:
    vals: The `n` eigenvalues of `H`, sorted from lowest to highest.
    vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
      of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
      to numerical error.
  Raises:
    TypeError: if `H` is not square.
  """
  nrows, ncols = H.shape
  if nrows != ncols:
    raise TypeError(f"Input H of shape {H.shape} must be square.")

  # Honour the documented `symmetrize` flag (it was previously ignored).
  if symmetrize:
    H = 0.5 * (H + H.conj().T)

  if ncols <= termination_size:
    return jnp.linalg.eigh(H)

  # Forward `termination_size`; otherwise the recursion silently falls back
  # to its own default regardless of what the caller requested.
  evals, evecs = _eigh_work(
      H, precision=precision, termination_size=termination_size)
  # Blocks come back in recursion order, so sort globally.
  sort_idxs = jnp.argsort(evals)
  evals = evals[sort_idxs]
  evecs = evecs[:, sort_idxs]
  return evals, evecs
|
||||
|
||||
|
||||
def svd(A, precision=lax.Precision.HIGHEST):
  """ Computes an SVD of `A` from its polar decomposition.

  `A = Up @ H` is computed with `jsp.linalg.polar`, then `H = V @ diag(S) @
  V^H` with `eigh`, so that `A = (Up @ V) @ diag(S) @ V^H`.

  Args:
    A: The `m` by `n` input matrix.
    precision: TPU matmul precision.
  Returns:
    U: `Up @ V`, holding `A`'s left singular vectors as columns. Its shape
      follows the unitary polar factor, which `polar` documents as
      `m` by `n`.
    S: The eigenvalues of `H`, i.e. `A`'s singular values. They are in the
      order returned by `eigh`, which sorts from lowest to highest.
    V_dag: An `n` by `n` unitary matrix of `A`'s conjugate transposed
      right singular vectors.
  """
  # Right polar decomposition: H is the n x n positive-semidefinite factor.
  Up, H, _ = jsp.linalg.polar(A)
  S, V = eigh(H, precision=precision)
  U = jnp.dot(Up, V, precision=precision)
  return U, S, V.conj().T
|
@ -35,3 +35,8 @@ from jax._src.scipy.linalg import (
|
||||
tril,
|
||||
triu,
|
||||
)
|
||||
|
||||
from jax._src.lax.polar import (
|
||||
polar,
|
||||
polar_unitary
|
||||
)
|
||||
|
@ -27,8 +27,11 @@ import scipy.special as osp_special
|
||||
|
||||
from jax._src import api
|
||||
from jax import numpy as jnp
|
||||
from jax import lax
|
||||
from jax import scipy as jsp
|
||||
from jax import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
import jax._src.scipy.eigh
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -44,6 +47,50 @@ float_dtypes = jtu.dtypes.floating
|
||||
complex_dtypes = jtu.dtypes.complex
|
||||
int_dtypes = jtu.dtypes.integer
|
||||
|
||||
# Params for the polar tests. `testPolar` runs over the Cartesian product of
# these lists (and the floating dtypes); most are forwarded to
# `_initialize_polar_test` to shape the spectrum of the test matrix.
polar_shapes = [(16, 12), (12, 16), (128, 128)]  # tall, fat, and square
n_zero_svs = [0, 4]  # number of exactly-zero singular values
degeneracies = [0, 4]  # number of singular values forced to be equal
geometric_spectra = [False, True]  # linspace vs. geomspace spacing
max_svs = [0.1, 10.]  # largest singular value
# max_sv / min_nonzero_sv.
# NOTE(review): 0.1 makes min_nonzero_sv larger than max_sv; confirm intended.
nonzero_condition_numbers = [0.1, 100000]
sides = ["right", "left"]
methods = ["qdwh", "svd"]
seeds = [1, 10]  # passed to np.random.seed

# Matrix sizes for the spectral divide-and-conquer eigh/svd tests.
linear_sizes = [16, 128, 256]
|
||||
|
||||
|
||||
def _initialize_polar_test(shape, n_zero_svs, degeneracy, geometric_spectrum,
|
||||
max_sv, nonzero_condition_number, dtype):
|
||||
|
||||
n_rows, n_cols = shape
|
||||
min_dim = min(shape)
|
||||
left_vecs = np.random.randn(n_rows, min_dim).astype(np.float64)
|
||||
left_vecs, _ = np.linalg.qr(left_vecs)
|
||||
right_vecs = np.random.randn(n_cols, min_dim).astype(np.float64)
|
||||
right_vecs, _ = np.linalg.qr(right_vecs)
|
||||
|
||||
min_nonzero_sv = max_sv / nonzero_condition_number
|
||||
num_nonzero_svs = min_dim - n_zero_svs
|
||||
if geometric_spectrum:
|
||||
nonzero_svs = np.geomspace(min_nonzero_sv, max_sv, num=num_nonzero_svs,
|
||||
dtype=np.float64)
|
||||
else:
|
||||
nonzero_svs = np.linspace(min_nonzero_sv, max_sv, num=num_nonzero_svs,
|
||||
dtype=np.float64)
|
||||
half_point = n_zero_svs // 2
|
||||
for i in range(half_point, half_point + degeneracy):
|
||||
nonzero_svs[i] = nonzero_svs[half_point]
|
||||
svs = np.zeros(min(shape), dtype=np.float64)
|
||||
svs[n_zero_svs:] = nonzero_svs
|
||||
svs = svs[::-1]
|
||||
|
||||
result = np.dot(left_vecs * svs, right_vecs.conj().T)
|
||||
result = jnp.array(result).astype(dtype)
|
||||
spectrum = jnp.array(svs).astype(dtype)
|
||||
return result, spectrum
|
||||
|
||||
# Immutable record with one field per listed name.
# NOTE(review): presumably describes a single op test case; its users are
# outside this view -- confirm.
OpRecord = collections.namedtuple(
    "OpRecord",
    ["name", "nargs", "dtypes", "rng_factory", "test_autodiff", "nondiff_argnums", "test_name"])
|
||||
@ -99,6 +146,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
def _GetArgsMaker(self, rng, shapes, dtypes):
|
||||
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
|
||||
@ -405,6 +453,151 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
    {'testcase_name':
     '_shape={}'
     '_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
     '_max_sv={}_method={}_side={}'
     '_nonzero_condition_number={}_seed={}'.format(
         jtu.format_shape_dtype_string(
             shape, jnp.dtype(dtype).name).replace(" ", ""),
         n_zero_sv, degeneracy, geometric_spectrum, max_sv,
         method, side, nonzero_condition_number, seed
     ),
     'n_zero_sv': n_zero_sv, 'degeneracy': degeneracy,
     'geometric_spectrum': geometric_spectrum,
     'max_sv': max_sv, 'shape': shape, 'method': method,
     'side': side, 'nonzero_condition_number': nonzero_condition_number,
     'dtype': dtype, 'seed': seed}
    for n_zero_sv in n_zero_svs
    for degeneracy in degeneracies
    for geometric_spectrum in geometric_spectra
    for max_sv in max_svs
    for shape in polar_shapes
    for method in methods
    for side in sides
    for nonzero_condition_number in nonzero_condition_numbers
    for dtype in jtu.dtypes.floating
    for seed in seeds))
def testPolar(
    self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape, method,
    side, nonzero_condition_number, dtype, seed):
  """ Tests jax.scipy.linalg.polar.

  Checks, for a matrix with a prescribed spectrum, that the returned unitary
  factor is an isometry, that the other factor is Hermitian with no
  significantly negative eigenvalues, and that the two factors multiply back
  to the input.
  """
  if jtu.device_under_test() != "cpu":
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      raise unittest.SkipTest("Skip half precision off CPU.")
    if method == "svd":
      raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.")

  # Seed NumPy's global RNG, which `_initialize_polar_test` draws from.
  np.random.seed(seed)
  matrix, _ = _initialize_polar_test(
      shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
      nonzero_condition_number, dtype)
  # Half precision is expected to be rejected rather than computed.
  if jnp.dtype(dtype).name in ("bfloat16", "float16"):
    self.assertRaises(
        NotImplementedError, jsp.linalg.polar, matrix, method=method,
        side=side)
    return

  unitary, posdef, info = jsp.linalg.polar(matrix, method=method, side=side)
  # Tall/square inputs have orthonormal columns; fat inputs orthonormal rows.
  if shape[0] >= shape[1]:
    should_be_eye = np.matmul(unitary.conj().T, unitary)
  else:
    should_be_eye = np.matmul(unitary, unitary.conj().T)
  tol = 10 * jnp.finfo(matrix.dtype).eps
  eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
  with self.subTest('Test unitarity.'):
    self.assertAllClose(
        eye_mat, should_be_eye, atol=tol * min(shape))

  with self.subTest('Test Hermiticity.'):
    self.assertAllClose(
        posdef, posdef.conj().T, atol=tol * jnp.linalg.norm(posdef))

  ev, _ = np.linalg.eigh(posdef)
  # Discard eigenvalues that are zero to within tolerance before sign-testing.
  ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
  negative_ev = jnp.sum(ev < 0.)
  with self.subTest('Test positive definiteness.'):
    assert negative_ev == 0.

  if side == "right":
    recon = jnp.matmul(unitary, posdef, precision=lax.Precision.HIGHEST)
  elif side == "left":
    recon = jnp.matmul(posdef, unitary, precision=lax.Precision.HIGHEST)
  with self.subTest('Test reconstruction.'):
    self.assertAllClose(
        matrix, recon, atol=tol * jnp.linalg.norm(matrix))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
    {'testcase_name':
     '_linear_size_={}_seed={}_dtype={}'.format(
         linear_size, seed, jnp.dtype(dtype).name
     ),
     'linear_size': linear_size, 'seed': seed, 'dtype': dtype}
    for linear_size in linear_sizes
    for seed in seeds
    for dtype in jtu.dtypes.floating))
def test_spectral_dac_eigh(self, linear_size, seed, dtype):
  """Tests the spectral divide-and-conquer `eigh` against `jnp.linalg.eigh`.

  Checks that the eigenvalues agree and that `H @ V == evs * V` to within a
  tolerance proportional to `eps * ||H||`.
  """
  # `device_under_test` must be *called*: comparing the function object to
  # "cpu" is always unequal, which skipped this test on every backend.
  if jtu.device_under_test() != "cpu":
    raise unittest.SkipTest("Skip eigh off CPU for now.")
  if jnp.dtype(dtype).name in ("bfloat16", "float16"):
    if jtu.device_under_test() != "cpu":
      raise unittest.SkipTest("Skip half precision off CPU.")

  np.random.seed(seed)
  H = np.random.randn(linear_size, linear_size)
  H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
  # Half precision is expected to be rejected rather than computed.
  if jnp.dtype(dtype).name in ("bfloat16", "float16"):
    self.assertRaises(
        NotImplementedError, jax._src.scipy.eigh.eigh, H)
    return
  evs, V = jax._src.scipy.eigh.eigh(H)
  ev_exp, eV_exp = jnp.linalg.eigh(H)
  HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
  vV = evs * V
  eps = jnp.finfo(H.dtype).eps
  atol = jnp.linalg.norm(H) * eps
  self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
  self.assertAllClose(HV, vV, atol=30 * atol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
    {'testcase_name':
     '_linear_size_={}_seed={}_dtype={}'.format(
         linear_size, seed, jnp.dtype(dtype).name
     ),
     'linear_size': linear_size, 'seed': seed, 'dtype': dtype}
    for linear_size in linear_sizes
    for seed in seeds
    for dtype in jtu.dtypes.floating))
def test_spectral_dac_svd(self, linear_size, seed, dtype):
  """Tests the polar/eigh-based `svd`: singular values, reconstruction of
  `A`, and unitarity of both returned factors."""
  half_precision = jnp.dtype(dtype).name in ("bfloat16", "float16")
  if half_precision and jtu.device_under_test() != "cpu":
    raise unittest.SkipTest("Skip half precision off CPU.")

  np.random.seed(seed)
  A = np.random.randn(linear_size, linear_size).astype(dtype)
  # Half precision is expected to be rejected rather than computed.
  if half_precision:
    self.assertRaises(NotImplementedError, jax._src.scipy.eigh.svd, A)
    return

  S_expected = np.linalg.svd(A, compute_uv=False)
  U, S, V = jax._src.scipy.eigh.svd(A)
  recon = jnp.dot((U * S), V, precision=lax.Precision.HIGHEST)
  eps = jnp.finfo(dtype).eps
  eps = eps * jnp.linalg.norm(A) * 10
  # Sort both sides so the comparison is independent of ordering convention.
  self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
  self.assertAllClose(A, recon, atol=eps)

  def check_unitary(factor):
    gram = jnp.dot(factor.conj().T, factor, precision=lax.Precision.HIGHEST)
    identity = jnp.eye(gram.shape[0], dtype=dtype)
    self.assertAllClose(gram, identity, atol=eps)

  # U is unitary.
  check_unitary(U)
  # V is unitary.
  check_unitary(V)
|
||||
|
||||
|
||||
# When executed as a script, run the tests through absltest using JAX's
# test loader.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user