mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add tests for lax.linalg.svd algorithm specification.
PiperOrigin-RevId: 730890907
This commit is contained in:
parent
30348e90e7
commit
a3a48af105
@ -841,9 +841,21 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
b=[(), (3,), (2, 3)],
|
||||
dtype=float_types + complex_types,
|
||||
compute_uv=[False, True],
|
||||
algorithm=[None, lax.linalg.SvdAlgorithm.QR, lax.linalg.SvdAlgorithm.JACOBI],
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
|
||||
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, algorithm):
|
||||
if algorithm is not None:
|
||||
if hermitian:
|
||||
self.skipTest("Hermitian SVD doesn't support the algorithm parameter.")
|
||||
if not jtu.test_device_matches(["cpu", "gpu"]):
|
||||
self.skipTest("SVD algorithm selection only supported on CPU and GPU.")
|
||||
# TODO(danfm): Remove this check after 0.5.2 is released.
|
||||
if jtu.test_device_matches(["cpu"]) and jtu.jaxlib_version() <= (0, 5, 1):
|
||||
self.skipTest("SVD algorithm selection on CPU requires a newer jaxlib version.")
|
||||
if jtu.test_device_matches(["cpu"]) and algorithm == lax.linalg.SvdAlgorithm.JACOBI:
|
||||
self.skipTest("Jacobi SVD not supported on GPU.")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(b + (m, n), dtype)]
|
||||
|
||||
@ -862,8 +874,11 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
a, = args_maker()
|
||||
if hermitian:
|
||||
a = a + np.conj(T(a))
|
||||
out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv,
|
||||
hermitian=hermitian)
|
||||
if algorithm is None:
|
||||
fun = partial(jnp.linalg.svd, hermitian=hermitian)
|
||||
else:
|
||||
fun = partial(lax.linalg.svd, algorithm=algorithm)
|
||||
out = fun(a, full_matrices=full_matrices, compute_uv=compute_uv)
|
||||
if compute_uv:
|
||||
# Check the reconstructed matrices
|
||||
out = list(out)
|
||||
@ -906,13 +921,12 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False),
|
||||
np.asarray(out), atol=1e-4, rtol=1e-4))
|
||||
|
||||
self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices,
|
||||
self._CompileAndCheck(partial(fun, full_matrices=full_matrices,
|
||||
compute_uv=compute_uv),
|
||||
args_maker)
|
||||
|
||||
if not compute_uv and a.size < 100000:
|
||||
svd = partial(jnp.linalg.svd, full_matrices=full_matrices,
|
||||
compute_uv=compute_uv)
|
||||
svd = partial(fun, full_matrices=full_matrices, compute_uv=compute_uv)
|
||||
# TODO(phawkins): these tolerances seem very loose.
|
||||
if dtype == np.complex128:
|
||||
jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=1e-4, atol=1e-4,
|
||||
|
Loading…
x
Reference in New Issue
Block a user