mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Finalize deprecation of zero-dimensional inputs to jnp.nonzero
PiperOrigin-RevId: 626299531
This commit is contained in:
parent
837f0bbf6f
commit
41fa67c2dc
@ -46,6 +46,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
|
||||
passing complex-valued inputs to it. This will raise an error when the
|
||||
deprecation is completed.
|
||||
* Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
|
||||
related functions now raise an error, following a similar change in NumPy.
|
||||
|
||||
## jaxlib 0.4.27
|
||||
|
||||
|
@ -1454,10 +1454,8 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
|
||||
arr = asarray(a)
|
||||
del a
|
||||
if ndim(arr) == 0:
|
||||
# Added 2023 Dec 6
|
||||
warnings.warn("Calling nonzero on 0d arrays is deprecated. Use `atleast_1d(arr).nonzero()",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
arr = atleast_1d(arr)
|
||||
raise ValueError("Calling nonzero on 0d arrays is not allowed. "
|
||||
"Use jnp.atleast_1d(scalar).nonzero() instead.")
|
||||
mask = arr if arr.dtype == bool else (arr != 0)
|
||||
calculated_size = mask.sum() if size is None else size
|
||||
calculated_size = core.concrete_dim_or_error(calculated_size,
|
||||
|
@ -323,17 +323,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
|
||||
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
|
||||
def testNonzero(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
|
||||
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, fill_value=fill_value)
|
||||
for shape in nonempty_array_shapes
|
||||
for shape in nonempty_nonscalar_array_shapes
|
||||
for fill_value in [None, -1, shape or (1,)]
|
||||
],
|
||||
dtype=all_dtypes,
|
||||
@ -351,17 +349,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
|
||||
for fval, arg in safe_zip(fillvals, result))
|
||||
jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value)
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
|
||||
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
|
||||
def testFlatNonzero(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
np_fun = jtu.ignore_warning(
|
||||
category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*")(np.flatnonzero)
|
||||
np_fun = np.flatnonzero
|
||||
jnp_fun = jnp.flatnonzero
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
@ -371,7 +365,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=nonempty_array_shapes,
|
||||
shape=nonempty_nonscalar_array_shapes,
|
||||
dtype=all_dtypes,
|
||||
fill_value=[None, -1, 10, (-1,), (10,)],
|
||||
size=[1, 5, 10],
|
||||
@ -379,7 +373,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testFlatNonzeroSize(self, shape, dtype, size, fill_value):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
|
||||
def np_fun(x):
|
||||
result = np.flatnonzero(x)
|
||||
if size <= len(result):
|
||||
@ -391,24 +384,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
|
||||
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
|
||||
def testArgWhere(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
|
||||
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
|
||||
|
||||
# JIT compilation requires specifying a size statically. Full test of this
|
||||
# behavior is in testNonzeroSize().
|
||||
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, fill_value=fill_value)
|
||||
for shape in nonempty_array_shapes
|
||||
for shape in nonempty_nonscalar_array_shapes
|
||||
for fill_value in [None, -1, shape or (1,)]
|
||||
],
|
||||
dtype=all_dtypes,
|
||||
@ -427,10 +416,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
for fval, arg in safe_zip(fillvals, result.T)]).T
|
||||
jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value)
|
||||
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
|
||||
@ -4490,24 +4477,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
shape=nonzerodim_shapes,
|
||||
dtype=all_dtypes,
|
||||
)
|
||||
def testWhereOneArgument(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
|
||||
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
|
||||
|
||||
# JIT compilation requires specifying a size statically. Full test of
|
||||
# this behavior is in testNonzeroSize().
|
||||
jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2)
|
||||
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shapes=filter(_shapes_are_broadcast_compatible,
|
||||
|
Loading…
x
Reference in New Issue
Block a user