mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Update test of QDWH to use stricter tolerances and test more shapes and types.
Get rid of comparison with scipy.linalg.polar, since its outputs are significantly less accurate than QDWH. Since the polar decomposition is unique, comparing to a less accurate implementation does not add value. PiperOrigin-RevId: 642423757
This commit is contained in:
parent
ebe0ab0b51
commit
95e2c17b61
@ -784,6 +784,7 @@ jax_test(
|
||||
jax_test(
|
||||
name = "qdwh_test",
|
||||
srcs = ["qdwh_test.py"],
|
||||
shard_count = 10,
|
||||
)
|
||||
|
||||
jax_test(
|
||||
|
@ -15,181 +15,145 @@
|
||||
"""Tests for the library of QDWH-based polar decomposition."""
|
||||
import functools
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import scipy.linalg as osp_linalg
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import qdwh
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
_JAX_ENABLE_X64_QDWH = config.enable_x64.value
|
||||
|
||||
# Input matrix data type for QdwhTest.
|
||||
_QDWH_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64_QDWH else np.float32
|
||||
|
||||
# Machine epsilon used by QdwhTest.
|
||||
_QDWH_TEST_EPS = jnp.finfo(_QDWH_TEST_DTYPE).eps
|
||||
|
||||
# Largest log10 value of condition numbers used by QdwhTest.
|
||||
_MAX_LOG_CONDITION_NUM = np.log10(int(1 / _QDWH_TEST_EPS))
|
||||
float_types = jtu.dtypes.floating
|
||||
complex_types = jtu.dtypes.complex
|
||||
|
||||
|
||||
def _check_symmetry(x: jax.Array) -> bool:
|
||||
"""Check if the array is symmetric."""
|
||||
m, n = x.shape
|
||||
eps = jnp.finfo(x.dtype).eps
|
||||
tol = 50.0 * eps
|
||||
is_hermitian = False
|
||||
if m == n:
|
||||
if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol:
|
||||
is_hermitian = True
|
||||
|
||||
return is_hermitian
|
||||
|
||||
def _compute_relative_diff(actual, expected):
|
||||
def _compute_relative_normwise_diff(actual, expected):
|
||||
"""Computes relative difference between two matrices."""
|
||||
return np.linalg.norm(actual - expected) / np.linalg.norm(expected)
|
||||
|
||||
_dot = functools.partial(jnp.dot, precision="highest")
|
||||
|
||||
_dot = functools.partial(jnp.dot, precision='highest')
|
||||
|
||||
|
||||
class QdwhTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(m=m, n=n) for m, n in [(8, 6), (10, 10), (20, 18)]],
|
||||
log_cond=np.linspace(1, _MAX_LOG_CONDITION_NUM, 4),
|
||||
)
|
||||
def testQdwhUnconvergedAfterMaxNumberIterations(
|
||||
self, m, n, log_cond):
|
||||
"""Tests unconvergence after maximum number of iterations."""
|
||||
a = jnp.triu(jnp.ones((m, n)))
|
||||
u, s, v = jnp.linalg.svd(a, full_matrices=False)
|
||||
cond = 10**log_cond
|
||||
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
|
||||
with jax.numpy_dtype_promotion('standard'):
|
||||
a = (u * s) @ v
|
||||
is_hermitian = _check_symmetry(a)
|
||||
max_iterations = 2
|
||||
def _testReconstruction(self, a, u, h, tol):
|
||||
"""Tests that a = u*p."""
|
||||
with self.subTest('Test reconstruction'):
|
||||
diff = _compute_relative_normwise_diff(_dot(u, h), a)
|
||||
self.assertLessEqual(diff, tol)
|
||||
|
||||
_, _, actual_num_iterations, is_converged = qdwh.qdwh(
|
||||
a, is_hermitian=is_hermitian, max_iterations=max_iterations)
|
||||
|
||||
with self.subTest('Number of iterations.'):
|
||||
self.assertEqual(max_iterations, actual_num_iterations)
|
||||
|
||||
with self.subTest('Converged.'):
|
||||
self.assertFalse(is_converged)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(m=m, n=n) for m, n in [(8, 6), (10, 10), (20, 18)]],
|
||||
log_cond=np.linspace(1, _MAX_LOG_CONDITION_NUM, 4),
|
||||
)
|
||||
def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
|
||||
"""Tests qdwh with upper triangular input of all ones."""
|
||||
a = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)
|
||||
u, s, v = jnp.linalg.svd(a, full_matrices=False)
|
||||
cond = 10**log_cond
|
||||
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
|
||||
a = (u * s) @ v
|
||||
is_hermitian = _check_symmetry(a)
|
||||
max_iterations = 10
|
||||
|
||||
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian,
|
||||
max_iterations=max_iterations)
|
||||
expected_u, expected_h = osp_linalg.polar(a)
|
||||
|
||||
# Sets the test tolerance.
|
||||
rtol = 1E6 * _QDWH_TEST_EPS
|
||||
|
||||
with self.subTest('Test u.'):
|
||||
relative_diff_u = _compute_relative_diff(actual_u, expected_u)
|
||||
np.testing.assert_almost_equal(relative_diff_u, 1E-6, decimal=5)
|
||||
|
||||
with self.subTest('Test h.'):
|
||||
relative_diff_h = _compute_relative_diff(actual_h, expected_h)
|
||||
np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)
|
||||
|
||||
with self.subTest('Test u.dot(h).'):
|
||||
a_round_trip = _dot(actual_u, actual_h)
|
||||
relative_diff_a = _compute_relative_diff(a_round_trip, a)
|
||||
np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)
|
||||
|
||||
with self.subTest('Test orthogonality.'):
|
||||
actual_results = _dot(actual_u.T, actual_u)
|
||||
expected_results = np.eye(n)
|
||||
def _testUnitary(self, u, tol):
|
||||
"""Tests that u is unitary."""
|
||||
with self.subTest('Test unitary'):
|
||||
m, n = u.shape
|
||||
self.assertAllClose(
|
||||
actual_results, expected_results, rtol=rtol, atol=1E-5)
|
||||
_dot(u.conj().T, u), np.eye(n, dtype=u.dtype), atol=tol, rtol=tol
|
||||
)
|
||||
|
||||
def _testHermitian(self, h, tol):
|
||||
"""Tests that h is Hermitian."""
|
||||
with self.subTest('Test hermitian'):
|
||||
self.assertAllClose(h, h.conj().T, atol=tol, rtol=tol)
|
||||
|
||||
def _testPolarDecomposition(self, a, u, h, tol):
|
||||
"""Tests that u*h is the polar decomposition of a"""
|
||||
self._testReconstruction(a, u, h, tol)
|
||||
self._testUnitary(u, tol)
|
||||
self._testHermitian(h, tol)
|
||||
|
||||
def _testQdwh(self, a, is_hermitian=False):
|
||||
"""Computes the polar decomposition and tests its basic properties."""
|
||||
eps = jnp.finfo(a.dtype).eps
|
||||
|
||||
u, h, iters, conv = qdwh.qdwh(a, is_hermitian=is_hermitian)
|
||||
tol = 10 * eps
|
||||
self._testPolarDecomposition(a, u, h, tol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(m=m, n=n) for m, n in [(6, 6), (8, 4)]],
|
||||
padding=(None, (3, 2)),
|
||||
log_cond=np.linspace(1, 4, 4),
|
||||
shape=[(8, 6), (10, 10), (20, 18)],
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
def testQdwhWithRandomMatrix(self, m, n, log_cond, padding):
|
||||
"""Tests qdwh with random input."""
|
||||
rng = jtu.rand_uniform(self.rng(), low=0.3, high=0.9)
|
||||
a = rng((m, n), _QDWH_TEST_DTYPE)
|
||||
u, s, v = jnp.linalg.svd(a, full_matrices=False)
|
||||
cond = 10**log_cond
|
||||
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
|
||||
a = (u * s) @ v
|
||||
is_hermitian = _check_symmetry(a)
|
||||
max_iterations = 10
|
||||
def testQdwhWithUpperTriangularInputAllOnes(self, shape, dtype):
|
||||
"""Tests qdwh with upper triangular input of all ones."""
|
||||
eps = jnp.finfo(dtype).eps
|
||||
m, n = shape
|
||||
a = jnp.triu(jnp.ones((m, n))).astype(dtype)
|
||||
self._testQdwh(a)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(8, 6), (10, 10), (20, 18), (300, 300)],
|
||||
log_cond=np.linspace(0, 1, 4),
|
||||
hermitian=[True, False],
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
def testQdwhWithRandomMatrix(self, shape, log_cond, hermitian, dtype):
|
||||
"""Tests qdwh with upper triangular input of all ones."""
|
||||
eps = jnp.finfo(dtype).eps
|
||||
m, n = shape
|
||||
max_cond = np.log10(1.0 / eps)
|
||||
log_cond = log_cond * max_cond
|
||||
cond = 10**log_cond
|
||||
|
||||
# Generates input matrix with prescribed condition number.
|
||||
rng = jtu.rand_uniform(self.rng())
|
||||
a = rng((m, n), dtype)
|
||||
u, _, v = jnp.linalg.svd(a, full_matrices=False)
|
||||
s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
|
||||
a = (u * s.astype(u.dtype)) @ v
|
||||
hermitian = hermitian & (m == n)
|
||||
if hermitian:
|
||||
a = (a + a.conj().T) / 2
|
||||
self._testQdwh(a, is_hermitian=hermitian)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(m=m, n=n) for m, n in [(6, 6), (8, 4)]],
|
||||
padding=(None, (3, 2)),
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
def testQdwhJitCompatibility(self, m, n, padding, dtype):
|
||||
"""Tests JIT compilation of QDWH with and without dynamic shape."""
|
||||
rng = jtu.rand_uniform(self.rng())
|
||||
a = rng((m, n), dtype)
|
||||
def lsp_linalg_fn(a):
|
||||
if padding is not None:
|
||||
pm, pn = padding
|
||||
a = jnp.pad(a, [(0, pm), (0, pn)], constant_values=jnp.nan)
|
||||
u, h, _, _ = qdwh.qdwh(
|
||||
a, is_hermitian=is_hermitian, max_iterations=max_iterations,
|
||||
dynamic_shape=(m, n) if padding else None)
|
||||
u, h, _, _ = qdwh.qdwh(a, dynamic_shape=(m, n) if padding else None)
|
||||
if padding is not None:
|
||||
u = u[:m, :n]
|
||||
h = h[:n, :n]
|
||||
return u, h
|
||||
|
||||
args_maker = lambda: [a]
|
||||
|
||||
# Sets the test tolerance.
|
||||
rtol = 1E6 * _QDWH_TEST_EPS
|
||||
|
||||
with self.subTest('Test JIT compatibility'):
|
||||
self._CompileAndCheck(lsp_linalg_fn, args_maker)
|
||||
|
||||
with self.subTest('Test against numpy.'):
|
||||
self._CheckAgainstNumpy(osp_linalg.polar, lsp_linalg_fn, args_maker,
|
||||
rtol=rtol, atol=1E-3)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(m=m, n=n, r=r) for m, n, r in [(10, 10, 8), (8, 8, 7), (12, 8, 5)]],
|
||||
log_cond=np.linspace(1, 4, 4),
|
||||
[dict(m=m, n=n, r=r) for m, n, r in [(10, 10, 8), (8, 8, 7), (12, 8, 5)]],
|
||||
log_cond=np.linspace(0, 1, 4),
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
def testQdwhOnRankDeficientInput(self, m, n, r, log_cond):
|
||||
def testQdwhOnRankDeficientInput(self, m, n, r, log_cond, dtype):
|
||||
"""Tests qdwh on rank-deficient input."""
|
||||
a = np.triu(np.ones((m, n))).astype(_QDWH_TEST_DTYPE)
|
||||
eps = jnp.finfo(dtype).eps
|
||||
a = np.triu(np.ones((m, n))).astype(dtype)
|
||||
|
||||
# Generates a rank-deficient input.
|
||||
# Generates a rank-deficient input with prescribed condition number.
|
||||
max_cond = np.log10(1.0 / eps)
|
||||
log_cond = log_cond * max_cond
|
||||
u, _, vh = np.linalg.svd(a, full_matrices=False)
|
||||
s = 10**jnp.linspace(log_cond, 0, min(m, n))
|
||||
print(s)
|
||||
s = jnp.expand_dims(s.at[r:].set(0), range(u.ndim - 1))
|
||||
a = (u * s) @ vh
|
||||
a = (u * s.astype(u.dtype)) @ vh
|
||||
|
||||
actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=_check_symmetry(a))
|
||||
_, expected_h = osp_linalg.polar(a)
|
||||
actual_u, actual_h, _, _ = qdwh.qdwh(a)
|
||||
|
||||
with self.subTest('Test h.'):
|
||||
relative_diff_h = _compute_relative_diff(actual_h, expected_h)
|
||||
np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)
|
||||
|
||||
with self.subTest('Test u.dot(h).'):
|
||||
a_round_trip = _dot(actual_u, actual_h)
|
||||
relative_diff_a = _compute_relative_diff(a_round_trip, a)
|
||||
np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)
|
||||
self._testHermitian(actual_h, 10 * eps)
|
||||
self._testReconstruction(a, actual_u, actual_h, 60 * eps)
|
||||
|
||||
# QDWH gives U_p = U Σₖ V* for input A with SVD A = U Σ V*. For full rank
|
||||
# input, we expect convergence Σₖ → I, giving the correct polar factor
|
||||
@ -202,34 +166,31 @@ class QdwhTest(jtu.JaxTestCase):
|
||||
vr = vh.conj().T[:, :r]
|
||||
uvr = _dot(actual_u, vr)
|
||||
actual_results = _dot(uvr.T.conj(), uvr)
|
||||
expected_results = np.eye(r)
|
||||
expected_results = np.eye(r, dtype=actual_u.dtype)
|
||||
self.assertAllClose(
|
||||
actual_results, expected_results, rtol=_QDWH_TEST_EPS, atol=1e-6
|
||||
actual_results, expected_results, atol=25 * eps, rtol=25 * eps
|
||||
)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(m=m, n=n, r=r, c=c) for m, n, r, c in [(4, 3, 1, 1), (5, 2, 0, 0)]],
|
||||
dtype=jtu.dtypes.floating,
|
||||
[dict(m=m, n=n, r=r, c=c) for m, n, r, c in [(4, 3, 1, 1), (5, 2, 0, 0)]],
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
def testQdwhWithTinyElement(self, m, n, r, c, dtype):
|
||||
"""Tests qdwh on matrix with zeros and close-to-zero entries."""
|
||||
a = jnp.zeros((m, n), dtype=dtype)
|
||||
tiny_elem = jnp.finfo(a.dtype).tiny
|
||||
one = dtype(1.0)
|
||||
tiny_elem = dtype(jnp.finfo(a.dtype).tiny)
|
||||
a = a.at[r, c].set(tiny_elem)
|
||||
|
||||
is_hermitian = _check_symmetry(a)
|
||||
max_iterations = 10
|
||||
|
||||
@jax.jit
|
||||
def lsp_linalg_fn(a):
|
||||
u, h, _, _ = qdwh.qdwh(
|
||||
a, is_hermitian=is_hermitian, max_iterations=max_iterations)
|
||||
u, h, _, _ = qdwh.qdwh(a)
|
||||
return u, h
|
||||
|
||||
actual_u, actual_h = lsp_linalg_fn(a)
|
||||
|
||||
expected_u = jnp.zeros((m, n), dtype=dtype)
|
||||
expected_u = expected_u.at[r, c].set(1.0)
|
||||
expected_u = expected_u.at[r, c].set(one)
|
||||
with self.subTest('Test u.'):
|
||||
np.testing.assert_array_equal(expected_u, actual_u)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user