From c5f73b3d8e84acbb7613c33c97b7317aad5907e9 Mon Sep 17 00:00:00 2001 From: Tianjian Lu Date: Fri, 29 Oct 2021 14:44:27 -0700 Subject: [PATCH] [JAX] Added `jax.lax.linalg.qdwh`. PiperOrigin-RevId: 406453671 --- CHANGELOG.md | 2 + docs/jax.lax.rst | 1 + jax/_src/lax/qdwh.py | 185 ++++++++++++++++++++++++++++++++++++++++ jax/lax/linalg.py | 5 ++ tests/qdwh_test.py | 195 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 388 insertions(+) create mode 100644 jax/_src/lax/qdwh.py create mode 100644 tests/qdwh_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 114a6352e..efac0c79b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * Moved `jax.experimental.stax` to `jax.example_libraries.stax` * Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers` +* New features: + * Added `jax.lax.linalg.qdwh`. ## jax 0.2.24 (Oct 19, 2021) * [GitHub diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 1faf8e7ff..aed44a3b0 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -199,6 +199,7 @@ Linear algebra operators (jax.lax.linalg) eig eigh lu + qdwh qr svd triangular_solve diff --git a/jax/_src/lax/qdwh.py b/jax/_src/lax/qdwh.py new file mode 100644 index 000000000..41f0c65b2 --- /dev/null +++ b/jax/_src/lax/qdwh.py @@ -0,0 +1,185 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""A JIT-compatible library for QDWH-based polar decomposition. + +QDWH is short for QR-based dynamically weighted Halley iteration. The Halley +iteration implemented through QR decmopositions does not require matrix +inversion. This is desirable for multicore and heterogeneous computing systems. + +Reference: Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi. +"Optimizing Halley's iteration for computing the matrix polar decomposition." +SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720. +https://epubs.siam.org/doi/abs/10.1137/090774999 +""" + +import functools + +import jax +from jax import core +import jax.numpy as jnp +from jax._src.lax import linalg as lax_linalg + + +def _use_qr(u, params): + """Uses QR decomposition.""" + a, b, c = params + m, n = u.shape + y = jnp.concatenate([jnp.sqrt(c) * u, jnp.eye(n)]) + q, _ = jnp.linalg.qr(y) + q1 = q[:m, :] + q2 = (q[m:, :]).T.conj() + e = b / c + u = (e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2)) + return u + + +def _use_cholesky(u, params): + """Uses Cholesky decomposition.""" + a, b, c = params + _, n = u.shape + x = c * u.T.conj() @ u + jnp.eye(n) + + # `y` is lower triangular. + y = lax_linalg.cholesky(x, symmetrize_input=False) + + z = lax_linalg.triangular_solve( + y, u.T, left_side=True, lower=True, conjugate_a=True).conj() + + z = lax_linalg.triangular_solve(y, z, left_side=True, lower=True, + transpose_a=True, conjugate_a=True).T.conj() + + e = b / c + u = e * u + (a - e) * z + return u + + +@functools.partial(jax.jit, static_argnums=(1, 2, 3)) +def _qdwh(x, is_symmetric, max_iterations): + """QR-based dynamically weighted Halley iteration for polar decomposition.""" + + # Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of + # norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for + # the smallest singular value of x. + eps = jnp.finfo(x.dtype).eps + alpha = jnp.sqrt(jnp.linalg.norm(x, ord=1) * jnp.linalg.norm(x, ord=jnp.inf)) + l = eps + + u = x / alpha + + # Iteration tolerances. + tol_l = 10.0 * eps / 2.0 + tol_norm = jnp.cbrt(tol_l) + + def cond_fun(state): + _, _, _, is_unconverged, is_not_max_iteration = state + return jnp.logical_and(is_unconverged, is_not_max_iteration) + + def body_fun(state): + u, l, iter_idx, _, _ = state + + u_prev = u + + # Computes parameters. + l2 = l**2 + dd = jnp.cbrt(4.0 * (1.0 / l2 - 1.0) / l2) + sqd = jnp.sqrt(1.0 + dd) + a = (sqd + jnp.sqrt(8.0 - 4.0 * dd + 8.0 * (2.0 - l2) / (l2 * sqd)) / 2) + a = jnp.real(a) + b = (a - 1.0)**2 / 4.0 + c = a + b - 1.0 + + # Updates l. + l = l * (a + b * l2) / (1.0 + c * l2) + + # Uses QR or Cholesky decomposition. + def true_fn(u): + return _use_qr(u, params=(a, b, c)) + + def false_fn(u): + return _use_cholesky(u, params=(a, b, c)) + + u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u)) + + if is_symmetric: + u = (u + u.T.conj()) / 2.0 + + # Checks convergence. + iterating_l = jnp.abs(1.0 - l) > tol_l + iterating_u = jnp.linalg.norm((u-u_prev)) > tol_norm + is_unconverged = jnp.logical_or(iterating_l, iterating_u) + + is_not_max_iteration = iter_idx < max_iterations + + return u, l, iter_idx + 1, is_unconverged, is_not_max_iteration + + iter_idx = 1 + is_unconverged = True + is_not_max_iteration = True + u, _, num_iters, is_unconverged, _ = jax.lax.while_loop( + cond_fun=cond_fun, body_fun=body_fun, + init_val=(u, l, iter_idx, is_unconverged, is_not_max_iteration)) + + # Applies Newton-Schulz refinement for better accuracy. + u = 1.5 * u - 0.5 * u @ (u.T.conj() @ u) + + h = u.T.conj() @ x + h = (h + h.T.conj()) / 2.0 + + # Converged within the maximum number of iterations. + is_converged = jnp.logical_not(is_unconverged) + + return u, h, num_iters - 1, is_converged + + +# TODO: Add pivoting. +def qdwh(x, is_symmetric, max_iterations=10): + """QR-based dynamically weighted Halley iteration for polar decomposition. + + Args: + x: A full-rank matrix of shape `m x n` with `m >= n`. + is_symmetric: True if `x` is symmetric. + max_iterations: The predefined maximum number of iterations. + + Returns: + A four-tuple of (u, h, num_iters, is_converged) containing the + polar decomposition of `x = u * h`, the number of iterations to compute `u`, + and `is_converged`, whose value is `True` when the convergence is achieved + within the maximum number of iterations. + """ + m, n = x.shape + + if m < n: + raise ValueError('The input matrix of shape m x n must have m >= n.') + + max_iterations = core.concrete_or_error( + int, max_iterations, 'The `max_iterations` argument must be statically ' + 'specified to use `qdwh` within JAX transformations.') + + is_symmetric = core.concrete_or_error( + bool, is_symmetric, 'The `is_symmetric` argument must be statically ' + 'specified to use `qdwh` within JAX transformations.') + + if is_symmetric: + eps = jnp.finfo(x.dtype).eps + tol = 50.0 * eps + relative_diff = jnp.linalg.norm(x - x.T.conj()) / jnp.linalg.norm(x) + if relative_diff > tol: + raise ValueError('The input `x` is NOT symmetric because ' + '`norm(x-x.H) / norm(x)` is {}, which is greater than ' + 'the tolerance {}.'.format(relative_diff, tol)) + + u, h, num_iters, is_converged = _qdwh(x, is_symmetric, max_iterations) + + return u, h, num_iters, is_converged diff --git a/jax/lax/linalg.py b/jax/lax/linalg.py index 391ec0e26..f310318e1 100644 --- a/jax/lax/linalg.py +++ b/jax/lax/linalg.py @@ -34,3 +34,8 @@ from jax._src.lax.linalg import ( schur, schur_p ) + + +from jax._src.lax.qdwh import ( + qdwh as qdwh +) diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py new file mode 100644 index 000000000..0bfd7af8e --- /dev/null +++ b/tests/qdwh_test.py @@ -0,0 +1,195 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Tests for the library of QDWH-based polar decomposition.""" + +from jax import test_util as jtu +from jax.config import config +import jax.numpy as jnp +import numpy as np +import scipy.linalg as osp_linalg +from jax._src.lax import qdwh + +from absl.testing import absltest +from absl.testing import parameterized + + +config.parse_flags_with_absl() +_JAX_ENABLE_X64 = config.x64_enabled + +# Input matrix data type for PolarTest. +_POLAR_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64 else np.float32 + +# Machine epsilon used by PolarTest. +_POLAR_TEST_EPS = jnp.finfo(_POLAR_TEST_DTYPE).eps + +# Largest log10 value of condition numbers used by PolarTest. +_MAX_LOG_CONDITION_NUM = np.log10(int(1 / _POLAR_TEST_EPS)) + + +def _check_symmetry(x: jnp.ndarray) -> bool: + """Check if the array is symmetric.""" + m, n = x.shape + eps = jnp.finfo(x.dtype).eps + tol = 50.0 * eps + is_symmetric = False + if m == n: + if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol: + is_symmetric = True + + return is_symmetric + + +class PolarTest(jtu.JaxTestCase): + + @parameterized.named_parameters(jtu.cases_from_list( + { # pylint:disable=g-complex-comprehension + 'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond), + 'm': m, 'n': n, 'log_cond': log_cond} + for m, n in zip([8, 10, 20], [6, 10, 18]) + for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4))) + def testQdwhUnconvergedAfterMaxNumberIterations( + self, m, n, log_cond): + """Tests unconvergence after maximum number of iterations.""" + a = jnp.triu(jnp.ones((m, n))) + u, s, v = jnp.linalg.svd(a, full_matrices=False) + cond = 10**log_cond + s = jnp.linspace(cond, 1, min(m, n)) + a = (u * s) @ v + is_symmetric = _check_symmetry(a) + max_iterations = 2 + + _, _, actual_num_iterations, is_converged = qdwh.qdwh( + a, is_symmetric, max_iterations) + + with self.subTest('Number of iterations.'): + self.assertEqual(max_iterations, actual_num_iterations) + + with self.subTest('Converged.'): + self.assertFalse(is_converged) + + @parameterized.named_parameters(jtu.cases_from_list( + { # pylint:disable=g-complex-comprehension + 'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond), + 'm': m, 'n': n, 'log_cond': log_cond} + for m, n in zip([8, 10, 20], [6, 10, 18]) + for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4))) + def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond): + """Tests qdwh with upper triangular input of all ones.""" + a = jnp.triu(jnp.ones((m, n))) + u, s, v = jnp.linalg.svd(a, full_matrices=False) + cond = 10**log_cond + s = jnp.linspace(cond, 1, min(m, n)) + a = (u * s) @ v + is_symmetric = _check_symmetry(a) + max_iterations = 10 + actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations) + expected_u, expected_h = osp_linalg.polar(a) + + # Sets the test tolerance. + rtol = 1E6 * _POLAR_TEST_EPS + + with self.subTest('Test u.'): + self.assertAllClose(actual_u, expected_u, rtol=rtol) + + with self.subTest('Test h.'): + self.assertAllClose(actual_h, expected_h, rtol=rtol) + + with self.subTest('Test u.dot(h).'): + a_round_trip = actual_u.dot(actual_h) + self.assertAllClose(a_round_trip, a, rtol=rtol) + + with self.subTest('Test orthogonality.'): + actual_results = actual_u.T.dot(actual_u) + expected_results = np.eye(n) + self.assertAllClose( + actual_results, expected_results, rtol=rtol, atol=1E-4) + + @parameterized.named_parameters(jtu.cases_from_list( + { # pylint:disable=g-complex-comprehension + 'testcase_name': '_m={}_by_n={}_log_cond={}'.format( + m, n, log_cond), + 'm': m, 'n': n, 'log_cond': log_cond} + for m, n in zip([6, 8], [6, 4]) + for log_cond in np.linspace(1, 4, 4))) + def testQdwhWithRandomMatrix(self, m, n, log_cond): + """Tests qdwh with random input.""" + + a = np.random.uniform( + low=0.3, high=0.9, size=(m, n)).astype(_POLAR_TEST_DTYPE) + u, s, v = jnp.linalg.svd(a, full_matrices=False) + cond = 10**log_cond + s = jnp.linspace(cond, 1, min(m, n)) + a = (u * s) @ v + is_symmetric = _check_symmetry(a) + max_iterations = 10 + + def lsp_linalg_fn(a): + u, h, _, _ = qdwh.qdwh( + a, is_symmetric=is_symmetric, max_iterations=max_iterations) + return u, h + + args_maker = lambda: [a] + + # Sets the test tolerance. + rtol = 1E6 * _POLAR_TEST_EPS + + with self.subTest('Test JIT compatibility'): + self._CompileAndCheck(lsp_linalg_fn, args_maker) + + with self.subTest('Test against numpy.'): + self._CheckAgainstNumpy(osp_linalg.polar, lsp_linalg_fn, args_maker, + rtol=rtol, atol=1E-3) + + @parameterized.named_parameters(jtu.cases_from_list( + { # pylint:disable=g-complex-comprehension + 'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond), + 'm': m, 'n': n, 'log_cond': log_cond} + for m, n in zip([10, 12], [10, 12]) + for log_cond in np.linspace(1, 4, 4))) + def testQdwhWithOnRankDeficientInput(self, m, n, log_cond): + """Tests qdwh with rank-deficient input.""" + a = jnp.triu(jnp.ones((m, n))).astype(_POLAR_TEST_DTYPE) + + # Generates a rank-deficient input. + u, s, v = jnp.linalg.svd(a, full_matrices=False) + cond = 10**log_cond + s = jnp.linspace(cond, 1, min(m, n)) + s = s.at[-1].set(0) + a = (u * s) @ v + + is_symmetric = _check_symmetry(a) + max_iterations = 10 + actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations) + _, expected_h = osp_linalg.polar(a) + + # Sets the test tolerance. + rtol = 1E6 * _POLAR_TEST_EPS + + # For rank-deficient matrix, `u` is not unique. + with self.subTest('Test h.'): + self.assertAllClose(actual_h, expected_h, rtol=rtol) + + with self.subTest('Test u.dot(h).'): + a_round_trip = actual_u.dot(actual_h) + self.assertAllClose(a_round_trip, a, rtol=rtol) + + with self.subTest('Test orthogonality.'): + actual_results = actual_u.T.dot(actual_u) + expected_results = np.eye(n) + self.assertAllClose( + actual_results, expected_results, rtol=rtol, atol=1E-5) + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader())