Prefer expand_dims over reshape

This commit is contained in:
Lukas Geiger 2023-11-16 01:15:48 +00:00
parent 234be736c4
commit 52d7f4911c
3 changed files with 3 additions and 3 deletions

View File

@ -327,7 +327,7 @@ class ufunc:
ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis)
def loop_body(i, out):
return _where((i > ind_start) & (i < ind_end),
self._call(out, take(a, i.reshape(1), axis=axis)),
self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
out)
return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)

View File

@ -987,7 +987,7 @@ def _threefry_seed(seed: typing.Array) -> typing.Array:
raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
if not np.issubdtype(seed.dtype, np.integer):
raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
convert = lambda k: lax.expand_dims(lax.convert_element_type(k, np.uint32), [0])
k1 = convert(
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
with config.numpy_dtype_promotion('standard'):

View File

@ -678,7 +678,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
xsubs *= win.sum() # This takes care of the 'spectrum' scaling
# make win broadcastable over xsubs
win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1,))
win = lax.expand_dims(win, (*range(xsubs.ndim - 2), -1))
x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)