mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Finalize deprecation of jnp.where keyword arguments
PiperOrigin-RevId: 629086639
This commit is contained in:
parent
dfc17187ab
commit
ba540ca735
@ -58,6 +58,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
|
||||
been removed after being deprecated in JAX v0.4.22. Instead use
|
||||
{attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
|
||||
* The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now
|
||||
positional-only, following deprecation of the keywords in JAX v0.4.21.
|
||||
|
||||
* Bug fixes
|
||||
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
|
||||
|
@ -1083,13 +1083,8 @@ def where(condition: ArrayLike, x: ArrayLike | None = None,
|
||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
||||
) -> Array | tuple[Array, ...]: ...
|
||||
|
||||
def where(
|
||||
acondition = None, if_true = None, if_false = None, /, *,
|
||||
size=None, fill_value=None,
|
||||
# Deprecated keyword-only names.
|
||||
condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG,
|
||||
y = _DEPRECATED_WHERE_ARG
|
||||
) -> Array | tuple[Array, ...]:
|
||||
|
||||
def where(condition, x=None, y=None, /, *, size=None, fill_value=None):
|
||||
"""Select elements from two arrays based on a condition.
|
||||
|
||||
JAX implementation of :func:`numpy.where`.
|
||||
@ -1149,41 +1144,19 @@ def where(
|
||||
>>> jnp.where(x > 4, x, 0)
|
||||
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
|
||||
"""
|
||||
if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG
|
||||
or y is not _DEPRECATED_WHERE_ARG):
|
||||
# TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires.
|
||||
warnings.warn(
|
||||
"Passing condition, x, or y to jax.numpy.where via keyword arguments "
|
||||
"is deprecated.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if condition is not _DEPRECATED_WHERE_ARG:
|
||||
if acondition is not None:
|
||||
raise ValueError("condition should be a positional-only argument")
|
||||
acondition = condition
|
||||
if x is not _DEPRECATED_WHERE_ARG:
|
||||
if if_true is not None:
|
||||
raise ValueError("x should be a positional-only argument")
|
||||
if_true = x
|
||||
if y is not _DEPRECATED_WHERE_ARG:
|
||||
if if_false is not None:
|
||||
raise ValueError("y should be a positional-only argument")
|
||||
if_false = y
|
||||
|
||||
if if_true is None and if_false is None:
|
||||
util.check_arraylike("where", acondition)
|
||||
return nonzero(acondition, size=size, fill_value=fill_value)
|
||||
if x is None and y is None:
|
||||
util.check_arraylike("where", condition)
|
||||
return nonzero(condition, size=size, fill_value=fill_value)
|
||||
else:
|
||||
util.check_arraylike("where", acondition, if_true, if_false)
|
||||
util.check_arraylike("where", condition, x, y)
|
||||
if size is not None or fill_value is not None:
|
||||
raise ValueError("size and fill_value arguments cannot be used in "
|
||||
"three-term where function.")
|
||||
if if_true is None or if_false is None:
|
||||
if x is None or y is None:
|
||||
raise ValueError("Either both or neither of the x and y arguments "
|
||||
"should be provided to jax.numpy.where, got "
|
||||
f"{if_true} and {if_false}.")
|
||||
return util._where(acondition, if_true, if_false)
|
||||
f"{x} and {y}.")
|
||||
return util._where(condition, x, y)
|
||||
|
||||
|
||||
@util.implements(np.select)
|
||||
|
Loading…
x
Reference in New Issue
Block a user