mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
This commit is contained in:
parent
038879248d
commit
84c1e825c0
@ -21,7 +21,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or
|
* Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or
|
||||||
within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
|
within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
|
||||||
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
|
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
|
||||||
|
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
|
||||||
|
keyword arguments has been deprecated, to match `numpy.where`.
|
||||||
|
|
||||||
|
|
||||||
## jaxlib 0.4.21
|
## jaxlib 0.4.21
|
||||||
|
|
||||||
|
@ -1062,25 +1062,27 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
|
|||||||
return jitted_interp(x, xp, fp, left, right, period)
|
return jitted_interp(x, xp, fp, left, right, period)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload # type: ignore[no-overload-impl]
|
||||||
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
|
def where(condition: ArrayLike, x: Literal[None] = None,
|
||||||
size: int | None = None,
|
y: Literal[None] = None, /, *, size: int | None = None,
|
||||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
||||||
) -> tuple[Array, ...]: ...
|
) -> tuple[Array, ...]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
|
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, / ,*,
|
||||||
size: int | None = None,
|
size: int | None = None,
|
||||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
||||||
) -> Array: ...
|
) -> Array: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def where(condition: ArrayLike, x: ArrayLike | None = None,
|
def where(condition: ArrayLike, x: ArrayLike | None = None,
|
||||||
y: ArrayLike | None = None, *, size: int | None = None,
|
y: ArrayLike | None = None, /, *, size: int | None = None,
|
||||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
||||||
) -> Array | tuple[Array, ...]: ...
|
) -> Array | tuple[Array, ...]: ...
|
||||||
|
|
||||||
@util._wraps(np.where,
|
_DEPRECATED_WHERE_ARG = object()
|
||||||
|
|
||||||
|
@util._wraps(np.where, # type: ignore[no-redef]
|
||||||
lax_description=_dedent("""
|
lax_description=_dedent("""
|
||||||
At present, JAX does not support JIT-compilation of the single-argument form
|
At present, JAX does not support JIT-compilation of the single-argument form
|
||||||
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
|
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
|
||||||
@ -1104,18 +1106,43 @@ def where(condition: ArrayLike, x: ArrayLike | None = None,
|
|||||||
fill_value : array_like, optional
|
fill_value : array_like, optional
|
||||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||||
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
|
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
|
||||||
def where(condition: ArrayLike, x: ArrayLike | None = None,
|
def where(
|
||||||
y: ArrayLike | None = None, *, size: int | None = None,
|
acondition = None, if_true = None, if_false = None, /, *,
|
||||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
size=None, fill_value=None,
|
||||||
) -> Array | tuple[Array, ...]:
|
# Deprecated keyword-only names.
|
||||||
if x is None and y is None:
|
condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG,
|
||||||
util.check_arraylike("where", condition)
|
y = _DEPRECATED_WHERE_ARG
|
||||||
return nonzero(condition, size=size, fill_value=fill_value)
|
) -> Array | tuple[Array, ...]:
|
||||||
|
if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG
|
||||||
|
or y is not _DEPRECATED_WHERE_ARG):
|
||||||
|
# TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires.
|
||||||
|
warnings.warn(
|
||||||
|
"Passing condition, x, or y to jax.numpy.where via keyword arguments "
|
||||||
|
"is deprecated.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
if condition is not _DEPRECATED_WHERE_ARG:
|
||||||
|
if acondition is not None:
|
||||||
|
raise ValueError("condition should be a positional-only argument")
|
||||||
|
acondition = condition
|
||||||
|
if x is not _DEPRECATED_WHERE_ARG:
|
||||||
|
if if_true is not None:
|
||||||
|
raise ValueError("x should be a positional-only argument")
|
||||||
|
if_true = x
|
||||||
|
if y is not _DEPRECATED_WHERE_ARG:
|
||||||
|
if if_false is not None:
|
||||||
|
raise ValueError("y should be a positional-only argument")
|
||||||
|
if_false = y
|
||||||
|
|
||||||
|
if if_true is None and if_false is None:
|
||||||
|
util.check_arraylike("where", acondition)
|
||||||
|
return nonzero(acondition, size=size, fill_value=fill_value)
|
||||||
else:
|
else:
|
||||||
util.check_arraylike("where", condition, x, y)
|
util.check_arraylike("where", acondition, if_true, if_false)
|
||||||
if size is not None or fill_value is not None:
|
if size is not None or fill_value is not None:
|
||||||
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
|
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
|
||||||
return util._where(condition, x, y)
|
return util._where(acondition, if_true, if_false)
|
||||||
|
|
||||||
|
|
||||||
@util._wraps(np.select)
|
@util._wraps(np.select)
|
||||||
|
@ -704,7 +704,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
|||||||
x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
|
x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
|
||||||
|
|
||||||
failed = jnp.isnan(_norm(x))
|
failed = jnp.isnan(_norm(x))
|
||||||
info = jnp.where(failed, x=-1, y=0)
|
info = jnp.where(failed, -1, 0)
|
||||||
return x, info
|
return x, info
|
||||||
|
|
||||||
|
|
||||||
|
@ -831,19 +831,19 @@ def vstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]],
|
|||||||
|
|
||||||
@overload
|
@overload
|
||||||
def where(condition: ArrayLike, x: Literal[None] = ..., y: Literal[None] = ...,
|
def where(condition: ArrayLike, x: Literal[None] = ..., y: Literal[None] = ...,
|
||||||
*, size: Optional[int] = ...,
|
/, *, size: Optional[int] = ...,
|
||||||
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
|
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
|
||||||
) -> tuple[Array, ...]: ...
|
) -> tuple[Array, ...]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
|
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *,
|
||||||
size: Optional[int] = ...,
|
size: Optional[int] = ...,
|
||||||
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
|
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
|
||||||
) -> Array: ...
|
) -> Array: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def where(condition: ArrayLike, x: Optional[ArrayLike] = ...,
|
def where(condition: ArrayLike, x: Optional[ArrayLike] = ...,
|
||||||
y: Optional[ArrayLike] = ..., *, size: Optional[int] = ...,
|
y: Optional[ArrayLike] = ..., /, *, size: Optional[int] = ...,
|
||||||
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
|
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
|
||||||
) -> Union[Array, tuple[Array, ...]]: ...
|
) -> Union[Array, tuple[Array, ...]]: ...
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user