Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.

PiperOrigin-RevId: 584377134
This commit is contained in:
Peter Hawkins 2023-11-21 11:09:28 -08:00 committed by jax authors
parent 038879248d
commit 84c1e825c0
4 changed files with 49 additions and 20 deletions

View File

@ -21,7 +21,9 @@ Remember to align the itemized text with the first line of an item within a list
* Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or
within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
keyword arguments has been deprecated, to match `numpy.where`.
## jaxlib 0.4.21

View File

@ -1062,25 +1062,27 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
return jitted_interp(x, xp, fp, left, right, period)
@overload
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
size: int | None = None,
@overload # type: ignore[no-overload-impl]
def where(condition: ArrayLike, x: Literal[None] = None,
y: Literal[None] = None, /, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> tuple[Array, ...]: ...
@overload
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, / ,*,
size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> Array: ...
@overload
def where(condition: ArrayLike, x: ArrayLike | None = None,
y: ArrayLike | None = None, *, size: int | None = None,
y: ArrayLike | None = None, /, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> Array | tuple[Array, ...]: ...
@util._wraps(np.where,
_DEPRECATED_WHERE_ARG = object()
@util._wraps(np.where, # type: ignore[no-redef]
lax_description=_dedent("""
At present, JAX does not support JIT-compilation of the single-argument form
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
@ -1104,18 +1106,43 @@ def where(condition: ArrayLike, x: ArrayLike | None = None,
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def where(condition: ArrayLike, x: ArrayLike | None = None,
y: ArrayLike | None = None, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> Array | tuple[Array, ...]:
if x is None and y is None:
util.check_arraylike("where", condition)
return nonzero(condition, size=size, fill_value=fill_value)
def where(
acondition = None, if_true = None, if_false = None, /, *,
size=None, fill_value=None,
# Deprecated keyword-only names.
condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG,
y = _DEPRECATED_WHERE_ARG
) -> Array | tuple[Array, ...]:
if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG
or y is not _DEPRECATED_WHERE_ARG):
# TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires.
warnings.warn(
"Passing condition, x, or y to jax.numpy.where via keyword arguments "
"is deprecated.",
DeprecationWarning,
stacklevel=2,
)
if condition is not _DEPRECATED_WHERE_ARG:
if acondition is not None:
raise ValueError("condition should be a positional-only argument")
acondition = condition
if x is not _DEPRECATED_WHERE_ARG:
if if_true is not None:
raise ValueError("x should be a positional-only argument")
if_true = x
if y is not _DEPRECATED_WHERE_ARG:
if if_false is not None:
raise ValueError("y should be a positional-only argument")
if_false = y
if if_true is None and if_false is None:
util.check_arraylike("where", acondition)
return nonzero(acondition, size=size, fill_value=fill_value)
else:
util.check_arraylike("where", condition, x, y)
util.check_arraylike("where", acondition, if_true, if_false)
if size is not None or fill_value is not None:
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
return util._where(condition, x, y)
return util._where(acondition, if_true, if_false)
@util._wraps(np.select)

View File

@ -704,7 +704,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
failed = jnp.isnan(_norm(x))
info = jnp.where(failed, x=-1, y=0)
info = jnp.where(failed, -1, 0)
return x, info

View File

@ -831,19 +831,19 @@ def vstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]],
@overload
def where(condition: ArrayLike, x: Literal[None] = ..., y: Literal[None] = ...,
*, size: Optional[int] = ...,
/, *, size: Optional[int] = ...,
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
) -> tuple[Array, ...]: ...
@overload
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *,
size: Optional[int] = ...,
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
) -> Array: ...
@overload
def where(condition: ArrayLike, x: Optional[ArrayLike] = ...,
y: Optional[ArrayLike] = ..., *, size: Optional[int] = ...,
y: Optional[ArrayLike] = ..., /, *, size: Optional[int] = ...,
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
) -> Union[Array, tuple[Array, ...]]: ...