mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add test for QDWH with dynamic shapes.
PiperOrigin-RevId: 643087130
This commit is contained in:
parent
0dc706d79f
commit
dd3b0a6981
@ -64,12 +64,18 @@ class QdwhTest(jtu.JaxTestCase):
|
||||
self._testUnitary(u, tol)
|
||||
self._testHermitian(h, tol)
|
||||
|
||||
def _testQdwh(self, a, is_hermitian=False):
|
||||
def _testQdwh(self, a, is_hermitian=False, dynamic_shape=None):
|
||||
"""Computes the polar decomposition and tests its basic properties."""
|
||||
eps = jnp.finfo(a.dtype).eps
|
||||
|
||||
u, h, iters, conv = qdwh.qdwh(a, is_hermitian=is_hermitian)
|
||||
u, h, iters, conv = qdwh.qdwh(
|
||||
a, is_hermitian=is_hermitian, dynamic_shape=dynamic_shape
|
||||
)
|
||||
tol = 10 * eps
|
||||
if dynamic_shape is not None:
|
||||
m, n = dynamic_shape
|
||||
a = a[:m, :n]
|
||||
u = u[:m, :n]
|
||||
h = h[:n, :n]
|
||||
self._testPolarDecomposition(a, u, h, tol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -83,6 +89,40 @@ class QdwhTest(jtu.JaxTestCase):
|
||||
a = jnp.triu(jnp.ones((m, n))).astype(dtype)
|
||||
self._testQdwh(a)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(2, 2), (5, 5), (8, 5), (10, 10)],
|
||||
is_hermitian=[True, False],
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
def testQdwhWithDynamicShape(self, shape, is_hermitian, dtype):
|
||||
"""Tests qdwh with dynamic shapes."""
|
||||
if is_hermitian & (shape[0] != shape[1]):
|
||||
self.skipTest('Invalid combination')
|
||||
rng = jtu.rand_uniform(self.rng())
|
||||
a = rng((10, 10), dtype)
|
||||
|
||||
if is_hermitian:
|
||||
a = (a + a.conj().T) / 2
|
||||
self._testQdwh(a, is_hermitian=is_hermitian, dynamic_shape=shape)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(8, 6), (5, 5), (10, 10), (20, 18), (300, 300)],
|
||||
log_cond=np.linspace(0, 1, 4),
|
||||
hermitian=[True, False],
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
def testQdwhWithRandomMatrix0(self, shape, log_cond, hermitian, dtype):
|
||||
"""Tests qdwh with upper triangular input of all ones."""
|
||||
m, n = shape
|
||||
rng = jtu.rand_uniform(self.rng())
|
||||
a = rng((m, n), dtype)
|
||||
|
||||
# Generates input matrix with prescribed condition number.
|
||||
hermitian = hermitian & (m == n)
|
||||
if hermitian:
|
||||
a = (a + a.conj().T) / 2
|
||||
self._testQdwh(a, is_hermitian=hermitian)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(8, 6), (10, 10), (20, 18), (300, 300)],
|
||||
log_cond=np.linspace(0, 1, 4),
|
||||
|
Loading…
x
Reference in New Issue
Block a user