Issue1635 expm (#1940)

* Issue1635 expm
Implemented expm using Pade approximation. The implementation is
wrapped using custom_transforms. Frechet derivatives are provided
using defvjp.

* Issue1635 expm

Implemented expm using Pade approximation based on tf.linalg.expm.

* Revert "Revert "Merge remote-tracking branch 'origin/Issue1635' into Issue1635""

This reverts commit dd26c6eeeb60fa556f55abc8acb2f5969b64a2f5, reversing
changes made to b63c190c7671ebb9b911a52dcc203285c56a8051.

* Issue1635 expm testing

Add a test that compares numerical output of scipy.linalg.expm against jax.scipy.linalg.expm

* travis build Issue1635 branch

* Issue1635 expm testing

Use rand_small to get numerical agreement

* Issue1635 expm testing
Use @jit to prevent recompilation

* Issue1635 expm testing

Use rand_small to get numerical agreement

* Revert "travis build Issue1635 branch"

This reverts commit 6139772555e3af79dc0307fce88838a480e42d38.

* Issue1635

Replace construct with jax.numpy.select

* Issue1635

Restructure to support the docstring from SciPy

* Issue1635

Restructure to support the docstring from SciPy

* Issue1635

Remove the note that sparsity is not exploited because JAX does not support sparsity.

* Issue1635 expm

Support for the case where A is upper triangular. Instead of autodetection, the option is specified explicitly.

* Issue1635

Rename argument, make it positional. Update documentation

Co-authored-by: Jan <j.hueckelheim@imperial.ac.uk>
This commit is contained in:
Sri Hari Krishna Narayanan 2020-01-22 00:11:51 -05:00 committed by Stephan Hoyer
parent f04348ed53
commit 03b2ae6d59
3 changed files with 141 additions and 1 deletions

View File

@ -14,6 +14,7 @@ jax.scipy.linalg
cholesky
det
eigh
expm
inv
lu
lu_factor

View File

@ -19,6 +19,7 @@ from __future__ import print_function
from functools import partial
import scipy.linalg
import textwrap
from jax import jit
from .. import lax
@ -27,7 +28,6 @@ from ..numpy.lax_numpy import _wraps
from ..numpy import lax_numpy as np
from ..numpy import linalg as np_linalg
# Swap the last two axes of ``x``; any leading axes are left where they are.
_T = lambda x: np.swapaxes(x, -1, -2)
@partial(jit, static_argnums=(1,))
@ -264,3 +264,118 @@ def tril(m, k=0):
@_wraps(scipy.linalg.triu)
def triu(m, k=0):
    # Thin wrapper: both arguments are forwarded unchanged to ``np.triu``.
    # No docstring is written here because the ``_wraps`` decorator is handed
    # ``scipy.linalg.triu`` (presumably to derive the documentation from it).
    return np.triu(m, k=k)
@_wraps(scipy.linalg.expm, lax_description=textwrap.dedent("""\
    In addition to the original SciPy argument(s) listed below,
    also supports the optional boolean argument ``upper_triangular``
    to specify whether the ``A`` matrix is upper triangular.
    """))
def expm(A, *, upper_triangular=False):
    # Public entry point for the matrix exponential. The wrapped reference is
    # ``scipy.linalg.expm``, so the description above refers to SciPy's
    # argument list (it previously said NumPy).
    #
    # ``upper_triangular`` is keyword-only. It is passed straight through to
    # ``_expm``; nothing here checks that ``A`` really is upper triangular.
    return _expm(A, upper_triangular)
def _expm(A, upper_triangular=False):
    """Evaluate the matrix exponential of ``A`` in three stages.

    1. ``_calc_P_Q`` returns the numerator ``P``, the denominator ``Q`` and
       the number of squarings to apply afterwards.
    2. ``_solve_P_Q`` solves ``Q @ R = P`` for ``R``; ``upper_triangular``
       is forwarded so that it can pick the solver.
    3. ``_squaring`` squares ``R`` the requested number of times.

    Returns the result of the final stage.
    """
    numerator, denominator, n_squarings = _calc_P_Q(A)
    ratio = _solve_P_Q(numerator, denominator, upper_triangular)
    return _squaring(ratio, n_squarings)
@jit
def _calc_P_Q(A):
    """Build the Pade numerator/denominator pair for ``expm(A)``.

    The approximant order is chosen from the 1-norm of ``A``: the lowest
    order whose threshold exceeds the norm is used, and the highest order
    (13 for double precision, 7 for single precision) is the fallback. For
    the fallback, ``A`` is first divided by ``2**n_squarings``.

    Args:
      A: square 2-D array of dtype float32, float64, complex64 or complex128.

    Returns:
      A tuple ``(P, Q, n_squarings)`` where ``P = U + V`` is the numerator,
      ``Q = -U + V`` is the denominator, and ``n_squarings`` is the number of
      times the solved ratio must be squared to undo the scaling.

    Raises:
      ValueError: if ``A`` is not a square matrix.
      TypeError: if ``A.dtype`` is not one of the supported dtypes.
    """
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected A to be a square matrix')
    A_L1 = np_linalg.norm(A, 1)
    n_squarings = 0
    if A.dtype == 'float64' or A.dtype == 'complex128':
        # The low-order approximants use the unscaled matrix.
        U3, V3 = _pade3(A)
        U5, V5 = _pade5(A)
        U7, V7 = _pade7(A)
        U9, V9 = _pade9(A)
        maxnorm = 5.371920351148152
        # floor(x) equals floor_divide(x, 1) for finite x, and it maps
        # log2(0) = -inf to -inf, which the clamp below turns into 0
        # squarings for a zero-norm matrix.
        n_squarings = np.maximum(0, np.floor(np.log2(A_L1 / maxnorm)))
        A = A / 2**n_squarings
        U13, V13 = _pade13(A)
        conds = (1.495585217958292e-002, 2.539398330063230e-001,
                 9.504178996162932e-001, 2.097847961257068e+000)
        # Select on the norm of the matrix. Comparing the constant
        # ``maxnorm`` against these (smaller) thresholds is always False and
        # would unconditionally pick the order-13 result.
        use_order = [A_L1 < bound for bound in conds]
        U = np.select(use_order, (U3, U5, U7, U9), U13)
        V = np.select(use_order, (V3, V5, V7, V9), V13)
    elif A.dtype == 'float32' or A.dtype == 'complex64':
        U3, V3 = _pade3(A)
        U5, V5 = _pade5(A)
        maxnorm = 3.925724783138660
        # Same clamp as in the double-precision branch.
        n_squarings = np.maximum(0, np.floor(np.log2(A_L1 / maxnorm)))
        A = A / 2**n_squarings
        U7, V7 = _pade7(A)
        conds = (4.258730016922831e-001, 1.880152677804762e+000)
        # Select on the norm of the matrix, not on ``maxnorm``.
        use_order = [A_L1 < bound for bound in conds]
        U = np.select(use_order, (U3, U5), U7)
        V = np.select(use_order, (V3, V5), V7)
    else:
        raise TypeError("A.dtype={} is not supported.".format(A.dtype))
    P = U + V  # p_m(A) : numerator
    Q = -U + V  # q_m(A) : denominator
    return P, Q, n_squarings
def _solve_P_Q(P, Q, upper_triangular=False):
    """Solve ``Q @ R = P`` for ``R`` and return it.

    Args:
      P: right-hand side (the Pade numerator).
      Q: coefficient matrix (the Pade denominator).
      upper_triangular: when true, ``solve_triangular`` is called with its
        default options instead of the general ``np_linalg.solve``.
    """
    if not upper_triangular:
        return np_linalg.solve(Q, P)
    return solve_triangular(Q, P)
@jit
def _squaring(R, n_squarings):
    """Replace ``R`` by ``R @ R`` repeatedly, ``n_squarings`` times.

    This is the squaring step that undoes the earlier scaling of the input.
    ``n_squarings`` must expose a ``dtype`` attribute (i.e. be an array).
    """
    def _square(_, acc):
        # The loop counter is not needed; only the running product matters.
        return np.dot(acc, acc)

    # The lower bound is a zero carrying the dtype of ``n_squarings``
    # (presumably so that both loop bounds share one dtype).
    start = np.zeros(1, dtype=n_squarings.dtype)[0]
    return lax.fori_loop(start, n_squarings, _square, R)
def _pade3(A):
    """Return the ``(U, V)`` pair of the order-3 Pade approximant for ``A``.

    ``U`` holds the odd-power terms (``A`` times a polynomial in ``A @ A``)
    and ``V`` the even-power terms, using the coefficients listed below.
    """
    coeffs = (120., 60., 12., 1.)
    eye = np.eye(*A.shape, dtype=A.dtype)
    A2 = np.dot(A, A)
    odd_poly = coeffs[3]*A2 + coeffs[1]*eye
    U = np.dot(A, odd_poly)
    V = coeffs[2]*A2 + coeffs[0]*eye
    return U, V
def _pade5(A):
    """Return the ``(U, V)`` pair of the order-5 Pade approximant for ``A``.

    ``U`` holds the odd-power terms (``A`` times a polynomial in ``A @ A``)
    and ``V`` the even-power terms, using the coefficients listed below.
    """
    coeffs = (30240., 15120., 3360., 420., 30., 1.)
    eye = np.eye(*A.shape, dtype=A.dtype)
    A2 = np.dot(A, A)
    A4 = np.dot(A2, A2)
    odd_poly = coeffs[5]*A4 + coeffs[3]*A2 + coeffs[1]*eye
    U = np.dot(A, odd_poly)
    V = coeffs[4]*A4 + coeffs[2]*A2 + coeffs[0]*eye
    return U, V
def _pade7(A):
    """Return the ``(U, V)`` pair of the order-7 Pade approximant for ``A``.

    ``U`` holds the odd-power terms (``A`` times a polynomial in ``A @ A``)
    and ``V`` the even-power terms, using the coefficients listed below.
    """
    coeffs = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
    eye = np.eye(*A.shape, dtype=A.dtype)
    A2 = np.dot(A, A)
    A4 = np.dot(A2, A2)
    A6 = np.dot(A4, A2)
    odd_poly = coeffs[7]*A6 + coeffs[5]*A4 + coeffs[3]*A2 + coeffs[1]*eye
    U = np.dot(A, odd_poly)
    V = coeffs[6]*A6 + coeffs[4]*A4 + coeffs[2]*A2 + coeffs[0]*eye
    return U, V
def _pade9(A):
    """Return the ``(U, V)`` pair of the order-9 Pade approximant for ``A``.

    ``U`` holds the odd-power terms (``A`` times a polynomial in ``A @ A``)
    and ``V`` the even-power terms, using the coefficients listed below.
    """
    coeffs = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
              2162160., 110880., 3960., 90., 1.)
    eye = np.eye(*A.shape, dtype=A.dtype)
    A2 = np.dot(A, A)
    A4 = np.dot(A2, A2)
    A6 = np.dot(A4, A2)
    A8 = np.dot(A6, A2)
    odd_poly = (coeffs[9]*A8 + coeffs[7]*A6 + coeffs[5]*A4 + coeffs[3]*A2
                + coeffs[1]*eye)
    U = np.dot(A, odd_poly)
    V = (coeffs[8]*A8 + coeffs[6]*A6 + coeffs[4]*A4 + coeffs[2]*A2
         + coeffs[0]*eye)
    return U, V
