diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b2587c3f..5421e3e42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,9 @@ Remember to align the itemized text with the first line of an item within a list * Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`. It currently is converted to NaN, and in the future will raise a {obj}`TypeError`. - + * Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by + keyword arguments has been deprecated, to match `numpy.where`. + ## jaxlib 0.4.21 diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index bf3c9adc8..e948a6488 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1062,25 +1062,27 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return jitted_interp(x, xp, fp, left, right, period) -@overload -def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *, - size: int | None = None, +@overload # type: ignore[no-overload-impl] +def where(condition: ArrayLike, x: Literal[None] = None, + y: Literal[None] = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> tuple[Array, ...]: ... @overload -def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *, +def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, / ,*, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> Array: ... @overload def where(condition: ArrayLike, x: ArrayLike | None = None, - y: ArrayLike | None = None, *, size: int | None = None, + y: ArrayLike | None = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> Array | tuple[Array, ...]: ... -@util._wraps(np.where, +_DEPRECATED_WHERE_ARG = object() + +@util._wraps(np.where, # type: ignore[no-redef] lax_description=_dedent(""" At present, JAX does not support JIT-compilation of the single-argument form of :py:func:`jax.numpy.where` because its output shape is data-dependent. The @@ -1104,18 +1106,43 @@ def where(condition: ArrayLike, x: ArrayLike | None = None, fill_value : array_like, optional When ``size`` is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with ``fill_value``, which defaults to zero.""")) -def where(condition: ArrayLike, x: ArrayLike | None = None, - y: ArrayLike | None = None, *, size: int | None = None, - fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None - ) -> Array | tuple[Array, ...]: - if x is None and y is None: - util.check_arraylike("where", condition) - return nonzero(condition, size=size, fill_value=fill_value) +def where( + acondition = None, if_true = None, if_false = None, /, *, + size=None, fill_value=None, + # Deprecated keyword-only names. + condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG, + y = _DEPRECATED_WHERE_ARG +) -> Array | tuple[Array, ...]: + if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG + or y is not _DEPRECATED_WHERE_ARG): + # TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires. + warnings.warn( + "Passing condition, x, or y to jax.numpy.where via keyword arguments " + "is deprecated.", + DeprecationWarning, + stacklevel=2, + ) + if condition is not _DEPRECATED_WHERE_ARG: + if acondition is not None: + raise ValueError("condition should be a positional-only argument") + acondition = condition + if x is not _DEPRECATED_WHERE_ARG: + if if_true is not None: + raise ValueError("x should be a positional-only argument") + if_true = x + if y is not _DEPRECATED_WHERE_ARG: + if if_false is not None: + raise ValueError("y should be a positional-only argument") + if_false = y + + if if_true is None and if_false is None: + util.check_arraylike("where", acondition) + return nonzero(acondition, size=size, fill_value=fill_value) else: - util.check_arraylike("where", condition, x, y) + util.check_arraylike("where", acondition, if_true, if_false) if size is not None or fill_value is not None: raise ValueError("size and fill_value arguments cannot be used in three-term where function.") - return util._where(condition, x, y) + return util._where(acondition, if_true, if_false) @util._wraps(np.select) diff --git a/jax/_src/scipy/sparse/linalg.py b/jax/_src/scipy/sparse/linalg.py index 8614dfa38..9fd11f6d2 100644 --- a/jax/_src/scipy/sparse/linalg.py +++ b/jax/_src/scipy/sparse/linalg.py @@ -704,7 +704,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None, x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve) failed = jnp.isnan(_norm(x)) - info = jnp.where(failed, x=-1, y=0) + info = jnp.where(failed, -1, 0) return x, info diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index e63cc12be..876d76448 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -831,19 +831,19 @@ def vstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]], @overload def where(condition: ArrayLike, x: Literal[None] = ..., y: Literal[None] = ..., - *, size: Optional[int] = ..., + /, *, size: Optional[int] = ..., fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... ) -> tuple[Array, ...]: ... @overload -def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *, +def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *, size: Optional[int] = ..., fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... ) -> Array: ... @overload def where(condition: ArrayLike, x: Optional[ArrayLike] = ..., - y: Optional[ArrayLike] = ..., *, size: Optional[int] = ..., + y: Optional[ArrayLike] = ..., /, *, size: Optional[int] = ..., fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... ) -> Union[Array, tuple[Array, ...]]: ...