Merge pull request #25290 from jakevdp:reduction-where

PiperOrigin-RevId: 703182008
This commit is contained in:
jax authors 2024-12-05 11:17:15 -08:00
commit f73fa7a7ad
3 changed files with 59 additions and 3 deletions

View File

@ -129,5 +129,6 @@ register('jax-numpy-clip-args')
register('jax-numpy-linalg-matrix_rank-tol')
register('jax-numpy-linalg-pinv-rcond')
register('jax-numpy-quantile-interpolation')
register('jax-numpy-reduction-non-boolean-where')
register('jax-numpy-trimzeros-not-1d-array')
register('pallas-gpu-triton')

View File

@ -81,6 +81,20 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
return dtypes.int_
return dtype
def check_where(name: str, where: ArrayLike | None) -> Array | None:
if where is None:
return where
check_arraylike(name, where)
where_arr = lax_internal.asarray(where)
if where_arr.dtype != bool:
# Deprecation added 2024-12-05
deprecations.warn(
'jax-numpy-reduction-non-boolean-where',
f"jnp.{name}: where must be None or a boolean array; got dtype={where_arr.dtype}.",
stacklevel=2)
return where_arr.astype(bool)
return where_arr
ReductionOp = Callable[[Any, Any], Any]
@ -101,6 +115,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
check_arraylike(name, a)
where_ = check_where(name, where_)
dtypes.check_user_dtype_supported(dtype, name)
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
@ -730,6 +745,8 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.")
dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce")
check_arraylike("logsumexp", a)
where = check_where("logsumexp", where)
a_arr, = promote_dtypes_inexact(a)
pos_dims, dims = _reduction_dims(a_arr, axis)
amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf)
@ -748,6 +765,8 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.")
dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce")
check_arraylike("logsumexp2", a)
where = check_where("logsumexp2", where)
ln2 = float(np.log(2))
if initial is not None:
initial *= ln2
@ -850,6 +869,7 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
upcast_f16_for_computation: bool = True,
where: ArrayLike | None = None) -> Array:
check_arraylike("mean", a)
where = check_where("mean", where)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
@ -1087,6 +1107,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("var", a)
where = check_where("var", where)
dtypes.check_user_dtype_supported(dtype, "var")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
@ -1224,6 +1245,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("std", a)
where = check_where("std", where)
dtypes.check_user_dtype_supported(dtype, "std")
if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
@ -1330,13 +1352,15 @@ def count_nonzero(a: ArrayLike, axis: Axis = None,
def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
init_val: ArrayLike, nan_if_all_nan: bool,
axis: Axis = None, keepdims: bool = False, **kwargs) -> Array:
axis: Axis = None, keepdims: bool = False, where: ArrayLike | None = None,
**kwargs) -> Array:
check_arraylike(name, a)
where = check_where(name, where)
if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)
return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs)
out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a),
axis=axis, keepdims=keepdims, **kwargs)
axis=axis, keepdims=keepdims, where=where, **kwargs)
if nan_if_all_nan:
return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
_lax_const(a, np.nan), out)
@ -1755,6 +1779,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
Array([[nan, nan, nan, nan]], dtype=float32)
"""
check_arraylike("nanmean", a)
where = check_where("nanmean", where)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
@ -1848,6 +1873,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
[4. ]], dtype=float32)
"""
check_arraylike("nanvar", a)
where = check_where("nanvar", where)
dtypes.check_user_dtype_supported(dtype, "nanvar")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
@ -1943,6 +1969,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
Array([[0.5, 0.5, 0. , 0. ]], dtype=float32)
"""
check_arraylike("nanstd", a)
where = check_where("nanstd", where)
dtypes.check_user_dtype_supported(dtype, "nanstd")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")

View File

@ -448,6 +448,34 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(rec=JAX_REDUCER_INITIAL_RECORDS)
def testReducerWhereNonBooleanErrorInitial(self, rec):
dtype = rec.dtypes[0]
x = jnp.zeros((10,), dtype)
where = jnp.ones(10, dtype=int)
func = getattr(jnp, rec.name)
def assert_warns_or_errors(msg):
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
func(x, where=where, initial=jnp.array(0, dtype=dtype))
@jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)
def testReducerWhereNonBooleanErrorNoInitial(self, rec):
dtype = rec.dtypes[0]
x = jnp.zeros((10,), dtype)
where = jnp.ones(10, dtype=int)
func = getattr(jnp, rec.name)
def assert_warns_or_errors(msg):
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
func(x, where=where)
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,