mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
239 lines
8.4 KiB
Python
239 lines
8.4 KiB
Python
# Copyright 2021 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License
|
|
|
|
|
|
"""Serial algorithm for eigh."""
|
|
|
|
import functools
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import jax.scipy as jsp
|
|
from jax import lax
|
|
|
|
|
|
def _similarity_transform(
|
|
matrix_in, matrix_out, precision=lax.Precision.HIGHEST):
|
|
""" Returns matrix_out.conj().T @ matrix_in @ matrix_out, done in
|
|
order from left to right.
|
|
"""
|
|
out = jnp.dot(matrix_out.conj().T, matrix_in, precision=precision)
|
|
return jnp.dot(out, matrix_out, precision=precision)
|
|
|
|
|
|
def _projector_subspace(P, H, rank, maxiter=2):
|
|
""" Decomposes the `n x n` rank `rank` Hermitian projector `P` into
|
|
an `n x rank` isometry `Vm` such that `P = Vm @ Vm.conj().T` and
|
|
an `n x (n - rank)` isometry `Vm` such that -(I - P) = Vp @ Vp.conj().T`.
|
|
|
|
The subspaces are computed using the naiive QR eigendecomposition
|
|
algorithm, which converges very quickly due to the sharp separation
|
|
between the relevant eigenvalues of the projector.
|
|
|
|
Args:
|
|
P: A rank-`rank` Hermitian projector into the space of `H`'s
|
|
first `rank` eigenpairs.
|
|
H: The aforementioned Hermitian matrix, which is used to track
|
|
convergence.
|
|
rank: Rank of `P`.
|
|
maxiter: Maximum number of iterations.
|
|
Returns:
|
|
Vm, Vp: Isometries into the eigenspaces described in the docstring.
|
|
"""
|
|
# Choose an initial guess: the `rank` largest-norm columns of P.
|
|
column_norms = jnp.linalg.norm(P, axis=1)
|
|
sort_idxs = jnp.argsort(column_norms)
|
|
X = P[:, sort_idxs]
|
|
X = X[:, :rank]
|
|
|
|
H_norm = jnp.linalg.norm(H)
|
|
thresh = 10 * jnp.finfo(X.dtype).eps * H_norm
|
|
|
|
# First iteration skips the matmul.
|
|
def body_f_after_matmul(X):
|
|
Q, _ = jnp.linalg.qr(X, mode="complete")
|
|
V1 = Q[:, :rank]
|
|
V2 = Q[:, rank:]
|
|
# TODO: might be able to get away with lower precision here
|
|
error_matrix = jnp.dot(V2.conj().T, H, precision=lax.Precision.HIGHEST)
|
|
error_matrix = jnp.dot(error_matrix, V1, precision=lax.Precision.HIGHEST)
|
|
error = jnp.linalg.norm(error_matrix) / H_norm
|
|
return V1, V2, error
|
|
|
|
def cond_f(args):
|
|
_, _, j, error = args
|
|
still_counting = j < maxiter
|
|
unconverged = error > thresh
|
|
return jnp.logical_and(still_counting, unconverged)[0]
|
|
|
|
def body_f(args):
|
|
V1, _, j, _ = args
|
|
X = jnp.dot(P, V1, precision=lax.Precision.HIGHEST)
|
|
V1, V2, error = body_f_after_matmul(X)
|
|
return V1, V2, j + 1, error
|
|
|
|
V1, V2, error = body_f_after_matmul(X)
|
|
one = jnp.ones(1, dtype=jnp.int32)
|
|
V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
|
|
return V1, V2
|
|
|
|
|
|
@functools.partial(jax.jit, static_argnums=(3, 4))
|
|
def _split_spectrum_jittable(P, H, V0, rank, precision):
|
|
""" The jittable portion of `split_spectrum`. At this point the sizes of the
|
|
relevant matrix blocks have been concretized.
|
|
|
|
Args:
|
|
P: Projection matrix.
|
|
H: Matrix to be projected.
|
|
V0: Accumulates the isometries into the projected subspaces.
|
|
rank: Rank of P.
|
|
precision: The matmul precision.
|
|
Returns:
|
|
H1, V1: Projection of H into the column space of P, and the accumulated
|
|
isometry performing that projection.
|
|
H2, V2: Projection of H into the null space of P, and the accumulated
|
|
isometry performing that projection.
|
|
"""
|
|
Vm, Vp = _projector_subspace(P, H, rank)
|
|
Hm = _similarity_transform(H, Vm, precision)
|
|
Hp = _similarity_transform(H, Vp, precision)
|
|
if V0 is not None:
|
|
Vm = jnp.dot(V0, Vm, precision=precision)
|
|
Vp = jnp.dot(V0, Vp, precision=precision)
|
|
return Hm, Vm, Hp, Vp
|
|
|
|
|
|
def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
|
|
""" The Hermitian matrix `H` is split into two matrices `Hm`
|
|
`Hp`, respectively sharing its eigenspaces beneath and above
|
|
its `split_point`th eigenvalue.
|
|
|
|
Returns, in addition, `Vm` and `Vp`, isometries such that
|
|
`Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
|
|
returned instead; this allows the overall isometries mapping from
|
|
an initial input matrix to progressively smaller blocks to be formed.
|
|
|
|
Args:
|
|
H: The Hermitian matrix to split.
|
|
split_point: The eigenvalue to split along.
|
|
V0: Matrix of isometries to be updated.
|
|
precision: TPU matmul precision.
|
|
Returns:
|
|
Hm: A Hermitian matrix sharing the eigenvalues of `H` beneath
|
|
`split_point`.
|
|
Vm: An isometry from the input space of `V0` to `Hm`.
|
|
Hp: A Hermitian matrix sharing the eigenvalues of `H` above
|
|
`split_point`.
|
|
Vp: An isometry from the input space of `V0` to `Hp`.
|
|
"""
|
|
def _fill_diagonal(X, vals):
|
|
return X.at[jnp.diag_indices(X.shape[0])].set(vals)
|
|
|
|
H_shift = _fill_diagonal(H, H.diagonal() - split_point)
|
|
U, _ = jsp.linalg.polar_unitary(H_shift)
|
|
P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.)
|
|
rank = jnp.round(jnp.trace(P)).astype(jnp.int32)
|
|
rank = int(rank)
|
|
return _split_spectrum_jittable(P, H, V0, rank, precision)
|
|
|
|
|
|
def _eigh_work(
|
|
H, V=None, precision=lax.Precision.HIGHEST, termination_size=128):
|
|
""" The main work loop performing the symmetric eigendecomposition of H.
|
|
Each step recursively computes a projector into the space of eigenvalues
|
|
above jnp.mean(jnp.diag(H)). The result of the projections into and out of
|
|
that space, along with the isometries accomplishing these, are then computed.
|
|
This is performed recursively until the projections have size 1, and thus
|
|
store an eigenvalue of the original input; the corresponding isometry is
|
|
the related eigenvector. The results are then composed.
|
|
|
|
This function cannot be Jitted because the internal split_spectrum cannot
|
|
be.
|
|
|
|
Args:
|
|
H: The Hermitian input.
|
|
V: Stores the isometries projecting H into its subspaces.
|
|
precision: The matmul precision.
|
|
|
|
Returns:
|
|
H, V: The result of the projection.
|
|
"""
|
|
if H.shape[0] <= termination_size:
|
|
evals, evecs = jnp.linalg.eigh(H)
|
|
if V is not None:
|
|
evecs = jnp.dot(V, evecs, precision=precision)
|
|
return evals, evecs
|
|
|
|
split_point = jnp.median(jnp.diag(H)) # TODO: Improve this?
|
|
Hm, Vm, Hp, Vp = split_spectrum(H, split_point, V0=V, precision=precision)
|
|
Hm, Vm = _eigh_work(
|
|
Hm, V=Vm, precision=precision, termination_size=termination_size)
|
|
Hp, Vp = _eigh_work(
|
|
Hp, V=Vp, precision=precision, termination_size=termination_size)
|
|
|
|
if Hm.ndim != 1 or Hp.ndim != 1:
|
|
raise ValueError(f"One of Hm.ndim={Hm.ndim} or Hp.ndim={Hp.ndim} != 1 ",
|
|
"indicating recursion terminated unexpectedly.")
|
|
|
|
evals = jnp.hstack((Hm, Hp))
|
|
evecs = jnp.hstack((Vm, Vp))
|
|
return evals, evecs
|
|
|
|
|
|
def eigh(
|
|
H, precision=lax.Precision.HIGHEST, symmetrize=True, termination_size=128):
|
|
""" Computes the eigendecomposition of the symmetric/Hermitian matrix H.
|
|
|
|
Args:
|
|
H: The `n x n` Hermitian input.
|
|
precision: The matmul precision.
|
|
symmetrize: If True, `0.5 * (H + H.conj().T)` rather than `H` is used.
|
|
termination_size: Recursion ends once the blocks reach this linear size.
|
|
Returns:
|
|
vals: The `n` eigenvalues of `H`, sorted from lowest to highest.
|
|
vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
|
|
of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
|
|
to numerical error.
|
|
"""
|
|
nrows, ncols = H.shape
|
|
if nrows != ncols:
|
|
raise TypeError(f"Input H of shape {H.shape} must be square.")
|
|
|
|
if ncols <= termination_size:
|
|
return jnp.linalg.eigh(H)
|
|
|
|
evals, evecs = _eigh_work(H, precision=precision)
|
|
sort_idxs = jnp.argsort(evals)
|
|
evals = evals[sort_idxs]
|
|
evecs = evecs[:, sort_idxs]
|
|
return evals, evecs
|
|
|
|
|
|
def svd(A, precision=lax.Precision.HIGHEST):
|
|
""" Computes an SVD of `A`.
|
|
|
|
Args:
|
|
A: The `m` by `n` input matrix.
|
|
precision: TPU matmul precision.
|
|
Returns:
|
|
U: An `m` by `m` unitary matrix of `A`'s left singular vectors.
|
|
S: A length-`min(m, n)` vector of `A`'s singular values.
|
|
V_dag: An `n` by `n` unitary matrix of `A`'s conjugate transposed
|
|
right singular vectors.
|
|
"""
|
|
Up, H = jsp.linalg.polar(A)
|
|
S, V = eigh(H, precision=precision)
|
|
U = jnp.dot(Up, V, precision=precision)
|
|
return U, S, V.conj().T
|