diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py index e6d299ae9..67a081e41 100644 --- a/tests/qdwh_test.py +++ b/tests/qdwh_test.py @@ -64,12 +64,18 @@ class QdwhTest(jtu.JaxTestCase): self._testUnitary(u, tol) self._testHermitian(h, tol) - def _testQdwh(self, a, is_hermitian=False): + def _testQdwh(self, a, is_hermitian=False, dynamic_shape=None): """Computes the polar decomposition and tests its basic properties.""" eps = jnp.finfo(a.dtype).eps - - u, h, iters, conv = qdwh.qdwh(a, is_hermitian=is_hermitian) + u, h, iters, conv = qdwh.qdwh( + a, is_hermitian=is_hermitian, dynamic_shape=dynamic_shape + ) tol = 10 * eps + if dynamic_shape is not None: + m, n = dynamic_shape + a = a[:m, :n] + u = u[:m, :n] + h = h[:n, :n] self._testPolarDecomposition(a, u, h, tol=tol) @jtu.sample_product( @@ -83,6 +89,40 @@ class QdwhTest(jtu.JaxTestCase): a = jnp.triu(jnp.ones((m, n))).astype(dtype) self._testQdwh(a) + @jtu.sample_product( + shape=[(2, 2), (5, 5), (8, 5), (10, 10)], + is_hermitian=[True, False], + dtype=float_types + complex_types, + ) + def testQdwhWithDynamicShape(self, shape, is_hermitian, dtype): + """Tests qdwh with dynamic shapes.""" + if is_hermitian & (shape[0] != shape[1]): + self.skipTest('Invalid combination') + rng = jtu.rand_uniform(self.rng()) + a = rng((10, 10), dtype) + + if is_hermitian: + a = (a + a.conj().T) / 2 + self._testQdwh(a, is_hermitian=is_hermitian, dynamic_shape=shape) + + @jtu.sample_product( + shape=[(8, 6), (5, 5), (10, 10), (20, 18), (300, 300)], + log_cond=np.linspace(0, 1, 4), + hermitian=[True, False], + dtype=float_types + complex_types, + ) + def testQdwhWithRandomMatrix0(self, shape, log_cond, hermitian, dtype): + """Tests qdwh with upper triangular input of all ones.""" + m, n = shape + rng = jtu.rand_uniform(self.rng()) + a = rng((m, n), dtype) + + # Generates input matrix with prescribed condition number. + hermitian = hermitian & (m == n) + if hermitian: + a = (a + a.conj().T) / 2 + self._testQdwh(a, is_hermitian=hermitian) + @jtu.sample_product( shape=[(8, 6), (10, 10), (20, 18), (300, 300)], log_cond=np.linspace(0, 1, 4),