Merge pull request #26411 from jakevdp:jnp-window-functions

PiperOrigin-RevId: 725195238
This commit is contained in:
jax authors 2025-02-10 06:46:07 -08:00
commit 260a879bbf
3 changed files with 190 additions and 163 deletions

View File

@ -11573,164 +11573,6 @@ def _canonicalize_tuple_index(arr_ndim, idx):
idx = tuple(idx) + colons
return idx
@export
def blackman(M: int) -> Array:
"""Return a Blackman window of size M.
JAX implementation of :func:`numpy.blackman`.
Args:
M: The window size.
Returns:
An array of size M containing the Blackman window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.blackman(4))
[-0. 0.63 0.63 -0. ]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.blackman")
dtype = dtypes.canonicalize_dtype(dtypes.float_)
if M <= 1:
return ones(M, dtype)
n = lax.iota(dtype, M)
return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1))
@export
def bartlett(M: int) -> Array:
"""Return a Bartlett window of size M.
JAX implementation of :func:`numpy.bartlett`.
Args:
M: The window size.
Returns:
An array of size M containing the Bartlett window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.bartlett(4))
[0. 0.67 0.67 0. ]
See also:
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.bartlett")
dtype = dtypes.canonicalize_dtype(dtypes.float_)
if M <= 1:
return ones(M, dtype)
n = lax.iota(dtype, M)
return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1)
@export
def hamming(M: int) -> Array:
"""Return a Hamming window of size M.
JAX implementation of :func:`numpy.hamming`.
Args:
M: The window size.
Returns:
An array of size M containing the Hamming window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.hamming(4))
[0.08 0.77 0.77 0.08]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.hamming")
dtype = dtypes.canonicalize_dtype(dtypes.float_)
if M <= 1:
return ones(M, dtype)
n = lax.iota(dtype, M)
return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1))
@export
def hanning(M: int) -> Array:
"""Return a Hanning window of size M.
JAX implementation of :func:`numpy.hanning`.
Args:
M: The window size.
Returns:
An array of size M containing the Hanning window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.hanning(4))
[0. 0.75 0.75 0. ]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.hanning")
dtype = dtypes.canonicalize_dtype(dtypes.float_)
if M <= 1:
return ones(M, dtype)
n = lax.iota(dtype, M)
return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1)))
@export
def kaiser(M: int, beta: ArrayLike) -> Array:
"""Return a Kaiser window of size M.
JAX implementation of :func:`numpy.kaiser`.
Args:
M: The window size.
beta: The Kaiser window parameter.
Returns:
An array of size M containing the Kaiser window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.kaiser(4, 1.5))
[0.61 0.95 0.95 0.61]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.kaiser")
dtype = dtypes.canonicalize_dtype(dtypes.float_)
if M <= 1:
return ones(M, dtype)
n = lax.iota(dtype, M)
alpha = 0.5 * (M - 1)
return i0(beta * ufuncs.sqrt(1 - ((n - alpha) / alpha) ** 2)) / i0(beta)
def _gcd_cond_fn(xs: tuple[Array, Array]) -> Array:
x1, x2 = xs
return reductions.any(x2 != 0)

View File

@ -0,0 +1,182 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import core
from jax._src import dtypes
from jax._src.numpy import lax_numpy
from jax._src.numpy import ufuncs
from jax._src.typing import Array, ArrayLike
from jax._src.util import set_module
from jax import lax
export = set_module('jax.numpy')
@export
def blackman(M: int) -> Array:
"""Return a Blackman window of size M.
JAX implementation of :func:`numpy.blackman`.
Args:
M: The window size.
Returns:
An array of size M containing the Blackman window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.blackman(4))
[-0. 0.63 0.63 -0. ]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.blackman")
dtype = dtypes.canonicalize_dtype(dtypes.float_)
if M <= 1:
return lax.full((M,), 1, dtype)
n = lax.iota(dtype, M)
return 0.42 - 0.5 * ufuncs.cos(2 * np.pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * np.pi * n / (M - 1))
@export
def bartlett(M: int) -> Array:
"""Return a Bartlett window of size M.
JAX implementation of :func:`numpy.bartlett`.
Args:
M: The window size.
Returns:
An array of size M containing the Bartlett window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.bartlett(4))
[0. 0.67 0.67 0. ]
See also:
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.bartlett")
dtype = dtypes.canonicalize_dtype(dtypes.float_)
if M <= 1:
return lax.full((M,), 1, dtype)
n = lax.iota(dtype, M)
return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1)
@export
def hamming(M: int) -> Array:
"""Return a Hamming window of size M.
JAX implementation of :func:`numpy.hamming`.
Args:
M: The window size.
Returns:
An array of size M containing the Hamming window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.hamming(4))
[0.08 0.77 0.77 0.08]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.hamming")
dtype = dtypes.canonicalize_dtype(dtypes.float_)
if M <= 1:
return lax.full((M,), 1, dtype)
n = lax.iota(dtype, M)
return 0.54 - 0.46 * ufuncs.cos(2 * np.pi * n / (M - 1))
@export
def hanning(M: int) -> Array:
"""Return a Hanning window of size M.
JAX implementation of :func:`numpy.hanning`.
Args:
M: The window size.
Returns:
An array of size M containing the Hanning window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.hanning(4))
[0. 0.75 0.75 0. ]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.hanning")
dtype = dtypes.canonicalize_dtype(dtypes.float_)
if M <= 1:
return lax.full((M,), 1, dtype)
n = lax.iota(dtype, M)
return 0.5 * (1 - ufuncs.cos(2 * np.pi * n / (M - 1)))
@export
def kaiser(M: int, beta: ArrayLike) -> Array:
"""Return a Kaiser window of size M.
JAX implementation of :func:`numpy.kaiser`.
Args:
M: The window size.
beta: The Kaiser window parameter.
Returns:
An array of size M containing the Kaiser window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.kaiser(4, 1.5))
[0.61 0.95 0.95 0.61]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.kaiser")
dtype = dtypes.canonicalize_dtype(dtypes.float_)
if M <= 1:
return lax.full((M,), 1, dtype)
n = lax.iota(dtype, M)
alpha = 0.5 * (M - 1)
return lax_numpy.i0(beta * ufuncs.sqrt(1 - ((n - alpha) / alpha) ** 2)) / lax_numpy.i0(beta)

View File

@ -47,9 +47,7 @@ from jax._src.numpy.lax_numpy import (
atleast_1d as atleast_1d,
atleast_2d as atleast_2d,
atleast_3d as atleast_3d,
bartlett as bartlett,
bincount as bincount,
blackman as blackman,
block as block,
broadcast_arrays as broadcast_arrays,
broadcast_shapes as broadcast_shapes,
@ -106,8 +104,6 @@ from jax._src.numpy.lax_numpy import (
geomspace as geomspace,
get_printoptions as get_printoptions,
gradient as gradient,
hamming as hamming,
hanning as hanning,
histogram as histogram,
histogram_bin_edges as histogram_bin_edges,
histogram2d as histogram2d,
@ -131,7 +127,6 @@ from jax._src.numpy.lax_numpy import (
issubdtype as issubdtype,
iterable as iterable,
ix_ as ix_,
kaiser as kaiser,
kron as kron,
lcm as lcm,
linspace as linspace,
@ -265,6 +260,14 @@ from jax._src.numpy.sorting import (
sort_complex as sort_complex,
)
from jax._src.numpy.window_functions import (
bartlett as bartlett,
blackman as blackman,
hamming as hamming,
hanning as hanning,
kaiser as kaiser,
)
# NumPy generic scalar types:
from numpy import (
character as character,