Remove use of int casting in STFT collapse of batch dimensions.

PiperOrigin-RevId: 524115535
This commit is contained in:
jax authors 2023-04-13 15:14:37 -07:00
parent 3e93833ed8
commit 0fd5b2ca61

View File

@ -233,7 +233,7 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array],
else:
step = nperseg - noverlap
batch_shape = list(batch_shape)
x = x.reshape((math.prod(batch_shape)), signal_length)[..., np.newaxis]
x = x.reshape((math.prod(batch_shape), signal_length, 1))
result = jax.lax.conv_general_dilated_patches(
x, (nperseg,), (step,),
'VALID',