jax.scipy.linalg.expm: support batched inputs

This commit is contained in:
Jake VanderPlas 2023-03-27 16:39:48 -07:00
parent 670fba3a91
commit ad0fc8979b
2 changed files with 22 additions and 6 deletions

View File

@ -433,6 +433,16 @@ where norm() denotes the L1 norm, and
@_wraps(scipy.linalg.expm, lax_description=_expm_description)
@partial(jit, static_argnames=('upper_triangular', 'max_squarings'))
def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array:
A, = promote_dtypes_inexact(A)
if A.ndim < 2 or A.shape[-1] != A.shape[-2]:
raise ValueError(f"Expected A to be a (batched) square matrix, got {A.shape=}.")
if A.ndim > 2:
return jnp.vectorize(
partial(expm, upper_triangular=upper_triangular, max_squarings=max_squarings),
signature="(n,n)->(n,n)")(A)
P, Q, n_squarings = _calc_P_Q(A)
def _nan(args):

View File

@ -29,12 +29,15 @@ from jax import jit, grad, jvp, vmap
from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
T = lambda x: np.swapaxes(x, -1, -2)
@ -1240,18 +1243,21 @@ class ScipyLinalgTest(jtu.JaxTestCase):
@jtu.sample_product(
n=[1, 4, 5, 20, 50, 100],
dtype=float_types + complex_types,
batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()],
dtype=int_types + float_types + complex_types
)
def testExpm(self, n, dtype):
def testExpm(self, n, batch_size, dtype):
rng = jtu.rand_small(self.rng())
args_maker = lambda: [rng((n, n), dtype)]
args_maker = lambda: [rng((*batch_size, n, n), dtype)]
osp_fun = lambda a: osp.linalg.expm(a)
jsp_fun = lambda a: jsp.linalg.expm(a)
# Compare to numpy with JAX type promotion semantics.
def osp_fun(A):
return osp.linalg.expm(np.array(*promote_dtypes_inexact(A)))
jsp_fun = jsp.linalg.expm
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker)
self._CompileAndCheck(jsp_fun, args_maker)
args_maker_triu = lambda: [np.triu(rng((n, n), dtype))]
args_maker_triu = lambda: [np.triu(rng((*batch_size, n, n), dtype))]
jsp_fun_triu = lambda a: jsp.linalg.expm(a, upper_triangular=True)
self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu)
self._CompileAndCheck(jsp_fun_triu, args_maker_triu)