mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
This commit is contained in:
parent
038879248d
commit
84c1e825c0
@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or
|
||||
within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
|
||||
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
|
||||
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
|
||||
keyword arguments has been deprecated, to match `numpy.where`.
|
||||
|
||||
|
||||
## jaxlib 0.4.21
|
||||
|
@ -1062,25 +1062,27 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
|
||||
return jitted_interp(x, xp, fp, left, right, period)
|
||||
|
||||
|
||||
@overload
|
||||
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
|
||||
size: int | None = None,
|
||||
@overload # type: ignore[no-overload-impl]
|
||||
def where(condition: ArrayLike, x: Literal[None] = None,
|
||||
y: Literal[None] = None, /, *, size: int | None = None,
|
||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
||||
) -> tuple[Array, ...]: ...
|
||||
|
||||
@overload
|
||||
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
|
||||
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, / ,*,
|
||||
size: int | None = None,
|
||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
||||
) -> Array: ...
|
||||
|
||||
@overload
|
||||
def where(condition: ArrayLike, x: ArrayLike | None = None,
|
||||
y: ArrayLike | None = None, *, size: int | None = None,
|
||||
y: ArrayLike | None = None, /, *, size: int | None = None,
|
||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
||||
) -> Array | tuple[Array, ...]: ...
|
||||
|
||||
@util._wraps(np.where,
|
||||
_DEPRECATED_WHERE_ARG = object()
|
||||
|
||||
@util._wraps(np.where, # type: ignore[no-redef]
|
||||
lax_description=_dedent("""
|
||||
At present, JAX does not support JIT-compilation of the single-argument form
|
||||
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
|
||||
@ -1104,18 +1106,43 @@ def where(condition: ArrayLike, x: ArrayLike | None = None,
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
|
||||
def where(condition: ArrayLike, x: ArrayLike | None = None,
|
||||
y: ArrayLike | None = None, *, size: int | None = None,
|
||||
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
|
||||
) -> Array | tuple[Array, ...]:
|
||||
if x is None and y is None:
|
||||
util.check_arraylike("where", condition)
|
||||
return nonzero(condition, size=size, fill_value=fill_value)
|
||||
def where(
|
||||
acondition = None, if_true = None, if_false = None, /, *,
|
||||
size=None, fill_value=None,
|
||||
# Deprecated keyword-only names.
|
||||
condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG,
|
||||
y = _DEPRECATED_WHERE_ARG
|
||||
) -> Array | tuple[Array, ...]:
|
||||
if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG
|
||||
or y is not _DEPRECATED_WHERE_ARG):
|
||||
# TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires.
|
||||
warnings.warn(
|
||||
"Passing condition, x, or y to jax.numpy.where via keyword arguments "
|
||||
"is deprecated.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if condition is not _DEPRECATED_WHERE_ARG:
|
||||
if acondition is not None:
|
||||
raise ValueError("condition should be a positional-only argument")
|
||||
acondition = condition
|
||||
if x is not _DEPRECATED_WHERE_ARG:
|
||||
if if_true is not None:
|
||||
raise ValueError("x should be a positional-only argument")
|
||||
if_true = x
|
||||
if y is not _DEPRECATED_WHERE_ARG:
|
||||
if if_false is not None:
|
||||
raise ValueError("y should be a positional-only argument")
|
||||
if_false = y
|
||||
|
||||
if if_true is None and if_false is None:
|
||||
util.check_arraylike("where", acondition)
|
||||
return nonzero(acondition, size=size, fill_value=fill_value)
|
||||
else:
|
||||
util.check_arraylike("where", condition, x, y)
|
||||
util.check_arraylike("where", acondition, if_true, if_false)
|
||||
if size is not None or fill_value is not None:
|
||||
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
|
||||
return util._where(condition, x, y)
|
||||
return util._where(acondition, if_true, if_false)
|
||||
|
||||
|
||||
@util._wraps(np.select)
|
||||
|
@ -704,7 +704,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||
x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
|
||||
|
||||
failed = jnp.isnan(_norm(x))
|
||||
info = jnp.where(failed, x=-1, y=0)
|
||||
info = jnp.where(failed, -1, 0)
|
||||
return x, info
|
||||
|
||||
|
||||
|
@ -831,19 +831,19 @@ def vstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]],
|
||||
|
||||
@overload
|
||||
def where(condition: ArrayLike, x: Literal[None] = ..., y: Literal[None] = ...,
|
||||
*, size: Optional[int] = ...,
|
||||
/, *, size: Optional[int] = ...,
|
||||
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
|
||||
) -> tuple[Array, ...]: ...
|
||||
|
||||
@overload
|
||||
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
|
||||
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *,
|
||||
size: Optional[int] = ...,
|
||||
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
|
||||
) -> Array: ...
|
||||
|
||||
@overload
|
||||
def where(condition: ArrayLike, x: Optional[ArrayLike] = ...,
|
||||
y: Optional[ArrayLike] = ..., *, size: Optional[int] = ...,
|
||||
y: Optional[ArrayLike] = ..., /, *, size: Optional[int] = ...,
|
||||
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
|
||||
) -> Union[Array, tuple[Array, ...]]: ...
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user