mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Remove use of int casting in STFT collapse of batch dimensions.
PiperOrigin-RevId: 524115535
This commit is contained in:
parent
3e93833ed8
commit
0fd5b2ca61
@ -233,7 +233,7 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array],
|
||||
else:
|
||||
step = nperseg - noverlap
|
||||
batch_shape = list(batch_shape)
|
||||
x = x.reshape((math.prod(batch_shape)), signal_length)[..., np.newaxis]
|
||||
x = x.reshape((math.prod(batch_shape), signal_length, 1))
|
||||
result = jax.lax.conv_general_dilated_patches(
|
||||
x, (nperseg,), (step,),
|
||||
'VALID',
|
||||
|
Loading…
x
Reference in New Issue
Block a user