From 845f8df8376f8421907603caf22230e49c05dc1c Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 7 Nov 2022 12:40:16 -0500
Subject: [PATCH] Avoid forming an identity matrix in the SVD JVP.

Set the default matmul precision in the SVD JVP, and use @ to express
matmuls. Also fix a flaky test failure in the QR test on Mac ARM.
---
 jax/_src/lax/linalg.py | 16 +++++++++-------
 tests/linalg_test.py   |  8 +-------
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index a0c77bec2..cc959738f 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -21,6 +21,7 @@ import warnings
 import numpy as np
 from typing_extensions import Literal
 
+import jax
 from jax._src.numpy import lax_numpy as jnp
 from jax._src.numpy.vectorize import vectorize
 from jax._src import ad_util
@@ -1551,6 +1552,7 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
   else:
     raise NotImplementedError
 
+@jax.default_matmul_precision("float32")
 def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
   A, = primals
   dA, = tangents
@@ -1563,7 +1565,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
 
   Ut, V = _H(U), _H(Vt)
   s_dim = s[..., None, :]
-  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
+  dS = Ut @ dA @ V
   ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
 
   if not compute_uv:
@@ -1580,16 +1582,16 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
   s_zeros = jnp.ones((), dtype=A.dtype) * (s == 0.)
   s_inv = 1 / (s + s_zeros) - s_zeros
   s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
   dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat.astype(A.dtype)
-  dU = jnp.matmul(U, F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
-  dV = jnp.matmul(V, F.astype(A.dtype) * (SdS + _H(SdS)))
+  dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
+  dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)))
 
   m, n = A.shape[-2:]
   if m > n:
-    I = lax.expand_dims(jnp.eye(m, dtype=A.dtype), range(U.ndim - 2))
-    dU = dU + jnp.matmul(I - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim.astype(A.dtype)
+    dAV = dA @ V
+    dU = dU + (dAV - U @ (Ut @ dAV)) / s_dim.astype(A.dtype)
   if n > m:
-    I = lax.expand_dims(jnp.eye(n, dtype=A.dtype), range(V.ndim - 2))
-    dV = dV + jnp.matmul(I - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim.astype(A.dtype)
+    dAHU = _H(dA) @ U
+    dV = dV + (dAHU - V @ (Vt @ dAHU)) / s_dim.astype(A.dtype)
 
   return (s, U, Vt), (ds, dU, _H(dV))

diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 7767f4619..2721ce46d 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -15,7 +15,6 @@
 """Tests for the LAPAX linear algebra module."""
 
 from functools import partial
-import unittest
 
 import numpy as np
 import scipy
@@ -478,11 +477,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   )
   def testNorm(self, shape, dtype, ord, axis, keepdims):
     rng = jtu.rand_default(self.rng())
-    if (ord in ('nuc', 2, -2) and (
-        jtu.device_under_test() != "cpu" or
-        (isinstance(axis, tuple) and len(axis) == 2))):
-      raise unittest.SkipTest("No adequate SVD implementation available")
-
     args_maker = lambda: [rng(shape, dtype)]
     np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
     jnp_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
@@ -663,7 +657,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
       phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
       q1 *= phases
       nm = norm(q1 - q2)
-      self.assertTrue(np.all(nm < 140), msg=f"norm={np.amax(nm)}")
+      self.assertTrue(np.all(nm < 160), msg=f"norm={np.amax(nm)}")
 
       # Check a ~= qr
       self.assertTrue(np.all(norm(a - np.matmul(lq, lr)) < 40))
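
For reference, the m > n and n > m branches above rely on the identity
(I - U U^H) B = B - U (U^H B): distributing the product avoids ever
materializing the m x m (or n x n) projector, and the inner parentheses keep
every intermediate at most m x k. A minimal self-contained sketch of the
rewrite; the helper names (project_explicit, project_implicit) are
illustrative and do not appear in the JAX sources:

import numpy as np
import jax.numpy as jnp

def _H(x):
  # Conjugate transpose over the trailing two axes.
  return jnp.conj(jnp.swapaxes(x, -1, -2))

def project_explicit(U, B):
  # Old formulation: materializes the m x m matrix I - U @ U^H.
  I = jnp.eye(U.shape[-2], dtype=U.dtype)
  return (I - U @ _H(U)) @ B

def project_implicit(U, B):
  # New formulation: (I - U U^H) B == B - U (U^H B).
  # The parenthesization matters: U @ (_H(U) @ B) only ever builds
  # k x k and m x k intermediates, never an m x m matrix.
  return B - U @ (_H(U) @ B)

rng = np.random.default_rng(0)
U = jnp.asarray(np.linalg.qr(rng.normal(size=(5, 2)))[0].astype(np.float32))
B = jnp.asarray(rng.normal(size=(5, 2)).astype(np.float32))
np.testing.assert_allclose(project_explicit(U, B), project_implicit(U, B),
                           atol=1e-5)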
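
The @jax.default_matmul_precision("float32") decorator added to _svd_jvp_rule
sets the default precision for any matmul issued while the function runs that
does not request a precision explicitly; on some backends (notably TPU),
matmuls otherwise default to a faster, lower-precision mode. A small usage
sketch (the function names here are illustrative, not from the patch):

import jax
import jax.numpy as jnp

@jax.default_matmul_precision("float32")
def dot_accurate(a, b):
  # Matmuls inside this function accumulate in full float32.
  return a @ b

def dot_scoped(a, b):
  # The same API also works as a context manager.
  with jax.default_matmul_precision("float32"):
    return a @ b

a = jnp.ones((8, 8))
b = jnp.ones((8, 8))
print(dot_accurate(a, b)[0, 0], dot_scoped(a, b)[0, 0])  # 8.0 8.0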