jax.scipy.linalg.toeplitz: support implicit batching

This commit is contained in:
Jake VanderPlas 2024-11-11 15:32:43 -08:00
parent 6892e628fb
commit 3f98c57f7b
3 changed files with 53 additions and 26 deletions

View File

@ -40,6 +40,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
`platforms` instead. `platforms` instead.
* Hashing of tracers, which has been deprecated since version 0.4.30, now * Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a `TypeError`. results in a `TypeError`.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
* New Features * New Features
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for

View File

@ -2004,7 +2004,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
r"""Construct a Toeplitz matrix r"""Construct a Toeplitz matrix.
JAX implementation of :func:`scipy.linalg.toeplitz`. JAX implementation of :func:`scipy.linalg.toeplitz`.
@ -2023,13 +2023,13 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
Notice this implies that :math:`r_0` is ignored. Notice this implies that :math:`r_0` is ignored.
Args: Args:
c: array specifying the first column. Will be flattened c: array of shape ``(..., N)`` specifying the first column.
if not 1-dimensional. r: (optional) array of shape ``(..., M)`` specifying the first row. Leading
r: (optional) array specifying the first row. If not specified, defaults dimensions must be broadcast-compatible with those of ``c``. If not specified,
to ``conj(c)``. Will be flattened if not 1-dimensional. ``r`` defaults to ``conj(c)``.
Returns: Returns:
toeplitz matrix of shape ``(c.size, r.size)``. A Toeplitz matrix of shape ``(... N, M)``.
Examples: Examples:
Specifying ``c`` only: Specifying ``c`` only:
@ -2059,32 +2059,40 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
[1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64) [1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64)
>>> print("M is Hermitian:", jnp.all(M == M.conj().T)) >>> print("M is Hermitian:", jnp.all(M == M.conj().T))
M is Hermitian: True M is Hermitian: True
For N-dimensional ``c`` and/or ``r``, the result is a batch of Toeplitz matrices:
>>> c = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> jax.scipy.linalg.toeplitz(c)
Array([[[1, 2, 3],
[2, 1, 2],
[3, 2, 1]],
<BLANKLINE>
[[4, 5, 6],
[5, 4, 5],
[6, 5, 4]]], dtype=int32)
""" """
if r is None: if r is None:
check_arraylike("toeplitz", c) check_arraylike("toeplitz", c)
r = jnp.conjugate(jnp.asarray(c)) r = jnp.conjugate(jnp.asarray(c))
else: else:
check_arraylike("toeplitz", c, r) check_arraylike("toeplitz", c, r)
return _toeplitz(jnp.atleast_1d(jnp.asarray(c)), jnp.atleast_1d(jnp.asarray(r)))
c_arr = jnp.asarray(c).flatten() @partial(jnp.vectorize, signature="(m),(n)->(m,n)")
r_arr = jnp.asarray(r).flatten() def _toeplitz(c: Array, r: Array) -> Array:
ncols, = c.shape
ncols, = c_arr.shape nrows, = r.shape
nrows, = r_arr.shape
if ncols == 0 or nrows == 0: if ncols == 0 or nrows == 0:
return jnp.empty((ncols, nrows), return jnp.empty((ncols, nrows), dtype=jnp.promote_types(c.dtype, r.dtype))
dtype=jnp.promote_types(c_arr.dtype, r_arr.dtype))
nelems = ncols + nrows - 1 nelems = ncols + nrows - 1
elems = jnp.concatenate((c_arr[::-1], r_arr[1:])) elems = jnp.concatenate((c[::-1], r[1:]))
patches = lax.conv_general_dilated_patches( patches = lax.conv_general_dilated_patches(
elems.reshape((1, nelems, 1)), elems.reshape((1, nelems, 1)),
(nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'), (nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'),
precision=lax.Precision.HIGHEST)[0] precision=lax.Precision.HIGHEST)[0]
return jnp.flip(patches, axis=0) return jnp.flip(patches, axis=0)
@partial(jit, static_argnames=("n",)) @partial(jit, static_argnames=("n",))
def hilbert(n: int) -> Array: def hilbert(n: int) -> Array:
r"""Create a Hilbert matrix of order n. r"""Create a Hilbert matrix of order n.

View File

@ -53,6 +53,22 @@ def _is_required_cuda_version_satisfied(cuda_version):
else: else:
return int(version.split()[-1]) >= cuda_version return int(version.split()[-1]) >= cuda_version
def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray:
"""scipy.linalg.toeplitz with v1.17+ batching semantics."""
if scipy_version >= (1, 17, 0):
return scipy.linalg.toeplitz(c, r)
elif r is None:
c = np.atleast_1d(c)
return np.vectorize(
scipy.linalg.toeplitz, signature="(m)->(m,m)", otypes=(c.dtype,))(c)
else:
c = np.atleast_1d(c)
r = np.atleast_1d(r)
return np.vectorize(
scipy.linalg.toeplitz, signature="(m),(n)->(m,n)", otypes=(np.result_type(c, r),))(c, r)
class NumpyLinalgTest(jtu.JaxTestCase): class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.sample_product( @jtu.sample_product(
@ -1990,11 +2006,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
self.assertAllClose(root, expected, check_dtypes=False) self.assertAllClose(root, expected, check_dtypes=False)
@jtu.sample_product( @jtu.sample_product(
cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)], cshape=[(), (4,), (8,), (4, 7), (2, 1, 5)],
cdtype=float_types + complex_types, cdtype=float_types + complex_types,
rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)], rshape=[(), (3,), (7,), (4, 4), (2, 4, 0)],
rdtype=float_types + complex_types + int_types) rdtype=float_types + complex_types + int_types)
def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype): def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype):
if ((rdtype in [np.float64, np.complex128] if ((rdtype in [np.float64, np.complex128]
or cdtype in [np.float64, np.complex128]) or cdtype in [np.float64, np.complex128])
and not config.enable_x64.value): and not config.enable_x64.value):
@ -2007,10 +2023,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)] args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)]
with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]): with jax.numpy_rank_promotion("allow"):
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz), with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]):
jsp.linalg.toeplitz, args_maker) self._CheckAgainstNumpy(jtu.promote_like_jnp(osp_linalg_toeplitz),
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) jsp.linalg.toeplitz, args_maker)
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
@jtu.sample_product( @jtu.sample_product(
shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)], shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)],
@ -2028,8 +2045,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)] args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz), self._CheckAgainstNumpy(osp_linalg_toeplitz, jsp.linalg.toeplitz, args_maker)
jsp.linalg.toeplitz, args_maker)
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
def testToeplitzConstructionWithKnownCases(self): def testToeplitzConstructionWithKnownCases(self):