mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Use pinv to compute lstsq.
The current implementation of lstsq is equivalent to pinv(A) @ b, with a different order of matrix multiplications. If we write it that way we benefit from a more stable derivative that does not require differentiating through the singular value decomposition. PiperOrigin-RevId: 487903227
This commit is contained in:
parent
047974dd0c
commit
7c3fb81310
@ -1569,13 +1569,14 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
|
||||
"Singular value decomposition JVP not implemented for full matrices")
|
||||
|
||||
Ut, V = _H(U), _H(Vt)
|
||||
if not compute_uv:
|
||||
ds = jnp.real(jnp.einsum("...ab,...bc,...ca->...a", Ut, dA, V))
|
||||
return (s,), (ds,)
|
||||
|
||||
s_dim = s[..., None, :]
|
||||
dS = Ut @ dA @ V
|
||||
ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
|
||||
|
||||
if not compute_uv:
|
||||
return (s,), (ds,)
|
||||
|
||||
s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
|
||||
s_diffs_zeros = jnp.eye(s.shape[-1], dtype=s.dtype) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else
|
||||
s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
|
||||
@ -1594,7 +1595,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
|
||||
if m > n:
|
||||
dAV = dA @ V
|
||||
dU = dU + (dAV - U @ (Ut @ dAV)) / s_dim.astype(A.dtype)
|
||||
if n > m:
|
||||
elif n > m:
|
||||
dAHU = _H(dA) @ U
|
||||
dV = dV + (dAHU - V @ (Vt @ dAHU)) / s_dim.astype(A.dtype)
|
||||
|
||||
|
@ -409,7 +409,6 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
|
||||
return w
|
||||
|
||||
|
||||
@partial(custom_jvp, nondiff_argnums=(1, 2))
|
||||
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
|
||||
It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
|
||||
default `rcond` is `1e-15`. Here the default is
|
||||
@ -422,35 +421,45 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
|
||||
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
|
||||
_check_arraylike("jnp.linalg.pinv", a)
|
||||
arr = jnp.asarray(a)
|
||||
m, n = arr.shape[-2:]
|
||||
if m == 0 or n == 0:
|
||||
return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
|
||||
arr = jnp.conj(arr)
|
||||
if rcond is None:
|
||||
max_rows_cols = max(arr.shape[-2:])
|
||||
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
|
||||
rcond = jnp.asarray(rcond)
|
||||
u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
|
||||
return _pinv(arr, rcond, hermitian, compute_s=False) # type: ignore
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def _svd_to_pinv(u, s, vh, rcond):
|
||||
# Singular values less than or equal to ``rcond * largest_singular_value``
|
||||
# are set to zero.
|
||||
rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
|
||||
cutoff = rcond * s[..., 0:1]
|
||||
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
|
||||
res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]),
|
||||
precision=lax.Precision.HIGHEST)
|
||||
return lax.convert_element_type(res, arr.dtype)
|
||||
cutoff = rcond.astype(s.dtype) * s[..., 0:1]
|
||||
s_truncated = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
|
||||
return _T(vh) @ jnp.divide(_T(u), s_truncated[..., jnp.newaxis])
|
||||
|
||||
@partial(custom_jvp, nondiff_argnums=(1, 2, 3))
|
||||
def _pinv(arr, rcond, hermitian, compute_s):
|
||||
m, n = arr.shape[-2:]
|
||||
if m == 0 or n == 0:
|
||||
out = jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
|
||||
s = jnp.empty(arr.shape[:-2] + (0,), jnp.finfo(arr.dtype).dtype)
|
||||
else:
|
||||
u, s, vh = svd(jnp.conj(arr), full_matrices=False, hermitian=hermitian)
|
||||
out = _svd_to_pinv(u, s, vh, rcond)
|
||||
return (out, s) if compute_s else out
|
||||
|
||||
@pinv.defjvp
|
||||
@_pinv.defjvp
|
||||
@jax.default_matmul_precision("float32")
|
||||
def _pinv_jvp(rcond, hermitian, primals, tangents):
|
||||
def _pinv_jvp(rcond, hermitian, compute_s, primals, tangents):
|
||||
# The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
|
||||
# Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
|
||||
# Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
|
||||
# (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
|
||||
a, = primals # m x n
|
||||
a_dot, = tangents
|
||||
p = pinv(a, rcond=rcond, hermitian=hermitian) # n x m
|
||||
|
||||
u, s, vh = svd(jnp.conj(a), full_matrices=False, hermitian=hermitian)
|
||||
p = _svd_to_pinv(u, s, vh, rcond)
|
||||
|
||||
if hermitian:
|
||||
# svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
|
||||
a = _symmetrize(a)
|
||||
@ -460,13 +469,18 @@ def _pinv_jvp(rcond, hermitian, primals, tangents):
|
||||
# supported triangular matrix multiplication.
|
||||
m, n = a.shape[-2:]
|
||||
if m >= n:
|
||||
s = (p @ _H(p)) @ _H(a_dot) # nxm
|
||||
t = (_H(a_dot) @ _H(p)) @ p # nxm
|
||||
p_dot = -(p @ a_dot) @ p + s - (s @ a) @ p + t - (p @ a) @ t
|
||||
t1 = (p @ _H(p)) @ _H(a_dot) # nxm
|
||||
t2 = (_H(a_dot) @ _H(p)) @ p # nxm
|
||||
p_dot = -(p @ a_dot) @ p + t1 - (t1 @ a) @ p + t2 - (p @ a) @ t2
|
||||
else: # m < n
|
||||
s = p @ (_H(p) @ _H(a_dot))
|
||||
t = _H(a_dot) @ (_H(p) @ p)
|
||||
p_dot = -p @ (a_dot @ p) + s - s @ (a @ p) + t - p @ (a @ t)
|
||||
t1 = p @ (_H(p) @ _H(a_dot))
|
||||
t2 = _H(a_dot) @ (_H(p) @ p)
|
||||
p_dot = -p @ (a_dot @ p) + t1 - t1 @ (a @ p) + t2 - p @ (a @ t2)
|
||||
|
||||
if compute_s:
|
||||
ds = jnp.real(jnp.einsum("...ab,...bc,...ca->...a", _H(u), a_dot, _H(vh)))
|
||||
return (p, s), (p_dot, ds)
|
||||
|
||||
return p, p_dot
|
||||
|
||||
|
||||
@ -612,10 +626,10 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array:
|
||||
return lax_linalg._solve(a, b)
|
||||
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
|
||||
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
|
||||
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
|
||||
# TODO: add custom jvp rule for more robust lstsq differentiation
|
||||
a, b = _promote_dtypes_inexact(a, b)
|
||||
if a.shape[0] != b.shape[0]:
|
||||
raise ValueError("Leading dimensions of input arrays must match")
|
||||
@ -630,29 +644,26 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
|
||||
f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
|
||||
m, n = a.shape
|
||||
dtype = a.dtype
|
||||
if a.size == 0:
|
||||
s = jnp.empty(0, dtype=a.dtype)
|
||||
rank = jnp.array(0, dtype=int)
|
||||
x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
|
||||
if rcond is None:
|
||||
rcond = jnp.finfo(dtype).eps * max(n, m)
|
||||
else:
|
||||
if rcond is None:
|
||||
rcond = jnp.finfo(dtype).eps * max(n, m)
|
||||
else:
|
||||
rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
|
||||
u, s, vt = svd(a, full_matrices=False)
|
||||
rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
|
||||
|
||||
inv: Array
|
||||
s: Array
|
||||
inv, s = _pinv(a, rcond, hermitian=False, compute_s=True) # type: ignore
|
||||
x = inv @ b
|
||||
if s.size > 0:
|
||||
mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
|
||||
rank = mask.sum()
|
||||
safe_s = jnp.where(mask, s, 1).astype(a.dtype)
|
||||
s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
|
||||
uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
|
||||
else:
|
||||
rank = jnp.array(0, dtype=int)
|
||||
# Numpy returns empty residuals in some cases. To allow compilation, we
|
||||
# default to returning full residuals in all cases.
|
||||
if numpy_resid and (rank < n or m <= n):
|
||||
resid = jnp.asarray([])
|
||||
else:
|
||||
b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
|
||||
resid = norm(b - b_estimate, axis=0) ** 2
|
||||
resid = norm(b - (a @ x), axis=0) ** 2
|
||||
if b_orig_ndim == 1:
|
||||
x = x.ravel()
|
||||
return x, resid, rank, s
|
||||
|
@ -908,6 +908,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
((0, 3), (0,)),
|
||||
((3, 0), (3,)),
|
||||
((3, 1), (3, 0)),
|
||||
((400000, 2), (400000,)),
|
||||
]
|
||||
],
|
||||
rcond=[-1, None, 0.5],
|
||||
@ -918,7 +919,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
np_fun = partial(np.linalg.lstsq, rcond=rcond)
|
||||
jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
|
||||
jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
|
||||
tol = {np.float32: 1e-4, np.float64: 1e-12,
|
||||
tol = {np.float32: 1e-3, np.float64: 1e-12,
|
||||
np.complex64: 1e-5, np.complex128: 1e-12}
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user