diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 35f608e9e..0031341ab 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -11573,164 +11573,6 @@ def _canonicalize_tuple_index(arr_ndim, idx): idx = tuple(idx) + colons return idx - -@export -def blackman(M: int) -> Array: - """Return a Blackman window of size M. - - JAX implementation of :func:`numpy.blackman`. - - Args: - M: The window size. - - Returns: - An array of size M containing the Blackman window. - - Examples: - >>> with jnp.printoptions(precision=2, suppress=True): - ... print(jnp.blackman(4)) - [-0. 0.63 0.63 -0. ] - - See also: - - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. - - :func:`jax.numpy.hamming`: return a Hamming window of size M. - - :func:`jax.numpy.hanning`: return a Hanning window of size M. - - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. - """ - M = core.concrete_or_error(int, M, "M argument of jnp.blackman") - dtype = dtypes.canonicalize_dtype(dtypes.float_) - if M <= 1: - return ones(M, dtype) - n = lax.iota(dtype, M) - return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1)) - - -@export -def bartlett(M: int) -> Array: - """Return a Bartlett window of size M. - - JAX implementation of :func:`numpy.bartlett`. - - Args: - M: The window size. - - Returns: - An array of size M containing the Bartlett window. - - Examples: - >>> with jnp.printoptions(precision=2, suppress=True): - ... print(jnp.bartlett(4)) - [0. 0.67 0.67 0. ] - - See also: - - :func:`jax.numpy.blackman`: return a Blackman window of size M. - - :func:`jax.numpy.hamming`: return a Hamming window of size M. - - :func:`jax.numpy.hanning`: return a Hanning window of size M. - - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. - """ - M = core.concrete_or_error(int, M, "M argument of jnp.bartlett") - dtype = dtypes.canonicalize_dtype(dtypes.float_) - if M <= 1: - return ones(M, dtype) - n = lax.iota(dtype, M) - return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) - - -@export -def hamming(M: int) -> Array: - """Return a Hamming window of size M. - - JAX implementation of :func:`numpy.hamming`. - - Args: - M: The window size. - - Returns: - An array of size M containing the Hamming window. - - Examples: - >>> with jnp.printoptions(precision=2, suppress=True): - ... print(jnp.hamming(4)) - [0.08 0.77 0.77 0.08] - - See also: - - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. - - :func:`jax.numpy.blackman`: return a Blackman window of size M. - - :func:`jax.numpy.hanning`: return a Hanning window of size M. - - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. - """ - M = core.concrete_or_error(int, M, "M argument of jnp.hamming") - dtype = dtypes.canonicalize_dtype(dtypes.float_) - if M <= 1: - return ones(M, dtype) - n = lax.iota(dtype, M) - return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1)) - - -@export -def hanning(M: int) -> Array: - """Return a Hanning window of size M. - - JAX implementation of :func:`numpy.hanning`. - - Args: - M: The window size. - - Returns: - An array of size M containing the Hanning window. - - Examples: - >>> with jnp.printoptions(precision=2, suppress=True): - ... print(jnp.hanning(4)) - [0. 0.75 0.75 0. ] - - See also: - - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. - - :func:`jax.numpy.blackman`: return a Blackman window of size M. - - :func:`jax.numpy.hamming`: return a Hamming window of size M. - - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. - """ - M = core.concrete_or_error(int, M, "M argument of jnp.hanning") - dtype = dtypes.canonicalize_dtype(dtypes.float_) - if M <= 1: - return ones(M, dtype) - n = lax.iota(dtype, M) - return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1))) - - -@export -def kaiser(M: int, beta: ArrayLike) -> Array: - """Return a Kaiser window of size M. - - JAX implementation of :func:`numpy.kaiser`. - - Args: - M: The window size. - beta: The Kaiser window parameter. - - Returns: - An array of size M containing the Kaiser window. - - Examples: - >>> with jnp.printoptions(precision=2, suppress=True): - ... print(jnp.kaiser(4, 1.5)) - [0.61 0.95 0.95 0.61] - - See also: - - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. - - :func:`jax.numpy.blackman`: return a Blackman window of size M. - - :func:`jax.numpy.hamming`: return a Hamming window of size M. - - :func:`jax.numpy.hanning`: return a Hanning window of size M. - """ - M = core.concrete_or_error(int, M, "M argument of jnp.kaiser") - dtype = dtypes.canonicalize_dtype(dtypes.float_) - if M <= 1: - return ones(M, dtype) - n = lax.iota(dtype, M) - alpha = 0.5 * (M - 1) - return i0(beta * ufuncs.sqrt(1 - ((n - alpha) / alpha) ** 2)) / i0(beta) - - def _gcd_cond_fn(xs: tuple[Array, Array]) -> Array: x1, x2 = xs return reductions.any(x2 != 0) diff --git a/jax/_src/numpy/window_functions.py b/jax/_src/numpy/window_functions.py new file mode 100644 index 000000000..96a15db77 --- /dev/null +++ b/jax/_src/numpy/window_functions.py @@ -0,0 +1,182 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from jax._src import core +from jax._src import dtypes +from jax._src.numpy import lax_numpy +from jax._src.numpy import ufuncs +from jax._src.typing import Array, ArrayLike +from jax._src.util import set_module +from jax import lax + +export = set_module('jax.numpy') + + +@export +def blackman(M: int) -> Array: + """Return a Blackman window of size M. + + JAX implementation of :func:`numpy.blackman`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Blackman window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.blackman(4)) + [-0. 0.63 0.63 -0. ] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ + M = core.concrete_or_error(int, M, "M argument of jnp.blackman") + dtype = dtypes.canonicalize_dtype(dtypes.float_) + if M <= 1: + return lax.full((M,), 1, dtype) + n = lax.iota(dtype, M) + return 0.42 - 0.5 * ufuncs.cos(2 * np.pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * np.pi * n / (M - 1)) + + +@export +def bartlett(M: int) -> Array: + """Return a Bartlett window of size M. + + JAX implementation of :func:`numpy.bartlett`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Bartlett window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.bartlett(4)) + [0. 0.67 0.67 0. ] + + See also: + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ + M = core.concrete_or_error(int, M, "M argument of jnp.bartlett") + dtype = dtypes.canonicalize_dtype(dtypes.float_) + if M <= 1: + return lax.full((M,), 1, dtype) + n = lax.iota(dtype, M) + return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) + + +@export +def hamming(M: int) -> Array: + """Return a Hamming window of size M. + + JAX implementation of :func:`numpy.hamming`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Hamming window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.hamming(4)) + [0.08 0.77 0.77 0.08] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ + M = core.concrete_or_error(int, M, "M argument of jnp.hamming") + dtype = dtypes.canonicalize_dtype(dtypes.float_) + if M <= 1: + return lax.full((M,), 1, dtype) + n = lax.iota(dtype, M) + return 0.54 - 0.46 * ufuncs.cos(2 * np.pi * n / (M - 1)) + + +@export +def hanning(M: int) -> Array: + """Return a Hanning window of size M. + + JAX implementation of :func:`numpy.hanning`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Hanning window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.hanning(4)) + [0. 0.75 0.75 0. ] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ + M = core.concrete_or_error(int, M, "M argument of jnp.hanning") + dtype = dtypes.canonicalize_dtype(dtypes.float_) + if M <= 1: + return lax.full((M,), 1, dtype) + n = lax.iota(dtype, M) + return 0.5 * (1 - ufuncs.cos(2 * np.pi * n / (M - 1))) + + +@export +def kaiser(M: int, beta: ArrayLike) -> Array: + """Return a Kaiser window of size M. + + JAX implementation of :func:`numpy.kaiser`. + + Args: + M: The window size. + beta: The Kaiser window parameter. + + Returns: + An array of size M containing the Kaiser window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.kaiser(4, 1.5)) + [0.61 0.95 0.95 0.61] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + """ + M = core.concrete_or_error(int, M, "M argument of jnp.kaiser") + dtype = dtypes.canonicalize_dtype(dtypes.float_) + if M <= 1: + return lax.full((M,), 1, dtype) + n = lax.iota(dtype, M) + alpha = 0.5 * (M - 1) + return lax_numpy.i0(beta * ufuncs.sqrt(1 - ((n - alpha) / alpha) ** 2)) / lax_numpy.i0(beta) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 41badcacc..cefc5e0b3 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -47,9 +47,7 @@ from jax._src.numpy.lax_numpy import ( atleast_1d as atleast_1d, atleast_2d as atleast_2d, atleast_3d as atleast_3d, - bartlett as bartlett, bincount as bincount, - blackman as blackman, block as block, broadcast_arrays as broadcast_arrays, broadcast_shapes as broadcast_shapes, @@ -106,8 +104,6 @@ from jax._src.numpy.lax_numpy import ( geomspace as geomspace, get_printoptions as get_printoptions, gradient as gradient, - hamming as hamming, - hanning as hanning, histogram as histogram, histogram_bin_edges as histogram_bin_edges, histogram2d as histogram2d, @@ -131,7 +127,6 @@ from jax._src.numpy.lax_numpy import ( issubdtype as issubdtype, iterable as iterable, ix_ as ix_, - kaiser as kaiser, kron as kron, lcm as lcm, linspace as linspace, @@ -265,6 +260,14 @@ from jax._src.numpy.sorting import ( sort_complex as sort_complex, ) +from jax._src.numpy.window_functions import ( + bartlett as bartlett, + blackman as blackman, + hamming as hamming, + hanning as hanning, + kaiser as kaiser, +) + # NumPy generic scalar types: from numpy import ( character as character,