mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #8003 from hawkinsp:jitlinalg
PiperOrigin-RevId: 399179578
This commit is contained in:
commit
f3b9cac75d
@ -1254,8 +1254,8 @@ def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
|
||||
if not compute_uv:
|
||||
return (s,), (ds,)
|
||||
|
||||
s_diffs = jnp.square(s_dim) - jnp.square(_T(s_dim))
|
||||
s_diffs_zeros = jnp.eye(s.shape[-1], dtype=A.dtype) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else
|
||||
s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
|
||||
s_diffs_zeros = jnp.eye(s.shape[-1], dtype=s.dtype) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else
|
||||
F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
|
||||
dSS = s_dim * dS # dS.dot(jnp.diag(s))
|
||||
SdS = _T(s_dim) * dS # jnp.diag(s).dot(dS)
|
||||
|
@ -46,18 +46,21 @@ def _promote_arg_dtypes(*args):
|
||||
|
||||
|
||||
@_wraps(np.linalg.cholesky)
|
||||
@jit
|
||||
def cholesky(a):
|
||||
a = _promote_arg_dtypes(jnp.asarray(a))
|
||||
return lax_linalg.cholesky(a)
|
||||
|
||||
|
||||
@_wraps(np.linalg.svd)
|
||||
@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
|
||||
def svd(a, full_matrices=True, compute_uv=True):
|
||||
a = _promote_arg_dtypes(jnp.asarray(a))
|
||||
return lax_linalg.svd(a, full_matrices, compute_uv)
|
||||
|
||||
|
||||
@_wraps(np.linalg.matrix_power)
|
||||
@partial(jit, static_argnames=('n',))
|
||||
def matrix_power(a, n):
|
||||
a = _promote_arg_dtypes(jnp.asarray(a))
|
||||
|
||||
@ -95,6 +98,7 @@ def matrix_power(a, n):
|
||||
|
||||
|
||||
@_wraps(np.linalg.matrix_rank)
|
||||
@jit
|
||||
def matrix_rank(M, tol=None):
|
||||
M = _promote_arg_dtypes(jnp.asarray(M))
|
||||
if M.ndim > 2:
|
||||
@ -288,12 +292,14 @@ def eig(a):
|
||||
|
||||
|
||||
@_wraps(np.linalg.eigvals)
|
||||
@jit
|
||||
def eigvals(a):
|
||||
return lax_linalg.eig(a, compute_left_eigenvectors=False,
|
||||
compute_right_eigenvectors=False)[0]
|
||||
|
||||
|
||||
@_wraps(np.linalg.eigh)
|
||||
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
|
||||
def eigh(a, UPLO=None, symmetrize_input=True):
|
||||
if UPLO is None or UPLO == "L":
|
||||
lower = True
|
||||
@ -309,6 +315,7 @@ def eigh(a, UPLO=None, symmetrize_input=True):
|
||||
|
||||
|
||||
@_wraps(np.linalg.eigvalsh)
|
||||
@partial(jit, static_argnames=('UPLO',))
|
||||
def eigvalsh(a, UPLO='L'):
|
||||
w, _ = eigh(a, UPLO)
|
||||
return w
|
||||
@ -320,6 +327,7 @@ def eigvalsh(a, UPLO='L'):
|
||||
default `rcond` is `1e-15`. Here the default is
|
||||
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
|
||||
"""))
|
||||
@jit
|
||||
def pinv(a, rcond=None):
|
||||
# Uses same algorithm as
|
||||
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
|
||||
@ -356,6 +364,7 @@ def _pinv_jvp(rcond, primals, tangents):
|
||||
|
||||
|
||||
@_wraps(np.linalg.inv)
|
||||
@jit
|
||||
def inv(a):
|
||||
if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
|
||||
raise ValueError(
|
||||
@ -364,8 +373,10 @@ def inv(a):
|
||||
a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(1, 2, 3))
|
||||
def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
|
||||
@_wraps(np.linalg.norm)
|
||||
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
|
||||
def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None,
|
||||
keepdims=False):
|
||||
x = _promote_arg_dtypes(jnp.asarray(x))
|
||||
x_shape = jnp.shape(x)
|
||||
ndim = len(x_shape)
|
||||
@ -451,12 +462,9 @@ def _norm(x, ord, axis: Union[None, Tuple[int, ...], int], keepdims):
|
||||
raise ValueError(
|
||||
"Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
|
||||
|
||||
@_wraps(np.linalg.norm)
|
||||
def norm(x, ord=None, axis=None, keepdims=False):
|
||||
return _norm(x, ord, axis, keepdims)
|
||||
|
||||
|
||||
@_wraps(np.linalg.qr)
|
||||
@partial(jit, static_argnames=('mode',))
|
||||
def qr(a, mode="reduced"):
|
||||
if mode in ("reduced", "r", "full"):
|
||||
full_matrices = False
|
||||
@ -478,20 +486,7 @@ def solve(a, b):
|
||||
return lax_linalg._solve(a, b)
|
||||
|
||||
|
||||
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
|
||||
It has two important differences:
|
||||
|
||||
1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
|
||||
the default will be `None`. Here, the default rcond is `None`.
|
||||
2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or under-determined
|
||||
solutions. Here, the residuals are returned in all cases, to make the function
|
||||
compatible with jit. The non-jit compatible numpy behavior can be recovered by
|
||||
passing numpy_resid=True.
|
||||
|
||||
The lstsq function does not currently have a custom JVP rule, so the gradient is
|
||||
poorly behaved for some inputs, particularly for low-rank `a`.
|
||||
"""))
|
||||
def lstsq(a, b, rcond=None, *, numpy_resid=False):
|
||||
def _lstsq(a, b, rcond, *, numpy_resid=False):
|
||||
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
|
||||
# TODO: add custom jvp rule for more robust lstsq differentiation
|
||||
a, b = _promote_arg_dtypes(a, b)
|
||||
@ -510,8 +505,8 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False):
|
||||
dtype = a.dtype
|
||||
if rcond is None:
|
||||
rcond = jnp.finfo(dtype).eps * max(n, m)
|
||||
elif rcond < 0:
|
||||
rcond = jnp.finfo(dtype).eps
|
||||
else:
|
||||
rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
|
||||
u, s, vt = svd(a, full_matrices=False)
|
||||
mask = s >= rcond * s[0]
|
||||
rank = mask.sum()
|
||||
@ -529,3 +524,23 @@ def lstsq(a, b, rcond=None, *, numpy_resid=False):
|
||||
if b_orig_ndim == 1:
|
||||
x = x.ravel()
|
||||
return x, resid, rank, s
|
||||
|
||||
_jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
|
||||
|
||||
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
|
||||
It has two important differences:
|
||||
|
||||
1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
|
||||
the default will be `None`. Here, the default rcond is `None`.
|
||||
2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or under-determined
|
||||
solutions. Here, the residuals are returned in all cases, to make the function
|
||||
compatible with jit. The non-jit compatible numpy behavior can be recovered by
|
||||
passing numpy_resid=True.
|
||||
|
||||
The lstsq function does not currently have a custom JVP rule, so the gradient is
|
||||
poorly behaved for some inputs, particularly for low-rank `a`.
|
||||
"""))
|
||||
def lstsq(a, b, rcond=None, *, numpy_resid=False):
|
||||
if numpy_resid:
|
||||
return _lstsq(a, b, rcond, numpy_resid=True)
|
||||
return _jit_lstsq(a, b, rcond)
|
||||
|
@ -29,7 +29,7 @@ from jax._src.numpy import linalg as np_linalg
|
||||
|
||||
_T = lambda x: jnp.swapaxes(x, -1, -2)
|
||||
|
||||
@partial(jit, static_argnums=(1,))
|
||||
@partial(jit, static_argnames=('lower',))
|
||||
def _cholesky(a, lower):
|
||||
a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
|
||||
l = lax_linalg.cholesky(a if lower else jnp.conj(_T(a)), symmetrize_input=False)
|
||||
@ -44,7 +44,7 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
|
||||
def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
|
||||
return (cholesky(a, lower=lower), lower)
|
||||
|
||||
@partial(jit, static_argnums=(2,))
|
||||
@partial(jit, static_argnames=('lower',))
|
||||
def _cho_solve(c, b, lower):
|
||||
c, b = np_linalg._promote_arg_dtypes(jnp.asarray(c), jnp.asarray(b))
|
||||
lax_linalg._check_solve_shapes(c, b)
|
||||
@ -60,13 +60,17 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
|
||||
c, lower = c_and_lower
|
||||
return _cho_solve(c, b, lower)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
|
||||
def _svd(a, *, full_matrices, compute_uv):
|
||||
a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
|
||||
return lax_linalg.svd(a, full_matrices, compute_uv)
|
||||
|
||||
@_wraps(scipy.linalg.svd)
|
||||
def svd(a, full_matrices=True, compute_uv=True, overwrite_a=False,
|
||||
check_finite=True, lapack_driver='gesdd'):
|
||||
del overwrite_a, check_finite, lapack_driver
|
||||
a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
|
||||
return lax_linalg.svd(a, full_matrices, compute_uv)
|
||||
|
||||
return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
|
||||
|
||||
@_wraps(scipy.linalg.det)
|
||||
def det(a, overwrite_a=False, check_finite=True):
|
||||
@ -74,11 +78,8 @@ def det(a, overwrite_a=False, check_finite=True):
|
||||
return np_linalg.det(a)
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.eigh)
|
||||
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
||||
overwrite_b=False, turbo=True, eigvals=None, type=1,
|
||||
check_finite=True):
|
||||
del overwrite_a, overwrite_b, turbo, check_finite
|
||||
@partial(jit, static_argnames=('lower', 'eigvals_only', 'eigvals', 'type'))
|
||||
def _eigh(a, b, lower, eigvals_only, eigvals, type):
|
||||
if b is not None:
|
||||
raise NotImplementedError("Only the b=None case of eigh is implemented")
|
||||
if type != 1:
|
||||
@ -95,6 +96,14 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
||||
else:
|
||||
return w, v
|
||||
|
||||
@_wraps(scipy.linalg.eigh)
|
||||
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
||||
overwrite_b=False, turbo=True, eigvals=None, type=1,
|
||||
check_finite=True):
|
||||
del overwrite_a, overwrite_b, turbo, check_finite
|
||||
return _eigh(a, b, lower, eigvals_only, eigvals, type)
|
||||
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.inv)
|
||||
def inv(a, overwrite_a=False, check_finite=True):
|
||||
@ -103,6 +112,7 @@ def inv(a, overwrite_a=False, check_finite=True):
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.lu_factor)
|
||||
@partial(jit, static_argnames=('overwrite_a', 'check_finite'))
|
||||
def lu_factor(a, overwrite_a=False, check_finite=True):
|
||||
del overwrite_a, check_finite
|
||||
a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
|
||||
@ -111,6 +121,7 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
|
||||
|
||||
|
||||
@_wraps(scipy.linalg.lu_solve)
|
||||
@partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite'))
|
||||
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
|
||||
del overwrite_b, check_finite
|
||||
lu, pivots = lu_and_piv
|
||||
@ -135,11 +146,12 @@ def _lu(a, permute_l):
|
||||
return p, l, u
|
||||
|
||||
@_wraps(scipy.linalg.lu, update_doc=False)
|
||||
@partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite'))
|
||||
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
|
||||
del overwrite_a, check_finite
|
||||
return _lu(a, permute_l)
|
||||
|
||||
@partial(jit, static_argnums=(1, 2))
|
||||
@partial(jit, static_argnames=('mode', 'pivoting'))
|
||||
def _qr(a, mode, pivoting):
|
||||
if pivoting:
|
||||
raise NotImplementedError(
|
||||
@ -163,7 +175,7 @@ def qr(a, overwrite_a=False, lwork=None, mode="full", pivoting=False,
|
||||
return _qr(a, mode, pivoting)
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(2, 3))
|
||||
@partial(jit, static_argnames=('sym_pos', 'lower'))
|
||||
def _solve(a, b, sym_pos, lower):
|
||||
if not sym_pos:
|
||||
return np_linalg.solve(a, b)
|
||||
@ -193,7 +205,7 @@ def solve(a, b, sym_pos=False, lower=False, overwrite_a=False, overwrite_b=False
|
||||
del overwrite_a, overwrite_b, debug, check_finite
|
||||
return _solve(a, b, sym_pos, lower)
|
||||
|
||||
@partial(jit, static_argnums=(2, 3, 4))
|
||||
@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))
|
||||
def _solve_triangular(a, b, trans, lower, unit_diagonal):
|
||||
if trans == 0 or trans == "N":
|
||||
transpose_a, conjugate_a = False, False
|
||||
@ -252,11 +264,8 @@ where norm() denotes the L1 norm, and
|
||||
""")
|
||||
|
||||
@_wraps(scipy.linalg.expm, lax_description=_expm_description)
|
||||
@partial(jit, static_argnames=('upper_triangular', 'max_squarings'))
|
||||
def expm(A, *, upper_triangular=False, max_squarings=16):
|
||||
return _expm(A, upper_triangular, max_squarings)
|
||||
|
||||
@partial(jit, static_argnums=(1, 2))
|
||||
def _expm(A, upper_triangular, max_squarings):
|
||||
P, Q, n_squarings = _calc_P_Q(A)
|
||||
|
||||
def _nan(args):
|
||||
@ -391,10 +400,8 @@ support the ``method='blockEnlarge'`` argument.
|
||||
""")
|
||||
|
||||
@_wraps(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description)
|
||||
@partial(jit, static_argnames=('method', 'compute_expm'))
|
||||
def expm_frechet(A, E, *, method=None, compute_expm=True):
|
||||
return _expm_frechet(A, E, method, compute_expm)
|
||||
|
||||
def _expm_frechet(A, E, method=None, compute_expm=True):
|
||||
A = jnp.asarray(A)
|
||||
E = jnp.asarray(E)
|
||||
if A.ndim != 2 or A.shape[0] != A.shape[1]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user