[JAX] Added jit-able singular value decomposition.

PiperOrigin-RevId: 426193395
This commit is contained in:
Tianjian Lu 2022-02-03 11:16:18 -08:00 committed by jax authors
parent b7dcc4ce01
commit 5a012d5e7b
2 changed files with 277 additions and 0 deletions

143
jax/_src/lax/svd.py Normal file
View File

@ -0,0 +1,143 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""A JIT-compatible library for QDWH-based SVD decomposition.
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
iteration implemented through QR decmopositions is numerically stable and does
not require solving a linear system involving the iteration matrix or
computing its inversion. This is desirable for multicore and heterogeneous
computing systems.
References:
Nakatsukasa, Yuji, and Nicholas J. Higham.
"Stable and efficient spectral divide and conquer algorithms for the symmetric
eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing 35,
no. 3 (2013): A1325-A1349.
https://epubs.siam.org/doi/abs/10.1137/120876605
Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
"Optimizing Halley's iteration for computing the matrix polar decomposition."
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
https://epubs.siam.org/doi/abs/10.1137/090774999
"""
import functools
from typing import Sequence
import jax
from jax import core
from jax import lax
import jax.numpy as jnp
@functools.partial(jax.jit, static_argnums=(1, 2))
def _svd(a: jnp.ndarray,
is_hermitian: bool,
max_iterations: int) -> Sequence[jnp.ndarray]:
"""Singular value decomposition for m x n matrix and m >= n.
Args:
a: A matrix of shape `m x n` with `m >= n`.
is_hermitian: True if `a` is Hermitian.
max_iterations: The predefined maximum number of iterations of QDWH.
Returns:
A 3-tuple (`u`, `s`, `v`), where `u` is a unitary matrix of shape `m x n`,
`s` is vector of length `n` containing the singular values in the descending
order, `v` is a unitary matrix of shape `n x n`, and
`a = (u * s) @ v.T.conj()`.
"""
u, h, _, _ = lax.linalg.qdwh(a, is_hermitian, max_iterations)
v, s = lax.linalg.eigh(h)
# Flips the singular values in descending order.
s_out = jnp.flip(s)
# Reorders eigenvectors.
v_out = jnp.fliplr(v)
u_out = u @ v_out
# Makes correction if computed `u` from qdwh is not unitary.
# Section 5.5 of Nakatsukasa, Yuji, and Nicholas J. Higham. "Stable and
# efficient spectral divide and conquer algorithms for the symmetric
# eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing
# 35, no. 3 (2013): A1325-A1349.
def correct_rank_deficiency(u_out):
u_out, r = lax.linalg.qr(u_out, full_matrices=False)
u_out = u_out @ jnp.diag(lax.sign(jnp.diag(r)))
return u_out
eps = jnp.finfo(a.dtype).eps
u_out = lax.cond(s[0] < a.shape[1] * eps * s_out[0],
correct_rank_deficiency,
lambda u_out: u_out,
operand=(u_out))
return (u_out, s_out, v_out)
@functools.partial(jax.jit, static_argnums=(1, 2))
def svd(a: jnp.ndarray,
is_hermitian: bool = False,
max_iterations: int = 10) -> Sequence[jnp.ndarray]:
"""Singular value decomposition.
Args:
a: A matrix of shape `m x n`.
is_hermitian: True if `a` is Hermitian.
max_iterations: The predefined maximum number of iterations of QDWH.
Returns:
A 3-tuple (`u`, `s`, `vh`), where `u` is a unitary matrix of shape `m x k`,
`s` is vector of length `k` containing the singular values in the descending
order, `vh` is a unitary matrix of shape `k x n`, `k = min(m, n)`, and
`a = (u * s) @ vh`.
"""
is_hermitian = core.concrete_or_error(
bool, is_hermitian, 'The `is_hermitian` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
max_iterations = core.concrete_or_error(
int, max_iterations, 'The `max_iterations` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
m, n = a.shape
is_flip = False
if m < n:
a = a.T.conj()
m, n = a.shape
is_flip = True
reduce_to_square = False
if m > 1.15 * n:
m = n
q, a = lax.linalg.qr(a, full_matrices=False)
reduce_to_square = True
u_out, s_out, v_out = _svd(a, is_hermitian, max_iterations)
if reduce_to_square:
u_out = q @ u_out
if is_flip:
return(v_out, s_out, u_out.T.conj())
return (u_out, s_out, v_out.T.conj())

134
tests/svd_test.py Normal file
View File

@ -0,0 +1,134 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""Tests for the library of QDWH-based SVD decomposition."""
import functools
import jax
from jax import test_util as jtu
from jax.config import config
import jax.numpy as jnp
import numpy as np
import scipy.linalg as osp_linalg
from jax._src.lax import svd
from absl.testing import absltest
from absl.testing import parameterized
config.parse_flags_with_absl()
_JAX_ENABLE_X64 = config.x64_enabled
# Input matrix data type for SvdTest.
_SVD_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64 else np.float32
# Machine epsilon used by SvdTest.
_SVD_TEST_EPS = jnp.finfo(_SVD_TEST_DTYPE).eps
# SvdTest relative tolerance.
_SVD_RTOL = 1E-6 if _JAX_ENABLE_X64 else 1E-2
_MAX_LOG_CONDITION_NUM = 9 if _JAX_ENABLE_X64 else 4
@jtu.with_config(jax_numpy_rank_promotion='allow')
class SvdTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{ # pylint:disable=g-complex-comprehension
'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
'm': m, 'n': n, 'log_cond': log_cond}
for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
def testSvdWithRectangularInput(self, m, n, log_cond):
"""Tests SVD with rectangular input."""
with jax.default_matmul_precision('float32'):
a = np.random.uniform(
low=0.3, high=0.9, size=(m, n)).astype(_SVD_TEST_DTYPE)
u, s, v = jnp.linalg.svd(a, full_matrices=False)
cond = 10**log_cond
s = jnp.linspace(cond, 1, min(m, n))
a = (u * s) @ v
a = a + 1j * a
osp_linalg_fn = functools.partial(osp_linalg.svd, full_matrices=False)
actual_u, actual_s, actual_v = svd.svd(a)
k = min(m, n)
if m > n:
unitary_u = jnp.abs(actual_u.T.conj() @ actual_u)
unitary_v = jnp.abs(actual_v.T.conj() @ actual_v)
else:
unitary_u = jnp.abs(actual_u @ actual_u.T.conj())
unitary_v = jnp.abs(actual_v @ actual_v.T.conj())
_, expected_s, _ = osp_linalg_fn(a)
args_maker = lambda: [a]
with self.subTest('Test JIT compatibility'):
self._CompileAndCheck(svd.svd, args_maker)
with self.subTest('Test unitary u.'):
self.assertAllClose(np.eye(k), unitary_u, rtol=_SVD_RTOL, atol=2E-3)
with self.subTest('Test unitary v.'):
self.assertAllClose(np.eye(k), unitary_v, rtol=_SVD_RTOL, atol=2E-3)
with self.subTest('Test s.'):
self.assertAllClose(
expected_s, jnp.real(actual_s), rtol=_SVD_RTOL, atol=1E-6)
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name': '_m={}_by_n={}'.format(m, n), 'm': m, 'n': n}
for m, n in zip([50, 6], [3, 60])))
def testSvdWithSkinnyTallInput(self, m, n):
"""Tests SVD with skinny and tall input."""
# Generates a skinny and tall input
with jax.default_matmul_precision('float32'):
np.random.seed(1235)
a = np.random.randn(m, n).astype(_SVD_TEST_DTYPE)
u, s, v = svd.svd(a, is_hermitian=False)
relative_diff = np.linalg.norm(a - (u * s) @ v) / np.linalg.norm(a)
np.testing.assert_almost_equal(relative_diff, 1E-6, decimal=6)
@parameterized.named_parameters(jtu.cases_from_list(
{ # pylint:disable=g-complex-comprehension
'testcase_name': '_m={}_r={}_log_cond={}'.format(m, r, log_cond),
'm': m, 'r': r, 'log_cond': log_cond}
for m, r in zip([8, 8, 8, 10], [3, 5, 7, 9])
for log_cond in np.linspace(1, 3, 3)))
def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
"""Tests SVD with rank-deficient input."""
with jax.default_matmul_precision('float32'):
a = jnp.triu(jnp.ones((m, m))).astype(_SVD_TEST_DTYPE)
# Generates a rank-deficient input.
u, s, v = jnp.linalg.svd(a, full_matrices=False)
cond = 10**log_cond
s = jnp.linspace(cond, 1, m)
s = s.at[r:m].set(jnp.zeros((m-r,)))
a = (u * s) @ v
with jax.default_matmul_precision('float32'):
u, s, v = svd.svd(a, is_hermitian=False)
diff = np.linalg.norm(a - (u * s) @ v)
np.testing.assert_almost_equal(diff, 1E-4, decimal=2)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())