Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Use pinv to compute lstsq.
The current implementation of lstsq is equivalent to pinv(A) @ b, just with a different order of matrix multiplications. Writing it that way lets lstsq reuse pinv's more stable derivative, which does not require differentiating through the singular value decomposition.

PiperOrigin-RevId: 487903227
This commit is contained in:
parent
047974dd0c
commit
7c3fb81310
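
As a quick illustration of the equivalence the commit message describes (a hypothetical standalone example, not code from this commit; the arrays are made up):

    import jax.numpy as jnp

    # lstsq returns the minimum-norm least-squares solution of a @ x = b,
    # which is exactly pinv(a) @ b.
    a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
    b = jnp.array([1., 0., 1.])

    x_lstsq, *_ = jnp.linalg.lstsq(a, b)
    x_pinv = jnp.linalg.pinv(a) @ b
    print(jnp.allclose(x_lstsq, x_pinv, atol=1e-5))  # True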
@@ -1569,13 +1569,14 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
       "Singular value decomposition JVP not implemented for full matrices")
 
   Ut, V = _H(U), _H(Vt)
+  if not compute_uv:
+    ds = jnp.real(jnp.einsum("...ab,...bc,...ca->...a", Ut, dA, V))
+    return (s,), (ds,)
+
   s_dim = s[..., None, :]
   dS = Ut @ dA @ V
   ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
 
-  if not compute_uv:
-    return (s,), (ds,)
-
   s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
   s_diffs_zeros = jnp.eye(s.shape[-1], dtype=s.dtype)  # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.)  # is 1. where s_diffs is 0. and is 0. everywhere else
   s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
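
The new early return computes only the diagonal of Ut @ dA @ V with an einsum rather than forming the full product and then taking its diagonal. A small sketch of that identity (hypothetical shapes and values, unbatched for simplicity):

    import jax
    import jax.numpy as jnp

    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    Ut = jax.random.normal(k1, (3, 5))
    dA = jax.random.normal(k2, (5, 4))
    V = jax.random.normal(k3, (4, 3))

    full = jnp.diagonal(Ut @ dA @ V)             # forms the whole 3x3 product
    fast = jnp.einsum("ab,bc,ca->a", Ut, dA, V)  # computes only the diagonal
    print(jnp.allclose(full, fast, atol=1e-4))   # True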
@@ -1594,7 +1595,7 @@ def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
   if m > n:
     dAV = dA @ V
     dU = dU + (dAV - U @ (Ut @ dAV)) / s_dim.astype(A.dtype)
-  if n > m:
+  elif n > m:
     dAHU = _H(dA) @ U
     dV = dV + (dAHU - V @ (Vt @ dAHU)) / s_dim.astype(A.dtype)
 
@@ -409,7 +409,6 @@ def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
   return w
 
 
-@partial(custom_jvp, nondiff_argnums=(1, 2))
 @_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
     It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
     default `rcond` is `1e-15`. Here the default is
@@ -422,35 +421,45 @@ def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
   # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
   _check_arraylike("jnp.linalg.pinv", a)
   arr = jnp.asarray(a)
-  m, n = arr.shape[-2:]
-  if m == 0 or n == 0:
-    return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
-  arr = jnp.conj(arr)
   if rcond is None:
     max_rows_cols = max(arr.shape[-2:])
     rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
   rcond = jnp.asarray(rcond)
-  u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
+  return _pinv(arr, rcond, hermitian, compute_s=False)  # type: ignore
+
+@jax.default_matmul_precision("float32")
+def _svd_to_pinv(u, s, vh, rcond):
   # Singular values less than or equal to ``rcond * largest_singular_value``
   # are set to zero.
   rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
-  cutoff = rcond * s[..., 0:1]
-  s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
-  res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]),
-                   precision=lax.Precision.HIGHEST)
-  return lax.convert_element_type(res, arr.dtype)
+  cutoff = rcond.astype(s.dtype) * s[..., 0:1]
+  s_truncated = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
+  return _T(vh) @ jnp.divide(_T(u), s_truncated[..., jnp.newaxis])
 
 
+@partial(custom_jvp, nondiff_argnums=(1, 2, 3))
+def _pinv(arr, rcond, hermitian, compute_s):
+  m, n = arr.shape[-2:]
+  if m == 0 or n == 0:
+    out = jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
+    s = jnp.empty(arr.shape[:-2] + (0,), jnp.finfo(arr.dtype).dtype)
+  else:
+    u, s, vh = svd(jnp.conj(arr), full_matrices=False, hermitian=hermitian)
+    out = _svd_to_pinv(u, s, vh, rcond)
+  return (out, s) if compute_s else out
+
-@pinv.defjvp
+@_pinv.defjvp
 @jax.default_matmul_precision("float32")
-def _pinv_jvp(rcond, hermitian, primals, tangents):
+def _pinv_jvp(rcond, hermitian, compute_s, primals, tangents):
   # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
   # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
   # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
   # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
   a, = primals  # m x n
   a_dot, = tangents
-  p = pinv(a, rcond=rcond, hermitian=hermitian)  # n x m
+  u, s, vh = svd(jnp.conj(a), full_matrices=False, hermitian=hermitian)
+  p = _svd_to_pinv(u, s, vh, rcond)
+
   if hermitian:
     # svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
     a = _symmetrize(a)
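
The inf-masking in _svd_to_pinv is what zeroes out truncated modes: replacing small singular values with inf makes their reciprocals exactly zero in the reconstruction. A self-contained sketch of the same construction for a single real matrix (the committed code additionally handles batching, conjugation, and dtype conversion):

    import jax.numpy as jnp

    def pinv_sketch(a, rcond=1e-6):
      u, s, vh = jnp.linalg.svd(a, full_matrices=False)  # s is sorted descending
      cutoff = rcond * s[0]
      s_trunc = jnp.where(s > cutoff, s, jnp.inf)  # 1/inf == 0, so small modes vanish
      return (vh.conj().T / s_trunc) @ u.conj().T

    a = jnp.array([[1., 0.], [0., 1e-9], [0., 0.]])
    print(pinv_sketch(a))                  # the 1e-9 mode is truncated
    print(jnp.linalg.pinv(a, rcond=1e-6))  # should agree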
@@ -460,13 +469,18 @@ def _pinv_jvp(rcond, hermitian, primals, tangents):
   # supported triangular matrix multiplication.
   m, n = a.shape[-2:]
   if m >= n:
-    s = (p @ _H(p)) @ _H(a_dot)  # nxm
-    t = (_H(a_dot) @ _H(p)) @ p  # nxm
-    p_dot = -(p @ a_dot) @ p + s - (s @ a) @ p + t - (p @ a) @ t
+    t1 = (p @ _H(p)) @ _H(a_dot)  # nxm
+    t2 = (_H(a_dot) @ _H(p)) @ p  # nxm
+    p_dot = -(p @ a_dot) @ p + t1 - (t1 @ a) @ p + t2 - (p @ a) @ t2
   else:  # m < n
-    s = p @ (_H(p) @ _H(a_dot))
-    t = _H(a_dot) @ (_H(p) @ p)
-    p_dot = -p @ (a_dot @ p) + s - s @ (a @ p) + t - p @ (a @ t)
+    t1 = p @ (_H(p) @ _H(a_dot))
+    t2 = _H(a_dot) @ (_H(p) @ p)
+    p_dot = -p @ (a_dot @ p) + t1 - t1 @ (a @ p) + t2 - p @ (a @ t2)
 
+  if compute_s:
+    ds = jnp.real(jnp.einsum("...ab,...bc,...ca->...a", _H(u), a_dot, _H(vh)))
+    return (p, s), (p_dot, ds)
+
   return p, p_dot
 
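
One way to sanity-check the custom JVP is against central finite differences on a well-conditioned input (a hypothetical test snippet, not part of the commit):

    import jax
    import jax.numpy as jnp

    key_a, key_t = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, (5, 3))   # well-conditioned with high probability
    da = jax.random.normal(key_t, (5, 3))

    p, p_dot = jax.jvp(jnp.linalg.pinv, (a,), (da,))
    eps = 1e-3
    fd = (jnp.linalg.pinv(a + eps * da) - jnp.linalg.pinv(a - eps * da)) / (2 * eps)
    print(jnp.max(jnp.abs(p_dot - fd)))    # small, up to float32 finite-difference noise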
@@ -612,10 +626,10 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array:
   return lax_linalg._solve(a, b)
 
 
+@jax.default_matmul_precision("float32")
 def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
            numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
   # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
-  # TODO: add custom jvp rule for more robust lstsq differentiation
   a, b = _promote_dtypes_inexact(a, b)
   if a.shape[0] != b.shape[0]:
     raise ValueError("Leading dimensions of input arrays must match")
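
For context, jax.default_matmul_precision(...) sets the default precision of every matrix multiply inside the decorated function, which matters on backends whose fast matmul paths trade away float32 accuracy; the commit uses it in place of the per-call precision=lax.Precision.HIGHEST arguments it deletes. A minimal sketch of the decorator form used here (hypothetical function name):

    import jax

    @jax.default_matmul_precision("float32")
    def gram(a):
      # This matmul runs at float32 precision even where the backend default is lower.
      return a.T @ a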
@@ -630,29 +644,26 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
         f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
   m, n = a.shape
   dtype = a.dtype
-  if a.size == 0:
-    s = jnp.empty(0, dtype=a.dtype)
-    rank = jnp.array(0, dtype=int)
-    x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
+  if rcond is None:
+    rcond = jnp.finfo(dtype).eps * max(n, m)
   else:
-    if rcond is None:
-      rcond = jnp.finfo(dtype).eps * max(n, m)
-    else:
-      rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
-    u, s, vt = svd(a, full_matrices=False)
+    rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
+  inv: Array
+  s: Array
+  inv, s = _pinv(a, rcond, hermitian=False, compute_s=True)  # type: ignore
+  x = inv @ b
+  if s.size > 0:
     mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
     rank = mask.sum()
-    safe_s = jnp.where(mask, s, 1).astype(a.dtype)
-    s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
-    uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
-    x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
+  else:
+    rank = jnp.array(0, dtype=int)
+
   # Numpy returns empty residuals in some cases. To allow compilation, we
   # default to returning full residuals in all cases.
   if numpy_resid and (rank < n or m <= n):
     resid = jnp.asarray([])
   else:
-    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
-    resid = norm(b - b_estimate, axis=0) ** 2
+    resid = norm(b - (a @ x), axis=0) ** 2
   if b_orig_ndim == 1:
     x = x.ravel()
   return x, resid, rank, s
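
With _lstsq phrased as pinv(a) @ b, differentiating a least-squares solve goes through the pinv JVP above instead of through the SVD's own derivative. A hypothetical sketch of taking a gradient through the solve (made-up data):

    import jax
    import jax.numpy as jnp

    def fit_and_score(a, b):
      x, resid, rank, s = jnp.linalg.lstsq(a, b)
      return jnp.sum(x ** 2)

    a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
    b = jnp.array([1., 0., 1.])
    print(jax.grad(fit_and_score)(a, b))  # gradient flows through the pinv rule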
@@ -908,6 +908,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
         ((0, 3), (0,)),
         ((3, 0), (3,)),
         ((3, 1), (3, 0)),
+        ((400000, 2), (400000,)),
       ]
     ],
     rcond=[-1, None, 0.5],
@@ -918,7 +919,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     np_fun = partial(np.linalg.lstsq, rcond=rcond)
     jnp_fun = partial(jnp.linalg.lstsq, rcond=rcond)
     jnp_fun_numpy_resid = partial(jnp.linalg.lstsq, rcond=rcond, numpy_resid=True)
-    tol = {np.float32: 1e-4, np.float64: 1e-12,
+    tol = {np.float32: 1e-3, np.float64: 1e-12,
            np.complex64: 1e-5, np.complex128: 1e-12}
     args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
 