mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
196 lines
6.7 KiB
Python
196 lines
6.7 KiB
Python
![]() |
# Copyright 2021 Google LLC
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License
|
||
|
|
||
|
"""Tests for the library of QDWH-based polar decomposition."""
|
||
|
|
||
|
from jax import test_util as jtu
|
||
|
from jax.config import config
|
||
|
import jax.numpy as jnp
|
||
|
import numpy as np
|
||
|
import scipy.linalg as osp_linalg
|
||
|
from jax._src.lax import qdwh
|
||
|
|
||
|
from absl.testing import absltest
|
||
|
from absl.testing import parameterized
|
||
|
|
||
|
|
||
|
config.parse_flags_with_absl()
|
||
|
_JAX_ENABLE_X64 = config.x64_enabled
|
||
|
|
||
|
# Input matrix data type for PolarTest.
|
||
|
_POLAR_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64 else np.float32
|
||
|
|
||
|
# Machine epsilon used by PolarTest.
|
||
|
_POLAR_TEST_EPS = jnp.finfo(_POLAR_TEST_DTYPE).eps
|
||
|
|
||
|
# Largest log10 value of condition numbers used by PolarTest.
|
||
|
_MAX_LOG_CONDITION_NUM = np.log10(int(1 / _POLAR_TEST_EPS))
|
||
|
|
||
|
|
||
|
def _check_symmetry(x: jnp.ndarray) -> bool:
  """Checks whether `x` is numerically symmetric (Hermitian).

  `x` is treated as symmetric when it is square and the Frobenius norm of
  `x - x^H` is at most `50 * eps` times the Frobenius norm of `x`, where `eps`
  is the machine epsilon of `x.dtype`.

  Args:
    x: A 2-D floating-point or complex array.

  Returns:
    True if `x` is square and symmetric up to the tolerance above, otherwise
    False.
  """
  m, n = x.shape
  if m != n:
    return False

  eps = jnp.finfo(x.dtype).eps
  tol = 50.0 * eps
  # Compare `||x - x^H|| <= tol * ||x||` rather than dividing by `||x||`: an
  # all-zero matrix is trivially symmetric, and dividing would give 0/0 = NaN
  # (plus a RuntimeWarning) and misreport it as non-symmetric.
  asymmetry = np.linalg.norm(x - x.T.conj())
  return bool(asymmetry <= tol * np.linalg.norm(x))
