From 7c3fb813106de01dd050acda4cd7a9581d69ac80 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Nov 2022 13:28:11 -0800 Subject: [PATCH] Use pinv to compute lstsq. The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition. PiperOrigin-RevId: 487903227 --- jax/_src/lax/linalg.py | 9 +++-- jax/_src/numpy/linalg.py | 83 +++++++++++++++++++++++----------------- tests/linalg_test.py | 3 +- 3 files changed, 54 insertions(+), 41 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 335dfd827..f18c80d69 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1569,13 +1569,14 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv): "Singular value decomposition JVP not implemented for full matrices") Ut, V = _H(U), _H(Vt) + if not compute_uv: + ds = jnp.real(jnp.einsum("...ab,...bc,...ca->...a", Ut, dA, V)) + return (s,), (ds,) + s_dim = s[..., None, :] dS = Ut @ dA @ V ds = jnp.real(jnp.diagonal(dS, 0, -2, -1)) - if not compute_uv: - return (s,), (ds,) - s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim)) s_diffs_zeros = jnp.eye(s.shape[-1], dtype=s.dtype) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2)) @@ -1594,7 +1595,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv): if m > n: dAV = dA @ V dU = dU + (dAV - U @ (Ut @ dAV)) / s_dim.astype(A.dtype) - if n > m: + elif n > m: dAHU = _H(dA) @ U dV = dV + (dAHU - V @ (Vt @ dAHU)) / s_dim.astype(A.dtype) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index bb4286868..f8373cc50 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -409,7 +409,6 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array: return w -@partial(custom_jvp, nondiff_argnums=(1, 2)) @_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\ It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the default `rcond` is `1e-15`. Here the default is @@ -422,35 +421,45 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None, # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979 _check_arraylike("jnp.linalg.pinv", a) arr = jnp.asarray(a) - m, n = arr.shape[-2:] - if m == 0 or n == 0: - return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype) - arr = jnp.conj(arr) if rcond is None: max_rows_cols = max(arr.shape[-2:]) rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps) rcond = jnp.asarray(rcond) - u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian) + return _pinv(arr, rcond, hermitian, compute_s=False) # type: ignore + +@jax.default_matmul_precision("float32") +def _svd_to_pinv(u, s, vh, rcond): # Singular values less than or equal to ``rcond * largest_singular_value`` # are set to zero. rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1)) - cutoff = rcond * s[..., 0:1] - s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype) - res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]), - precision=lax.Precision.HIGHEST) - return lax.convert_element_type(res, arr.dtype) + cutoff = rcond.astype(s.dtype) * s[..., 0:1] + s_truncated = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype) + return _T(vh) @ jnp.divide(_T(u), s_truncated[..., jnp.newaxis]) +@partial(custom_jvp, nondiff_argnums=(1, 2, 3)) +def _pinv(arr, rcond, hermitian, compute_s): + m, n = arr.shape[-2:] + if m == 0 or n == 0: + out = jnp.empty(arr.shape[:-2] + (n, m), arr.dtype) + s = jnp.empty(arr.shape[:-2] + (0,), jnp.finfo(arr.dtype).dtype) + else: + u, s, vh = svd(jnp.conj(arr), full_matrices=False, hermitian=hermitian) + out = _svd_to_pinv(u, s, vh, rcond) + return (out, s) if compute_s else out -@pinv.defjvp +@_pinv.defjvp @jax.default_matmul_precision("float32") -def _pinv_jvp(rcond, hermitian, primals, tangents): +def _pinv_jvp(rcond, hermitian, compute_s, primals, tangents): # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432. # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative) a, = primals # m x n a_dot, = tangents - p = pinv(a, rcond=rcond, hermitian=hermitian) # n x m + + u, s, vh = svd(jnp.conj(a), full_matrices=False, hermitian=hermitian) + p = _svd_to_pinv(u, s, vh, rcond) + if hermitian: # svd(..., hermitian=True) symmetrizes its input, and the JVP must match. a = _symmetrize(a) @@ -460,13 +469,18 @@ def _pinv_jvp(rcond, hermitian, primals, tangents): # supported triangular matrix multiplication. m, n = a.shape[-2:] if m >= n: - s = (p @ _H(p)) @ _H(a_dot) # nxm - t = (_H(a_dot) @ _H(p)) @ p # nxm - p_dot = -(p @ a_dot) @ p + s - (s @ a) @ p + t - (p @ a) @ t + t1 = (p @ _H(p)) @ _H(a_dot) # nxm + t2 = (_H(a_dot) @ _H(p)) @ p # nxm + p_dot = -(p @ a_dot) @ p + t1 - (t1 @ a) @ p + t2 - (p @ a) @ t2 else: # m < n - s = p @ (_H(p) @ _H(a_dot)) - t = _H(a_dot) @ (_H(p) @ p) - p_dot = -p @ (a_dot @ p) + s - s @ (a @ p) + t - p @ (a @ t) + t1 = p @ (_H(p) @ _H(a_dot)) + t2 = _H(a_dot) @ (_H(p) @ p) + p_dot = -p @ (a_dot @ p) + t1 - t1 @ (a @ p) + t2 - p @ (a @ t2) + + if compute_s: + ds = jnp.real(jnp.einsum("...ab,...bc,...ca->...a", _H(u), a_dot, _H(vh))) + return (p, s), (p_dot, ds) + return p, p_dot @@ -612,10 +626,10 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array: return lax_linalg._solve(a, b) +@jax.default_matmul_precision("float32") def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *, numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]: # TODO: add lstsq to lax_linalg and implement this function via those wrappers. - # TODO: add custom jvp rule for more robust lstsq differentiation a, b = _promote_dtypes_inexact(a, b) if a.shape[0] != b.shape[0]: raise ValueError("Leading dimensions of input arrays must match") @@ -630,29 +644,26 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *, f"{b.ndim}-dimensional array given. Array must be one or two-dimensional") m, n = a.shape dtype = a.dtype - if a.size == 0: - s = jnp.empty(0, dtype=a.dtype) - rank = jnp.array(0, dtype=int) - x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype) + if rcond is None: + rcond = jnp.finfo(dtype).eps * max(n, m) else: - if rcond is None: - rcond = jnp.finfo(dtype).eps * max(n, m) - else: - rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond) - u, s, vt = svd(a, full_matrices=False) + rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond) + + inv: Array + s: Array + inv, s = _pinv(a, rcond, hermitian=False, compute_s=True) # type: ignore + x = inv @ b + if s.size > 0: mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0] rank = mask.sum() - safe_s = jnp.where(mask, s, 1).astype(a.dtype) - s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis] - uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) - x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) + else: + rank = jnp.array(0, dtype=int) # Numpy returns empty residuals in some cases. To allow compilation, we # default to returning full residuals in all cases. if numpy_resid and (rank < n or m <= n): resid = jnp.asarray([]) else: - b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST) - resid = norm(b - b_estimate, axis=0) ** 2 + resid = norm(b - (a @ x), axis=0) ** 2 if b_orig_ndim == 1: x = x.ravel() return x, resid, rank, s diff --git a/tests/linalg_test.py b/tests/linalg_test.py index dd0174c8d..c2f039b6a 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -908,6 +908,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): ((0, 3), (0,)), ((3, 0), (3,)), ((3, 1), (3, 0)), + ((400000, 2), (400000,)), ] ], rcond=[-1, None, 0.5], @@ -918,7 +919,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): np_fun = partial(np.linalg.lstsq, rcond=rcond) jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond) jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True) - tol = {np.float32: 1e-4, np.float64: 1e-12, + tol = {np.float32: 1e-3, np.float64: 1e-12, np.complex64: 1e-5, np.complex128: 1e-12} args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]