Add tests for lax.linalg.svd algorithm specification.

PiperOrigin-RevId: 730890907
This commit is contained in:
Dan Foreman-Mackey 2025-02-25 08:11:17 -08:00 committed by jax authors
parent 30348e90e7
commit a3a48af105

View File

@ -841,9 +841,21 @@ class NumpyLinalgTest(jtu.JaxTestCase):
b=[(), (3,), (2, 3)],
dtype=float_types + complex_types,
compute_uv=[False, True],
algorithm=[None, lax.linalg.SvdAlgorithm.QR, lax.linalg.SvdAlgorithm.JACOBI],
)
@jax.default_matmul_precision("float32")
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, algorithm):
if algorithm is not None:
if hermitian:
self.skipTest("Hermitian SVD doesn't support the algorithm parameter.")
if not jtu.test_device_matches(["cpu", "gpu"]):
self.skipTest("SVD algorithm selection only supported on CPU and GPU.")
# TODO(danfm): Remove this check after 0.5.2 is released.
if jtu.test_device_matches(["cpu"]) and jtu.jaxlib_version() <= (0, 5, 1):
self.skipTest("SVD algorithm selection on CPU requires a newer jaxlib version.")
if jtu.test_device_matches(["cpu"]) and algorithm == lax.linalg.SvdAlgorithm.JACOBI:
self.skipTest("Jacobi SVD not supported on GPU.")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(b + (m, n), dtype)]
@ -862,8 +874,11 @@ class NumpyLinalgTest(jtu.JaxTestCase):
a, = args_maker()
if hermitian:
a = a + np.conj(T(a))
out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv,
hermitian=hermitian)
if algorithm is None:
fun = partial(jnp.linalg.svd, hermitian=hermitian)
else:
fun = partial(lax.linalg.svd, algorithm=algorithm)
out = fun(a, full_matrices=full_matrices, compute_uv=compute_uv)
if compute_uv:
# Check the reconstructed matrices
out = list(out)
@ -906,13 +921,12 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False),
np.asarray(out), atol=1e-4, rtol=1e-4))
self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices,
self._CompileAndCheck(partial(fun, full_matrices=full_matrices,
compute_uv=compute_uv),
args_maker)
if not compute_uv and a.size < 100000:
svd = partial(jnp.linalg.svd, full_matrices=full_matrices,
compute_uv=compute_uv)
svd = partial(fun, full_matrices=full_matrices, compute_uv=compute_uv)
# TODO(phawkins): these tolerances seem very loose.
if dtype == np.complex128:
jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=1e-4, atol=1e-4,