def _pade13(A):
    """Return the ``(U, V)`` pair of the order-13 Pade approximant for ``A``.

    ``U`` holds the odd-power terms and ``V`` the even-power terms, using the
    coefficients listed below. Only ``A2``, ``A4`` and ``A6`` are formed
    explicitly; the higher powers come from one extra product with ``A6``.
    """
    coeffs = (64764752532480000., 32382376266240000., 7771770303897600.,
              1187353796428800., 129060195264000., 10559470521600.,
              670442572800., 33522128640., 1323241920., 40840800., 960960.,
              16380., 182., 1.)
    eye = np.eye(*A.shape, dtype=A.dtype)
    A2 = np.dot(A, A)
    A4 = np.dot(A2, A2)
    A6 = np.dot(A4, A2)
    # Odd part: the three highest coefficients are lifted by A6 first.
    odd_high = np.dot(A6, coeffs[13]*A6 + coeffs[11]*A4 + coeffs[9]*A2)
    odd_poly = (odd_high + coeffs[7]*A6 + coeffs[5]*A4 + coeffs[3]*A2
                + coeffs[1]*eye)
    U = np.dot(A, odd_poly)
    # Even part, built the same way.
    even_high = np.dot(A6, coeffs[12]*A6 + coeffs[10]*A4 + coeffs[8]*A2)
    V = (even_high + coeffs[6]*A6 + coeffs[4]*A4 + coeffs[2]*A2
         + coeffs[0]*eye)
    return U, V

