Use pinv to compute lstsq.

The current implementation of lstsq is equivalent to pinv(A) @ b, up to the order of the matrix multiplications. If we write it that way, lstsq benefits from pinv's more stable derivative, which does not require differentiating through the singular value decomposition.
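
An illustrative sketch of that equivalence (not part of the commit; the matrices, tolerances, and loss function below are arbitrary):

import jax
import jax.numpy as jnp

# An overdetermined system: A is m x n with m > n.
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = jnp.array([1.0, 0.0, 1.0])

# lstsq minimizes ||A x - b||_2; its solution can equivalently be written
# as pinv(A) @ b.
x_lstsq, *_ = jnp.linalg.lstsq(A, b)
x_pinv = jnp.linalg.pinv(A) @ b
print(jnp.allclose(x_lstsq, x_pinv, atol=1e-5))  # True

# Routing lstsq through pinv means its derivative reuses pinv's custom JVP
# rather than differentiating through the SVD itself.
loss = lambda rhs: jnp.sum(jnp.linalg.lstsq(A, rhs)[0] ** 2)
print(jax.grad(loss)(b))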

PiperOrigin-RevId: 487903227
Peter Hawkins 2022-11-11 13:28:11 -08:00 committed by jax authors
parent 047974dd0c
commit 7c3fb81310
3 changed files with 54 additions and 41 deletions


@@ -1569,13 +1569,14 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
"Singular value decomposition JVP not implemented for full matrices")
Ut, V = _H(U), _H(Vt)
if not compute_uv:
ds = jnp.real(jnp.einsum("...ab,...bc,...ca->...a", Ut, dA, V))
return (s,), (ds,)
s_dim = s[..., None, :]
dS = Ut @ dA @ V
ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
if not compute_uv:
return (s,), (ds,)
s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
s_diffs_zeros = jnp.eye(s.shape[-1], dtype=s.dtype) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else
s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
@@ -1594,7 +1595,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
if m > n:
dAV = dA @ V
dU = dU + (dAV - U @ (Ut @ dAV)) / s_dim.astype(A.dtype)
if n > m:
elif n > m:
dAHU = _H(dA) @ U
dV = dV + (dAHU - V @ (Vt @ dAHU)) / s_dim.astype(A.dtype)


@@ -409,7 +409,6 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
return w
@partial(custom_jvp, nondiff_argnums=(1, 2))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
default `rcond` is `1e-15`. Here the default is
@@ -422,35 +421,45 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
_check_arraylike("jnp.linalg.pinv", a)
arr = jnp.asarray(a)
m, n = arr.shape[-2:]
if m == 0 or n == 0:
return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
arr = jnp.conj(arr)
if rcond is None:
max_rows_cols = max(arr.shape[-2:])
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
rcond = jnp.asarray(rcond)
u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
return _pinv(arr, rcond, hermitian, compute_s=False) # type: ignore
@jax.default_matmul_precision("float32")
def _svd_to_pinv(u, s, vh, rcond):
# Singular values less than or equal to ``rcond * largest_singular_value``
# are set to zero.
rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
cutoff = rcond * s[..., 0:1]
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]),
precision=lax.Precision.HIGHEST)
return lax.convert_element_type(res, arr.dtype)
cutoff = rcond.astype(s.dtype) * s[..., 0:1]
s_truncated = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
return _T(vh) @ jnp.divide(_T(u), s_truncated[..., jnp.newaxis])
@partial(custom_jvp, nondiff_argnums=(1, 2, 3))
def _pinv(arr, rcond, hermitian, compute_s):
m, n = arr.shape[-2:]
if m == 0 or n == 0:
out = jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
s = jnp.empty(arr.shape[:-2] + (0,), jnp.finfo(arr.dtype).dtype)
else:
u, s, vh = svd(jnp.conj(arr), full_matrices=False, hermitian=hermitian)
out = _svd_to_pinv(u, s, vh, rcond)
return (out, s) if compute_s else out
@pinv.defjvp
@_pinv.defjvp
@jax.default_matmul_precision("float32")
def _pinv_jvp(rcond, hermitian, primals, tangents):
def _pinv_jvp(rcond, hermitian, compute_s, primals, tangents):
# The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
# Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
# Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
# (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
a, = primals # m x n
a_dot, = tangents
p = pinv(a, rcond=rcond, hermitian=hermitian) # n x m
u, s, vh = svd(jnp.conj(a), full_matrices=False, hermitian=hermitian)
p = _svd_to_pinv(u, s, vh, rcond)
if hermitian:
# svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
a = _symmetrize(a)
@@ -460,13 +469,18 @@ def _pinv_jvp(rcond, hermitian, primals, tangents):
# supported triangular matrix multiplication.
m, n = a.shape[-2:]
if m >= n:
s = (p @ _H(p)) @ _H(a_dot) # nxm
t = (_H(a_dot) @ _H(p)) @ p # nxm
p_dot = -(p @ a_dot) @ p + s - (s @ a) @ p + t - (p @ a) @ t
t1 = (p @ _H(p)) @ _H(a_dot) # nxm
t2 = (_H(a_dot) @ _H(p)) @ p # nxm
p_dot = -(p @ a_dot) @ p + t1 - (t1 @ a) @ p + t2 - (p @ a) @ t2
else: # m < n
s = p @ (_H(p) @ _H(a_dot))
t = _H(a_dot) @ (_H(p) @ p)
p_dot = -p @ (a_dot @ p) + s - s @ (a @ p) + t - p @ (a @ t)
t1 = p @ (_H(p) @ _H(a_dot))
t2 = _H(a_dot) @ (_H(p) @ p)
p_dot = -p @ (a_dot @ p) + t1 - t1 @ (a @ p) + t2 - p @ (a @ t2)
if compute_s:
ds = jnp.real(jnp.einsum("...ab,...bc,...ca->...a", _H(u), a_dot, _H(vh)))
return (p, s), (p_dot, ds)
return p, p_dot
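
For reference, the JVP above implements the Golub–Pereyra expression cited in the comment (via the Moore–Penrose derivative); this is a restatement, not part of the diff. Assuming A has locally constant rank, the differential of the pseudo-inverse A^+ is

\frac{dA^{+}}{dt} = -A^{+}\,\frac{dA}{dt}\,A^{+}
  \;+\; A^{+}\,(A^{+})^{H}\,\frac{dA^{H}}{dt}\,\bigl(I - A A^{+}\bigr)
  \;+\; \bigl(I - A^{+} A\bigr)\,\frac{dA^{H}}{dt}\,(A^{+})^{H} A^{+}

where t1 and t2 in the code are the leading factors of the second and third terms, and the two branches are the same formula with the products reassociated so that the smaller square intermediate (of size min(m, n)) is formed first.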
@@ -612,10 +626,10 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array:
return lax_linalg._solve(a, b)
@jax.default_matmul_precision("float32")
def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
# TODO: add custom jvp rule for more robust lstsq differentiation
a, b = _promote_dtypes_inexact(a, b)
if a.shape[0] != b.shape[0]:
raise ValueError("Leading dimensions of input arrays must match")
@@ -630,29 +644,26 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
m, n = a.shape
dtype = a.dtype
if a.size == 0:
s = jnp.empty(0, dtype=a.dtype)
rank = jnp.array(0, dtype=int)
x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
if rcond is None:
rcond = jnp.finfo(dtype).eps * max(n, m)
else:
if rcond is None:
rcond = jnp.finfo(dtype).eps * max(n, m)
else:
rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
u, s, vt = svd(a, full_matrices=False)
rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
inv: Array
s: Array
inv, s = _pinv(a, rcond, hermitian=False, compute_s=True) # type: ignore
x = inv @ b
if s.size > 0:
mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
rank = mask.sum()
safe_s = jnp.where(mask, s, 1).astype(a.dtype)
s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
else:
rank = jnp.array(0, dtype=int)
# Numpy returns empty residuals in some cases. To allow compilation, we
# default to returning full residuals in all cases.
if numpy_resid and (rank < n or m <= n):
resid = jnp.asarray([])
else:
b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
resid = norm(b - b_estimate, axis=0) ** 2
resid = norm(b - (a @ x), axis=0) ** 2
if b_orig_ndim == 1:
x = x.ravel()
return x, resid, rank, s
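
As a usage note on the residual behavior mentioned in the comment above (an illustrative sketch, not part of the diff; it relies on the numpy_resid keyword exercised by the tests below):

import jax.numpy as jnp

A = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # m = 3 > n = 2, full rank
b = jnp.array([1.0, 2.0, 3.0])

# Default behavior: the squared residuals are always returned, which keeps
# the output shape static and the function compilable.
x, resid, rank, s = jnp.linalg.lstsq(A, b)

# numpy_resid=True mimics NumPy, which returns empty residuals whenever
# rank < n or m <= n; here m > n and rank == n, so residuals are kept.
x2, resid2, rank2, s2 = jnp.linalg.lstsq(A, b, numpy_resid=True)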


@@ -908,6 +908,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
((0, 3), (0,)),
((3, 0), (3,)),
((3, 1), (3, 0)),
((400000, 2), (400000,)),
]
],
rcond=[-1, None, 0.5],
@@ -918,7 +919,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
np_fun = partial(np.linalg.lstsq, rcond=rcond)
jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
tol = {np.float32: 1e-4, np.float64: 1e-12,
tol = {np.float32: 1e-3, np.float64: 1e-12,
np.complex64: 1e-5, np.complex128: 1e-12}
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]