Update test of QDWH to use stricter tolerances and test more shapes and types.

Get rid of comparison with scipy.linalg.polar, since its outputs are significantly less accurate than QDWH. Since the polar decomposition is unique, comparing to a less accurate implementation does not add value.

PiperOrigin-RevId: 642423757
This commit is contained in:
jax authors 2024-06-11 16:03:55 -07:00 committed by jax authors
parent ebe0ab0b51
commit 95e2c17b61
2 changed files with 104 additions and 142 deletions

View File

@ -784,6 +784,7 @@ jax_test(
jax_test(
name = "qdwh_test",
srcs = ["qdwh_test.py"],
shard_count = 10,
)
jax_test(

View File

@ -15,181 +15,145 @@
"""Tests for the library of QDWH-based polar decomposition."""
import functools
from absl.testing import absltest
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg as osp_linalg
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lax import qdwh
from absl.testing import absltest
import jax.numpy as jnp
import numpy as np
config.parse_flags_with_absl()
_JAX_ENABLE_X64_QDWH = config.enable_x64.value
# Input matrix data type for QdwhTest.
_QDWH_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64_QDWH else np.float32
# Machine epsilon used by QdwhTest.
_QDWH_TEST_EPS = jnp.finfo(_QDWH_TEST_DTYPE).eps
# Largest log10 value of condition numbers used by QdwhTest.
_MAX_LOG_CONDITION_NUM = np.log10(int(1 / _QDWH_TEST_EPS))
float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
def _check_symmetry(x: jax.Array) -> bool:
"""Check if the array is symmetric."""
m, n = x.shape
eps = jnp.finfo(x.dtype).eps
tol = 50.0 * eps
is_hermitian = False
if m == n:
if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol:
is_hermitian = True
return is_hermitian
def _compute_relative_diff(actual, expected):
def _compute_relative_normwise_diff(actual, expected):
"""Computes relative difference between two matrices."""
return np.linalg.norm(actual - expected) / np.linalg.norm(expected)
_dot = functools.partial(jnp.dot, precision="highest")
_dot = functools.partial(jnp.dot, precision='highest')
class QdwhTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(m=m, n=n) for m, n in [(8, 6), (10, 10), (20, 18)]],
log_cond=np.linspace(1, _MAX_LOG_CONDITION_NUM, 4),
)
def testQdwhUnconvergedAfterMaxNumberIterations(
self, m, n, log_cond):
"""Tests unconvergence after maximum number of iterations."""
a = jnp.triu(jnp.ones((m, n)))
u, s, v = jnp.linalg.svd(a, full_matrices=False)
cond = 10**log_cond
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
with jax.numpy_dtype_promotion('standard'):
a = (u * s) @ v
is_hermitian = _check_symmetry(a)
max_iterations = 2
def _testReconstruction(self, a, u, h, tol):
"""Tests that a = u*p."""
with self.subTest('Test reconstruction'):
diff = _compute_relative_normwise_diff(_dot(u, h), a)
self.assertLessEqual(diff, tol)
_, _, actual_num_iterations, is_converged = qdwh.qdwh(
a, is_hermitian=is_hermitian, max_iterations=max_iterations)
with self.subTest('Number of iterations.'):
self.assertEqual(max_iterations, actual_num_iterations)
with self.subTest('Converged.'):
self.assertFalse(is_converged)
@jtu.sample_product(
[dict(m=m, n=n) for m, n in [(8, 6), (10, 10), (20, 18)]],
log_cond=np.linspace(1, _MAX_LOG_CONDITION_NUM, 4),
)
def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
"""Tests qdwh with upper triangular input of all ones."""
a = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)
u, s, v = jnp.linalg.svd(a, full_matrices=False)
cond = 10**log_cond
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
a = (u * s) @ v
is_hermitian = _check_symmetry(a)
max_iterations = 10
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian,
max_iterations=max_iterations)
expected_u, expected_h = osp_linalg.polar(a)
# Sets the test tolerance.
rtol = 1E6 * _QDWH_TEST_EPS
with self.subTest('Test u.'):
relative_diff_u = _compute_relative_diff(actual_u, expected_u)
np.testing.assert_almost_equal(relative_diff_u, 1E-6, decimal=5)
with self.subTest('Test h.'):
relative_diff_h = _compute_relative_diff(actual_h, expected_h)
np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)
with self.subTest('Test u.dot(h).'):
a_round_trip = _dot(actual_u, actual_h)
relative_diff_a = _compute_relative_diff(a_round_trip, a)
np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)
with self.subTest('Test orthogonality.'):
actual_results = _dot(actual_u.T, actual_u)
expected_results = np.eye(n)
def _testUnitary(self, u, tol):
"""Tests that u is unitary."""
with self.subTest('Test unitary'):
m, n = u.shape
self.assertAllClose(
actual_results, expected_results, rtol=rtol, atol=1E-5)
_dot(u.conj().T, u), np.eye(n, dtype=u.dtype), atol=tol, rtol=tol
)
def _testHermitian(self, h, tol):
"""Tests that h is Hermitian."""
with self.subTest('Test hermitian'):
self.assertAllClose(h, h.conj().T, atol=tol, rtol=tol)
def _testPolarDecomposition(self, a, u, h, tol):
"""Tests that u*h is the polar decomposition of a"""
self._testReconstruction(a, u, h, tol)
self._testUnitary(u, tol)
self._testHermitian(h, tol)
def _testQdwh(self, a, is_hermitian=False):
"""Computes the polar decomposition and tests its basic properties."""
eps = jnp.finfo(a.dtype).eps
u, h, iters, conv = qdwh.qdwh(a, is_hermitian=is_hermitian)
tol = 10 * eps
self._testPolarDecomposition(a, u, h, tol=tol)
@jtu.sample_product(
[dict(m=m, n=n) for m, n in [(6, 6), (8, 4)]],
padding=(None, (3, 2)),
log_cond=np.linspace(1, 4, 4),
shape=[(8, 6), (10, 10), (20, 18)],
dtype=float_types + complex_types,
)
def testQdwhWithRandomMatrix(self, m, n, log_cond, padding):
"""Tests qdwh with random input."""
rng = jtu.rand_uniform(self.rng(), low=0.3, high=0.9)
a = rng((m, n), _QDWH_TEST_DTYPE)
u, s, v = jnp.linalg.svd(a, full_matrices=False)
cond = 10**log_cond
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
a = (u * s) @ v
is_hermitian = _check_symmetry(a)
max_iterations = 10
def testQdwhWithUpperTriangularInputAllOnes(self, shape, dtype):
"""Tests qdwh with upper triangular input of all ones."""
eps = jnp.finfo(dtype).eps
m, n = shape
a = jnp.triu(jnp.ones((m, n))).astype(dtype)
self._testQdwh(a)
@jtu.sample_product(
shape=[(8, 6), (10, 10), (20, 18), (300, 300)],
log_cond=np.linspace(0, 1, 4),
hermitian=[True, False],
dtype=float_types + complex_types,
)
def testQdwhWithRandomMatrix(self, shape, log_cond, hermitian, dtype):
"""Tests qdwh with upper triangular input of all ones."""
eps = jnp.finfo(dtype).eps
m, n = shape
max_cond = np.log10(1.0 / eps)
log_cond = log_cond * max_cond
cond = 10**log_cond
# Generates input matrix with prescribed condition number.
rng = jtu.rand_uniform(self.rng())
a = rng((m, n), dtype)
u, _, v = jnp.linalg.svd(a, full_matrices=False)
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
a = (u * s.astype(u.dtype)) @ v
hermitian = hermitian & (m == n)
if hermitian:
a = (a + a.conj().T) / 2
self._testQdwh(a, is_hermitian=hermitian)
@jtu.sample_product(
[dict(m=m, n=n) for m, n in [(6, 6), (8, 4)]],
padding=(None, (3, 2)),
dtype=float_types + complex_types,
)
def testQdwhJitCompatibility(self, m, n, padding, dtype):
"""Tests JIT compilation of QDWH with and without dynamic shape."""
rng = jtu.rand_uniform(self.rng())
a = rng((m, n), dtype)
def lsp_linalg_fn(a):
if padding is not None:
pm, pn = padding
a = jnp.pad(a, [(0, pm), (0, pn)], constant_values=jnp.nan)
u, h, _, _ = qdwh.qdwh(
a, is_hermitian=is_hermitian, max_iterations=max_iterations,
dynamic_shape=(m, n) if padding else None)
u, h, _, _ = qdwh.qdwh(a, dynamic_shape=(m, n) if padding else None)
if padding is not None:
u = u[:m, :n]
h = h[:n, :n]
return u, h
args_maker = lambda: [a]
# Sets the test tolerance.
rtol = 1E6 * _QDWH_TEST_EPS
with self.subTest('Test JIT compatibility'):
self._CompileAndCheck(lsp_linalg_fn, args_maker)
with self.subTest('Test against numpy.'):
self._CheckAgainstNumpy(osp_linalg.polar, lsp_linalg_fn, args_maker,
rtol=rtol, atol=1E-3)
@jtu.sample_product(
[dict(m=m, n=n, r=r) for m, n, r in [(10, 10, 8), (8, 8, 7), (12, 8, 5)]],
log_cond=np.linspace(1, 4, 4),
[dict(m=m, n=n, r=r) for m, n, r in [(10, 10, 8), (8, 8, 7), (12, 8, 5)]],
log_cond=np.linspace(0, 1, 4),
dtype=float_types + complex_types,
)
def testQdwhOnRankDeficientInput(self, m, n, r, log_cond):
def testQdwhOnRankDeficientInput(self, m, n, r, log_cond, dtype):
"""Tests qdwh on rank-deficient input."""
a = np.triu(np.ones((m, n))).astype(_QDWH_TEST_DTYPE)
eps = jnp.finfo(dtype).eps
a = np.triu(np.ones((m, n))).astype(dtype)
# Generates a rank-deficient input.
# Generates a rank-deficient input with prescribed condition number.
max_cond = np.log10(1.0 / eps)
log_cond = log_cond * max_cond
u, _, vh = np.linalg.svd(a, full_matrices=False)
s = 10**jnp.linspace(log_cond, 0, min(m, n))
print(s)
s = jnp.expand_dims(s.at[r:].set(0), range(u.ndim - 1))
a = (u * s) @ vh
a = (u * s.astype(u.dtype)) @ vh
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=_check_symmetry(a))
_, expected_h = osp_linalg.polar(a)
actual_u, actual_h, _, _ = qdwh.qdwh(a)
with self.subTest('Test h.'):
relative_diff_h = _compute_relative_diff(actual_h, expected_h)
np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)
with self.subTest('Test u.dot(h).'):
a_round_trip = _dot(actual_u, actual_h)
relative_diff_a = _compute_relative_diff(a_round_trip, a)
np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)
self._testHermitian(actual_h, 10 * eps)
self._testReconstruction(a, actual_u, actual_h, 60 * eps)
# QDWH gives U_p = U Σₖ V* for input A with SVD A = U Σ V*. For full rank
# input, we expect convergence Σₖ → I, giving the correct polar factor
@ -202,34 +166,31 @@ class QdwhTest(jtu.JaxTestCase):
vr = vh.conj().T[:, :r]
uvr = _dot(actual_u, vr)
actual_results = _dot(uvr.T.conj(), uvr)
expected_results = np.eye(r)
expected_results = np.eye(r, dtype=actual_u.dtype)
self.assertAllClose(
actual_results, expected_results, rtol=_QDWH_TEST_EPS, atol=1e-6
actual_results, expected_results, atol=25 * eps, rtol=25 * eps
)
@jtu.sample_product(
[dict(m=m, n=n, r=r, c=c) for m, n, r, c in [(4, 3, 1, 1), (5, 2, 0, 0)]],
dtype=jtu.dtypes.floating,
[dict(m=m, n=n, r=r, c=c) for m, n, r, c in [(4, 3, 1, 1), (5, 2, 0, 0)]],
dtype=float_types + complex_types,
)
def testQdwhWithTinyElement(self, m, n, r, c, dtype):
"""Tests qdwh on matrix with zeros and close-to-zero entries."""
a = jnp.zeros((m, n), dtype=dtype)
tiny_elem = jnp.finfo(a.dtype).tiny
one = dtype(1.0)
tiny_elem = dtype(jnp.finfo(a.dtype).tiny)
a = a.at[r, c].set(tiny_elem)
is_hermitian = _check_symmetry(a)
max_iterations = 10
@jax.jit
def lsp_linalg_fn(a):
u, h, _, _ = qdwh.qdwh(
a, is_hermitian=is_hermitian, max_iterations=max_iterations)
u, h, _, _ = qdwh.qdwh(a)
return u, h
actual_u, actual_h = lsp_linalg_fn(a)
expected_u = jnp.zeros((m, n), dtype=dtype)
expected_u = expected_u.at[r, c].set(1.0)
expected_u = expected_u.at[r, c].set(one)
with self.subTest('Test u.'):
np.testing.assert_array_equal(expected_u, actual_u)