mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Merge pull request #25290 from jakevdp:reduction-where
PiperOrigin-RevId: 703182008
This commit is contained in:
commit
f73fa7a7ad
@ -129,5 +129,6 @@ register('jax-numpy-clip-args')
|
||||
register('jax-numpy-linalg-matrix_rank-tol')
|
||||
register('jax-numpy-linalg-pinv-rcond')
|
||||
register('jax-numpy-quantile-interpolation')
|
||||
register('jax-numpy-reduction-non-boolean-where')
|
||||
register('jax-numpy-trimzeros-not-1d-array')
|
||||
register('pallas-gpu-triton')
|
||||
|
@ -81,6 +81,20 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
|
||||
return dtypes.int_
|
||||
return dtype
|
||||
|
||||
def check_where(name: str, where: ArrayLike | None) -> Array | None:
|
||||
if where is None:
|
||||
return where
|
||||
check_arraylike(name, where)
|
||||
where_arr = lax_internal.asarray(where)
|
||||
if where_arr.dtype != bool:
|
||||
# Deprecation added 2024-12-05
|
||||
deprecations.warn(
|
||||
'jax-numpy-reduction-non-boolean-where',
|
||||
f"jnp.{name}: where must be None or a boolean array; got dtype={where_arr.dtype}.",
|
||||
stacklevel=2)
|
||||
return where_arr.astype(bool)
|
||||
return where_arr
|
||||
|
||||
|
||||
ReductionOp = Callable[[Any, Any], Any]
|
||||
|
||||
@ -101,6 +115,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
|
||||
if out is not None:
|
||||
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
|
||||
check_arraylike(name, a)
|
||||
where_ = check_where(name, where_)
|
||||
dtypes.check_user_dtype_supported(dtype, name)
|
||||
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
|
||||
|
||||
@ -730,6 +745,8 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.")
|
||||
dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce")
|
||||
check_arraylike("logsumexp", a)
|
||||
where = check_where("logsumexp", where)
|
||||
a_arr, = promote_dtypes_inexact(a)
|
||||
pos_dims, dims = _reduction_dims(a_arr, axis)
|
||||
amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf)
|
||||
@ -748,6 +765,8 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.")
|
||||
dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce")
|
||||
check_arraylike("logsumexp2", a)
|
||||
where = check_where("logsumexp2", where)
|
||||
ln2 = float(np.log(2))
|
||||
if initial is not None:
|
||||
initial *= ln2
|
||||
@ -850,6 +869,7 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
upcast_f16_for_computation: bool = True,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
check_arraylike("mean", a)
|
||||
where = check_where("mean", where)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
|
||||
|
||||
@ -1087,6 +1107,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
check_arraylike("var", a)
|
||||
where = check_where("var", where)
|
||||
dtypes.check_user_dtype_supported(dtype, "var")
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
|
||||
@ -1224,6 +1245,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
check_arraylike("std", a)
|
||||
where = check_where("std", where)
|
||||
dtypes.check_user_dtype_supported(dtype, "std")
|
||||
if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
|
||||
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
|
||||
@ -1330,13 +1352,15 @@ def count_nonzero(a: ArrayLike, axis: Axis = None,
|
||||
|
||||
def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
|
||||
init_val: ArrayLike, nan_if_all_nan: bool,
|
||||
axis: Axis = None, keepdims: bool = False, **kwargs) -> Array:
|
||||
axis: Axis = None, keepdims: bool = False, where: ArrayLike | None = None,
|
||||
**kwargs) -> Array:
|
||||
check_arraylike(name, a)
|
||||
where = check_where(name, where)
|
||||
if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
|
||||
return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)
|
||||
return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs)
|
||||
|
||||
out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a),
|
||||
axis=axis, keepdims=keepdims, **kwargs)
|
||||
axis=axis, keepdims=keepdims, where=where, **kwargs)
|
||||
if nan_if_all_nan:
|
||||
return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
|
||||
_lax_const(a, np.nan), out)
|
||||
@ -1755,6 +1779,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
|
||||
Array([[nan, nan, nan, nan]], dtype=float32)
|
||||
"""
|
||||
check_arraylike("nanmean", a)
|
||||
where = check_where("nanmean", where)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
|
||||
if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
|
||||
@ -1848,6 +1873,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
|
||||
[4. ]], dtype=float32)
|
||||
"""
|
||||
check_arraylike("nanvar", a)
|
||||
where = check_where("nanvar", where)
|
||||
dtypes.check_user_dtype_supported(dtype, "nanvar")
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
|
||||
@ -1943,6 +1969,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
|
||||
Array([[0.5, 0.5, 0. , 0. ]], dtype=float32)
|
||||
"""
|
||||
check_arraylike("nanstd", a)
|
||||
where = check_where("nanstd", where)
|
||||
dtypes.check_user_dtype_supported(dtype, "nanstd")
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
|
||||
|
@ -448,6 +448,34 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(rec=JAX_REDUCER_INITIAL_RECORDS)
|
||||
def testReducerWhereNonBooleanErrorInitial(self, rec):
|
||||
dtype = rec.dtypes[0]
|
||||
x = jnp.zeros((10,), dtype)
|
||||
where = jnp.ones(10, dtype=int)
|
||||
func = getattr(jnp, rec.name)
|
||||
def assert_warns_or_errors(msg):
|
||||
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
|
||||
return self.assertRaisesRegex(ValueError, msg)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, msg)
|
||||
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
|
||||
func(x, where=where, initial=jnp.array(0, dtype=dtype))
|
||||
|
||||
@jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)
|
||||
def testReducerWhereNonBooleanErrorNoInitial(self, rec):
|
||||
dtype = rec.dtypes[0]
|
||||
x = jnp.zeros((10,), dtype)
|
||||
where = jnp.ones(10, dtype=int)
|
||||
func = getattr(jnp, rec.name)
|
||||
def assert_warns_or_errors(msg):
|
||||
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
|
||||
return self.assertRaisesRegex(ValueError, msg)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, msg)
|
||||
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
|
||||
func(x, where=where)
|
||||
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,
|
||||
|
Loading…
x
Reference in New Issue
Block a user