diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index f18c80d69..335dfd827 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -1569,14 +1569,13 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
       "Singular value decomposition JVP not implemented for full matrices")
 
   Ut, V = _H(U), _H(Vt)
-  if not compute_uv:
-    ds = jnp.real(jnp.einsum("...ab,...bc,...ca->...a", Ut, dA, V))
-    return (s,), (ds,)
-
   s_dim = s[..., None, :]
   dS = Ut @ dA @ V
   ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
 
+  if not compute_uv:
+    return (s,), (ds,)
+
   s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
   s_diffs_zeros = jnp.eye(s.shape[-1], dtype=s.dtype)  # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.)  # is 1. where s_diffs is 0. and is 0. everywhere else
   s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
@@ -1595,7 +1594,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
   if m > n:
     dAV = dA @ V
     dU = dU + (dAV - U @ (Ut @ dAV)) / s_dim.astype(A.dtype)
-  elif n > m:
+  if n > m:
     dAHU = _H(dA) @ U
     dV = dV + (dAHU - V @ (Vt @ dAHU)) / s_dim.astype(A.dtype)
 
diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py
index f8373cc50..bb4286868 100644
--- a/jax/_src/numpy/linalg.py
+++ b/jax/_src/numpy/linalg.py
@@ -409,6 +409,7 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
   return w
 
 
+@partial(custom_jvp, nondiff_argnums=(1, 2))
 @_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
     It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
     default `rcond` is `1e-15`. Here the default is
@@ -421,45 +422,35 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
   # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
   _check_arraylike("jnp.linalg.pinv", a)
   arr = jnp.asarray(a)
+  m, n = arr.shape[-2:]
+  if m == 0 or n == 0:
+    return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
+  arr = jnp.conj(arr)
   if rcond is None:
     max_rows_cols = max(arr.shape[-2:])
     rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
   rcond = jnp.asarray(rcond)
-  return _pinv(arr, rcond, hermitian, compute_s=False)  # type: ignore
-
-@jax.default_matmul_precision("float32")
-def _svd_to_pinv(u, s, vh, rcond):
+  u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
   # Singular values less than or equal to ``rcond * largest_singular_value``
   # are set to zero.
   rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
-  cutoff = rcond.astype(s.dtype) * s[..., 0:1]
-  s_truncated = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
-  return _T(vh) @ jnp.divide(_T(u), s_truncated[..., jnp.newaxis])
+  cutoff = rcond * s[..., 0:1]
+  s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
+  res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]),
+                   precision=lax.Precision.HIGHEST)
+  return lax.convert_element_type(res, arr.dtype)
 
 
-@partial(custom_jvp, nondiff_argnums=(1, 2, 3))
-def _pinv(arr, rcond, hermitian, compute_s):
-  m, n = arr.shape[-2:]
-  if m == 0 or n == 0:
-    out = jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
-    s = jnp.empty(arr.shape[:-2] + (0,), jnp.finfo(arr.dtype).dtype)
-  else:
-    u, s, vh = svd(jnp.conj(arr), full_matrices=False, hermitian=hermitian)
-    out = _svd_to_pinv(u, s, vh, rcond)
-  return (out, s) if compute_s else out
-@_pinv.defjvp
+@pinv.defjvp
 @jax.default_matmul_precision("float32")
-def _pinv_jvp(rcond, hermitian, compute_s, primals, tangents):
+def _pinv_jvp(rcond, hermitian, primals, tangents):
   # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
   # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
   # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
   # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
   a, = primals  # m x n
   a_dot, = tangents
-
-  u, s, vh = svd(jnp.conj(a), full_matrices=False, hermitian=hermitian)
-  p = _svd_to_pinv(u, s, vh, rcond)
-
+  p = pinv(a, rcond=rcond, hermitian=hermitian)  # n x m
   if hermitian:
     # svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
     a = _symmetrize(a)
@@ -469,18 +460,13 @@ def _pinv_jvp(rcond, hermitian, compute_s, primals, tangents):
 
   # TODO(phawkins): this could be simplified in the Hermitian case if we
   # supported triangular matrix multiplication.
   m, n = a.shape[-2:]
   if m >= n:
-    t1 = (p @ _H(p)) @ _H(a_dot)  # nxm
-    t2 = (_H(a_dot) @ _H(p)) @ p  # nxm
-    p_dot = -(p @ a_dot) @ p + t1 - (t1 @ a) @ p + t2 - (p @ a) @ t2
+    s = (p @ _H(p)) @ _H(a_dot)  # nxm
+    t = (_H(a_dot) @ _H(p)) @ p  # nxm
+    p_dot = -(p @ a_dot) @ p + s - (s @ a) @ p + t - (p @ a) @ t
   else:  # m < n
-    t1 = p @ (_H(p) @ _H(a_dot))
-    t2 = _H(a_dot) @ (_H(p) @ p)
-    p_dot = -p @ (a_dot @ p) + t1 - t1 @ (a @ p) + t2 - p @ (a @ t2)
-
-  if compute_s:
-    ds = jnp.real(jnp.einsum("...ab,...bc,...ca->...a", _H(u), a_dot, _H(vh)))
-    return (p, s), (p_dot, ds)
-
+    s = p @ (_H(p) @ _H(a_dot))
+    t = _H(a_dot) @ (_H(p) @ p)
+    p_dot = -p @ (a_dot @ p) + s - s @ (a @ p) + t - p @ (a @ t)
   return p, p_dot
@@ -626,10 +612,10 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array:
   return lax_linalg._solve(a, b)
 
 
-@jax.default_matmul_precision("float32")
 def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
            numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
   # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
+  # TODO: add custom jvp rule for more robust lstsq differentiation
   a, b = _promote_dtypes_inexact(a, b)
   if a.shape[0] != b.shape[0]:
     raise ValueError("Leading dimensions of input arrays must match")
@@ -644,26 +630,29 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
      f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
   m, n = a.shape
   dtype = a.dtype
-  if rcond is None:
-    rcond = jnp.finfo(dtype).eps * max(n, m)
+  if a.size == 0:
+    s = jnp.empty(0, dtype=a.dtype)
+    rank = jnp.array(0, dtype=int)
+    x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
   else:
-    rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
-
-  inv: Array
-  s: Array
-  inv, s = _pinv(a, rcond, hermitian=False, compute_s=True)  # type: ignore
-  x = inv @ b
-  if s.size > 0:
+    if rcond is None:
+      rcond = jnp.finfo(dtype).eps * max(n, m)
+    else:
+      rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
+    u, s, vt = svd(a, full_matrices=False)
     mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
     rank = mask.sum()
-  else:
-    rank = jnp.array(0, dtype=int)
+    safe_s = jnp.where(mask, s, 1).astype(a.dtype)
+    s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
+    uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
+    x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
   # Numpy returns empty residuals in some cases. To allow compilation, we
   # default to returning full residuals in all cases.
   if numpy_resid and (rank < n or m <= n):
     resid = jnp.asarray([])
   else:
-    resid = norm(b - (a @ x), axis=0) ** 2
+    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
+    resid = norm(b - b_estimate, axis=0) ** 2
   if b_orig_ndim == 1:
     x = x.ravel()
   return x, resid, rank, s
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index c2f039b6a..dd0174c8d 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -908,7 +908,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
           ((0, 3), (0,)),
           ((3, 0), (3,)),
           ((3, 1), (3, 0)),
-          ((400000, 2), (400000,)),
       ]
     ],
     rcond=[-1, None, 0.5],
@@ -919,7 +918,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     np_fun = partial(np.linalg.lstsq, rcond=rcond)
     jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
     jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
-    tol = {np.float32: 1e-3, np.float64: 1e-12,
+    tol = {np.float32: 1e-4, np.float64: 1e-12,
           np.complex64: 1e-5, np.complex128: 1e-12}
     args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
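
Reviewer note (not part of the patch): the `_pinv_jvp` rule restored above
implements the Golub & Pereyra (1973) expression for the derivative of the
Moore-Penrose pseudo-inverse, cited in the comments of the diff. The sketch
below checks jax.jvp of jnp.linalg.pinv against that closed form. The helper
name pinv_jvp_reference, the shapes, and the tolerance are illustrative
assumptions, not part of this change.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # compare in float64

def pinv_jvp_reference(a, a_dot):
  # Golub & Pereyra (1973), for full-rank a:
  #   dP = -P dA P + P P^H dA^H (I - A P) + (I - P A) dA^H P^H P
  p = jnp.linalg.pinv(a)
  m, n = a.shape
  i_m, i_n = jnp.eye(m, dtype=a.dtype), jnp.eye(n, dtype=a.dtype)
  aH = lambda x: x.conj().T  # conjugate transpose, _H in the patch
  return (-p @ a_dot @ p
          + p @ aH(p) @ aH(a_dot) @ (i_m - a @ p)
          + (i_n - p @ a) @ aH(a_dot) @ aH(p) @ p)

a = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
a_dot = jax.random.normal(jax.random.PRNGKey(1), (5, 3))
p, p_dot = jax.jvp(jnp.linalg.pinv, (a,), (a_dot,))
assert jnp.allclose(p_dot, pinv_jvp_reference(a, a_dot), atol=1e-8)

The masked reciprocal in the restored _lstsq (safe_s / s_inv) appears to be
the usual double-where pattern: it keeps both the primal values and their
gradients free of inf/NaN when singular values fall below the rcond cutoff.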