From 0fd5b2ca616f8956f1e5dee4f0c44aae95f5aa28 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 13 Apr 2023 15:14:37 -0700 Subject: [PATCH] Remove use of int casting in STFT collapse of batch dimensions. PiperOrigin-RevId: 524115535 --- jax/_src/scipy/signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 70b9f6084..f7424cfce 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -233,7 +233,7 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array], else: step = nperseg - noverlap batch_shape = list(batch_shape) - x = x.reshape((math.prod(batch_shape)), signal_length)[..., np.newaxis] + x = x.reshape((math.prod(batch_shape), signal_length, 1)) result = jax.lax.conv_general_dilated_patches( x, (nperseg,), (step,), 'VALID',