mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Prefer expand_dims
over reshape
This commit is contained in:
parent
234be736c4
commit
52d7f4911c
@ -327,7 +327,7 @@ class ufunc:
|
||||
ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis)
|
||||
def loop_body(i, out):
|
||||
return _where((i > ind_start) & (i < ind_end),
|
||||
self._call(out, take(a, i.reshape(1), axis=axis)),
|
||||
self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
|
||||
out)
|
||||
return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)
|
||||
|
||||
|
@ -987,7 +987,7 @@ def _threefry_seed(seed: typing.Array) -> typing.Array:
|
||||
raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
|
||||
if not np.issubdtype(seed.dtype, np.integer):
|
||||
raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
|
||||
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
|
||||
convert = lambda k: lax.expand_dims(lax.convert_element_type(k, np.uint32), [0])
|
||||
k1 = convert(
|
||||
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
|
||||
with config.numpy_dtype_promotion('standard'):
|
||||
|
@ -678,7 +678,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
|
||||
xsubs *= win.sum() # This takes care of the 'spectrum' scaling
|
||||
|
||||
# make win broadcastable over xsubs
|
||||
win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1,))
|
||||
win = lax.expand_dims(win, (*range(xsubs.ndim - 2), -1))
|
||||
x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
|
||||
win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
|
||||
norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)
|
||||
|
Loading…
x
Reference in New Issue
Block a user