mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[JAX] Added jit-able singular value decomposition.
PiperOrigin-RevId: 426193395
This commit is contained in:
parent
b7dcc4ce01
commit
5a012d5e7b
143
jax/_src/lax/svd.py
Normal file
143
jax/_src/lax/svd.py
Normal file
@ -0,0 +1,143 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
"""A JIT-compatible library for QDWH-based SVD decomposition.
|
||||
|
||||
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
|
||||
iteration implemented through QR decmopositions is numerically stable and does
|
||||
not require solving a linear system involving the iteration matrix or
|
||||
computing its inversion. This is desirable for multicore and heterogeneous
|
||||
computing systems.
|
||||
|
||||
References:
|
||||
Nakatsukasa, Yuji, and Nicholas J. Higham.
|
||||
"Stable and efficient spectral divide and conquer algorithms for the symmetric
|
||||
eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing 35,
|
||||
no. 3 (2013): A1325-A1349.
|
||||
https://epubs.siam.org/doi/abs/10.1137/120876605
|
||||
|
||||
Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
|
||||
"Optimizing Halley's iteration for computing the matrix polar decomposition."
|
||||
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
|
||||
https://epubs.siam.org/doi/abs/10.1137/090774999
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnums=(1, 2))
|
||||
def _svd(a: jnp.ndarray,
|
||||
is_hermitian: bool,
|
||||
max_iterations: int) -> Sequence[jnp.ndarray]:
|
||||
"""Singular value decomposition for m x n matrix and m >= n.
|
||||
|
||||
Args:
|
||||
a: A matrix of shape `m x n` with `m >= n`.
|
||||
is_hermitian: True if `a` is Hermitian.
|
||||
max_iterations: The predefined maximum number of iterations of QDWH.
|
||||
|
||||
Returns:
|
||||
A 3-tuple (`u`, `s`, `v`), where `u` is a unitary matrix of shape `m x n`,
|
||||
`s` is vector of length `n` containing the singular values in the descending
|
||||
order, `v` is a unitary matrix of shape `n x n`, and
|
||||
`a = (u * s) @ v.T.conj()`.
|
||||
"""
|
||||
|
||||
u, h, _, _ = lax.linalg.qdwh(a, is_hermitian, max_iterations)
|
||||
|
||||
v, s = lax.linalg.eigh(h)
|
||||
|
||||
# Flips the singular values in descending order.
|
||||
s_out = jnp.flip(s)
|
||||
|
||||
# Reorders eigenvectors.
|
||||
v_out = jnp.fliplr(v)
|
||||
|
||||
u_out = u @ v_out
|
||||
|
||||
# Makes correction if computed `u` from qdwh is not unitary.
|
||||
# Section 5.5 of Nakatsukasa, Yuji, and Nicholas J. Higham. "Stable and
|
||||
# efficient spectral divide and conquer algorithms for the symmetric
|
||||
# eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing
|
||||
# 35, no. 3 (2013): A1325-A1349.
|
||||
def correct_rank_deficiency(u_out):
|
||||
u_out, r = lax.linalg.qr(u_out, full_matrices=False)
|
||||
u_out = u_out @ jnp.diag(lax.sign(jnp.diag(r)))
|
||||
return u_out
|
||||
|
||||
eps = jnp.finfo(a.dtype).eps
|
||||
u_out = lax.cond(s[0] < a.shape[1] * eps * s_out[0],
|
||||
correct_rank_deficiency,
|
||||
lambda u_out: u_out,
|
||||
operand=(u_out))
|
||||
|
||||
return (u_out, s_out, v_out)
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnums=(1, 2))
|
||||
def svd(a: jnp.ndarray,
|
||||
is_hermitian: bool = False,
|
||||
max_iterations: int = 10) -> Sequence[jnp.ndarray]:
|
||||
"""Singular value decomposition.
|
||||
|
||||
Args:
|
||||
a: A matrix of shape `m x n`.
|
||||
is_hermitian: True if `a` is Hermitian.
|
||||
max_iterations: The predefined maximum number of iterations of QDWH.
|
||||
|
||||
Returns:
|
||||
A 3-tuple (`u`, `s`, `vh`), where `u` is a unitary matrix of shape `m x k`,
|
||||
`s` is vector of length `k` containing the singular values in the descending
|
||||
order, `vh` is a unitary matrix of shape `k x n`, `k = min(m, n)`, and
|
||||
`a = (u * s) @ vh`.
|
||||
"""
|
||||
|
||||
is_hermitian = core.concrete_or_error(
|
||||
bool, is_hermitian, 'The `is_hermitian` argument must be statically '
|
||||
'specified to use `qdwh` within JAX transformations.')
|
||||
|
||||
max_iterations = core.concrete_or_error(
|
||||
int, max_iterations, 'The `max_iterations` argument must be statically '
|
||||
'specified to use `qdwh` within JAX transformations.')
|
||||
|
||||
m, n = a.shape
|
||||
|
||||
is_flip = False
|
||||
if m < n:
|
||||
a = a.T.conj()
|
||||
m, n = a.shape
|
||||
is_flip = True
|
||||
|
||||
reduce_to_square = False
|
||||
if m > 1.15 * n:
|
||||
m = n
|
||||
q, a = lax.linalg.qr(a, full_matrices=False)
|
||||
reduce_to_square = True
|
||||
|
||||
u_out, s_out, v_out = _svd(a, is_hermitian, max_iterations)
|
||||
|
||||
if reduce_to_square:
|
||||
u_out = q @ u_out
|
||||
|
||||
if is_flip:
|
||||
return(v_out, s_out, u_out.T.conj())
|
||||
|
||||
return (u_out, s_out, v_out.T.conj())
|
134
tests/svd_test.py
Normal file
134
tests/svd_test.py
Normal file
@ -0,0 +1,134 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
"""Tests for the library of QDWH-based SVD decomposition."""
|
||||
import functools
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import scipy.linalg as osp_linalg
|
||||
from jax._src.lax import svd
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
_JAX_ENABLE_X64 = config.x64_enabled
|
||||
|
||||
# Input matrix data type for SvdTest.
|
||||
_SVD_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64 else np.float32
|
||||
|
||||
# Machine epsilon used by SvdTest.
|
||||
_SVD_TEST_EPS = jnp.finfo(_SVD_TEST_DTYPE).eps
|
||||
|
||||
# SvdTest relative tolerance.
|
||||
_SVD_RTOL = 1E-6 if _JAX_ENABLE_X64 else 1E-2
|
||||
|
||||
_MAX_LOG_CONDITION_NUM = 9 if _JAX_ENABLE_X64 else 4
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion='allow')
|
||||
class SvdTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{ # pylint:disable=g-complex-comprehension
|
||||
'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
|
||||
'm': m, 'n': n, 'log_cond': log_cond}
|
||||
for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
|
||||
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
|
||||
def testSvdWithRectangularInput(self, m, n, log_cond):
|
||||
"""Tests SVD with rectangular input."""
|
||||
with jax.default_matmul_precision('float32'):
|
||||
a = np.random.uniform(
|
||||
low=0.3, high=0.9, size=(m, n)).astype(_SVD_TEST_DTYPE)
|
||||
u, s, v = jnp.linalg.svd(a, full_matrices=False)
|
||||
cond = 10**log_cond
|
||||
s = jnp.linspace(cond, 1, min(m, n))
|
||||
a = (u * s) @ v
|
||||
a = a + 1j * a
|
||||
|
||||
osp_linalg_fn = functools.partial(osp_linalg.svd, full_matrices=False)
|
||||
actual_u, actual_s, actual_v = svd.svd(a)
|
||||
|
||||
k = min(m, n)
|
||||
if m > n:
|
||||
unitary_u = jnp.abs(actual_u.T.conj() @ actual_u)
|
||||
unitary_v = jnp.abs(actual_v.T.conj() @ actual_v)
|
||||
else:
|
||||
unitary_u = jnp.abs(actual_u @ actual_u.T.conj())
|
||||
unitary_v = jnp.abs(actual_v @ actual_v.T.conj())
|
||||
|
||||
_, expected_s, _ = osp_linalg_fn(a)
|
||||
|
||||
args_maker = lambda: [a]
|
||||
|
||||
with self.subTest('Test JIT compatibility'):
|
||||
self._CompileAndCheck(svd.svd, args_maker)
|
||||
|
||||
with self.subTest('Test unitary u.'):
|
||||
self.assertAllClose(np.eye(k), unitary_u, rtol=_SVD_RTOL, atol=2E-3)
|
||||
|
||||
with self.subTest('Test unitary v.'):
|
||||
self.assertAllClose(np.eye(k), unitary_v, rtol=_SVD_RTOL, atol=2E-3)
|
||||
|
||||
with self.subTest('Test s.'):
|
||||
self.assertAllClose(
|
||||
expected_s, jnp.real(actual_s), rtol=_SVD_RTOL, atol=1E-6)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{'testcase_name': '_m={}_by_n={}'.format(m, n), 'm': m, 'n': n}
|
||||
for m, n in zip([50, 6], [3, 60])))
|
||||
def testSvdWithSkinnyTallInput(self, m, n):
|
||||
"""Tests SVD with skinny and tall input."""
|
||||
# Generates a skinny and tall input
|
||||
with jax.default_matmul_precision('float32'):
|
||||
np.random.seed(1235)
|
||||
a = np.random.randn(m, n).astype(_SVD_TEST_DTYPE)
|
||||
u, s, v = svd.svd(a, is_hermitian=False)
|
||||
|
||||
relative_diff = np.linalg.norm(a - (u * s) @ v) / np.linalg.norm(a)
|
||||
|
||||
np.testing.assert_almost_equal(relative_diff, 1E-6, decimal=6)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{ # pylint:disable=g-complex-comprehension
|
||||
'testcase_name': '_m={}_r={}_log_cond={}'.format(m, r, log_cond),
|
||||
'm': m, 'r': r, 'log_cond': log_cond}
|
||||
for m, r in zip([8, 8, 8, 10], [3, 5, 7, 9])
|
||||
for log_cond in np.linspace(1, 3, 3)))
|
||||
def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
|
||||
"""Tests SVD with rank-deficient input."""
|
||||
with jax.default_matmul_precision('float32'):
|
||||
a = jnp.triu(jnp.ones((m, m))).astype(_SVD_TEST_DTYPE)
|
||||
|
||||
# Generates a rank-deficient input.
|
||||
u, s, v = jnp.linalg.svd(a, full_matrices=False)
|
||||
cond = 10**log_cond
|
||||
s = jnp.linspace(cond, 1, m)
|
||||
s = s.at[r:m].set(jnp.zeros((m-r,)))
|
||||
a = (u * s) @ v
|
||||
|
||||
with jax.default_matmul_precision('float32'):
|
||||
u, s, v = svd.svd(a, is_hermitian=False)
|
||||
diff = np.linalg.norm(a - (u * s) @ v)
|
||||
|
||||
np.testing.assert_almost_equal(diff, 1E-4, decimal=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user