mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Finalize deprecation of jnp.where keyword arguments
PiperOrigin-RevId: 629086639
This commit is contained in:
parent
dfc17187ab
commit
ba540ca735
@ -58,6 +58,8 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
|
* The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
|
||||||
been removed after being deprecated in JAX v0.4.22. Instead use
|
been removed after being deprecated in JAX v0.4.22. Instead use
|
||||||
{attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
|
{attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
|
||||||
|
* The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now
|
||||||
|
positional-only, following deprecation of the keywords in JAX v0.4.21.
|
||||||
|
|
||||||
* Bug fixes
|
* Bug fixes
|
||||||
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
|
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
|
||||||
|
@ -1083,13 +1083,8 @@ def where(condition: ArrayLike, x: ArrayLike | None = None,
|
|||||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
||||||
) -> Array | tuple[Array, ...]: ...
|
) -> Array | tuple[Array, ...]: ...
|
||||||
|
|
||||||
def where(
|
|
||||||
acondition = None, if_true = None, if_false = None, /, *,
|
def where(condition, x=None, y=None, /, *, size=None, fill_value=None):
|
||||||
size=None, fill_value=None,
|
|
||||||
# Deprecated keyword-only names.
|
|
||||||
condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG,
|
|
||||||
y = _DEPRECATED_WHERE_ARG
|
|
||||||
) -> Array | tuple[Array, ...]:
|
|
||||||
"""Select elements from two arrays based on a condition.
|
"""Select elements from two arrays based on a condition.
|
||||||
|
|
||||||
JAX implementation of :func:`numpy.where`.
|
JAX implementation of :func:`numpy.where`.
|
||||||
@ -1149,41 +1144,19 @@ def where(
|
|||||||
>>> jnp.where(x > 4, x, 0)
|
>>> jnp.where(x > 4, x, 0)
|
||||||
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
|
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
|
||||||
"""
|
"""
|
||||||
if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG
|
if x is None and y is None:
|
||||||
or y is not _DEPRECATED_WHERE_ARG):
|
util.check_arraylike("where", condition)
|
||||||
# TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires.
|
return nonzero(condition, size=size, fill_value=fill_value)
|
||||||
warnings.warn(
|
|
||||||
"Passing condition, x, or y to jax.numpy.where via keyword arguments "
|
|
||||||
"is deprecated.",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
if condition is not _DEPRECATED_WHERE_ARG:
|
|
||||||
if acondition is not None:
|
|
||||||
raise ValueError("condition should be a positional-only argument")
|
|
||||||
acondition = condition
|
|
||||||
if x is not _DEPRECATED_WHERE_ARG:
|
|
||||||
if if_true is not None:
|
|
||||||
raise ValueError("x should be a positional-only argument")
|
|
||||||
if_true = x
|
|
||||||
if y is not _DEPRECATED_WHERE_ARG:
|
|
||||||
if if_false is not None:
|
|
||||||
raise ValueError("y should be a positional-only argument")
|
|
||||||
if_false = y
|
|
||||||
|
|
||||||
if if_true is None and if_false is None:
|
|
||||||
util.check_arraylike("where", acondition)
|
|
||||||
return nonzero(acondition, size=size, fill_value=fill_value)
|
|
||||||
else:
|
else:
|
||||||
util.check_arraylike("where", acondition, if_true, if_false)
|
util.check_arraylike("where", condition, x, y)
|
||||||
if size is not None or fill_value is not None:
|
if size is not None or fill_value is not None:
|
||||||
raise ValueError("size and fill_value arguments cannot be used in "
|
raise ValueError("size and fill_value arguments cannot be used in "
|
||||||
"three-term where function.")
|
"three-term where function.")
|
||||||
if if_true is None or if_false is None:
|
if x is None or y is None:
|
||||||
raise ValueError("Either both or neither of the x and y arguments "
|
raise ValueError("Either both or neither of the x and y arguments "
|
||||||
"should be provided to jax.numpy.where, got "
|
"should be provided to jax.numpy.where, got "
|
||||||
f"{if_true} and {if_false}.")
|
f"{x} and {y}.")
|
||||||
return util._where(acondition, if_true, if_false)
|
return util._where(condition, x, y)
|
||||||
|
|
||||||
|
|
||||||
@util.implements(np.select)
|
@util.implements(np.select)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user