mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv
This commit is contained in:
parent
c855bb0371
commit
13dd5e42cc
@ -23,6 +23,10 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
|
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
|
||||||
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
|
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
|
||||||
keyword arguments has been deprecated, to match `numpy.where`.
|
keyword arguments has been deprecated, to match `numpy.where`.
|
||||||
|
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
|
||||||
|
that cannot be converted to a JAX array is deprecated and now raises a
|
||||||
|
{obj}`DeprecationWaning`. Currently the functions return False, in the future this
|
||||||
|
will raise an exception.
|
||||||
|
|
||||||
|
|
||||||
## jaxlib 0.4.21
|
## jaxlib 0.4.21
|
||||||
|
@ -2298,7 +2298,12 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore
|
|||||||
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
|
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
|
||||||
try:
|
try:
|
||||||
a1, a2 = asarray(a1), asarray(a2)
|
a1, a2 = asarray(a1), asarray(a2)
|
||||||
except Exception:
|
except Exception as err:
|
||||||
|
# TODO(jakevdp): Deprecated 2023-11-23; change to error.
|
||||||
|
warnings.warn("Inputs to array_equal() cannot be coerced to array. "
|
||||||
|
"Returning False; in the future this will raise an exception.\n"
|
||||||
|
f"{err!r}",
|
||||||
|
DeprecationWarning, stacklevel=2)
|
||||||
return bool_(False)
|
return bool_(False)
|
||||||
if shape(a1) != shape(a2):
|
if shape(a1) != shape(a2):
|
||||||
return bool_(False)
|
return bool_(False)
|
||||||
@ -2312,7 +2317,12 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
|
|||||||
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
|
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
|
||||||
try:
|
try:
|
||||||
a1, a2 = asarray(a1), asarray(a2)
|
a1, a2 = asarray(a1), asarray(a2)
|
||||||
except Exception:
|
except Exception as err:
|
||||||
|
# TODO(jakevdp): Deprecated 2023-11-23; change to error.
|
||||||
|
warnings.warn("Inputs to array_equiv() cannot be coerced to array. "
|
||||||
|
"Returning False; in the future this will raise an exception.\n"
|
||||||
|
f"{err!r}",
|
||||||
|
DeprecationWarning, stacklevel=2)
|
||||||
return bool_(False)
|
return bool_(False)
|
||||||
try:
|
try:
|
||||||
eq = ufuncs.equal(a1, a2)
|
eq = ufuncs.equal(a1, a2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user