diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 9a64de81b..1ce64d3b2 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -16,6 +16,7 @@ jax.scipy.linalg det eigh expm + expm_frechet inv lu lu_factor diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index cb1a2751b..f2d147774 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -19,6 +19,7 @@ import scipy.linalg import textwrap from jax import jit, vmap +from .. import api from .. import lax from .. import lax_linalg from ..numpy._util import _wraps @@ -232,120 +233,338 @@ def tril(m, k=0): def triu(m, k=0): return jnp.triu(m, k) -@_wraps(scipy.linalg.expm, lax_description=textwrap.dedent("""\ - In addition to the original NumPy argument(s) listed below, - also supports the optional boolean argument ``upper_triangular`` - to specify whether the ``A`` matrix is upper triangular. - """)) -def expm(A, *, upper_triangular=False): - return _expm(A, upper_triangular) +_expm_description = textwrap.dedent(""" +In addition to the original NumPy argument(s) listed below, +also supports the optional boolean argument ``upper_triangular`` +to specify whether the ``A`` matrix is upper triangular. +""") -def _expm(A, upper_triangular=False): - P,Q,n_squarings = _calc_P_Q(A) - R = _solve_P_Q(P, Q, upper_triangular) - R = _squaring(R, n_squarings) - return R +@_wraps(scipy.linalg.expm, lax_description=_expm_description) +def expm(A, *, upper_triangular=False): + return _expm(A, upper_triangular) + +@partial(api.custom_jvp, nondiff_argnums=(1,)) +def _expm(A, upper_triangular): + P, Q, n_squarings = _calc_P_Q(A) + R = _solve_P_Q(P, Q, upper_triangular) + R = _squaring(R, n_squarings) + return R @jit def _calc_P_Q(A): - A = jnp.asarray(A) - if A.ndim != 2 or A.shape[0] != A.shape[1]: - raise ValueError('expected A to be a square matrix') - A_L1 = np_linalg.norm(A,1) - n_squarings = 0 - if A.dtype == 'float64' or A.dtype == 'complex128': - U3,V3 = _pade3(A) - U5,V5 = _pade5(A) - U7,V7 = _pade7(A) - U9,V9 = _pade9(A) - maxnorm = 5.371920351148152 - n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm))) - A = A / 2**n_squarings - U13,V13 = _pade13(A) - conds=jnp.array([1.495585217958292e-002, 2.539398330063230e-001, 9.504178996162932e-001, 2.097847961257068e+000]) - U = jnp.select((maxnorm