mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.scipy.linalg.toeplitz: support implicit batching
This commit is contained in:
parent
6892e628fb
commit
3f98c57f7b
@ -40,6 +40,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
`platforms` instead.
|
||||
* Hashing of tracers, which has been deprecated since version 0.4.30, now
|
||||
results in a `TypeError`.
|
||||
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
|
||||
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
|
||||
on the function inputs.
|
||||
|
||||
* New Features
|
||||
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
|
||||
|
@ -2004,7 +2004,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
|
||||
|
||||
|
||||
def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
|
||||
r"""Construct a Toeplitz matrix
|
||||
r"""Construct a Toeplitz matrix.
|
||||
|
||||
JAX implementation of :func:`scipy.linalg.toeplitz`.
|
||||
|
||||
@ -2023,13 +2023,13 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
|
||||
Notice this implies that :math:`r_0` is ignored.
|
||||
|
||||
Args:
|
||||
c: array specifying the first column. Will be flattened
|
||||
if not 1-dimensional.
|
||||
r: (optional) array specifying the first row. If not specified, defaults
|
||||
to ``conj(c)``. Will be flattened if not 1-dimensional.
|
||||
c: array of shape ``(..., N)`` specifying the first column.
|
||||
r: (optional) array of shape ``(..., M)`` specifying the first row. Leading
|
||||
dimensions must be broadcast-compatible with those of ``c``. If not specified,
|
||||
``r`` defaults to ``conj(c)``.
|
||||
|
||||
Returns:
|
||||
toeplitz matrix of shape ``(c.size, r.size)``.
|
||||
A Toeplitz matrix of shape ``(... N, M)``.
|
||||
|
||||
Examples:
|
||||
Specifying ``c`` only:
|
||||
@ -2059,32 +2059,40 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
|
||||
[1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64)
|
||||
>>> print("M is Hermitian:", jnp.all(M == M.conj().T))
|
||||
M is Hermitian: True
|
||||
|
||||
For N-dimensional ``c`` and/or ``r``, the result is a batch of Toeplitz matrices:
|
||||
|
||||
>>> c = jnp.array([[1, 2, 3], [4, 5, 6]])
|
||||
>>> jax.scipy.linalg.toeplitz(c)
|
||||
Array([[[1, 2, 3],
|
||||
[2, 1, 2],
|
||||
[3, 2, 1]],
|
||||
<BLANKLINE>
|
||||
[[4, 5, 6],
|
||||
[5, 4, 5],
|
||||
[6, 5, 4]]], dtype=int32)
|
||||
"""
|
||||
if r is None:
|
||||
check_arraylike("toeplitz", c)
|
||||
r = jnp.conjugate(jnp.asarray(c))
|
||||
else:
|
||||
check_arraylike("toeplitz", c, r)
|
||||
return _toeplitz(jnp.atleast_1d(jnp.asarray(c)), jnp.atleast_1d(jnp.asarray(r)))
|
||||
|
||||
c_arr = jnp.asarray(c).flatten()
|
||||
r_arr = jnp.asarray(r).flatten()
|
||||
|
||||
ncols, = c_arr.shape
|
||||
nrows, = r_arr.shape
|
||||
|
||||
@partial(jnp.vectorize, signature="(m),(n)->(m,n)")
|
||||
def _toeplitz(c: Array, r: Array) -> Array:
|
||||
ncols, = c.shape
|
||||
nrows, = r.shape
|
||||
if ncols == 0 or nrows == 0:
|
||||
return jnp.empty((ncols, nrows),
|
||||
dtype=jnp.promote_types(c_arr.dtype, r_arr.dtype))
|
||||
|
||||
return jnp.empty((ncols, nrows), dtype=jnp.promote_types(c.dtype, r.dtype))
|
||||
nelems = ncols + nrows - 1
|
||||
elems = jnp.concatenate((c_arr[::-1], r_arr[1:]))
|
||||
elems = jnp.concatenate((c[::-1], r[1:]))
|
||||
patches = lax.conv_general_dilated_patches(
|
||||
elems.reshape((1, nelems, 1)),
|
||||
(nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'),
|
||||
precision=lax.Precision.HIGHEST)[0]
|
||||
return jnp.flip(patches, axis=0)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=("n",))
|
||||
def hilbert(n: int) -> Array:
|
||||
r"""Create a Hilbert matrix of order n.
|
||||
|
@ -53,6 +53,22 @@ def _is_required_cuda_version_satisfied(cuda_version):
|
||||
else:
|
||||
return int(version.split()[-1]) >= cuda_version
|
||||
|
||||
|
||||
def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray:
|
||||
"""scipy.linalg.toeplitz with v1.17+ batching semantics."""
|
||||
if scipy_version >= (1, 17, 0):
|
||||
return scipy.linalg.toeplitz(c, r)
|
||||
elif r is None:
|
||||
c = np.atleast_1d(c)
|
||||
return np.vectorize(
|
||||
scipy.linalg.toeplitz, signature="(m)->(m,m)", otypes=(c.dtype,))(c)
|
||||
else:
|
||||
c = np.atleast_1d(c)
|
||||
r = np.atleast_1d(r)
|
||||
return np.vectorize(
|
||||
scipy.linalg.toeplitz, signature="(m),(n)->(m,n)", otypes=(np.result_type(c, r),))(c, r)
|
||||
|
||||
|
||||
class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -1990,11 +2006,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(root, expected, check_dtypes=False)
|
||||
|
||||
@jtu.sample_product(
|
||||
cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)],
|
||||
cshape=[(), (4,), (8,), (4, 7), (2, 1, 5)],
|
||||
cdtype=float_types + complex_types,
|
||||
rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)],
|
||||
rshape=[(), (3,), (7,), (4, 4), (2, 4, 0)],
|
||||
rdtype=float_types + complex_types + int_types)
|
||||
def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype):
|
||||
def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype):
|
||||
if ((rdtype in [np.float64, np.complex128]
|
||||
or cdtype in [np.float64, np.complex128])
|
||||
and not config.enable_x64.value):
|
||||
@ -2007,8 +2023,9 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)]
|
||||
with jax.numpy_rank_promotion("allow"):
|
||||
with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]):
|
||||
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
|
||||
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp_linalg_toeplitz),
|
||||
jsp.linalg.toeplitz, args_maker)
|
||||
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
|
||||
|
||||
@ -2028,8 +2045,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
|
||||
jsp.linalg.toeplitz, args_maker)
|
||||
self._CheckAgainstNumpy(osp_linalg_toeplitz, jsp.linalg.toeplitz, args_maker)
|
||||
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
|
||||
|
||||
def testToeplitzConstructionWithKnownCases(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user