jax.scipy.linalg.expm: support batched inputs
commit ad0fc8979b
parent 670fba3a91
@@ -433,6 +433,16 @@ where norm() denotes the L1 norm, and
 @_wraps(scipy.linalg.expm, lax_description=_expm_description)
 @partial(jit, static_argnames=('upper_triangular', 'max_squarings'))
 def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array:
+  A, = promote_dtypes_inexact(A)
+
+  if A.ndim < 2 or A.shape[-1] != A.shape[-2]:
+    raise ValueError(f"Expected A to be a (batched) square matrix, got {A.shape=}.")
+
+  if A.ndim > 2:
+    return jnp.vectorize(
+      partial(expm, upper_triangular=upper_triangular, max_squarings=max_squarings),
+      signature="(n,n)->(n,n)")(A)
+
   P, Q, n_squarings = _calc_P_Q(A)

   def _nan(args):
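The batched path added above dispatches through jnp.vectorize with signature="(n,n)->(n,n)", so any leading axes of A are treated as batch dimensions and expm is applied to each trailing (n, n) block independently. A minimal usage sketch of the new behaviour (the shapes, the 0.1 scaling, and the tolerances below are illustrative, not taken from the commit):

import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import expm

# A stack of three 4x4 matrices; scaled down so expm is well conditioned.
rng = np.random.RandomState(0)
A = jnp.asarray(0.1 * rng.randn(3, 4, 4))

# Batched call: expm is mapped over the leading axis via jnp.vectorize.
batched = expm(A)

# Reference: apply expm to each matrix in a Python loop.
looped = jnp.stack([expm(A[i]) for i in range(A.shape[0])])

assert batched.shape == (3, 4, 4)
np.testing.assert_allclose(np.asarray(batched), np.asarray(looped), rtol=1e-5, atol=1e-5)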
@@ -29,12 +29,15 @@ from jax import jit, grad, jvp, vmap
 from jax import lax
 from jax import numpy as jnp
 from jax import scipy as jsp
+from jax._src.numpy.util import promote_dtypes_inexact
 from jax._src import test_util as jtu

 from jax.config import config
 config.parse_flags_with_absl()
 FLAGS = config.FLAGS

+scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
+
 T = lambda x: np.swapaxes(x, -1, -2)

@@ -1240,18 +1243,21 @@ class ScipyLinalgTest(jtu.JaxTestCase):

   @jtu.sample_product(
     n=[1, 4, 5, 20, 50, 100],
-    dtype=float_types + complex_types,
+    batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()],
+    dtype=int_types + float_types + complex_types
   )
-  def testExpm(self, n, dtype):
+  def testExpm(self, n, batch_size, dtype):
     rng = jtu.rand_small(self.rng())
-    args_maker = lambda: [rng((n, n), dtype)]
+    args_maker = lambda: [rng((*batch_size, n, n), dtype)]

-    osp_fun = lambda a: osp.linalg.expm(a)
-    jsp_fun = lambda a: jsp.linalg.expm(a)
+    # Compare to numpy with JAX type promotion semantics.
+    def osp_fun(A):
+      return osp.linalg.expm(np.array(*promote_dtypes_inexact(A)))
+    jsp_fun = jsp.linalg.expm
     self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker)
     self._CompileAndCheck(jsp_fun, args_maker)

-    args_maker_triu = lambda: [np.triu(rng((n, n), dtype))]
+    args_maker_triu = lambda: [np.triu(rng((*batch_size, n, n), dtype))]
     jsp_fun_triu = lambda a: jsp.linalg.expm(a, upper_triangular=True)
     self._CheckAgainstNumpy(osp_fun, jsp_fun_triu, args_maker_triu)
     self._CompileAndCheck(jsp_fun_triu, args_maker_triu)
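The rewritten SciPy reference exists because the test now also covers integer dtypes: promote_dtypes_inexact mirrors JAX's type promotion, so scipy.linalg.expm is evaluated at the same inexact dtype that jsp.linalg.expm will return (the batched cases are additionally gated on scipy_version >= (1, 9, 0), since older SciPy releases reject stacked inputs). A rough sketch of that comparison outside the test harness, assuming an illustrative 2x2 int32 input and tolerance:

import numpy as np
import scipy.linalg as osp_linalg
import jax.scipy.linalg as jsp_linalg
from jax._src.numpy.util import promote_dtypes_inexact

A_int = np.array([[1, 0], [0, 2]], dtype=np.int32)  # illustrative integer input

# Promote to JAX's default inexact dtype (float32 unless x64 is enabled)
# before computing the SciPy reference, matching the test's osp_fun.
A_promoted, = promote_dtypes_inexact(A_int)
reference = osp_linalg.expm(np.array(A_promoted))

# jsp.linalg.expm performs the same promotion internally (see the first hunk).
result = jsp_linalg.expm(A_int)

np.testing.assert_allclose(np.asarray(result), reference, rtol=1e-5)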