Issue1635 expm frechet (#2062)

* Implement Frechet derivatives for expm.

* Update expm to use the current custom gradients API.

Make some stylistic fixes.

Co-authored-by: Peter Hawkins <phawkins@google.com>
3 changed files with 353 additions and 95 deletions
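
For illustration, here is a minimal sketch of the API this commit adds (not part of the patch itself; matrix values are arbitrary):

import jax.numpy as jnp
from jax.scipy.linalg import expm_frechet

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
E = jnp.array([[0.0, 1.0], [0.0, 0.0]])  # perturbation direction
# With compute_expm=True, returns (expm(A), L(A, E)), where L(A, E) is the
# Frechet derivative of the matrix exponential at A in direction E.
expm_A, L_AE = expm_frechet(A, E, compute_expm=True)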

docs/jax.scipy.rst

@@ -16,6 +16,7 @@ jax.scipy.linalg
det
eigh
expm
expm_frechet
inv
lu
lu_factor

jax/scipy/linalg.py

@@ -19,6 +19,7 @@ import scipy.linalg
import textwrap
from jax import jit, vmap
from .. import api
from .. import lax
from .. import lax_linalg
from ..numpy._util import _wraps
@@ -232,120 +233,338 @@ def tril(m, k=0):
def triu(m, k=0):
  return jnp.triu(m, k)
_expm_description = textwrap.dedent("""
In addition to the original NumPy argument(s) listed below,
also supports the optional boolean argument ``upper_triangular``
to specify whether the ``A`` matrix is upper triangular.
""")

@_wraps(scipy.linalg.expm, lax_description=_expm_description)
def expm(A, *, upper_triangular=False):
  return _expm(A, upper_triangular)

@partial(api.custom_jvp, nondiff_argnums=(1,))
def _expm(A, upper_triangular):
  P, Q, n_squarings = _calc_P_Q(A)
  R = _solve_P_Q(P, Q, upper_triangular)
  R = _squaring(R, n_squarings)
  return R
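
Registering `_expm` with `api.custom_jvp`, with `upper_triangular` marked non-differentiable via `nondiff_argnums`, makes forward-mode autodiff dispatch to the Frechet derivative rule `_expm_jvp` defined at the end of this file. A minimal sketch of the effect, assuming this patch is applied (values arbitrary):

import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

A = 0.5 * jnp.eye(3)
dA = 0.01 * jnp.ones((3, 3))
# The tangent output equals the Frechet derivative L(A, dA).
primal_out, tangent_out = jax.jvp(expm, (A,), (dA,))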
@jit
def _calc_P_Q(A):
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  A_L1 = np_linalg.norm(A, 1)
  n_squarings = 0
  if A.dtype == 'float64' or A.dtype == 'complex128':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    U7, V7 = _pade7(A)
    U9, V9 = _pade9(A)
    maxnorm = 5.371920351148152
    n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings
    U13, V13 = _pade13(A)
    conds = jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
                       9.504178996162932e-001, 2.097847961257068e+000])
    # Select the lowest-order approximant admissible for ||A||_1; when none
    # is, fall back to the scaled order-13 approximant.
    U = jnp.select((A_L1 < conds), (U3, U5, U7, U9), U13)
    V = jnp.select((A_L1 < conds), (V3, V5, V7, V9), V13)
  elif A.dtype == 'float32' or A.dtype == 'complex64':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    maxnorm = 3.925724783138660
    n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings
    U7, V7 = _pade7(A)
    conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000])
    U = jnp.select((A_L1 < conds), (U3, U5), U7)
    V = jnp.select((A_L1 < conds), (V3, V5), V7)
  else:
    raise TypeError("A.dtype={} is not supported.".format(A.dtype))
  P = U + V   # p_m(A) : numerator
  Q = -U + V  # q_m(A) : denominator
  return P, Q, n_squarings
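
`_calc_P_Q` performs the scaling half of the scaling-and-squaring method: `A` is divided by `2**n_squarings` so that its 1-norm falls below the Pade threshold, and `P` and `Q` are the numerator and denominator polynomials of the [m/m] Pade approximant, so that `solve(Q, P)` approximates `exp(A / 2**n_squarings)`. A sketch of that invariant, calling the module-internal helper purely for illustration:

import jax.numpy as jnp
import scipy.linalg as osp_linalg

A = jnp.array([[0.1, 0.2], [0.0, 0.3]])
P, Q, n = _calc_P_Q(A)
R = jnp.linalg.solve(Q, P)
# R should agree with exp(A / 2**n) from SciPy's reference expm.
assert jnp.allclose(R, osp_linalg.expm(A / 2**n), atol=1e-6)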
def _solve_P_Q(P, Q, upper_triangular=False):
  if upper_triangular:
    return solve_triangular(Q, P)
  else:
    return np_linalg.solve(Q, P)
@jit
def _squaring(R, n_squarings):
  # squaring step to undo scaling
  def my_body_fun(i, R):
    return jnp.dot(R, R)
  lower = jnp.zeros(1, dtype=n_squarings.dtype)
  R = lax.fori_loop(lower[0], n_squarings, my_body_fun, R)
  return R
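
The squaring loop exploits the identity `exp(A) = exp(A / 2**n) ** (2**n)`: after the Pade solve, `R` is squared `n_squarings` times. Mirroring the `fori_loop` semantics in plain Python, purely as an illustration:

import numpy as np

R = np.array([[1.0, 0.1], [0.0, 1.0]])
n_squarings = 3
for _ in range(n_squarings):
    R = R @ R  # after the loop, R has been raised to the power 2**3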
def _pade3(A):
  b = (120., 60., 12., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = jnp.dot(A, A)
  U = jnp.dot(A, (b[3]*A2 + b[1]*ident))
  V = b[2]*A2 + b[0]*ident
  return U, V
def _pade5(A):
  b = (30240., 15120., 3360., 420., 30., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = jnp.dot(A, A)
  A4 = jnp.dot(A2, A2)
  U = jnp.dot(A, b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[4]*A4 + b[2]*A2 + b[0]*ident
  return U, V
def _pade7(A):
  b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = jnp.dot(A, A)
  A4 = jnp.dot(A2, A2)
  A6 = jnp.dot(A4, A2)
  U = jnp.dot(A, b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U, V
def _pade9(A):
  b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
       2162160., 110880., 3960., 90., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = jnp.dot(A, A)
  A4 = jnp.dot(A2, A2)
  A6 = jnp.dot(A4, A2)
  A8 = jnp.dot(A6, A2)
  U = jnp.dot(A, b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U, V
def _pade13(A):
  b = (64764752532480000., 32382376266240000., 7771770303897600.,
       1187353796428800., 129060195264000., 10559470521600., 670442572800.,
       33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = jnp.dot(A, A)
  A4 = jnp.dot(A2, A2)
  A6 = jnp.dot(A4, A2)
  U = jnp.dot(A, jnp.dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = jnp.dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U, V
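
In each `_padeN` helper, `U` collects the odd-degree terms and `V` the even-degree terms of the degree-N Pade numerator p_N, exploiting the symmetry q_N(A) = p_N(-A) so that `P = U + V` and `Q = -U + V` above. A sketch checking `_pade3` against the explicit polynomial (module-internal helper, illustrative values):

import jax.numpy as jnp

A = jnp.array([[0.01, 0.02], [0.00, 0.03]])
U, V = _pade3(A)
ident = jnp.eye(2, dtype=A.dtype)
A2 = A @ A
# p_3(A) = 120*I + 60*A + 12*A**2 + A**3, with b = (120., 60., 12., 1.)
assert jnp.allclose(U + V, 120*ident + 60*A + 12*A2 + A @ A2, atol=1e-8)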
_expm_frechet_description = textwrap.dedent("""
Does not currently support the Scipy argument ``check_finite``, because
``jax.numpy.asarray_chkfinite`` does not exist at the moment. Does not
support the ``method='blockEnlarge'`` argument.
""")
@_wraps(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description)
def expm_frechet(A, E, *, method=None, compute_expm=True):
  return _expm_frechet(A, E, method, compute_expm)

def _expm_frechet(A, E, method=None, compute_expm=True):
  A = jnp.asarray(A)
  E = jnp.asarray(E)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  if E.ndim != 2 or E.shape[0] != E.shape[1]:
    raise ValueError('expected E to be a square matrix')
  if A.shape != E.shape:
    raise ValueError('expected A and E to be the same shape')
  if method is None:
    method = 'SPS'
  if method == 'SPS':
    expm_A, expm_frechet_AE = expm_frechet_algo_64(A, E)
  else:
    raise ValueError('only method=\'SPS\' is supported')
  if compute_expm:
    return expm_A, expm_frechet_AE
  else:
    return expm_frechet_AE
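
As a quick plausibility check, the Frechet derivative should agree with a central finite difference (a sketch; step and tolerance chosen loosely for single precision):

import jax.numpy as jnp
from jax.scipy.linalg import expm, expm_frechet

A = jnp.array([[0.1, 0.5], [0.2, 0.3]])
E = jnp.array([[0.0, 1.0], [1.0, 0.0]])
L_AE = expm_frechet(A, E, compute_expm=False)
h = 1e-2
fd = (expm(A + h * E) - expm(A - h * E)) / (2 * h)
assert jnp.allclose(L_AE, fd, atol=1e-3)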
"""
Maximal values ell_m of ||2**-s A|| such that the backward error bound
does not exceed 2**-53.
"""
ell_table_61 = (
None,
# 1
2.11e-8,
3.56e-4,
1.08e-2,
6.49e-2,
2.00e-1,
4.37e-1,
7.83e-1,
1.23e0,
1.78e0,
2.42e0,
# 11
3.13e0,
3.90e0,
4.74e0,
5.63e0,
6.56e0,
7.52e0,
8.53e0,
9.56e0,
1.06e1,
1.17e1,
)
@jit
def expm_frechet_algo_64(A, E):
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A_norm_1 = np_linalg.norm(A, 1)
  # Pick the Pade order using the subset of the maximal values ell_m of
  # ||2**-s A|| such that the backward error bound does not exceed 2**-53.
  args = (A, E, ident, A_norm_1)
  U3579, V3579, Lu3579, Lv3579, s3579 = lax.cond(
      A_norm_1 <= ell_table_61[3],
      args, lambda args: _diff_pade3(args),
      args, lambda args: lax.cond(
          A_norm_1 <= ell_table_61[5],
          args, lambda args: _diff_pade5(args),
          args, lambda args: lax.cond(
              A_norm_1 <= ell_table_61[7],
              args, lambda args: _diff_pade7(args),
              args, lambda args: _diff_pade9(args))))
  U13, V13, Lu13, Lv13, s13 = _diff_pade13(args)
  # Conditions must have length at least 2 for jnp.select to be used.
  ell_table_61_local99 = jnp.array([ell_table_61[9], ell_table_61[9]])
  U = jnp.select((A_norm_1 <= ell_table_61_local99), (U3579, U3579), U13)
  V = jnp.select((A_norm_1 <= ell_table_61_local99), (V3579, V3579), V13)
  Lu = jnp.select((A_norm_1 <= ell_table_61_local99), (Lu3579, Lu3579), Lu13)
  Lv = jnp.select((A_norm_1 <= ell_table_61_local99), (Lv3579, Lv3579), Lv13)
  s = jnp.select((A_norm_1 <= ell_table_61_local99), (s3579, s3579), s13)
  lu_piv = lu_factor(-U + V)
  R = lu_solve(lu_piv, U + V)
  L = lu_solve(lu_piv, Lu + Lv + jnp.dot((Lu - Lv), R))
  # Squaring: undo the scaling via the product rule,
  # L <- R L + L R and R <- R R, repeated s times.
  def my_body_fun(i, my_arg):
    R, L = my_arg
    L = jnp.dot(R, L) + jnp.dot(L, R)
    R = jnp.dot(R, R)
    return R, L
  lower = jnp.zeros(1, dtype=s.dtype)
  R, L = lax.fori_loop(lower[0], s, my_body_fun, (R, L))
  return R, L
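
The squaring loop propagates the derivative with the product rule: if R approximates exp(B) and L its Frechet derivative, then d(R @ R) = R @ dR + dR @ R. Equivalently, doubling both arguments of expm_frechet corresponds to one extra squaring (a sketch, arbitrary values):

import jax.numpy as jnp
from jax.scipy.linalg import expm_frechet

A = jnp.array([[0.2, 0.1], [0.0, 0.3]])
E = jnp.array([[0.0, 1.0], [0.0, 0.0]])
R, L = expm_frechet(A, E)
R2, L2 = expm_frechet(2.0 * A, 2.0 * E)
assert jnp.allclose(R2, R @ R, atol=1e-4)          # exp(2A) = exp(A) @ exp(A)
assert jnp.allclose(L2, R @ L + L @ R, atol=1e-4)  # product rule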
"""
# The b vectors and U and V are those from
# scipy.sparse.linalg.matfuncs.py.
# M, Lu, Lv follow (6.11), (6.12), (6.13), (3.3)
"""
@jit
def _diff_pade3(args):
  A, E, ident, _ = args
  s = 0
  b = (120., 60., 12., 1.)
  A2 = A.dot(A)
  M2 = jnp.dot(A, E) + jnp.dot(E, A)
  U = A.dot(b[3]*A2 + b[1]*ident)
  V = b[2]*A2 + b[0]*ident
  Lu = A.dot(b[3]*M2) + E.dot(b[3]*A2 + b[1]*ident)
  Lv = b[2]*M2
  return U, V, Lu, Lv, s
@jit
def _diff_pade5(args):
  A, E, ident, _ = args
  s = 0
  b = (30240., 15120., 3360., 420., 30., 1.)
  A2 = A.dot(A)
  M2 = jnp.dot(A, E) + jnp.dot(E, A)
  A4 = jnp.dot(A2, A2)
  M4 = jnp.dot(A2, M2) + jnp.dot(M2, A2)
  U = A.dot(b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[4]*A4 + b[2]*A2 + b[0]*ident
  Lu = (A.dot(b[5]*M4 + b[3]*M2) +
        E.dot(b[5]*A4 + b[3]*A2 + b[1]*ident))
  Lv = b[4]*M4 + b[2]*M2
  return U, V, Lu, Lv, s
@jit
def _diff_pade7(args):
  A, E, ident, _ = args
  s = 0
  b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
  A2 = A.dot(A)
  M2 = jnp.dot(A, E) + jnp.dot(E, A)
  A4 = jnp.dot(A2, A2)
  M4 = jnp.dot(A2, M2) + jnp.dot(M2, A2)
  A6 = jnp.dot(A2, A4)
  M6 = jnp.dot(A4, M2) + jnp.dot(M4, A2)
  U = A.dot(b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  Lu = (A.dot(b[7]*M6 + b[5]*M4 + b[3]*M2) +
        E.dot(b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident))
  Lv = b[6]*M6 + b[4]*M4 + b[2]*M2
  return U, V, Lu, Lv, s
@jit
def _diff_pade9(args):
  A, E, ident, _ = args
  s = 0
  b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
       2162160., 110880., 3960., 90., 1.)
  A2 = A.dot(A)
  M2 = jnp.dot(A, E) + jnp.dot(E, A)
  A4 = jnp.dot(A2, A2)
  M4 = jnp.dot(A2, M2) + jnp.dot(M2, A2)
  A6 = jnp.dot(A2, A4)
  M6 = jnp.dot(A4, M2) + jnp.dot(M4, A2)
  A8 = jnp.dot(A4, A4)
  M8 = jnp.dot(A4, M4) + jnp.dot(M4, A4)
  U = A.dot(b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  Lu = (A.dot(b[9]*M8 + b[7]*M6 + b[5]*M4 + b[3]*M2) +
        E.dot(b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident))
  Lv = b[8]*M8 + b[6]*M6 + b[4]*M4 + b[2]*M2
  return U, V, Lu, Lv, s
@jit
def _diff_pade13(args):
  A, E, ident, A_norm_1 = args
  s = jnp.maximum(0, jnp.floor_divide(lax.ceil(jnp.log2(A_norm_1 / ell_table_61[13])), 1))
  two = jnp.array([2.0], A.dtype)
  A = A * two[0]**-s
  E = E * two[0]**-s
  # pade order 13
  A2 = jnp.dot(A, A)
  M2 = jnp.dot(A, E) + jnp.dot(E, A)
  A4 = jnp.dot(A2, A2)
  M4 = jnp.dot(A2, M2) + jnp.dot(M2, A2)
  A6 = jnp.dot(A2, A4)
  M6 = jnp.dot(A4, M2) + jnp.dot(M4, A2)
  b = (64764752532480000., 32382376266240000., 7771770303897600.,
       1187353796428800., 129060195264000., 10559470521600.,
       670442572800., 33522128640., 1323241920., 40840800., 960960.,
       16380., 182., 1.)
  W1 = b[13]*A6 + b[11]*A4 + b[9]*A2
  W2 = b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident
  Z1 = b[12]*A6 + b[10]*A4 + b[8]*A2
  Z2 = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  W = jnp.dot(A6, W1) + W2
  U = jnp.dot(A, W)
  V = jnp.dot(A6, Z1) + Z2
  Lw1 = b[13]*M6 + b[11]*M4 + b[9]*M2
  Lw2 = b[7]*M6 + b[5]*M4 + b[3]*M2
  Lz1 = b[12]*M6 + b[10]*M4 + b[8]*M2
  Lz2 = b[6]*M6 + b[4]*M4 + b[2]*M2
  Lw = jnp.dot(A6, Lw1) + jnp.dot(M6, W1) + Lw2
  Lu = jnp.dot(A, Lw) + jnp.dot(E, W)
  Lv = jnp.dot(A6, Lz1) + jnp.dot(M6, Z1) + Lz2
  return U, V, Lu, Lv, s
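
Note that this branch scales both `A` and `E` by `2**-s`; that is consistent because the Frechet derivative is linear in its direction argument, and the squaring recurrence above restores the unscaled result. A sketch of the linearity property (values arbitrary):

import jax.numpy as jnp
from jax.scipy.linalg import expm_frechet

A = jnp.array([[0.4, 0.1], [0.2, 0.5]])
E = jnp.array([[1.0, 0.0], [0.0, -1.0]])
L1 = expm_frechet(A, E, compute_expm=False)
L2 = expm_frechet(A, 2.0 * E, compute_expm=False)
assert jnp.allclose(L2, 2.0 * L1, atol=1e-5)  # L(A, .) is linear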
@_expm.defjvp
def _expm_jvp(upper_triangular, primals, tangents):
  matrix, = primals
  g, = tangents
  return expm_frechet(matrix, g, compute_expm=True)
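
With the JVP rule registered, forward-mode transformations such as `jax.jacfwd` now work through `expm`; this is what the new test below exercises with `modes=["fwd"]`. A sketch:

import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

J = jax.jacfwd(expm)(0.1 * jnp.eye(2))
print(J.shape)  # (2, 2, 2, 2): d expm(A)[i, j] / d A[k, l]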
@_wraps(scipy.linalg.block_diag)

tests/linalg_test.py

@@ -1282,5 +1282,43 @@ class ScipyLinalgTest(jtu.JaxTestCase):
args_maker, tol=1e-3)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng_factory": rng_factory}
      for n in [1, 4, 5, 20, 50, 100]
      for dtype in float_types + complex_types
      for rng_factory in [jtu.rand_small]))
  def testExpmFrechet(self, n, dtype, rng_factory):
    rng = rng_factory(self.rng())
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((n, n), dtype), rng((n, n), dtype)]
    # compute_expm=True
    osp_fun = lambda a, e: osp.linalg.expm_frechet(a, e, compute_expm=True)
    jsp_fun = lambda a, e: jsp.linalg.expm_frechet(a, e, compute_expm=True)
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                            check_dtypes=False)
    self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False)
    # compute_expm=False
    osp_fun = lambda a, e: osp.linalg.expm_frechet(a, e, compute_expm=False)
    jsp_fun = lambda a, e: jsp.linalg.expm_frechet(a, e, compute_expm=False)
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                            check_dtypes=False)
    self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng_factory": rng_factory}
      for n in [1, 4, 5, 20, 50]
      for dtype in float_types + complex_types
      for rng_factory in [jtu.rand_small]))
  def testExpmGrad(self, n, dtype, rng_factory):
    rng = rng_factory(self.rng())
    _skip_if_unsupported_type(dtype)
    a = rng((n, n), dtype)
    jtu.check_grads(jsp.linalg.expm, (a,), modes=["fwd"], order=1)
if __name__ == "__main__":
  absltest.main()