mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #21262 from vfdev-5:depr-change-ddof-to-correction-21088
PiperOrigin-RevId: 636949170
This commit is contained in:
commit
bab7f40dec
@ -433,13 +433,17 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
|
||||
@implements(np.var, skip_params=['out'])
|
||||
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, ddof: int = 0, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
|
||||
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
|
||||
if correction is None:
|
||||
correction = ddof
|
||||
elif not isinstance(ddof, int) or ddof != 0:
|
||||
raise ValueError("ddof and correction can't be provided simultaneously.")
|
||||
return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
|
||||
where=where)
|
||||
|
||||
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
|
||||
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, ddof: int = 0, keepdims: bool = False, *,
|
||||
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
check_arraylike("var", a)
|
||||
dtypes.check_user_dtype_supported(dtype, "var")
|
||||
@ -465,7 +469,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
else:
|
||||
normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
|
||||
dtype=computation_dtype, keepdims=keepdims)
|
||||
normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype))
|
||||
normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype))
|
||||
result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
|
||||
return _where(normalizer > 0, lax.div(result, normalizer).astype(dtype), np.nan)
|
||||
|
||||
@ -494,13 +498,17 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
|
||||
@implements(np.std, skip_params=['out'])
|
||||
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, ddof: int = 0, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
|
||||
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
|
||||
if correction is None:
|
||||
correction = ddof
|
||||
elif not isinstance(ddof, int) or ddof != 0:
|
||||
raise ValueError("ddof and correction can't be provided simultaneously.")
|
||||
return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
|
||||
where=where)
|
||||
|
||||
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
|
||||
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, ddof: int = 0, keepdims: bool = False, *,
|
||||
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
check_arraylike("std", a)
|
||||
dtypes.check_user_dtype_supported(dtype, "std")
|
||||
@ -508,7 +516,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
|
||||
return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
|
||||
return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where))
|
||||
|
||||
|
||||
@implements(np.ptp, skip_params=['out'])
|
||||
|
@ -14,13 +14,12 @@
|
||||
|
||||
import jax
|
||||
|
||||
# TODO(micky774): Remove after deprecating ddof-->correction in jnp.std and
|
||||
# jnp.var
|
||||
|
||||
def std(x, /, *, axis=None, correction=0.0, keepdims=False):
|
||||
"""Calculates the standard deviation of the input array x."""
|
||||
return jax.numpy.std(x, axis=axis, ddof=correction, keepdims=keepdims)
|
||||
return jax.numpy.std(x, axis=axis, correction=correction, keepdims=keepdims)
|
||||
|
||||
|
||||
def var(x, /, *, axis=None, correction=0.0, keepdims=False):
|
||||
"""Calculates the variance of the input array x."""
|
||||
return jax.numpy.var(x, axis=axis, ddof=correction, keepdims=keepdims)
|
||||
return jax.numpy.var(x, axis=axis, correction=correction, keepdims=keepdims)
|
||||
|
@ -783,7 +783,7 @@ def stack(
|
||||
) -> Array: ...
|
||||
def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
|
||||
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
|
||||
where: Optional[ArrayLike] = ...) -> Array: ...
|
||||
where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ...
|
||||
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def sum(
|
||||
a: ArrayLike,
|
||||
@ -894,7 +894,7 @@ def vander(
|
||||
) -> Array: ...
|
||||
def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
|
||||
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
|
||||
where: Optional[ArrayLike] = ...) -> Array: ...
|
||||
where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ...
|
||||
def vdot(
|
||||
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ...,
|
||||
preferred_element_type: Optional[DTypeLike] = ...) -> Array: ...
|
||||
|
@ -540,31 +540,42 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
test_fns=[(np.var, jnp.var), (np.std, jnp.std)],
|
||||
shape=[(5,), (10, 5)],
|
||||
dtype=all_dtypes,
|
||||
out_dtype=inexact_dtypes,
|
||||
axis=[None, 0, -1],
|
||||
ddof=[0, 1, 2],
|
||||
ddof_correction=[(0, None), (1, None), (1, 0), (0, 0), (0, 1), (0, 2)],
|
||||
keepdims=[False, True],
|
||||
)
|
||||
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
|
||||
def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, ddof_correction, keepdims):
|
||||
np_fn, jnp_fn = test_fns
|
||||
ddof, correction = ddof_correction
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.")
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
def np_fun(x):
|
||||
# setup ddof and correction kwargs excluding case when correction is not specified
|
||||
ddof_correction_kwargs = {"ddof": ddof}
|
||||
if correction is not None:
|
||||
key = "correction" if numpy_version >= (2, 0) else "ddof"
|
||||
ddof_correction_kwargs[key] = correction
|
||||
# Numpy fails with bfloat16 inputs
|
||||
out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
|
||||
out = np_fn(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
|
||||
dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
|
||||
axis=axis, ddof=ddof, keepdims=keepdims)
|
||||
axis=axis, keepdims=keepdims, **ddof_correction_kwargs)
|
||||
return out.astype(out_dtype)
|
||||
jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
|
||||
jnp_fun = partial(jnp_fn, dtype=out_dtype, axis=axis, ddof=ddof, correction=correction,
|
||||
keepdims=keepdims)
|
||||
tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
|
||||
np.float64: 1e-3, np.complex128: 1e-6})
|
||||
if (jnp.issubdtype(dtype, jnp.complexfloating) and
|
||||
not jnp.issubdtype(out_dtype, jnp.complexfloating)):
|
||||
self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
|
||||
self.assertRaises(ValueError, jnp_fun, *args_maker())
|
||||
elif (correction is not None and ddof != 0):
|
||||
self.assertRaises(ValueError, jnp_fun, *args_maker())
|
||||
else:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
tol=tol)
|
||||
|
@ -5960,9 +5960,9 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'reshape': ['shape', 'copy'],
|
||||
'row_stack': ['casting'],
|
||||
'stack': ['casting'],
|
||||
'std': ['correction', 'mean'],
|
||||
'std': ['mean'],
|
||||
'tri': ['like'],
|
||||
'var': ['correction', 'mean'],
|
||||
'var': ['mean'],
|
||||
'vstack': ['casting'],
|
||||
'zeros_like': ['subok', 'order']
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user