rocm_jax/jax/_src/lax/svd.py
rajasekharporeddy b93da3873b Fix Typos
2024-06-17 13:55:46 +05:30

244 lines
8.6 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""A JIT-compatible library for QDWH-based singular value decomposition.
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
iteration implemented through QR decompositions is numerically stable and does
not require solving a linear system involving the iteration matrix or
computing its inverse. This is desirable for multicore and heterogeneous
computing systems.
References:
Nakatsukasa, Yuji, and Nicholas J. Higham.
"Stable and efficient spectral divide and conquer algorithms for the symmetric
eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing 35,
no. 3 (2013): A1325-A1349.
https://epubs.siam.org/doi/abs/10.1137/120876605
Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
"Optimizing Halley's iteration for computing the matrix polar decomposition."
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
https://epubs.siam.org/doi/abs/10.1137/090774999
"""
from __future__ import annotations
from collections.abc import Sequence
import functools
import operator
from typing import Any
import jax
from jax import lax
from jax._src import core
import jax.numpy as jnp
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def _svd_tall_and_square_input(
    a: Any,
    hermitian: bool,
    compute_uv: bool,
    max_iterations: int,
    subset_by_index: tuple[int, int] | None = None,
) -> Any | Sequence[Any]:
  """Computes the SVD of a tall or square matrix via QDWH followed by eigh.

  The matrix is first split into its polar factors with
  `lax.linalg.qdwh`. The Hermitian polar factor is then diagonalized with
  `lax.linalg.eigh`. Its (clamped) eigenvalues are the singular values, and
  its eigenvectors are the right singular vectors.

  Args:
    a: A matrix of shape `m x n` with `m >= n`.
    hermitian: True if `a` is Hermitian. Forwarded to `lax.linalg.qdwh` as
      `is_hermitian`.
    compute_uv: Whether to also compute `u` and `v` in addition to `s`.
    max_iterations: The predefined maximum number of iterations of QDWH.
    subset_by_index: Optional `(start, end)` range forwarded unchanged to
      `lax.linalg.eigh`. It is therefore interpreted in eigh's eigenvalue
      indexing, not in descending singular-value order. `None` computes all
      `n` components.

  Returns:
    A 3-tuple (`u`, `s`, `v`). Here `s` holds the `k` computed singular values
    in descending order, `v` has shape `n x k`, `u` has shape `m x k`, and
    `a ~= (u * s) @ v.T.conj()` when all components are computed. `k` is
    `n`, or the size of the `subset_by_index` range. For `compute_uv=False`,
    only `s` is returned.
  """
  polar_u, polar_h, _, _ = lax.linalg.qdwh(
      a, is_hermitian=hermitian, max_iterations=max_iterations
  )

  # TODO: Uses `eigvals_only=True` if `compute_uv=False`.
  eigvecs, eigvals = lax.linalg.eigh(
      polar_h, subset_by_index=subset_by_index, sort_eigenvalues=False
  )

  # eigh may return tiny negative eigenvalues through round-off, but singular
  # values are non-negative by definition, so clamp at zero.
  clamped = jnp.maximum(eigvals, 0.0)

  # Singular values are reported largest-first.
  order = jnp.argsort(clamped, descending=True)
  s_out = clamped[order]
  if not compute_uv:
    return s_out

  v_out = eigvecs[:, order]
  u_out = polar_u @ v_out

  # If qdwh's `u` is not unitary (numerically rank-deficient input), it is
  # re-orthonormalized. See Section 5.5 of Nakatsukasa, Yuji, and Nicholas J.
  # Higham. "Stable and efficient spectral divide and conquer algorithms for
  # the symmetric eigenvalue decomposition and the SVD." SIAM Journal on
  # Scientific Computing 35, no. 3 (2013): A1325-A1349.
  def _reorthonormalize(u):
    q_factor, r_factor = lax.linalg.qr(u, full_matrices=False)
    # Flip columns so the implied diag(R) is non-negative.
    column_signs = jnp.where(jnp.diag(r_factor) >= 0, 1, -1)
    return q_factor @ jnp.diag(column_signs)

  def _pending(carry):
    return carry[1]

  def _fix_once(carry):
    return (_reorthonormalize(carry[0]), False)

  tolerance = a.shape[1] * float(jnp.finfo(a.dtype).eps)
  needs_fix = s_out[-1] <= tolerance * s_out[0]
  # A while_loop whose body clears the flag runs at most once, so it acts
  # as a conditional that only pays for the QR when it is needed.
  u_out, _ = lax.while_loop(_pending, _fix_once, (u_out, needs_fix))
  return (u_out, s_out, v_out)
def _canonicalize_subset_by_index(
    subset_by_index: tuple[int, int] | None,
    rank: int,
    full_matrices: bool,
) -> tuple[int, int] | None:
  """Validates `subset_by_index` and converts it to eigenvalue numbering.

  Args:
    subset_by_index: `None` or a 2-entry `(start, end)` range of singular
      component indices, numbered in non-increasing singular-value order.
    rank: `min(m, n)` for the `m x n` input matrix.
    full_matrices: The already-concretized `full_matrices` flag of `svd`.

  Returns:
    `None` if `subset_by_index` is `None`. Otherwise, the concrete tuple
    `(rank - end, rank - start)`: the same components numbered in the
    non-decreasing order used for eigenvalues.

  Raises:
    ValueError: If the range does not have exactly two entries, is empty,
      starts below zero, ends beyond `rank`, or is combined with
      `full_matrices=True` without covering all `rank` components.
    TypeError: If an entry is not an integer (raised by `operator.index`).
  """
  if subset_by_index is None:
    return None
  if len(subset_by_index) != 2:
    raise ValueError('subset_by_index must be a tuple of size 2.')
  # Make sure subset_by_index is a concrete tuple of Python ints.
  start = operator.index(subset_by_index[0])
  end = operator.index(subset_by_index[1])
  if start >= end:
    raise ValueError('Got empty index range in subset_by_index.')
  if start < 0:
    raise ValueError('Indices in subset_by_index must be non-negative.')
  if end > rank:
    raise ValueError('Index in subset_by_index[1] exceeds matrix size.')
  if full_matrices and (start, end) != (0, rank):
    raise ValueError(
        'full_matrices and subset_by_index cannot be both be set.'
    )
  # By convention, eigenvalues are numbered in non-decreasing order, while
  # singular values are numbered in non-increasing order, so mirror the range.
  return (rank - end, rank - start)


@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
def svd(
    a: Any,
    full_matrices: bool,
    compute_uv: bool = True,
    hermitian: bool = False,
    max_iterations: int = 10,
    subset_by_index: tuple[int, int] | None = None,
) -> Any | Sequence[Any]:
  """Singular value decomposition.

  Args:
    a: A matrix of shape `m x n`.
    full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
      respectively. If False, the shapes are `m x k` and `k x n`, respectively,
      where `k = min(m, n)`.
    compute_uv: Whether to also compute `u` and `v` in addition to `s`.
    hermitian: True if `a` is Hermitian.
    max_iterations: The predefined maximum number of iterations of QDWH.
    subset_by_index: Optional 2-tuple `(start, end)` indicating the range of
      indices of singular components to compute, numbered in non-increasing
      singular-value order. For example, if ``subset_by_index`` = (0, 2), then
      ``svd`` computes the two largest singular values (and their singular
      vectors if `compute_uv` is true). When given, `k = end - start`. It may
      only be combined with `full_matrices=True` if it covers all
      `min(m, n)` components.

  Returns:
    A 3-tuple (`u`, `s`, `vh`), where `u` has orthonormal columns and `vh` has
    orthonormal rows (both are unitary when all components are computed with
    `full_matrices=True`, or when the input is square). `s` is a vector of
    length `k` containing the singular values in non-increasing order. The
    shapes of `u` and `vh` depend on the value of `full_matrices`. For
    `compute_uv=False`, only `s` is returned. When `u` and `vh` are computed
    and the (possibly QR-reduced) matrix contains a non-finite value, all
    three outputs are filled with NaN.

  Raises:
    ValueError: If `a` is not 2-D, or if `subset_by_index` is invalid.
  """
  full_matrices = core.concrete_or_error(
      bool, full_matrices, 'The `full_matrices` argument must be statically '
      'specified to use `svd` within JAX transformations.')
  compute_uv = core.concrete_or_error(
      bool, compute_uv, 'The `compute_uv` argument must be statically '
      'specified to use `svd` within JAX transformations.')
  hermitian = core.concrete_or_error(
      bool,
      hermitian,
      'The `hermitian` argument must be statically '
      'specified to use `svd` within JAX transformations.',
  )
  max_iterations = core.concrete_or_error(
      int,
      max_iterations,
      'The `max_iterations` argument must be statically '
      'specified to use `svd` within JAX transformations.',
  )

  if a.ndim != 2:
    raise ValueError(
        f'svd expects a 2-D matrix, but got an array of shape {a.shape}.'
    )
  m, n = a.shape
  subset_by_index = _canonicalize_subset_by_index(
      subset_by_index, min(m, n), full_matrices
  )

  # The solver needs m >= n. A wide input is conjugate-transposed here, and
  # the roles of the factors are swapped back at the end.
  is_flip = m < n
  if is_flip:
    a = a.T.conj()
    m, n = a.shape

  reduce_to_square = False
  if full_matrices:
    # The trailing `m - n` columns of the full Q complete `u` to `m x m`.
    q_full, a_full = lax.linalg.qr(a, full_matrices=True)
    q = q_full[:, :n]
    u_out_null = q_full[:, n:]
    a = a_full[:n, :]
    reduce_to_square = True
  elif m > 1.15 * n:
    # The constant `1.15` comes from Yuji Nakatsukasa's implementation
    # https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav
    q, a = lax.linalg.qr(a, full_matrices=False)
    reduce_to_square = True

  # After a QR reduction, `a` holds the triangular factor R, which is not
  # Hermitian even when the input was. Only forward the hint when the
  # original matrix is factored directly.
  qdwh_hermitian = hermitian and not reduce_to_square

  with jax.default_matmul_precision('float32'):
    result = _svd_tall_and_square_input(
        a, qdwh_hermitian, compute_uv, max_iterations, subset_by_index
    )
  if not compute_uv:
    return result
  u_out, s_out, v_out = result

  if reduce_to_square:
    u_out = q @ u_out
  if full_matrices:
    u_out = jnp.hstack((u_out, u_out_null))

  # Poison every output with NaN if the solver input had a non-finite entry.
  # The loop body sets the flag, so it runs at most once (a traced "if").
  is_finite = jnp.all(jnp.isfinite(a))

  def _has_nonfinite(carry):
    return jnp.logical_not(carry[0])

  def _fill_with_nan(carry):
    return (
        jnp.array(True),
        jnp.full_like(u_out, jnp.nan),
        jnp.full_like(s_out, jnp.nan),
        jnp.full_like(v_out, jnp.nan),
    )

  _, u_out, s_out, v_out = lax.while_loop(
      _has_nonfinite, _fill_with_nan, (is_finite, u_out, s_out, v_out)
  )

  if is_flip:
    return (v_out, s_out, u_out.T.conj())
  return (u_out, s_out, v_out.T.conj())