mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Added jax.lax.linalg.qdwh.
PiperOrigin-RevId: 406453671
This commit is contained in:
parent
d403b6d03c
commit
c5f73b3d8e
@ -20,6 +20,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
* Moved `jax.experimental.stax` to `jax.example_libraries.stax`
|
* Moved `jax.experimental.stax` to `jax.example_libraries.stax`
|
||||||
* Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`
|
* Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`
|
||||||
|
|
||||||
|
* New features:
|
||||||
|
* Added `jax.lax.linalg.qdwh`.
|
||||||
|
|
||||||
## jax 0.2.24 (Oct 19, 2021)
|
## jax 0.2.24 (Oct 19, 2021)
|
||||||
* [GitHub
|
* [GitHub
|
||||||
|
@ -199,6 +199,7 @@ Linear algebra operators (jax.lax.linalg)
|
|||||||
eig
|
eig
|
||||||
eigh
|
eigh
|
||||||
lu
|
lu
|
||||||
|
qdwh
|
||||||
qr
|
qr
|
||||||
svd
|
svd
|
||||||
triangular_solve
|
triangular_solve
|
||||||
|
185
jax/_src/lax/qdwh.py
Normal file
185
jax/_src/lax/qdwh.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
# Copyright 2021 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License
|
||||||
|
|
||||||
|
"""A JIT-compatible library for QDWH-based polar decomposition.
|
||||||
|
|
||||||
|
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
|
||||||
|
iteration implemented through QR decompositions does not require matrix
|
||||||
|
inversion. This is desirable for multicore and heterogeneous computing systems.
|
||||||
|
|
||||||
|
Reference: Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
|
||||||
|
"Optimizing Halley's iteration for computing the matrix polar decomposition."
|
||||||
|
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
|
||||||
|
https://epubs.siam.org/doi/abs/10.1137/090774999
|
||||||
|
"""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import core
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax._src.lax import linalg as lax_linalg
|
||||||
|
|
||||||
|
|
||||||
|
def _use_qr(u, params):
|
||||||
|
"""Uses QR decomposition."""
|
||||||
|
a, b, c = params
|
||||||
|
m, n = u.shape
|
||||||
|
y = jnp.concatenate([jnp.sqrt(c) * u, jnp.eye(n)])
|
||||||
|
q, _ = jnp.linalg.qr(y)
|
||||||
|
q1 = q[:m, :]
|
||||||
|
q2 = (q[m:, :]).T.conj()
|
||||||
|
e = b / c
|
||||||
|
u = (e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2))
|
||||||
|
return u
|
||||||
|
|
||||||
|
|
||||||
|
def _use_cholesky(u, params):
|
||||||
|
"""Uses Cholesky decomposition."""
|
||||||
|
a, b, c = params
|
||||||
|
_, n = u.shape
|
||||||
|
x = c * u.T.conj() @ u + jnp.eye(n)
|
||||||
|
|
||||||
|
# `y` is lower triangular.
|
||||||
|
y = lax_linalg.cholesky(x, symmetrize_input=False)
|
||||||
|
|
||||||
|
z = lax_linalg.triangular_solve(
|
||||||
|
y, u.T, left_side=True, lower=True, conjugate_a=True).conj()
|
||||||
|
|
||||||
|
z = lax_linalg.triangular_solve(y, z, left_side=True, lower=True,
|
||||||
|
transpose_a=True, conjugate_a=True).T.conj()
|
||||||
|
|
||||||
|
e = b / c
|
||||||
|
u = e * u + (a - e) * z
|
||||||
|
return u
|
||||||
|
|
||||||
|
|
||||||
|
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
|
||||||
|
def _qdwh(x, is_symmetric, max_iterations):
|
||||||
|
"""QR-based dynamically weighted Halley iteration for polar decomposition."""
|
||||||
|
|
||||||
|
# Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of
|
||||||
|
# norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for
|
||||||
|
# the smallest singular value of x.
|
||||||
|
eps = jnp.finfo(x.dtype).eps
|
||||||
|
alpha = jnp.sqrt(jnp.linalg.norm(x, ord=1) * jnp.linalg.norm(x, ord=jnp.inf))
|
||||||
|
l = eps
|
||||||
|
|
||||||
|
u = x / alpha
|
||||||
|
|
||||||
|
# Iteration tolerances.
|
||||||
|
tol_l = 10.0 * eps / 2.0
|
||||||
|
tol_norm = jnp.cbrt(tol_l)
|
||||||
|
|
||||||
|
def cond_fun(state):
|
||||||
|
_, _, _, is_unconverged, is_not_max_iteration = state
|
||||||
|
return jnp.logical_and(is_unconverged, is_not_max_iteration)
|
||||||
|
|
||||||
|
def body_fun(state):
|
||||||
|
u, l, iter_idx, _, _ = state
|
||||||
|
|
||||||
|
u_prev = u
|
||||||
|
|
||||||
|
# Computes parameters.
|
||||||
|
l2 = l**2
|
||||||
|
dd = jnp.cbrt(4.0 * (1.0 / l2 - 1.0) / l2)
|
||||||
|
sqd = jnp.sqrt(1.0 + dd)
|
||||||
|
a = (sqd + jnp.sqrt(8.0 - 4.0 * dd + 8.0 * (2.0 - l2) / (l2 * sqd)) / 2)
|
||||||
|
a = jnp.real(a)
|
||||||
|
b = (a - 1.0)**2 / 4.0
|
||||||
|
c = a + b - 1.0
|
||||||
|
|
||||||
|
# Updates l.
|
||||||
|
l = l * (a + b * l2) / (1.0 + c * l2)
|
||||||
|
|
||||||
|
# Uses QR or Cholesky decomposition.
|
||||||
|
def true_fn(u):
|
||||||
|
return _use_qr(u, params=(a, b, c))
|
||||||
|
|
||||||
|
def false_fn(u):
|
||||||
|
return _use_cholesky(u, params=(a, b, c))
|
||||||
|
|
||||||
|
u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u))
|
||||||
|
|
||||||
|
if is_symmetric:
|
||||||
|
u = (u + u.T.conj()) / 2.0
|
||||||
|
|
||||||
|
# Checks convergence.
|
||||||
|
iterating_l = jnp.abs(1.0 - l) > tol_l
|
||||||
|
iterating_u = jnp.linalg.norm((u-u_prev)) > tol_norm
|
||||||
|
is_unconverged = jnp.logical_or(iterating_l, iterating_u)
|
||||||
|
|
||||||
|
is_not_max_iteration = iter_idx < max_iterations
|
||||||
|
|
||||||
|
return u, l, iter_idx + 1, is_unconverged, is_not_max_iteration
|
||||||
|
|
||||||
|
iter_idx = 1
|
||||||
|
is_unconverged = True
|
||||||
|
is_not_max_iteration = True
|
||||||
|
u, _, num_iters, is_unconverged, _ = jax.lax.while_loop(
|
||||||
|
cond_fun=cond_fun, body_fun=body_fun,
|
||||||
|
init_val=(u, l, iter_idx, is_unconverged, is_not_max_iteration))
|
||||||
|
|
||||||
|
# Applies Newton-Schulz refinement for better accuracy.
|
||||||
|
u = 1.5 * u - 0.5 * u @ (u.T.conj() @ u)
|
||||||
|
|
||||||
|
h = u.T.conj() @ x
|
||||||
|
h = (h + h.T.conj()) / 2.0
|
||||||
|
|
||||||
|
# Converged within the maximum number of iterations.
|
||||||
|
is_converged = jnp.logical_not(is_unconverged)
|
||||||
|
|
||||||
|
return u, h, num_iters - 1, is_converged
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Add pivoting.
|
||||||
|
def qdwh(x, is_symmetric, max_iterations=10):
  """QR-based dynamically weighted Halley iteration for polar decomposition.

  Args:
    x: A full-rank matrix of shape `m x n` with `m >= n`.
    is_symmetric: True if `x` is symmetric. Must be a static (concrete) value.
    max_iterations: The predefined maximum number of iterations. Must be a
      static (concrete) value.

  Returns:
    A four-tuple of (u, h, num_iters, is_converged) containing the
    polar decomposition of `x = u * h`, the number of iterations to compute `u`,
    and `is_converged`, whose value is `True` when the convergence is achieved
    within the maximum number of iterations.

  Raises:
    ValueError: If `x` has fewer rows than columns, or if `is_symmetric` is
      true but the relative asymmetry `norm(x - x.H) / norm(x)` exceeds
      `50 * eps` for the dtype of `x`.
  """
  num_rows, num_cols = x.shape
  if num_rows < num_cols:
    raise ValueError('The input matrix of shape m x n must have m >= n.')

  # Both control arguments must be concrete Python values.
  max_iterations = core.concrete_or_error(
      int, max_iterations,
      'The `max_iterations` argument must be statically '
      'specified to use `qdwh` within JAX transformations.')
  is_symmetric = core.concrete_or_error(
      bool, is_symmetric,
      'The `is_symmetric` argument must be statically '
      'specified to use `qdwh` within JAX transformations.')

  if is_symmetric:
    # Verify the caller's symmetry claim before relying on it.
    tol = 50.0 * jnp.finfo(x.dtype).eps
    relative_diff = jnp.linalg.norm(x - x.T.conj()) / jnp.linalg.norm(x)
    if relative_diff > tol:
      raise ValueError('The input `x` is NOT symmetric because '
                       '`norm(x-x.H) / norm(x)` is {}, which is greater than '
                       'the tolerance {}.'.format(relative_diff, tol))

  return _qdwh(x, is_symmetric, max_iterations)
