mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.scipy.linalg.toeplitz: support implicit batching
This commit is contained in:
parent
6892e628fb
commit
3f98c57f7b
@ -40,6 +40,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
`platforms` instead.
|
`platforms` instead.
|
||||||
* Hashing of tracers, which has been deprecated since version 0.4.30, now
|
* Hashing of tracers, which has been deprecated since version 0.4.30, now
|
||||||
results in a `TypeError`.
|
results in a `TypeError`.
|
||||||
|
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
|
||||||
|
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
|
||||||
|
on the function inputs.
|
||||||
|
|
||||||
* New Features
|
* New Features
|
||||||
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
|
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
|
||||||
|
@ -2004,7 +2004,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
|
|||||||
|
|
||||||
|
|
||||||
def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
|
def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
|
||||||
r"""Construct a Toeplitz matrix
|
r"""Construct a Toeplitz matrix.
|
||||||
|
|
||||||
JAX implementation of :func:`scipy.linalg.toeplitz`.
|
JAX implementation of :func:`scipy.linalg.toeplitz`.
|
||||||
|
|
||||||
@ -2023,13 +2023,13 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
|
|||||||
Notice this implies that :math:`r_0` is ignored.
|
Notice this implies that :math:`r_0` is ignored.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
c: array specifying the first column. Will be flattened
|
c: array of shape ``(..., N)`` specifying the first column.
|
||||||
if not 1-dimensional.
|
r: (optional) array of shape ``(..., M)`` specifying the first row. Leading
|
||||||
r: (optional) array specifying the first row. If not specified, defaults
|
dimensions must be broadcast-compatible with those of ``c``. If not specified,
|
||||||
to ``conj(c)``. Will be flattened if not 1-dimensional.
|
``r`` defaults to ``conj(c)``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
toeplitz matrix of shape ``(c.size, r.size)``.
|
A Toeplitz matrix of shape ``(... N, M)``.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
Specifying ``c`` only:
|
Specifying ``c`` only:
|
||||||
@ -2059,32 +2059,40 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array:
|
|||||||
[1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64)
|
[1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64)
|
||||||
>>> print("M is Hermitian:", jnp.all(M == M.conj().T))
|
>>> print("M is Hermitian:", jnp.all(M == M.conj().T))
|
||||||
M is Hermitian: True
|
M is Hermitian: True
|
||||||
|
|
||||||
|
For N-dimensional ``c`` and/or ``r``, the result is a batch of Toeplitz matrices:
|
||||||
|
|
||||||
|
>>> c = jnp.array([[1, 2, 3], [4, 5, 6]])
|
||||||
|
>>> jax.scipy.linalg.toeplitz(c)
|
||||||
|
Array([[[1, 2, 3],
|
||||||
|
[2, 1, 2],
|
||||||
|
[3, 2, 1]],
|
||||||
|
<BLANKLINE>
|
||||||
|
[[4, 5, 6],
|
||||||
|
[5, 4, 5],
|
||||||
|
[6, 5, 4]]], dtype=int32)
|
||||||
"""
|
"""
|
||||||
if r is None:
|
if r is None:
|
||||||
check_arraylike("toeplitz", c)
|
check_arraylike("toeplitz", c)
|
||||||
r = jnp.conjugate(jnp.asarray(c))
|
r = jnp.conjugate(jnp.asarray(c))
|
||||||
else:
|
else:
|
||||||
check_arraylike("toeplitz", c, r)
|
check_arraylike("toeplitz", c, r)
|
||||||
|
return _toeplitz(jnp.atleast_1d(jnp.asarray(c)), jnp.atleast_1d(jnp.asarray(r)))
|
||||||
|
|
||||||
c_arr = jnp.asarray(c).flatten()
|
@partial(jnp.vectorize, signature="(m),(n)->(m,n)")
|
||||||
r_arr = jnp.asarray(r).flatten()
|
def _toeplitz(c: Array, r: Array) -> Array:
|
||||||
|
ncols, = c.shape
|
||||||
ncols, = c_arr.shape
|
nrows, = r.shape
|
||||||
nrows, = r_arr.shape
|
|
||||||
|
|
||||||
if ncols == 0 or nrows == 0:
|
if ncols == 0 or nrows == 0:
|
||||||
return jnp.empty((ncols, nrows),
|
return jnp.empty((ncols, nrows), dtype=jnp.promote_types(c.dtype, r.dtype))
|
||||||
dtype=jnp.promote_types(c_arr.dtype, r_arr.dtype))
|
|
||||||
|
|
||||||
nelems = ncols + nrows - 1
|
nelems = ncols + nrows - 1
|
||||||
elems = jnp.concatenate((c_arr[::-1], r_arr[1:]))
|
elems = jnp.concatenate((c[::-1], r[1:]))
|
||||||
patches = lax.conv_general_dilated_patches(
|
patches = lax.conv_general_dilated_patches(
|
||||||
elems.reshape((1, nelems, 1)),
|
elems.reshape((1, nelems, 1)),
|
||||||
(nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'),
|
(nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'),
|
||||||
precision=lax.Precision.HIGHEST)[0]
|
precision=lax.Precision.HIGHEST)[0]
|
||||||
return jnp.flip(patches, axis=0)
|
return jnp.flip(patches, axis=0)
|
||||||
|
|
||||||
|
|
||||||
@partial(jit, static_argnames=("n",))
|
@partial(jit, static_argnames=("n",))
|
||||||
def hilbert(n: int) -> Array:
|
def hilbert(n: int) -> Array:
|
||||||
r"""Create a Hilbert matrix of order n.
|
r"""Create a Hilbert matrix of order n.
|
||||||
|
@ -53,6 +53,22 @@ def _is_required_cuda_version_satisfied(cuda_version):
|
|||||||
else:
|
else:
|
||||||
return int(version.split()[-1]) >= cuda_version
|
return int(version.split()[-1]) >= cuda_version
|
||||||
|
|
||||||
|
|
||||||
|
def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray:
|
||||||
|
"""scipy.linalg.toeplitz with v1.17+ batching semantics."""
|
||||||
|
if scipy_version >= (1, 17, 0):
|
||||||
|
return scipy.linalg.toeplitz(c, r)
|
||||||
|
elif r is None:
|
||||||
|
c = np.atleast_1d(c)
|
||||||
|
return np.vectorize(
|
||||||
|
scipy.linalg.toeplitz, signature="(m)->(m,m)", otypes=(c.dtype,))(c)
|
||||||
|
else:
|
||||||
|
c = np.atleast_1d(c)
|
||||||
|
r = np.atleast_1d(r)
|
||||||
|
return np.vectorize(
|
||||||
|
scipy.linalg.toeplitz, signature="(m),(n)->(m,n)", otypes=(np.result_type(c, r),))(c, r)
|
||||||
|
|
||||||
|
|
||||||
class NumpyLinalgTest(jtu.JaxTestCase):
|
class NumpyLinalgTest(jtu.JaxTestCase):
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
@ -1990,11 +2006,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
|||||||
self.assertAllClose(root, expected, check_dtypes=False)
|
self.assertAllClose(root, expected, check_dtypes=False)
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)],
|
cshape=[(), (4,), (8,), (4, 7), (2, 1, 5)],
|
||||||
cdtype=float_types + complex_types,
|
cdtype=float_types + complex_types,
|
||||||
rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)],
|
rshape=[(), (3,), (7,), (4, 4), (2, 4, 0)],
|
||||||
rdtype=float_types + complex_types + int_types)
|
rdtype=float_types + complex_types + int_types)
|
||||||
def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype):
|
def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype):
|
||||||
if ((rdtype in [np.float64, np.complex128]
|
if ((rdtype in [np.float64, np.complex128]
|
||||||
or cdtype in [np.float64, np.complex128])
|
or cdtype in [np.float64, np.complex128])
|
||||||
and not config.enable_x64.value):
|
and not config.enable_x64.value):
|
||||||
@ -2007,10 +2023,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
rng = jtu.rand_default(self.rng())
|
rng = jtu.rand_default(self.rng())
|
||||||
args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)]
|
args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)]
|
||||||
with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]):
|
with jax.numpy_rank_promotion("allow"):
|
||||||
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
|
with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]):
|
||||||
jsp.linalg.toeplitz, args_maker)
|
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp_linalg_toeplitz),
|
||||||
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
|
jsp.linalg.toeplitz, args_maker)
|
||||||
|
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)],
|
shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)],
|
||||||
@ -2028,8 +2045,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
rng = jtu.rand_default(self.rng())
|
rng = jtu.rand_default(self.rng())
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz),
|
self._CheckAgainstNumpy(osp_linalg_toeplitz, jsp.linalg.toeplitz, args_maker)
|
||||||
jsp.linalg.toeplitz, args_maker)
|
|
||||||
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
|
self._CompileAndCheck(jsp.linalg.toeplitz, args_maker)
|
||||||
|
|
||||||
def testToeplitzConstructionWithKnownCases(self):
|
def testToeplitzConstructionWithKnownCases(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user