mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Finalize deprecation of zero-dimensional inputs to jnp.nonzero
PiperOrigin-RevId: 626299531
This commit is contained in:
parent
837f0bbf6f
commit
41fa67c2dc
@ -46,6 +46,8 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
|
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
|
||||||
passing complex-valued inputs to it. This will raise an error when the
|
passing complex-valued inputs to it. This will raise an error when the
|
||||||
deprecation is completed.
|
deprecation is completed.
|
||||||
|
* Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
|
||||||
|
related functions now raise an error, following a similar change in NumPy.
|
||||||
|
|
||||||
## jaxlib 0.4.27
|
## jaxlib 0.4.27
|
||||||
|
|
||||||
|
@ -1454,10 +1454,8 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
|
|||||||
arr = asarray(a)
|
arr = asarray(a)
|
||||||
del a
|
del a
|
||||||
if ndim(arr) == 0:
|
if ndim(arr) == 0:
|
||||||
# Added 2023 Dec 6
|
raise ValueError("Calling nonzero on 0d arrays is not allowed. "
|
||||||
warnings.warn("Calling nonzero on 0d arrays is deprecated. Use `atleast_1d(arr).nonzero()",
|
"Use jnp.atleast_1d(scalar).nonzero() instead.")
|
||||||
DeprecationWarning, stacklevel=2)
|
|
||||||
arr = atleast_1d(arr)
|
|
||||||
mask = arr if arr.dtype == bool else (arr != 0)
|
mask = arr if arr.dtype == bool else (arr != 0)
|
||||||
calculated_size = mask.sum() if size is None else size
|
calculated_size = mask.sum() if size is None else size
|
||||||
calculated_size = core.concrete_dim_or_error(calculated_size,
|
calculated_size = core.concrete_dim_or_error(calculated_size,
|
||||||
|
@ -323,17 +323,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
|
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
|
||||||
def testNonzero(self, shape, dtype):
|
def testNonzero(self, shape, dtype):
|
||||||
rng = jtu.rand_some_zero(self.rng())
|
rng = jtu.rand_some_zero(self.rng())
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
with jtu.ignore_warning(category=DeprecationWarning,
|
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
|
||||||
message="Calling nonzero on 0d arrays.*"):
|
|
||||||
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
|
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
[dict(shape=shape, fill_value=fill_value)
|
[dict(shape=shape, fill_value=fill_value)
|
||||||
for shape in nonempty_array_shapes
|
for shape in nonempty_nonscalar_array_shapes
|
||||||
for fill_value in [None, -1, shape or (1,)]
|
for fill_value in [None, -1, shape or (1,)]
|
||||||
],
|
],
|
||||||
dtype=all_dtypes,
|
dtype=all_dtypes,
|
||||||
@ -351,17 +349,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
|
return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
|
||||||
for fval, arg in safe_zip(fillvals, result))
|
for fval, arg in safe_zip(fillvals, result))
|
||||||
jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value)
|
jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value)
|
||||||
with jtu.ignore_warning(category=DeprecationWarning,
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||||
message="Calling nonzero on 0d arrays.*"):
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
|
||||||
|
|
||||||
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
|
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
|
||||||
def testFlatNonzero(self, shape, dtype):
|
def testFlatNonzero(self, shape, dtype):
|
||||||
rng = jtu.rand_some_zero(self.rng())
|
rng = jtu.rand_some_zero(self.rng())
|
||||||
np_fun = jtu.ignore_warning(
|
np_fun = np.flatnonzero
|
||||||
category=DeprecationWarning,
|
|
||||||
message="Calling nonzero on 0d arrays.*")(np.flatnonzero)
|
|
||||||
jnp_fun = jnp.flatnonzero
|
jnp_fun = jnp.flatnonzero
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||||
@ -371,7 +365,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
shape=nonempty_array_shapes,
|
shape=nonempty_nonscalar_array_shapes,
|
||||||
dtype=all_dtypes,
|
dtype=all_dtypes,
|
||||||
fill_value=[None, -1, 10, (-1,), (10,)],
|
fill_value=[None, -1, 10, (-1,), (10,)],
|
||||||
size=[1, 5, 10],
|
size=[1, 5, 10],
|
||||||
@ -379,7 +373,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
def testFlatNonzeroSize(self, shape, dtype, size, fill_value):
|
def testFlatNonzeroSize(self, shape, dtype, size, fill_value):
|
||||||
rng = jtu.rand_some_zero(self.rng())
|
rng = jtu.rand_some_zero(self.rng())
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
|
|
||||||
def np_fun(x):
|
def np_fun(x):
|
||||||
result = np.flatnonzero(x)
|
result = np.flatnonzero(x)
|
||||||
if size <= len(result):
|
if size <= len(result):
|
||||||
@ -391,24 +384,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
|
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
|
||||||
def testArgWhere(self, shape, dtype):
|
def testArgWhere(self, shape, dtype):
|
||||||
rng = jtu.rand_some_zero(self.rng())
|
rng = jtu.rand_some_zero(self.rng())
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
with jtu.ignore_warning(category=DeprecationWarning,
|
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
|
||||||
message="Calling nonzero on 0d arrays.*"):
|
|
||||||
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
|
|
||||||
|
|
||||||
# JIT compilation requires specifying a size statically. Full test of this
|
# JIT compilation requires specifying a size statically. Full test of this
|
||||||
# behavior is in testNonzeroSize().
|
# behavior is in testNonzeroSize().
|
||||||
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
|
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
|
||||||
with jtu.ignore_warning(category=DeprecationWarning,
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
message="Calling nonzero on 0d arrays.*"):
|
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
[dict(shape=shape, fill_value=fill_value)
|
[dict(shape=shape, fill_value=fill_value)
|
||||||
for shape in nonempty_array_shapes
|
for shape in nonempty_nonscalar_array_shapes
|
||||||
for fill_value in [None, -1, shape or (1,)]
|
for fill_value in [None, -1, shape or (1,)]
|
||||||
],
|
],
|
||||||
dtype=all_dtypes,
|
dtype=all_dtypes,
|
||||||
@ -427,10 +416,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
for fval, arg in safe_zip(fillvals, result.T)]).T
|
for fval, arg in safe_zip(fillvals, result.T)]).T
|
||||||
jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value)
|
jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value)
|
||||||
|
|
||||||
with jtu.ignore_warning(category=DeprecationWarning,
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||||
message="Calling nonzero on 0d arrays.*"):
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
[dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
|
[dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
|
||||||
@ -4490,24 +4477,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
shape=all_shapes,
|
shape=nonzerodim_shapes,
|
||||||
dtype=all_dtypes,
|
dtype=all_dtypes,
|
||||||
)
|
)
|
||||||
def testWhereOneArgument(self, shape, dtype):
|
def testWhereOneArgument(self, shape, dtype):
|
||||||
rng = jtu.rand_some_zero(self.rng())
|
rng = jtu.rand_some_zero(self.rng())
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
|
|
||||||
with jtu.ignore_warning(category=DeprecationWarning,
|
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
|
||||||
message="Calling nonzero on 0d arrays.*"):
|
|
||||||
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
|
|
||||||
|
|
||||||
# JIT compilation requires specifying a size statically. Full test of
|
# JIT compilation requires specifying a size statically. Full test of
|
||||||
# this behavior is in testNonzeroSize().
|
# this behavior is in testNonzeroSize().
|
||||||
jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2)
|
jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2)
|
||||||
|
|
||||||
with jtu.ignore_warning(category=DeprecationWarning,
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
message="Calling nonzero on 0d arrays.*"):
|
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
shapes=filter(_shapes_are_broadcast_compatible,
|
shapes=filter(_shapes_are_broadcast_compatible,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user