jax.scipy.linalg.toeplitz: support implicit batching

This commit is contained in:
Jake VanderPlas 2024-11-11 15:32:43 -08:00
parent 6892e628fb
commit 3f98c57f7b
3 changed files with 53 additions and 26 deletions

View File

@ -40,6 +40,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
`platforms` instead.
* Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a `TypeError`.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
* New Features
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for

View File

@ -2004,7 +2004,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
r"""Construct a Toeplitz matrix
r"""Construct a Toeplitz matrix.
JAX implementation of :func:`scipy.linalg.toeplitz`.
@ -2023,13 +2023,13 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
Notice this implies that :math:`r_0` is ignored.
Args:
c: array specifying the first column. Will be flattened
if not 1-dimensional.
r: (optional) array specifying the first row. If not specified, defaults
to ``conj(c)``. Will be flattened if not 1-dimensional.
c: array of shape ``(..., N)`` specifying the first column.
r: (optional) array of shape ``(..., M)`` specifying the first row. Leading
dimensions must be broadcast-compatible with those of ``c``. If not specified,
``r`` defaults to ``conj(c)``.
Returns:
toeplitz matrix of shape ``(c.size, r.size)``.
A Toeplitz matrix of shape ``(... N, M)``.
Examples:
Specifying ``c`` only:
@ -2059,32 +2059,40 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
[1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64)
>>> print("M is Hermitian:", jnp.all(M == M.conj().T))
M is Hermitian: True
For N-dimensional ``c`` and/or ``r``, the result is a batch of Toeplitz matrices:
>>> c = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> jax.scipy.linalg.toeplitz(c)
Array([[[1, 2, 3],
[2, 1, 2],
[3, 2, 1]],
<BLANKLINE>
[[4, 5, 6],
[5, 4, 5],
[6, 5, 4]]], dtype=int32)
"""
if r is None:
check_arraylike("toeplitz", c)
r = jnp.conjugate(jnp.asarray(c))
else:
check_arraylike("toeplitz", c, r)
return _toeplitz(jnp.atleast_1d(jnp.asarray(c)), jnp.atleast_1d(jnp.asarray(r)))
c_arr = jnp.asarray(c).flatten()
r_arr = jnp.asarray(r).flatten()
ncols, = c_arr.shape
nrows, = r_arr.shape
@partial(jnp.vectorize, signature="(m),(n)->(m,n)")
def _toeplitz(c: Array, r: Array) -> Array:
ncols, = c.shape
nrows, = r.shape
if ncols == 0 or nrows == 0:
return jnp.empty((ncols, nrows),
dtype=jnp.promote_types(c_arr.dtype, r_arr.dtype))
return jnp.empty((ncols, nrows), dtype=jnp.promote_types(c.dtype, r.dtype))
nelems = ncols + nrows - 1
elems = jnp.concatenate((c_arr[::-1], r_arr[1:]))
elems = jnp.concatenate((c[::-1], r[1:]))
patches = lax.conv_general_dilated_patches(
elems.reshape((1, nelems, 1)),
(nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'),
precision=lax.Precision.HIGHEST)[0]
return jnp.flip(patches, axis=0)
@partial(jit, static_argnames=("n",))
def hilbert(n: int) -> Array:
r"""Create a Hilbert matrix of order n.

View File

@ -53,6 +53,22 @@ def _is_required_cuda_version_satisfied(cuda_version):
else:
return int(version.split()[-1]) >= cuda_version
def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray:
"""scipy.linalg.toeplitz with v1.17+ batching semantics."""
if scipy_version >= (1, 17, 0):
return scipy.linalg.toeplitz(c, r)
elif r is None:
c = np.atleast_1d(c)
return np.vectorize(
scipy.linalg.toeplitz, signature="(m)->(m,m)", otypes=(c.dtype,))(c)
else:
c = np.atleast_1d(c)
r = np.atleast_1d(r)
return np.vectorize(
scipy.linalg.toeplitz, signature="(m),(n)->(m,n)", otypes=(np.result_type(c, r),))(c, r)
class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.sample_product(
@ -1990,11 +2006,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
self.assertAllClose(root, expected, check_dtypes=False)
@jtu.sample_product(
cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)],
cshape=[(), (4,), (8,), (4, 7), (2, 1, 5)],
cdtype=float_types + complex_types,
rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)],
rshape=[(), (3,), (7,), (4, 4), (2, 4, 0)],
rdtype=float_types + complex_types + int_types)
def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype):
def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype):
if ((rdtype in [np.float64, np.complex128]
or cdtype in [np.float64, np.complex128])
and not config.enable_x64.value):
@ -2007,10 +2023,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)]
with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]):
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
jsp.linalg.toeplitz, args_maker)
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
with jax.numpy_rank_promotion("allow"):
with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]):
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp_linalg_toeplitz),
jsp.linalg.toeplitz, args_maker)
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
@jtu.sample_product(
shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)],
@ -2028,8 +2045,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
jsp.linalg.toeplitz, args_maker)
self._CheckAgainstNumpy(osp_linalg_toeplitz, jsp.linalg.toeplitz, args_maker)
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
def testToeplitzConstructionWithKnownCases(self):