From 75271016727fa7ea590c9eb43a9fa8295c7c6cb1 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 24 Jul 2024 11:25:25 -0400 Subject: [PATCH] Don't broadcast scalar conditions in the jnp.where implementation(). The underlying lax primitive is perfectly happy to accept scalar conditions with the other arguments being non-scalar. --- jax/_src/numpy/util.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 21b96deea..5133a1589 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -447,9 +447,13 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array: if not np.issubdtype(_dtype(condition), np.bool_): condition = lax.ne(condition, lax._zero(condition)) x, y = promote_dtypes(x, y) - condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y) + if np.ndim(condition) == 0: + # lax.select() handles scalar conditions without broadcasting. + x_arr, y_arr = _broadcast_arrays(x, y) + else: + condition, x_arr, y_arr = _broadcast_arrays(condition, x, y) try: is_always_empty = core.is_empty_shape(x_arr.shape) except: is_always_empty = False # can fail with dynamic shapes - return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr + return lax.select(condition, x_arr, y_arr) if not is_always_empty else x_arr