From 52d7f4911c5c63ec8bc15ee3c3561f5b5e8a6364 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Thu, 16 Nov 2023 01:15:48 +0000 Subject: [PATCH] Prefer `expand_dims` over `reshape` --- jax/_src/numpy/ufunc_api.py | 2 +- jax/_src/prng.py | 2 +- jax/_src/scipy/signal.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 42c14aef1..ae287b106 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -327,7 +327,7 @@ class ufunc: ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) def loop_body(i, out): return _where((i > ind_start) & (i < ind_end), - self._call(out, take(a, i.reshape(1), axis=axis)), + self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), out) return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 76e9b1dc8..67670517a 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -987,7 +987,7 @@ def _threefry_seed(seed: typing.Array) -> typing.Array: raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.") if not np.issubdtype(seed.dtype, np.integer): raise TypeError(f"PRNG key seed must be an integer; got {seed!r}") - convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1]) + convert = lambda k: lax.expand_dims(lax.convert_element_type(k, np.uint32), [0]) k1 = convert( lax.shift_right_logical(seed, lax_internal._const(seed, 32))) with config.numpy_dtype_promotion('standard'): diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index d9e721340..f3d918e25 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -678,7 +678,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', xsubs *= win.sum() # This takes care of the 'spectrum' scaling # make win broadcastable over xsubs - win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1,)) + win = lax.expand_dims(win, (*range(xsubs.ndim - 2), -1)) x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep) win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1) norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)