Finalize deprecation of jnp.where keyword arguments

PiperOrigin-RevId: 629086639
This commit is contained in:
Jake VanderPlas 2024-04-29 09:09:13 -07:00 committed by jax authors
parent dfc17187ab
commit ba540ca735
2 changed files with 11 additions and 36 deletions

View File

@ -58,6 +58,8 @@ Remember to align the itemized text with the first line of an item within a list
* The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
been removed after being deprecated in JAX v0.4.22. Instead use
{attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
* The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now
positional-only, following deprecation of the keywords in JAX v0.4.21.
* Bug fixes
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.

View File

@ -1083,13 +1083,8 @@ def where(condition: ArrayLike, x: ArrayLike | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> Array | tuple[Array, ...]: ...
def where(
acondition = None, if_true = None, if_false = None, /, *,
size=None, fill_value=None,
# Deprecated keyword-only names.
condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG,
y = _DEPRECATED_WHERE_ARG
) -> Array | tuple[Array, ...]:
def where(condition, x=None, y=None, /, *, size=None, fill_value=None):
"""Select elements from two arrays based on a condition.
JAX implementation of :func:`numpy.where`.
@ -1149,41 +1144,19 @@ def where(
>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
"""
if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG
or y is not _DEPRECATED_WHERE_ARG):
# TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires.
warnings.warn(
"Passing condition, x, or y to jax.numpy.where via keyword arguments "
"is deprecated.",
DeprecationWarning,
stacklevel=2,
)
if condition is not _DEPRECATED_WHERE_ARG:
if acondition is not None:
raise ValueError("condition should be a positional-only argument")
acondition = condition
if x is not _DEPRECATED_WHERE_ARG:
if if_true is not None:
raise ValueError("x should be a positional-only argument")
if_true = x
if y is not _DEPRECATED_WHERE_ARG:
if if_false is not None:
raise ValueError("y should be a positional-only argument")
if_false = y
if if_true is None and if_false is None:
util.check_arraylike("where", acondition)
return nonzero(acondition, size=size, fill_value=fill_value)
if x is None and y is None:
util.check_arraylike("where", condition)
return nonzero(condition, size=size, fill_value=fill_value)
else:
util.check_arraylike("where", acondition, if_true, if_false)
util.check_arraylike("where", condition, x, y)
if size is not None or fill_value is not None:
raise ValueError("size and fill_value arguments cannot be used in "
"three-term where function.")
if if_true is None or if_false is None:
if x is None or y is None:
raise ValueError("Either both or neither of the x and y arguments "
"should be provided to jax.numpy.where, got "
f"{if_true} and {if_false}.")
return util._where(acondition, if_true, if_false)
f"{x} and {y}.")
return util._where(condition, x, y)
@util.implements(np.select)