mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add jax.scipy.linalg.toeplitz
.
This commit is contained in:
parent
440b25bf5d
commit
1ade5f8592
@ -45,6 +45,7 @@ jax.scipy.linalg
|
||||
solve_triangular
|
||||
sqrtm
|
||||
svd
|
||||
toeplitz
|
||||
tril
|
||||
triu
|
||||
|
||||
|
@ -27,6 +27,7 @@ from jax import lax
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.lax import qdwh
|
||||
from jax._src.numpy.lax_numpy import _check_arraylike
|
||||
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact, _promote_dtypes_complex
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import linalg as np_linalg
|
||||
@ -1031,3 +1032,28 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
|
||||
return h, q
|
||||
else:
|
||||
return h
|
||||
|
||||
@_wraps(scipy.linalg.toeplitz)
|
||||
def toeplitz(c: ArrayLike, r: Optional[ArrayLike] = None) -> Array:
|
||||
if r is None:
|
||||
_check_arraylike("toeplitz", c)
|
||||
r = jnp.conjugate(jnp.asarray(c))
|
||||
else:
|
||||
_check_arraylike("toeplitz", c, r)
|
||||
|
||||
c = jnp.asarray(c).flatten()
|
||||
r = jnp.asarray(r).flatten()
|
||||
|
||||
ncols, = c.shape
|
||||
nrows, = r.shape
|
||||
|
||||
if ncols == 0 or nrows == 0:
|
||||
return jnp.empty((ncols, nrows), dtype=jnp.promote_types(c.dtype, r.dtype))
|
||||
|
||||
nelems = ncols + nrows - 1
|
||||
elems = jnp.concatenate((c[::-1], r[1:]))
|
||||
patches = lax.conv_general_dilated_patches(
|
||||
elems.reshape((1, nelems, 1)),
|
||||
(nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'),
|
||||
precision=lax.Precision.HIGHEST)[0]
|
||||
return jnp.flip(patches, axis=0)
|
||||
|
@ -36,6 +36,7 @@ from jax._src.scipy.linalg import (
|
||||
solve as solve,
|
||||
solve_triangular as solve_triangular,
|
||||
svd as svd,
|
||||
toeplitz as toeplitz,
|
||||
tril as tril,
|
||||
triu as triu,
|
||||
)
|
||||
|
@ -42,6 +42,7 @@ T = lambda x: np.swapaxes(x, -1, -2)
|
||||
|
||||
float_types = jtu.dtypes.floating
|
||||
complex_types = jtu.dtypes.complex
|
||||
int_types = jtu.dtypes.all_integer
|
||||
|
||||
|
||||
class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
@ -1571,6 +1572,63 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(root, expected, check_dtypes=False)
|
||||
|
||||
@jtu.sample_product(
|
||||
cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)],
|
||||
cdtype=float_types + complex_types,
|
||||
rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)],
|
||||
rdtype=float_types + complex_types + int_types)
|
||||
def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype):
|
||||
if ((rdtype in [np.float64, np.complex128]
|
||||
or cdtype in [np.float64, np.complex128])
|
||||
and not config.x64_enabled):
|
||||
self.skipTest("Only run float64 testcase when float64 is enabled.")
|
||||
|
||||
int_types_excl_i8 = set(int_types) - {np.int8}
|
||||
if ((rdtype in int_types_excl_i8 or cdtype in int_types_excl_i8)
|
||||
and jtu.device_under_test() == "gpu"):
|
||||
self.skipTest("Integer (except int8) toeplitz is not supported on GPU yet.")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)]
|
||||
with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]):
|
||||
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
|
||||
jsp.linalg.toeplitz, args_maker)
|
||||
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)],
|
||||
dtype=float_types + complex_types + int_types)
|
||||
def testToeplitzSymmetricConstruction(self, shape, dtype):
|
||||
if (dtype in [np.float64, np.complex128]
|
||||
and not config.x64_enabled):
|
||||
self.skipTest("Only run float64 testcase when float64 is enabled.")
|
||||
|
||||
int_types_excl_i8 = set(int_types) - {np.int8}
|
||||
if (dtype in int_types_excl_i8
|
||||
and jtu.device_under_test() == "gpu"):
|
||||
self.skipTest("Integer (except int8) toeplitz is not supported on GPU yet.")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
|
||||
jsp.linalg.toeplitz, args_maker)
|
||||
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
|
||||
|
||||
def testToeplitzConstructionWithKnownCases(self):
|
||||
# Test with examples taken from SciPy doc for the corresponding function.
|
||||
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.toeplitz.html
|
||||
ret = jsp.linalg.toeplitz(np.array([1.0, 2+3j, 4-1j]))
|
||||
self.assertAllClose(ret, np.array([
|
||||
[ 1.+0.j, 2.-3.j, 4.+1.j],
|
||||
[ 2.+3.j, 1.+0.j, 2.-3.j],
|
||||
[ 4.-1.j, 2.+3.j, 1.+0.j]]))
|
||||
ret = jsp.linalg.toeplitz(np.array([1, 2, 3], dtype=np.float32),
|
||||
np.array([1, 4, 5, 6], dtype=np.float32))
|
||||
self.assertAllClose(ret, np.array([
|
||||
[1, 4, 5, 6],
|
||||
[2, 1, 4, 5],
|
||||
[3, 2, 1, 4]], dtype=np.float32))
|
||||
|
||||
|
||||
class LaxLinalgTest(jtu.JaxTestCase):
|
||||
"""Tests for lax.linalg primitives."""
|
||||
@ -1698,5 +1756,6 @@ class LaxLinalgTest(jtu.JaxTestCase):
|
||||
Ts, Ss = vmap(lax.linalg.schur)(args)
|
||||
self.assertAllClose(reconstruct(Ss, Ts), args, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user