Merge pull request #8003 from hawkinsp:jitlinalg

PiperOrigin-RevId: 399179578
This commit is contained in:
jax authors 2021-09-27 06:57:08 -07:00
commit f3b9cac75d
3 changed files with 66 additions and 44 deletions

View File

@ -1254,8 +1254,8 @@ def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
if not compute_uv:
return (s,), (ds,)
s_diffs = jnp.square(s_dim) - jnp.square(_T(s_dim))
s_diffs_zeros = jnp.eye(s.shape[-1], dtype=A.dtype) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else
s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
s_diffs_zeros = jnp.eye(s.shape[-1], dtype=s.dtype) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else
F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
dSS = s_dim * dS # dS.dot(jnp.diag(s))
SdS = _T(s_dim) * dS # jnp.diag(s).dot(dS)

View File

@ -46,18 +46,21 @@ def _promote_arg_dtypes(*args):
@_wraps(np.linalg.cholesky)
@jit
def cholesky(a):
  # Convert the input to an array, run it through _promote_arg_dtypes, and
  # hand the result straight to the lax_linalg implementation.
  return lax_linalg.cholesky(_promote_arg_dtypes(jnp.asarray(a)))
@_wraps(np.linalg.svd)
@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
def svd(a, full_matrices=True, compute_uv=True):
  # `full_matrices` and `compute_uv` are declared static in the jit decorator
  # above; only the array operand goes through dtype promotion here.
  promoted = _promote_arg_dtypes(jnp.asarray(a))
  return lax_linalg.svd(promoted, full_matrices, compute_uv)
@_wraps(np.linalg.matrix_power)
@partial(jit, static_argnames=('n',))
def matrix_power(a, n):
a = _promote_arg_dtypes(jnp.asarray(a))
@ -95,6 +98,7 @@ def matrix_power(a, n):
@_wraps(np.linalg.matrix_rank)
@jit
def matrix_rank(M, tol=None):
M = _promote_arg_dtypes(jnp.asarray(M))
if M.ndim > 2:
@ -288,12 +292,14 @@ def eig(a):
@_wraps(np.linalg.eigvals)
@jit
def eigvals(a):
  # Ask lax_linalg.eig for neither left nor right eigenvectors and return the
  # first element of what it gives back.
  eig_outputs = lax_linalg.eig(
      a,
      compute_left_eigenvectors=False,
      compute_right_eigenvectors=False,
  )
  return eig_outputs[0]
@_wraps(np.linalg.eigh)
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a, UPLO=None, symmetrize_input=True):
if UPLO is None or UPLO == "L":
lower = True
@ -309,6 +315,7 @@ def eigh(a, UPLO=None, symmetrize_input=True):
@_wraps(np.linalg.eigvalsh)
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a, UPLO='L'):
  # Delegate to eigh and keep only the first of its two results; the second
  # result is discarded.
  eigenvalues, _unused = eigh(a, UPLO)
  return eigenvalues
