Add jax.scipy.linalg.toeplitz.

This commit is contained in:
Yotaro Kubo 2022-11-15 18:40:52 +09:00
parent 440b25bf5d
commit 1ade5f8592
4 changed files with 87 additions and 0 deletions

View File

@ -45,6 +45,7 @@ jax.scipy.linalg
solve_triangular
sqrtm
svd
toeplitz
tril
triu

View File

@ -27,6 +27,7 @@ from jax import lax
from jax._src import dtypes
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import qdwh
from jax._src.numpy.lax_numpy import _check_arraylike
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _promote_dtypes_complex
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg as np_linalg
@ -1031,3 +1032,28 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
return h, q
else:
return h
@_wraps(scipy.linalg.toeplitz)
def toeplitz(c: ArrayLike, r: Optional[ArrayLike] = None) -> Array:
if r is None:
_check_arraylike("toeplitz", c)
r = jnp.conjugate(jnp.asarray(c))
else:
_check_arraylike("toeplitz", c, r)
c = jnp.asarray(c).flatten()
r = jnp.asarray(r).flatten()
ncols, = c.shape
nrows, = r.shape
if ncols == 0 or nrows == 0:
return jnp.empty((ncols, nrows), dtype=jnp.promote_types(c.dtype, r.dtype))
nelems = ncols + nrows - 1
elems = jnp.concatenate((c[::-1], r[1:]))
patches = lax.conv_general_dilated_patches(
elems.reshape((1, nelems, 1)),
(nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'),
precision=lax.Precision.HIGHEST)[0]
return jnp.flip(patches, axis=0)

View File

@ -36,6 +36,7 @@ from jax._src.scipy.linalg import (
solve as solve,
solve_triangular as solve_triangular,
svd as svd,
toeplitz as toeplitz,
tril as tril,
triu as triu,
)

View File

@ -42,6 +42,7 @@ T = lambda x: np.swapaxes(x, -1, -2)
float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
int_types = jtu.dtypes.all_integer
class NumpyLinalgTest(jtu.JaxTestCase):
@ -1571,6 +1572,63 @@ class ScipyLinalgTest(jtu.JaxTestCase):
self.assertAllClose(root, expected, check_dtypes=False)
@jtu.sample_product(
cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)],
cdtype=float_types + complex_types,
rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)],
rdtype=float_types + complex_types + int_types)
def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype):
if ((rdtype in [np.float64, np.complex128]
or cdtype in [np.float64, np.complex128])
and not config.x64_enabled):
self.skipTest("Only run float64 testcase when float64 is enabled.")
int_types_excl_i8 = set(int_types) - {np.int8}
if ((rdtype in int_types_excl_i8 or cdtype in int_types_excl_i8)
and jtu.device_under_test() == "gpu"):
self.skipTest("Integer (except int8) toeplitz is not supported on GPU yet.")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)]
with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]):
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
jsp.linalg.toeplitz, args_maker)
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
@jtu.sample_product(
shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)],
dtype=float_types + complex_types + int_types)
def testToeplitzSymmetricConstruction(self, shape, dtype):
if (dtype in [np.float64, np.complex128]
and not config.x64_enabled):
self.skipTest("Only run float64 testcase when float64 is enabled.")
int_types_excl_i8 = set(int_types) - {np.int8}
if (dtype in int_types_excl_i8
and jtu.device_under_test() == "gpu"):
self.skipTest("Integer (except int8) toeplitz is not supported on GPU yet.")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
jsp.linalg.toeplitz, args_maker)
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
def testToeplitzConstructionWithKnownCases(self):
# Test with examples taken from SciPy doc for the corresponding function.
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.toeplitz.html
ret = jsp.linalg.toeplitz(np.array([1.0, 2+3j, 4-1j]))
self.assertAllClose(ret, np.array([
[ 1.+0.j, 2.-3.j, 4.+1.j],
[ 2.+3.j, 1.+0.j, 2.-3.j],
[ 4.-1.j, 2.+3.j, 1.+0.j]]))
ret = jsp.linalg.toeplitz(np.array([1, 2, 3], dtype=np.float32),
np.array([1, 4, 5, 6], dtype=np.float32))
self.assertAllClose(ret, np.array([
[1, 4, 5, 6],
[2, 1, 4, 5],
[3, 2, 1, 4]], dtype=np.float32))
class LaxLinalgTest(jtu.JaxTestCase):
"""Tests for lax.linalg primitives."""
@ -1698,5 +1756,6 @@ class LaxLinalgTest(jtu.JaxTestCase):
Ts, Ss = vmap(lax.linalg.schur)(args)
self.assertAllClose(reconstruct(Ss, Ts), args, atol=1e-4)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())