|
@ -34,3 +34,8 @@ from jax._src.lax.linalg import (
|
|||||||
schur,
|
schur,
|
||||||
schur_p
|
schur_p
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
from jax._src.lax.qdwh import (
|
||||||
|
qdwh as qdwh
|
||||||
|
)
|
||||||
|
195
tests/qdwh_test.py
Normal file
195
tests/qdwh_test.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
# Copyright 2021 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License
|
||||||
|
|
||||||
|
"""Tests for the library of QDWH-based polar decomposition."""
|
||||||
|
|
||||||
|
from jax import test_util as jtu
|
||||||
|
from jax.config import config
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import scipy.linalg as osp_linalg
|
||||||
|
from jax._src.lax import qdwh
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
|
||||||
|
config.parse_flags_with_absl()
|
||||||
|
_JAX_ENABLE_X64 = config.x64_enabled
|
||||||
|
|
||||||
|
# Input matrix data type for PolarTest.
|
||||||
|
_POLAR_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64 else np.float32
|
||||||
|
|
||||||
|
# Machine epsilon used by PolarTest.
|
||||||
|
_POLAR_TEST_EPS = jnp.finfo(_POLAR_TEST_DTYPE).eps
|
||||||
|
|
||||||
|
# Largest log10 value of condition numbers used by PolarTest.
|
||||||
|
_MAX_LOG_CONDITION_NUM = np.log10(int(1 / _POLAR_TEST_EPS))
|
||||||
|
|
||||||
|
|
||||||
|
def _check_symmetry(x: jnp.ndarray) -> bool:
|
||||||
|
"""Check if the array is symmetric."""
|
||||||
|
m, n = x.shape
|
||||||
|
eps = jnp.finfo(x.dtype).eps
|
||||||
|
tol = 50.0 * eps
|
||||||
|
is_symmetric = False
|
||||||
|
if m == n:
|
||||||
|
if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol:
|
||||||
|
is_symmetric = True
|
||||||
|
|
||||||
|
return is_symmetric
|
||||||
|
|
||||||
|
|
||||||
|
class PolarTest(jtu.JaxTestCase):
  """Tests `qdwh.qdwh` on matrices with prescribed singular values.

  Each test takes the SVD of a base matrix and replaces its singular values
  with a linear ramp from `10**log_cond` down to 1, so the condition number of
  the input is controlled by `log_cond`.
  """

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
          'm': m, 'n': n, 'log_cond': log_cond}
      for m, n in zip([8, 10, 20], [6, 10, 18])
      for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
  def testQdwhUnconvergedAfterMaxNumberIterations(
      self, m, n, log_cond):
    """Tests unconvergence after maximum number of iterations."""
    a = jnp.triu(jnp.ones((m, n)))
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.linspace(cond, 1, min(m, n))
    a = (u * s) @ v
    is_symmetric = _check_symmetry(a)
    # Deliberately too few iterations for the algorithm to converge.
    max_iterations = 2

    _, _, actual_num_iterations, is_converged = qdwh.qdwh(
        a, is_symmetric, max_iterations)

    with self.subTest('Number of iterations.'):
      self.assertEqual(max_iterations, actual_num_iterations)

    with self.subTest('Converged.'):
      self.assertFalse(is_converged)

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
          'm': m, 'n': n, 'log_cond': log_cond}
      for m, n in zip([8, 10, 20], [6, 10, 18])
      for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
  def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
    """Tests qdwh with upper triangular input of all ones."""
    a = jnp.triu(jnp.ones((m, n)))
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.linspace(cond, 1, min(m, n))
    a = (u * s) @ v
    is_symmetric = _check_symmetry(a)
    max_iterations = 10
    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations)
    expected_u, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.
    rtol = 1E6 * _POLAR_TEST_EPS

    with self.subTest('Test u.'):
      self.assertAllClose(actual_u, expected_u, rtol=rtol)

    with self.subTest('Test h.'):
      self.assertAllClose(actual_h, expected_h, rtol=rtol)

    with self.subTest('Test u.dot(h).'):
      a_round_trip = actual_u.dot(actual_h)
      self.assertAllClose(a_round_trip, a, rtol=rtol)

    with self.subTest('Test orthogonality.'):
      actual_results = actual_u.T.dot(actual_u)
      expected_results = np.eye(n)
      self.assertAllClose(
          actual_results, expected_results, rtol=rtol, atol=1E-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_by_n={}_log_cond={}'.format(
              m, n, log_cond),
          'm': m, 'n': n, 'log_cond': log_cond}
      for m, n in zip([6, 8], [6, 4])
      for log_cond in np.linspace(1, 4, 4)))
  def testQdwhWithRandomMatrix(self, m, n, log_cond):
    """Tests qdwh with random input."""

    # Draw from the test case's own seeded generator rather than the global
    # `np.random` state, so a failing input can be reproduced.
    a = self.rng().uniform(
        low=0.3, high=0.9, size=(m, n)).astype(_POLAR_TEST_DTYPE)
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.linspace(cond, 1, min(m, n))
    a = (u * s) @ v
    is_symmetric = _check_symmetry(a)
    max_iterations = 10

    def lsp_linalg_fn(a):
      u, h, _, _ = qdwh.qdwh(
          a, is_symmetric=is_symmetric, max_iterations=max_iterations)
      return u, h

    args_maker = lambda: [a]

    # Sets the test tolerance.
    rtol = 1E6 * _POLAR_TEST_EPS

    with self.subTest('Test JIT compatibility'):
      self._CompileAndCheck(lsp_linalg_fn, args_maker)

    with self.subTest('Test against numpy.'):
      self._CheckAgainstNumpy(osp_linalg.polar, lsp_linalg_fn, args_maker,
                              rtol=rtol, atol=1E-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
          'm': m, 'n': n, 'log_cond': log_cond}
      for m, n in zip([10, 12], [10, 12])
      for log_cond in np.linspace(1, 4, 4)))
  def testQdwhWithOnRankDeficientInput(self, m, n, log_cond):
    """Tests qdwh with rank-deficient input."""
    a = jnp.triu(jnp.ones((m, n))).astype(_POLAR_TEST_DTYPE)

    # Generates a rank-deficient input by zeroing the smallest singular value.
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.linspace(cond, 1, min(m, n))
    s = s.at[-1].set(0)
    a = (u * s) @ v

    is_symmetric = _check_symmetry(a)
    max_iterations = 10
    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations)
    _, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.
    rtol = 1E6 * _POLAR_TEST_EPS

    # For rank-deficient matrix, `u` is not unique.
    with self.subTest('Test h.'):
      self.assertAllClose(actual_h, expected_h, rtol=rtol)

    with self.subTest('Test u.dot(h).'):
      a_round_trip = actual_u.dot(actual_h)
      self.assertAllClose(a_round_trip, a, rtol=rtol)

    with self.subTest('Test orthogonality.'):
      actual_results = actual_u.T.dot(actual_u)
      expected_results = np.eye(n)
      self.assertAllClose(
          actual_results, expected_results, rtol=rtol, atol=1E-5)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user