jnp.nonzero: deprecate zero-dimensional inputs

This commit is contained in:
Jake VanderPlas 2023-12-06 12:57:25 -08:00
parent fe6e195a45
commit 51960048f0
2 changed files with 32 additions and 26 deletions

View File

@ -1404,8 +1404,13 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> tuple[Array, ...]:
util.check_arraylike("nonzero", a)
arr = atleast_1d(a)
arr = asarray(a)
del a
if ndim(arr) == 0:
# Added 2023 Dec 6
warnings.warn("Calling nonzero on 0d arrays is deprecated. Use `atleast_1d(arr).nonzero()",
DeprecationWarning, stacklevel=2)
arr = atleast_1d(arr)
mask = arr if arr.dtype == bool else (arr != 0)
calculated_size = mask.sum() if size is None else size
calculated_size = core.concrete_dim_or_error(calculated_size,

View File

@ -280,13 +280,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testNonzero(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
np_fun = lambda x: np.nonzero(x)
np_fun = jtu.ignore_warning(
category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*")(np_fun)
jnp_fun = lambda x: jnp.nonzero(x)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
@jtu.sample_product(
[dict(shape=shape, fill_value=fill_value)
@ -299,7 +296,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testNonzeroSize(self, shape, dtype, size, fill_value):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
def np_fun(x):
result = np.nonzero(x)
if size <= len(result[0]):
@ -309,8 +305,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
for fval, arg in safe_zip(fillvals, result))
jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testFlatNonzero(self, shape, dtype):
@ -350,17 +348,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testArgWhere(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
np_fun = jtu.ignore_warning(
category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*")(np.argwhere)
jnp_fun = jnp.argwhere
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
# JIT compilation requires specifying a size statically. Full test of this
# behavior is in testNonzeroSize().
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
self._CompileAndCheck(jnp_fun, args_maker)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(shape=shape, fill_value=fill_value)
@ -373,7 +371,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testArgWhereSize(self, shape, dtype, size, fill_value):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
def np_fun(x):
result = np.argwhere(x)
if size <= len(result):
@ -383,8 +380,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
return np.empty((size, 0), dtype=int) if np.ndim(x) == 0 else np.stack([np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
for fval, arg in safe_zip(fillvals, result.T)]).T
jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
@ -4086,18 +4086,19 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
)
def testWhereOneArgument(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
np_fun = lambda x: np.where(x)
np_fun = jtu.ignore_warning(
category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*")(np_fun)
jnp_fun = lambda x: jnp.where(x)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
# JIT compilation requires specifying a size statically. Full test of
# this behavior is in testNonzeroSize().
jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2)
self._CompileAndCheck(jnp_fun, args_maker)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
shapes=filter(_shapes_are_broadcast_compatible,