From 03b2ae6d5907e76472c0d43e5d7793e80523bb95 Mon Sep 17 00:00:00 2001 From: Sri Hari Krishna Narayanan Date: Wed, 22 Jan 2020 00:11:51 -0500 Subject: [PATCH] Issue1635 expm (#1940) * Issue1635 expm Implemented expm using Pade approximation. The implmentation is wrapped using custom_transforms. Frechet derivatives are provided using defvjp. * Issue1635 expm Implemented expm using Pade approximation based on tf.linalg.expm. * Revert "Revert "Merge remote-tracking branch 'origin/Issue1635' into Issue1635"" This reverts commit dd26c6eeeb60fa556f55abc8acb2f5969b64a2f5, reversing changes made to b63c190c7671ebb9b911a52dcc203285c56a8051. * Issue1635 expm testing Add a test that compares numerical output of scipy.linalg.expm against jax.scipy.linalg.expm * travis build Issue1635 branch * Issue1635 expm testing Use rand_small to get numerical agreeming * Issue1635 expm testing Use @jit to prevent recompilation * Issue1635 expm testing Use rand_small to get numerical agreement * Revert "travis build Issue1635 branch" This reverts commit 6139772555e3af79dc0307fce88838a480e42d38. * Issue1635 Replace construct with jax.numpy.select * Issue1635 Restructure to support the docstring from SciPy * Issue1635 Restructure to support the docstring from SciPy * Issue1635 Remove the note that sparsity is not exploited because JAX does not support sparsity. * Issue1635 expm Support for the case where A is upper triangular. Instead of autodetection, the option is specified explicitly. * Issue1635 Rename argument, make it positional. Update documentation Co-authored-by: Jan --- docs/jax.scipy.rst | 1 + jax/scipy/linalg.py | 117 ++++++++++++++++++++++++++++++++++++++++++- tests/linalg_test.py | 24 +++++++++ 3 files changed, 141 insertions(+), 1 deletion(-) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 693b21804..9b3f07a59 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -14,6 +14,7 @@ jax.scipy.linalg cholesky det eigh + expm inv lu lu_factor diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 9fef2b3ab..b8e82357b 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -19,6 +19,7 @@ from __future__ import print_function from functools import partial import scipy.linalg +import textwrap from jax import jit from .. import lax @@ -27,7 +28,6 @@ from ..numpy.lax_numpy import _wraps from ..numpy import lax_numpy as np from ..numpy import linalg as np_linalg - _T = lambda x: np.swapaxes(x, -1, -2) @partial(jit, static_argnums=(1,)) @@ -264,3 +264,118 @@ def tril(m, k=0): @_wraps(scipy.linalg.triu) def triu(m, k=0): return np.triu(m, k) + +@_wraps(scipy.linalg.expm, lax_description=textwrap.dedent("""\ + In addition to the original NumPy argument(s) listed below, + also supports the optional boolean argument ``upper_triangular`` + to specify whether the ``A`` matrix is upper triangular. + """)) +def expm(A, *, upper_triangular=False): + return _expm(A, upper_triangular) + +def _expm(A, upper_triangular=False): + P,Q,n_squarings = _calc_P_Q(A) + R = _solve_P_Q(P, Q, upper_triangular) + R = _squaring(R, n_squarings) + return R + +@jit +def _calc_P_Q(A): + A = np.asarray(A) + if A.ndim != 2 or A.shape[0] != A.shape[1]: + raise ValueError('expected A to be a square matrix') + A_L1 = np_linalg.norm(A,1) + n_squarings = 0 + if A.dtype == 'float64' or A.dtype == 'complex128': + U3,V3 = _pade3(A) + U5,V5 = _pade5(A) + U7,V7 = _pade7(A) + U9,V9 = _pade9(A) + maxnorm = 5.371920351148152 + n_squarings = np.maximum(0, np.floor_divide(np.log2(A_L1 / maxnorm),1)) + A = A / 2**n_squarings + U13,V13 = _pade13(A) + conds=np.array([1.495585217958292e-002, 2.539398330063230e-001, 9.504178996162932e-001, 2.097847961257068e+000]) + U = np.select((maxnorm