Don't broadcast scalar conditions in the jnp.where implementation().

The underlying lax primitive is perfectly happy to accept scalar conditions with the other arguments being non-scalar.
This commit is contained in:
Peter Hawkins 2024-07-24 11:25:25 -04:00
parent 0e17d26b6d
commit 7527101672

View File

@ -447,9 +447,13 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
if not np.issubdtype(_dtype(condition), np.bool_):
condition = lax.ne(condition, lax._zero(condition))
x, y = promote_dtypes(x, y)
condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
if np.ndim(condition) == 0:
# lax.select() handles scalar conditions without broadcasting.
x_arr, y_arr = _broadcast_arrays(x, y)
else:
condition, x_arr, y_arr = _broadcast_arrays(condition, x, y)
try:
is_always_empty = core.is_empty_shape(x_arr.shape)
except:
is_always_empty = False # can fail with dynamic shapes
return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr
return lax.select(condition, x_arr, y_arr) if not is_always_empty else x_arr