Add test for QDWH with dynamic shapes.

PiperOrigin-RevId: 643087130
This commit is contained in:
jax authors 2024-06-13 12:32:37 -07:00 committed by jax authors
parent 0dc706d79f
commit dd3b0a6981

View File

@ -64,12 +64,18 @@ class QdwhTest(jtu.JaxTestCase):
self._testUnitary(u, tol)
self._testHermitian(h, tol)
def _testQdwh(self, a, is_hermitian=False):
def _testQdwh(self, a, is_hermitian=False, dynamic_shape=None):
"""Computes the polar decomposition and tests its basic properties."""
eps = jnp.finfo(a.dtype).eps
u, h, iters, conv = qdwh.qdwh(a, is_hermitian=is_hermitian)
u, h, iters, conv = qdwh.qdwh(
a, is_hermitian=is_hermitian, dynamic_shape=dynamic_shape
)
tol = 10 * eps
if dynamic_shape is not None:
m, n = dynamic_shape
a = a[:m, :n]
u = u[:m, :n]
h = h[:n, :n]
self._testPolarDecomposition(a, u, h, tol=tol)
@jtu.sample_product(
@ -83,6 +89,40 @@ class QdwhTest(jtu.JaxTestCase):
a = jnp.triu(jnp.ones((m, n))).astype(dtype)
self._testQdwh(a)
@jtu.sample_product(
shape=[(2, 2), (5, 5), (8, 5), (10, 10)],
is_hermitian=[True, False],
dtype=float_types + complex_types,
)
def testQdwhWithDynamicShape(self, shape, is_hermitian, dtype):
"""Tests qdwh with dynamic shapes."""
if is_hermitian & (shape[0] != shape[1]):
self.skipTest('Invalid combination')
rng = jtu.rand_uniform(self.rng())
a = rng((10, 10), dtype)
if is_hermitian:
a = (a + a.conj().T) / 2
self._testQdwh(a, is_hermitian=is_hermitian, dynamic_shape=shape)
@jtu.sample_product(
shape=[(8, 6), (5, 5), (10, 10), (20, 18), (300, 300)],
log_cond=np.linspace(0, 1, 4),
hermitian=[True, False],
dtype=float_types + complex_types,
)
def testQdwhWithRandomMatrix0(self, shape, log_cond, hermitian, dtype):
"""Tests qdwh with upper triangular input of all ones."""
m, n = shape
rng = jtu.rand_uniform(self.rng())
a = rng((m, n), dtype)
# Generates input matrix with prescribed condition number.
hermitian = hermitian & (m == n)
if hermitian:
a = (a + a.conj().T) / 2
self._testQdwh(a, is_hermitian=hermitian)
@jtu.sample_product(
shape=[(8, 6), (10, 10), (20, 18), (300, 300)],
log_cond=np.linspace(0, 1, 4),