mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Don't broadcast scalar conditions in the jnp.where implementation().
The underlying lax primitive is perfectly happy to accept scalar conditions with the other arguments being non-scalar.
This commit is contained in:
parent
0e17d26b6d
commit
7527101672
@ -447,9 +447,13 @@ def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
|
||||
if not np.issubdtype(_dtype(condition), np.bool_):
|
||||
condition = lax.ne(condition, lax._zero(condition))
|
||||
x, y = promote_dtypes(x, y)
|
||||
condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
|
||||
if np.ndim(condition) == 0:
|
||||
# lax.select() handles scalar conditions without broadcasting.
|
||||
x_arr, y_arr = _broadcast_arrays(x, y)
|
||||
else:
|
||||
condition, x_arr, y_arr = _broadcast_arrays(condition, x, y)
|
||||
try:
|
||||
is_always_empty = core.is_empty_shape(x_arr.shape)
|
||||
except:
|
||||
is_always_empty = False # can fail with dynamic shapes
|
||||
return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr
|
||||
return lax.select(condition, x_arr, y_arr) if not is_always_empty else x_arr
|
||||
|
Loading…
x
Reference in New Issue
Block a user