View File

@ -896,5 +896,29 @@ class ScipyLinalgTest(jtu.JaxTestCase):
(a, b),
(a, b))
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
     "n": n, "dtype": dtype, "rng_factory": rng_factory}
    for n in [1, 4, 5, 20, 50, 100]
    for dtype in float_types + complex_types
    for rng_factory in [jtu.rand_small]))
def testExpm(self, n, dtype, rng_factory):
    # Compares jsp.linalg.expm against osp.linalg.expm on (n, n) inputs drawn
    # from ``rng_factory``, and runs the compile check, for two cases:
    # a full matrix with the default solver, and an upper-triangular matrix
    # with ``upper_triangular=True``.
    rng = rng_factory()
    _skip_if_unsupported_type(dtype)
    shape = (n, n)

    def reference_expm(a):
        return osp.linalg.expm(a)

    def jax_expm(a):
        return jsp.linalg.expm(a)

    def jax_expm_triu(a):
        return jsp.linalg.expm(a, upper_triangular=True)

    def make_dense_args():
        return [rng(shape, dtype)]

    def make_triu_args():
        # Zero out everything below the diagonal of a random matrix.
        return [onp.triu(rng(shape, dtype))]

    # Full matrix, default solver.
    self._CheckAgainstNumpy(reference_expm, jax_expm, make_dense_args,
                            check_dtypes=True)
    self._CompileAndCheck(jax_expm, make_dense_args, check_dtypes=True)

    # Upper-triangular matrix through the triangular-solve path.
    self._CheckAgainstNumpy(reference_expm, jax_expm_triu, make_triu_args,
                            check_dtypes=True)
    self._CompileAndCheck(jax_expm_triu, make_triu_args, check_dtypes=True)
# Hand control to absltest's runner when this file is executed as a script.
if __name__ == "__main__":
    absltest.main()