|
||
|
|
||
|
|
||
|
class PolarTest(jtu.JaxTestCase):
  """Tests for the QDWH polar decomposition `qdwh.qdwh`.

  Inputs are built with prescribed singular values so that their condition
  number is controlled, and results are compared with `scipy.linalg.polar`.
  """

  @staticmethod
  def _with_singular_values(a, log_cond, rank_deficient=False):
    """Rebuilds `a` from its singular vectors with a prescribed spectrum.

    The singular values are replaced by `min(m, n)` values evenly spaced from
    `10**log_cond` down to 1, so the result has condition number
    `10**log_cond`. When `rank_deficient` is True the smallest singular value
    is set to zero instead.

    Args:
      a: An `(m, n)` matrix whose singular vectors are reused.
      log_cond: log10 of the largest singular value (the smallest is 1).
      rank_deficient: Whether to zero out the smallest singular value.

    Returns:
      The `(m, n)` matrix `u @ diag(s) @ v` with the new singular values `s`.
    """
    u, _, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.linspace(cond, 1, min(a.shape))
    if rank_deficient:
      s = s.at[-1].set(0)
    return (u * s) @ v

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
          'm': m, 'n': n, 'log_cond': log_cond}
      for m, n in zip([8, 10, 20], [6, 10, 18])
      for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
  def testQdwhUnconvergedAfterMaxNumberIterations(
      self, m, n, log_cond):
    """Tests unconvergence after maximum number of iterations."""
    a = self._with_singular_values(jnp.triu(jnp.ones((m, n))), log_cond)
    is_symmetric = _check_symmetry(a)
    max_iterations = 2

    _, _, actual_num_iterations, is_converged = qdwh.qdwh(
        a, is_symmetric, max_iterations)

    with self.subTest('Number of iterations.'):
      self.assertEqual(max_iterations, actual_num_iterations)

    with self.subTest('Converged.'):
      self.assertFalse(is_converged)

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
          'm': m, 'n': n, 'log_cond': log_cond}
      for m, n in zip([8, 10, 20], [6, 10, 18])
      for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
  def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
    """Tests qdwh with upper triangular input of all ones."""
    a = self._with_singular_values(jnp.triu(jnp.ones((m, n))), log_cond)
    is_symmetric = _check_symmetry(a)
    max_iterations = 10
    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations)
    expected_u, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.
    rtol = 1E6 * _POLAR_TEST_EPS

    with self.subTest('Test u.'):
      self.assertAllClose(actual_u, expected_u, rtol=rtol)

    with self.subTest('Test h.'):
      self.assertAllClose(actual_h, expected_h, rtol=rtol)

    with self.subTest('Test u.dot(h).'):
      a_round_trip = actual_u.dot(actual_h)
      self.assertAllClose(a_round_trip, a, rtol=rtol)

    with self.subTest('Test orthogonality.'):
      actual_results = actual_u.T.dot(actual_u)
      expected_results = np.eye(n)
      self.assertAllClose(
          actual_results, expected_results, rtol=rtol, atol=1E-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_by_n={}_log_cond={}'.format(
              m, n, log_cond),
          'm': m, 'n': n, 'log_cond': log_cond}
      for m, n in zip([6, 8], [6, 4])
      for log_cond in np.linspace(1, 4, 4)))
  def testQdwhWithRandomMatrix(self, m, n, log_cond):
    """Tests qdwh with random input."""
    # Draw from a locally seeded generator rather than the global
    # `np.random` state, so the input (and the test outcome) is reproducible.
    rng = np.random.RandomState(0)
    a = rng.uniform(
        low=0.3, high=0.9, size=(m, n)).astype(_POLAR_TEST_DTYPE)
    a = self._with_singular_values(a, log_cond)
    is_symmetric = _check_symmetry(a)
    max_iterations = 10

    def lsp_linalg_fn(a):
      u, h, _, _ = qdwh.qdwh(
          a, is_symmetric=is_symmetric, max_iterations=max_iterations)
      return u, h

    args_maker = lambda: [a]

    # Sets the test tolerance.
    rtol = 1E6 * _POLAR_TEST_EPS

    with self.subTest('Test JIT compatibility'):
      self._CompileAndCheck(lsp_linalg_fn, args_maker)

    with self.subTest('Test against numpy.'):
      self._CheckAgainstNumpy(osp_linalg.polar, lsp_linalg_fn, args_maker,
                              rtol=rtol, atol=1E-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
          'm': m, 'n': n, 'log_cond': log_cond}
      for m, n in zip([10, 12], [10, 12])
      for log_cond in np.linspace(1, 4, 4)))
  def testQdwhWithOnRankDeficientInput(self, m, n, log_cond):
    """Tests qdwh with rank-deficient input."""
    a = jnp.triu(jnp.ones((m, n))).astype(_POLAR_TEST_DTYPE)

    # Generates a rank-deficient input by zeroing the smallest singular value.
    a = self._with_singular_values(a, log_cond, rank_deficient=True)

    is_symmetric = _check_symmetry(a)
    max_iterations = 10
    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_symmetric, max_iterations)
    _, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.
    rtol = 1E6 * _POLAR_TEST_EPS

    # For rank-deficient matrix, `u` is not unique.
    with self.subTest('Test h.'):
      self.assertAllClose(actual_h, expected_h, rtol=rtol)

    with self.subTest('Test u.dot(h).'):
      a_round_trip = actual_u.dot(actual_h)
      self.assertAllClose(a_round_trip, a, rtol=rtol)

    with self.subTest('Test orthogonality.'):
      actual_results = actual_u.T.dot(actual_u)
      expected_results = np.eye(n)
      self.assertAllClose(
          actual_results, expected_results, rtol=rtol, atol=1E-5)
|
||
|
|
||
|
if __name__ == '__main__':
  # Run the absl test runner with JAX's own test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|