mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Issue1635 expm frechet (#2062)
* Implement Frechet derivatives for expm.
* Update expm to use the current custom gradients API. Make some stylistic fixes.

Co-authored-by: Peter Hawkins <phawkins@google.com>
parent 1f2025e12f
commit 7b57dc8c80
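
For orientation, a minimal usage sketch of the API this commit adds (the example matrices are illustrative, not from the diff):

import jax.numpy as jnp
import jax.scipy.linalg as jsl

A = jnp.array([[0., 1.], [-1., 0.]])  # generator of a 2-D rotation
E = jnp.eye(2)                        # perturbation direction

expm_A, frechet_AE = jsl.expm_frechet(A, E, compute_expm=True)
frechet_only = jsl.expm_frechet(A, E, compute_expm=False)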

@@ -16,6 +16,7 @@ jax.scipy.linalg
   det
   eigh
   expm
+  expm_frechet
   inv
   lu
   lu_factor

@@ -19,6 +19,7 @@ import scipy.linalg
 import textwrap

 from jax import jit, vmap
+from .. import api
 from .. import lax
 from .. import lax_linalg
 from ..numpy._util import _wraps

@@ -232,120 +233,338 @@ def tril(m, k=0):
def triu(m, k=0):
  return jnp.triu(m, k)

_expm_description = textwrap.dedent("""
    In addition to the original NumPy argument(s) listed below,
    also supports the optional boolean argument ``upper_triangular``
    to specify whether the ``A`` matrix is upper triangular.
    """)

@_wraps(scipy.linalg.expm, lax_description=_expm_description)
def expm(A, *, upper_triangular=False):
  return _expm(A, upper_triangular)

@partial(api.custom_jvp, nondiff_argnums=(1,))
def _expm(A, upper_triangular):
  P, Q, n_squarings = _calc_P_Q(A)
  R = _solve_P_Q(P, Q, upper_triangular)
  R = _squaring(R, n_squarings)
  return R
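
For reference, the helpers below implement the standard scaling-and-squaring scheme with diagonal Pade approximants (Higham's method, as used by scipy): choose s so that the scaled matrix has small norm, approximate its exponential by a rational function, then square s times:

    e^{A} = \left(e^{2^{-s}A}\right)^{2^{s}}, \qquad
    e^{X} \approx r_m(X) = q_m(X)^{-1}\,p_m(X), \qquad q_m(X) = p_m(-X).

_calc_P_Q evaluates p_m and q_m, _solve_P_Q inverts q_m, and _squaring undoes the scaling.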

@jit
def _calc_P_Q(A):
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  A_L1 = np_linalg.norm(A, 1)
  n_squarings = 0
  if A.dtype == 'float64' or A.dtype == 'complex128':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    U7, V7 = _pade7(A)
    U9, V9 = _pade9(A)
    maxnorm = 5.371920351148152
    n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings
    U13, V13 = _pade13(A)
    conds = jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
                       9.504178996162932e-001, 2.097847961257068e+000])
    U = jnp.select((maxnorm < conds), (U3, U5, U7, U9), U13)
    V = jnp.select((maxnorm < conds), (V3, V5, V7, V9), V13)
  elif A.dtype == 'float32' or A.dtype == 'complex64':
    U3, V3 = _pade3(A)
    U5, V5 = _pade5(A)
    maxnorm = 3.925724783138660
    n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
    A = A / 2**n_squarings
    U7, V7 = _pade7(A)
    conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000])
    U = jnp.select((maxnorm < conds), (U3, U5), U7)
    V = jnp.select((maxnorm < conds), (V3, V5), V7)
  else:
    raise TypeError("A.dtype={} is not supported.".format(A.dtype))
  P = U + V   # p_m(A) : numerator
  Q = -U + V  # q_m(A) : denominator
  return P, Q, n_squarings
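
As a worked example of the scaling choice (values are illustrative): for a float64 matrix with ||A||_1 = 20, the threshold is maxnorm = 5.3719..., so

    n_squarings = max(0, floor(log2(20 / 5.3719...))) = floor(1.896...) = 1

and the order-13 Pade approximant is evaluated at A / 2. Note that the branch on A.dtype is resolved at trace time, so it is ordinary Python control flow even under @jit.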

def _solve_P_Q(P, Q, upper_triangular=False):
  if upper_triangular:
    return solve_triangular(Q, P)
  else:
    return np_linalg.solve(Q, P)

@jit
def _squaring(R, n_squarings):
  # squaring step to undo scaling
  def my_body_fun(i, R):
    return jnp.dot(R, R)
  lower = jnp.zeros(1, dtype=n_squarings.dtype)
  R = lax.fori_loop(lower[0], n_squarings, my_body_fun, R)
  return R
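
A quick numerical sanity check of the squaring step (illustrative, not part of the diff): repeated squaring of expm(A / 2**s) must reproduce expm(A).

import jax.numpy as jnp
import jax.scipy.linalg as jsl

A = jnp.array([[0., 1.], [-1., 0.]])
s = 3
R = jsl.expm(A / 2**s)
for _ in range(s):
  R = R @ R          # (e^{A/2^s})^(2^s) == e^A
print(jnp.allclose(R, jsl.expm(A), atol=1e-6))  # True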

def _pade3(A):
  b = (120., 60., 12., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = jnp.dot(A, A)
  U = jnp.dot(A, (b[3]*A2 + b[1]*ident))
  V = b[2]*A2 + b[0]*ident
  return U, V

def _pade5(A):
  b = (30240., 15120., 3360., 420., 30., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = jnp.dot(A, A)
  A4 = jnp.dot(A2, A2)
  U = jnp.dot(A, b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[4]*A4 + b[2]*A2 + b[0]*ident
  return U, V

def _pade7(A):
  b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = jnp.dot(A, A)
  A4 = jnp.dot(A2, A2)
  A6 = jnp.dot(A4, A2)
  U = jnp.dot(A, b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U, V

def _pade9(A):
  b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
       2162160., 110880., 3960., 90., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = jnp.dot(A, A)
  A4 = jnp.dot(A2, A2)
  A6 = jnp.dot(A4, A2)
  A8 = jnp.dot(A6, A2)
  U = jnp.dot(A, b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U, V

def _pade13(A):
  b = (64764752532480000., 32382376266240000., 7771770303897600.,
       1187353796428800., 129060195264000., 10559470521600., 670442572800.,
       33522128640., 1323241920., 40840800., 960960., 16380., 182., 1.)
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A2 = jnp.dot(A, A)
  A4 = jnp.dot(A2, A2)
  A6 = jnp.dot(A4, A2)
  U = jnp.dot(A, jnp.dot(A6, b[13]*A6 + b[11]*A4 + b[9]*A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = jnp.dot(A6, b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  return U, V
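
The _padeN helpers exploit the symmetry of diagonal Pade approximants: splitting p_m into odd and even powers,

    p_m(A) = \underbrace{b_1 A + b_3 A^3 + \cdots}_{U}
           + \underbrace{b_0 I + b_2 A^2 + \cdots}_{V},
    \qquad q_m(A) = p_m(-A) = -U + V,

so each helper returns only U (which factors as A times a polynomial in A^2, hence the outer jnp.dot(A, ...)) and V, and _calc_P_Q assembles P = U + V and Q = -U + V. The order-13 form additionally factors out A^6 to save matrix multiplications.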

_expm_frechet_description = textwrap.dedent("""
    Does not currently support the Scipy argument ``check_finite``, because
    ``jax.numpy.asarray_chkfinite`` does not exist at the moment. Does not
    support the ``method='blockEnlarge'`` argument.
    """)

@_wraps(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description)
def expm_frechet(A, E, *, method=None, compute_expm=True):
  return _expm_frechet(A, E, method, compute_expm)
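
For reference, the Frechet derivative being computed is the linear term of the exponential under a perturbation of its argument,

    e^{A + tE} = e^{A} + t\,L(A, E) + o(t),

and it satisfies the block-matrix identity

    \exp\begin{pmatrix} A & E \\ 0 & A \end{pmatrix}
      = \begin{pmatrix} e^{A} & L(A, E) \\ 0 & e^{A} \end{pmatrix},

which is the basis of scipy's unsupported method='blockEnlarge'; the port below implements only the SPS (scaling-Pade-squaring) algorithm.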

def _expm_frechet(A, E, method=None, compute_expm=True):
  A = jnp.asarray(A)
  E = jnp.asarray(E)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  if E.ndim != 2 or E.shape[0] != E.shape[1]:
    raise ValueError('expected E to be a square matrix')
  if A.shape != E.shape:
    raise ValueError('expected A and E to be the same shape')
  if method is None:
    method = 'SPS'
  if method == 'SPS':
    expm_A, expm_frechet_AE = expm_frechet_algo_64(A, E)
  else:
    raise ValueError('only method=\'SPS\' is supported')
  if compute_expm:
    return expm_A, expm_frechet_AE
  else:
    return expm_frechet_AE
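
The block identity above gives an independent cross-check of this implementation (a sketch, not part of the diff; matrices are illustrative):

import jax.numpy as jnp
import jax.scipy.linalg as jsl

n = 3
A = jnp.arange(float(n * n)).reshape(n, n) / 10.
E = jnp.eye(n)
big = jnp.block([[A, E], [jnp.zeros((n, n)), A]])
L_block = jsl.expm(big)[:n, n:]   # top-right block is L(A, E)
_, L_sps = jsl.expm_frechet(A, E)
print(jnp.allclose(L_block, L_sps, atol=1e-5))  # True (tolerance may need loosening in float32)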
"""
|
||||
Maximal values ell_m of ||2**-s A|| such that the backward error bound
|
||||
does not exceed 2**-53.
|
||||
"""
|
||||
ell_table_61 = (
|
||||
None,
|
||||
# 1
|
||||
2.11e-8,
|
||||
3.56e-4,
|
||||
1.08e-2,
|
||||
6.49e-2,
|
||||
2.00e-1,
|
||||
4.37e-1,
|
||||
7.83e-1,
|
||||
1.23e0,
|
||||
1.78e0,
|
||||
2.42e0,
|
||||
# 11
|
||||
3.13e0,
|
||||
3.90e0,
|
||||
4.74e0,
|
||||
5.63e0,
|
||||
6.56e0,
|
||||
7.52e0,
|
||||
8.53e0,
|
||||
9.56e0,
|
||||
1.06e1,
|
||||
1.17e1,
|
||||
)
|
||||
|
||||

@jit
def expm_frechet_algo_64(A, E):
  ident = jnp.eye(*A.shape, dtype=A.dtype)
  A_norm_1 = np_linalg.norm(A, 1)
  """
  Subset of the maximal values ell_m of ||2**-s A||
  such that the backward error bound does not exceed 2**-53.
  """
  args = (A, E, ident, A_norm_1)
  U3579, V3579, Lu3579, Lv3579, s3579 = lax.cond(
      A_norm_1 <= ell_table_61[3],
      args, lambda args: _diff_pade3(args),
      args, lambda args: lax.cond(
          A_norm_1 <= ell_table_61[5],
          args, lambda args: _diff_pade5(args),
          args, lambda args: lax.cond(
              A_norm_1 <= ell_table_61[7],
              args, lambda args: _diff_pade7(args),
              args, lambda args: _diff_pade9(args))))
  U13, V13, Lu13, Lv13, s13 = _diff_pade13(args)

  # Must be of minimum length 2 for np.select to be used
  ell_table_61_local99 = jnp.array([ell_table_61[9], ell_table_61[9]])
  U = jnp.select((A_norm_1 <= ell_table_61_local99), (U3579, U3579), U13)
  V = jnp.select((A_norm_1 <= ell_table_61_local99), (V3579, V3579), V13)
  Lu = jnp.select((A_norm_1 <= ell_table_61_local99), (Lu3579, Lu3579), Lu13)
  Lv = jnp.select((A_norm_1 <= ell_table_61_local99), (Lv3579, Lv3579), Lv13)
  s = jnp.select((A_norm_1 <= ell_table_61_local99), (s3579, s3579), s13)

  lu_piv = lu_factor(-U + V)
  R = lu_solve(lu_piv, U + V)
  L = lu_solve(lu_piv, Lu + Lv + jnp.dot((Lu - Lv), R))
  # squaring
  def my_body_fun(i, my_arg):
    R, L = my_arg
    L = jnp.dot(R, L) + jnp.dot(L, R)
    R = jnp.dot(R, R)
    return R, L
  lower = jnp.zeros(1, dtype=s.dtype)
  R, L = lax.fori_loop(lower[0], s, my_body_fun, (R, L))
  return R, L
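
The squaring loop above propagates the derivative alongside the exponential: with R_{k+1} = R_k^2, the product rule gives

    L_{k+1} = R_k L_k + L_k R_k,

so s iterations starting from the Pade approximants of e^{2^{-s}A} and its Frechet derivative rebuild e^{A} and L(A, E).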
"""
|
||||
# The b vectors and U and V are those from
|
||||
# scipy.sparse.linalg.matfuncs.py.
|
||||
# M, Lu, Lv follow (6.11), (6.12), (6.13), (3.3)
|
||||
"""
|
||||

@jit
def _diff_pade3(args):
  A, E, ident, _ = args
  s = 0
  b = (120., 60., 12., 1.)
  A2 = A.dot(A)
  M2 = jnp.dot(A, E) + jnp.dot(E, A)
  U = A.dot(b[3]*A2 + b[1]*ident)
  V = b[2]*A2 + b[0]*ident
  Lu = A.dot(b[3]*M2) + E.dot(b[3]*A2 + b[1]*ident)
  Lv = b[2]*M2
  return U, V, Lu, Lv, s
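
In all of the _diff_padeN helpers, M_{2k} is the Frechet derivative of A^{2k} in direction E, built up by the product rule:

    M_2 = AE + EA, \qquad M_4 = A^2 M_2 + M_2 A^2, \qquad
    M_6 = A^4 M_2 + M_4 A^2, \qquad M_8 = A^4 M_4 + M_4 A^4.

Lu and Lv then differentiate U = A\,u(A^2) and V = v(A^2) term by term, with the extra E\,u(A^2) term in Lu coming from the leading factor of A.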

@jit
def _diff_pade5(args):
  A, E, ident, _ = args
  s = 0
  b = (30240., 15120., 3360., 420., 30., 1.)
  A2 = A.dot(A)
  M2 = jnp.dot(A, E) + jnp.dot(E, A)
  A4 = jnp.dot(A2, A2)
  M4 = jnp.dot(A2, M2) + jnp.dot(M2, A2)
  U = A.dot(b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[4]*A4 + b[2]*A2 + b[0]*ident
  Lu = (A.dot(b[5]*M4 + b[3]*M2) +
        E.dot(b[5]*A4 + b[3]*A2 + b[1]*ident))
  Lv = b[4]*M4 + b[2]*M2
  return U, V, Lu, Lv, s

@jit
def _diff_pade7(args):
  A, E, ident, _ = args
  s = 0
  b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
  A2 = A.dot(A)
  M2 = jnp.dot(A, E) + jnp.dot(E, A)
  A4 = jnp.dot(A2, A2)
  M4 = jnp.dot(A2, M2) + jnp.dot(M2, A2)
  A6 = jnp.dot(A2, A4)
  M6 = jnp.dot(A4, M2) + jnp.dot(M4, A2)
  U = A.dot(b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  Lu = (A.dot(b[7]*M6 + b[5]*M4 + b[3]*M2) +
        E.dot(b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident))
  Lv = b[6]*M6 + b[4]*M4 + b[2]*M2
  return U, V, Lu, Lv, s

@jit
def _diff_pade9(args):
  A, E, ident, _ = args
  s = 0
  b = (17643225600., 8821612800., 2075673600., 302702400., 30270240., 2162160.,
       110880., 3960., 90., 1.)
  A2 = A.dot(A)
  M2 = jnp.dot(A, E) + jnp.dot(E, A)
  A4 = jnp.dot(A2, A2)
  M4 = jnp.dot(A2, M2) + jnp.dot(M2, A2)
  A6 = jnp.dot(A2, A4)
  M6 = jnp.dot(A4, M2) + jnp.dot(M4, A2)
  A8 = jnp.dot(A4, A4)
  M8 = jnp.dot(A4, M4) + jnp.dot(M4, A4)
  U = A.dot(b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident)
  V = b[8]*A8 + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  Lu = (A.dot(b[9]*M8 + b[7]*M6 + b[5]*M4 + b[3]*M2) +
        E.dot(b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident))
  Lv = b[8]*M8 + b[6]*M6 + b[4]*M4 + b[2]*M2
  return U, V, Lu, Lv, s

@jit
def _diff_pade13(args):
  A, E, ident, A_norm_1 = args
  s = jnp.maximum(0, jnp.floor_divide(lax.ceil(jnp.log2(A_norm_1 / ell_table_61[13])), 1))
  two = jnp.array([2.0], A.dtype)
  A = A * two[0]**-s
  E = E * two[0]**-s
  # pade order 13
  A2 = jnp.dot(A, A)
  M2 = jnp.dot(A, E) + jnp.dot(E, A)
  A4 = jnp.dot(A2, A2)
  M4 = jnp.dot(A2, M2) + jnp.dot(M2, A2)
  A6 = jnp.dot(A2, A4)
  M6 = jnp.dot(A4, M2) + jnp.dot(M4, A2)
  b = (64764752532480000., 32382376266240000., 7771770303897600.,
       1187353796428800., 129060195264000., 10559470521600.,
       670442572800., 33522128640., 1323241920., 40840800., 960960.,
       16380., 182., 1.)
  W1 = b[13]*A6 + b[11]*A4 + b[9]*A2
  W2 = b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*ident
  Z1 = b[12]*A6 + b[10]*A4 + b[8]*A2
  Z2 = b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*ident
  W = jnp.dot(A6, W1) + W2
  U = jnp.dot(A, W)
  V = jnp.dot(A6, Z1) + Z2
  Lw1 = b[13]*M6 + b[11]*M4 + b[9]*M2
  Lw2 = b[7]*M6 + b[5]*M4 + b[3]*M2
  Lz1 = b[12]*M6 + b[10]*M4 + b[8]*M2
  Lz2 = b[6]*M6 + b[4]*M4 + b[2]*M2
  Lw = jnp.dot(A6, Lw1) + jnp.dot(M6, W1) + Lw2
  Lu = jnp.dot(A, Lw) + jnp.dot(E, W)
  Lv = jnp.dot(A6, Lz1) + jnp.dot(M6, Z1) + Lz2
  return U, V, Lu, Lv, s

@_expm.defjvp
def _expm_jvp(upper_triangular, primals, tangents):
  matrix, = primals
  g, = tangents
  return expm_frechet(matrix, g, compute_expm=True)
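
With this custom JVP registered, forward-mode differentiation of expm routes through expm_frechet. An illustrative check (not part of the diff):

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

A = jnp.array([[0., 1.], [-1., 0.]])
E = jnp.eye(2)
primal, tangent = jax.jvp(jsl.expm, (A,), (E,))
_, L_AE = jsl.expm_frechet(A, E)
print(jnp.allclose(tangent, L_AE))  # True: the tangent is L(A, E)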

@_wraps(scipy.linalg.block_diag)

@@ -1282,5 +1282,43 @@ class ScipyLinalgTest(jtu.JaxTestCase):
                          args_maker, tol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng_factory": rng_factory}
      for n in [1, 4, 5, 20, 50, 100]
      for dtype in float_types + complex_types
      for rng_factory in [jtu.rand_small]))
  def testExpmFrechet(self, n, dtype, rng_factory):
    rng = rng_factory(self.rng())
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((n, n), dtype), rng((n, n), dtype)]

    # compute_expm is True
    osp_fun = lambda a, e: osp.linalg.expm_frechet(a, e, compute_expm=True)
    jsp_fun = lambda a, e: jsp.linalg.expm_frechet(a, e, compute_expm=True)
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                            check_dtypes=False)
    self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False)
    # compute_expm is False
    osp_fun = lambda a, e: osp.linalg.expm_frechet(a, e, compute_expm=False)
    jsp_fun = lambda a, e: jsp.linalg.expm_frechet(a, e, compute_expm=False)
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                            check_dtypes=False)
    self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng_factory": rng_factory}
      for n in [1, 4, 5, 20, 50]
      for dtype in float_types + complex_types
      for rng_factory in [jtu.rand_small]))
  def testExpmGrad(self, n, dtype, rng_factory):
    rng = rng_factory(self.rng())
    _skip_if_unsupported_type(dtype)
    a = rng((n, n), dtype)
    jtu.check_grads(jsp.linalg.expm, (a,), modes=["fwd"], order=1)

if __name__ == "__main__":
  absltest.main()