diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ab334c15..d2a45c377 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `platforms` instead. * Hashing of tracers, which has been deprecated since version 0.4.30, now results in a `TypeError`. + * {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional + inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel` + on the function inputs. * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index d014e5ceb..1c5eba988 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -2004,7 +2004,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: - r"""Construct a Toeplitz matrix + r"""Construct a Toeplitz matrix. JAX implementation of :func:`scipy.linalg.toeplitz`. @@ -2023,13 +2023,13 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: Notice this implies that :math:`r_0` is ignored. Args: - c: array specifying the first column. Will be flattened - if not 1-dimensional. - r: (optional) array specifying the first row. If not specified, defaults - to ``conj(c)``. Will be flattened if not 1-dimensional. + c: array of shape ``(..., N)`` specifying the first column. + r: (optional) array of shape ``(..., M)`` specifying the first row. Leading + dimensions must be broadcast-compatible with those of ``c``. If not specified, + ``r`` defaults to ``conj(c)``. Returns: - toeplitz matrix of shape ``(c.size, r.size)``. + A Toeplitz matrix of shape ``(... N, M)``. Examples: Specifying ``c`` only: @@ -2059,32 +2059,40 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: [1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64) >>> print("M is Hermitian:", jnp.all(M == M.conj().T)) M is Hermitian: True + + For N-dimensional ``c`` and/or ``r``, the result is a batch of Toeplitz matrices: + + >>> c = jnp.array([[1, 2, 3], [4, 5, 6]]) + >>> jax.scipy.linalg.toeplitz(c) + Array([[[1, 2, 3], + [2, 1, 2], + [3, 2, 1]], + + [[4, 5, 6], + [5, 4, 5], + [6, 5, 4]]], dtype=int32) """ if r is None: check_arraylike("toeplitz", c) r = jnp.conjugate(jnp.asarray(c)) else: check_arraylike("toeplitz", c, r) + return _toeplitz(jnp.atleast_1d(jnp.asarray(c)), jnp.atleast_1d(jnp.asarray(r))) - c_arr = jnp.asarray(c).flatten() - r_arr = jnp.asarray(r).flatten() - - ncols, = c_arr.shape - nrows, = r_arr.shape - +@partial(jnp.vectorize, signature="(m),(n)->(m,n)") +def _toeplitz(c: Array, r: Array) -> Array: + ncols, = c.shape + nrows, = r.shape if ncols == 0 or nrows == 0: - return jnp.empty((ncols, nrows), - dtype=jnp.promote_types(c_arr.dtype, r_arr.dtype)) - + return jnp.empty((ncols, nrows), dtype=jnp.promote_types(c.dtype, r.dtype)) nelems = ncols + nrows - 1 - elems = jnp.concatenate((c_arr[::-1], r_arr[1:])) + elems = jnp.concatenate((c[::-1], r[1:])) patches = lax.conv_general_dilated_patches( elems.reshape((1, nelems, 1)), (nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'), precision=lax.Precision.HIGHEST)[0] return jnp.flip(patches, axis=0) - @partial(jit, static_argnames=("n",)) def hilbert(n: int) -> Array: r"""Create a Hilbert matrix of order n. diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5ace4b5ec..d3fe8f476 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -53,6 +53,22 @@ def _is_required_cuda_version_satisfied(cuda_version): else: return int(version.split()[-1]) >= cuda_version + +def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: + """scipy.linalg.toeplitz with v1.17+ batching semantics.""" + if scipy_version >= (1, 17, 0): + return scipy.linalg.toeplitz(c, r) + elif r is None: + c = np.atleast_1d(c) + return np.vectorize( + scipy.linalg.toeplitz, signature="(m)->(m,m)", otypes=(c.dtype,))(c) + else: + c = np.atleast_1d(c) + r = np.atleast_1d(r) + return np.vectorize( + scipy.linalg.toeplitz, signature="(m),(n)->(m,n)", otypes=(np.result_type(c, r),))(c, r) + + class NumpyLinalgTest(jtu.JaxTestCase): @jtu.sample_product( @@ -1990,11 +2006,11 @@ class ScipyLinalgTest(jtu.JaxTestCase): self.assertAllClose(root, expected, check_dtypes=False) @jtu.sample_product( - cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)], + cshape=[(), (4,), (8,), (4, 7), (2, 1, 5)], cdtype=float_types + complex_types, - rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)], + rshape=[(), (3,), (7,), (4, 4), (2, 4, 0)], rdtype=float_types + complex_types + int_types) - def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype): + def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype): if ((rdtype in [np.float64, np.complex128] or cdtype in [np.float64, np.complex128]) and not config.enable_x64.value): @@ -2007,10 +2023,11 @@ class ScipyLinalgTest(jtu.JaxTestCase): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)] - with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]): - self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz), - jsp.linalg.toeplitz, args_maker) - self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) + with jax.numpy_rank_promotion("allow"): + with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]): + self._CheckAgainstNumpy(jtu.promote_like_jnp(osp_linalg_toeplitz), + jsp.linalg.toeplitz, args_maker) + self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) @jtu.sample_product( shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)], @@ -2028,8 +2045,7 @@ class ScipyLinalgTest(jtu.JaxTestCase): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz), - jsp.linalg.toeplitz, args_maker) + self._CheckAgainstNumpy(osp_linalg_toeplitz, jsp.linalg.toeplitz, args_maker) self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) def testToeplitzConstructionWithKnownCases(self):