@ -320,6 +327,7 @@ def eigvalsh(a, UPLO='L'):
default `rcond` is `1e-15`. Here the default is
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
"""))
@jit
def pinv(a, rcond=None):
# Uses same algorithm as
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
@ -356,6 +364,7 @@ def _pinv_jvp(rcond, primals, tangents):
@_wraps(np.linalg.inv)
@jit
def inv(a):
if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
raise ValueError(
@ -364,8 +373,10 @@ def inv(a):
a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))
@partial(jit, static_argnums=(1, 2, 3))
def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
@_wraps(np.linalg.norm)
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None,
keepdims=False):
x = _promote_arg_dtypes(jnp.asarray(x))
x_shape = jnp.shape(x)
ndim = len(x_shape)
@ -451,12 +462,9 @@ def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
raise ValueError(
"Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
@_wraps(np.linalg.norm)
def norm(x, ord=None, axis=None, keepdims=False):
return _norm(x, ord, axis, keepdims)
@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a, mode="reduced"):
if mode in ("reduced", "r", "full"):
full_matrices = False
@ -478,20 +486,7 @@ def solve(a, b):
return lax_linalg._solve(a, b)
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
It has two important differences:
1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
the default will be `None`. Here, the default rcond is `None`.
2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined
solutions. Here, the residuals are returned in all cases, to make the function
compatible with jit. The non-jit compatible numpy behavior can be recovered by
passing numpy_resid=True.
The lstsq function does not currently have a custom JVP rule, so the gradient is
poorly behaved for some inputs, particularly for low-rank `a`.
"""))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
def _lstsq(a, b, rcond, *, numpy_resid=False):
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
# TODO: add custom jvp rule for more robust lstsq differentiation
a, b = _promote_arg_dtypes(a, b)
@ -510,8 +505,8 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False):
dtype = a.dtype
if rcond is None:
rcond = jnp.finfo(dtype).eps * max(n, m)
elif rcond < 0:
rcond = jnp.finfo(dtype).eps
else:
rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
u, s, vt = svd(a, full_matrices=False)
mask = s >= rcond * s[0]
rank = mask.sum()
@ -529,3 +524,23 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False):
if b_orig_ndim == 1:
x = x.ravel()
return x, resid, rank, s
# Jitted entry point for _lstsq with numpy_resid bound to False.
_jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
It has two important differences:
1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
the default will be `None`. Here, the default rcond is `None`.
2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined
solutions. Here, the residuals are returned in all cases, to make the function
compatible with jit. The non-jit compatible numpy behavior can be recovered by
passing numpy_resid=True.
The lstsq function does not currently have a custom JVP rule, so the gradient is
poorly behaved for some inputs, particularly for low-rank `a`.
"""))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
  # Default path: go through the jitted wrapper. Only when the caller asks for
  # numpy-style residuals is the un-jitted _lstsq called directly.
  if not numpy_resid:
    return _jit_lstsq(a, b, rcond)
  return _lstsq(a, b, rcond, numpy_resid=True)

View File

@ -29,7 +29,7 @@ from jax._src.numpy import linalg as np_linalg
_T = lambda x: jnp.swapaxes(x, -1, -2)
@partial(jit, static_argnums=(1,))
@partial(jit, static_argnames=('lower',))
def _cholesky(a, lower):
a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
l = lax_linalg.cholesky(a if lower else jnp.conj(_T(a)), symmetrize_input=False)
@ -44,7 +44,7 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
return (cholesky(a, lower=lower), lower)
@partial(jit, static_argnums=(2,))
@partial(jit, static_argnames=('lower',))
def _cho_solve(c, b, lower):
c, b = np_linalg._promote_arg_dtypes(jnp.asarray(c), jnp.asarray(b))
lax_linalg._check_solve_shapes(c, b)
@ -60,13 +60,17 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
c, lower = c_and_lower
return _cho_solve(c, b, lower)
@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
def _svd(a, *, full_matrices, compute_uv):
  # Jitted helper: both flags are keyword-only and static; the operand is
  # promoted with the numpy-linalg helper before calling lax_linalg.svd.
  promoted = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  return lax_linalg.svd(promoted, full_matrices, compute_uv)
@_wraps(scipy.linalg.svd)
def svd(a, full_matrices=True, compute_uv=True, overwrite_a=False,
check_finite=True, lapack_driver='gesdd'):
del overwrite_a, check_finite, lapack_driver
a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
return lax_linalg.svd(a, full_matrices, compute_uv)
return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
@_wraps(scipy.linalg.det)
def det(a, overwrite_a=False, check_finite=True):
@ -74,11 +78,8 @@ def det(a, overwrite_a=False, check_finite=True):
return np_linalg.det(a)
@_wraps(scipy.linalg.eigh)
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
overwrite_b=False, turbo=True, eigvals=None, type=1,
check_finite=True):
del overwrite_a, overwrite_b, turbo, check_finite
@partial(jit, static_argnames=('lower', 'eigvals_only', 'eigvals', 'type'))
def _eigh(a, b, lower, eigvals_only, eigvals, type):
if b is not None:
raise NotImplementedError("Only the b=None case of eigh is implemented")
if type != 1:
@ -95,6 +96,14 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
else:
return w, v
@_wraps(scipy.linalg.eigh)
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
         overwrite_b=False, turbo=True, eigvals=None, type=1,
         check_finite=True):
  # These four arguments are accepted for signature compatibility and dropped
  # before dispatching to the jitted implementation.
  del overwrite_a, overwrite_b, turbo, check_finite
  result = _eigh(a, b, lower, eigvals_only, eigvals, type)
  return result
@_wraps(scipy.linalg.inv)
def inv(a, overwrite_a=False, check_finite=True):
@ -103,6 +112,7 @@ def inv(a, overwrite_a=False, check_finite=True):
@_wraps(scipy.linalg.lu_factor)
@partial(jit, static_argnames=('overwrite_a', 'check_finite'))
def lu_factor(a, overwrite_a=False, check_finite=True):
del overwrite_a, check_finite
a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
@ -111,6 +121,7 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
@_wraps(scipy.linalg.lu_solve)
# static_argnames must name parameters of the signature below: this function
# takes `overwrite_b`, not `overwrite_a`, so the previous tuple never marked
# the intended argument as static.
@partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite'))
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
  # overwrite_b and check_finite are accepted for signature compatibility only.
  del overwrite_b, check_finite
  # The first argument is a pair: the LU factor and its pivots.
  lu, pivots = lu_and_piv
@ -135,11 +146,12 @@ def _lu(a, permute_l):
return p, l, u
@_wraps(scipy.linalg.lu, update_doc=False)
@partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite'))
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
  # overwrite_a and check_finite are discarded; the work is done by _lu.
  del overwrite_a, check_finite
  factors = _lu(a, permute_l)
  return factors
@partial(jit, static_argnums=(1, 2))
@partial(jit, static_argnames=('mode', 'pivoting'))
def _qr(a, mode, pivoting):
if pivoting:
raise NotImplementedError(
@ -163,7 +175,7 @@ def qr(a, overwrite_a=False, lwork=None, mode="full", pivoting=False,
return _qr(a, mode, pivoting)
@partial(jit, static_argnums=(2, 3))
@partial(jit, static_argnames=('sym_pos', 'lower'))
def _solve(a, b, sym_pos, lower):
if not sym_pos:
return np_linalg.solve(a, b)
@ -193,7 +205,7 @@ def solve(a, b, sym_pos=False, lower=False, overwrite_a=False, overwrite_b=False
del overwrite_a, overwrite_b, debug, check_finite
return _solve(a, b, sym_pos, lower)
@partial(jit, static_argnums=(2, 3, 4))
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))
def _solve_triangular(a, b, trans, lower, unit_diagonal):
if trans == 0 or trans == "N":
transpose_a, conjugate_a = False, False
@ -252,11 +264,8 @@ where norm() denotes the L1 norm, and
""")
@_wraps(scipy.linalg.expm, lax_description=_expm_description)
@partial(jit, static_argnames=('upper_triangular', 'max_squarings'))
def expm(A, *, upper_triangular=False, max_squarings=16):
  # Public, jitted entry point: both keyword-only options are static and are
  # forwarded positionally to the private implementation.
  result = _expm(A, upper_triangular, max_squarings)
  return result
@partial(jit, static_argnums=(1, 2))
def _expm(A, upper_triangular, max_squarings):
P, Q, n_squarings = _calc_P_Q(A)
def _nan(args):
@ -391,10 +400,8 @@ support the ``method='blockEnlarge'`` argument.
""")
@_wraps(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description)
@partial(jit, static_argnames=('method', 'compute_expm'))
def expm_frechet(A, E, *, method=None, compute_expm=True):
  # Public, jitted entry point: `method` and `compute_expm` are static and are
  # forwarded positionally to the private implementation.
  result = _expm_frechet(A, E, method, compute_expm)
  return result
def _expm_frechet(A, E, method=None, compute_expm=True):
A = jnp.asarray(A)
E = jnp.asarray(E)
if A.ndim != 2 or A.shape[0] != A.shape[1]: