Merge pull request #21262 from vfdev-5:depr-change-ddof-to-correction-21088

PiperOrigin-RevId: 636949170
This commit is contained in:
jax authors 2024-05-24 09:47:27 -07:00
commit bab7f40dec
5 changed files with 40 additions and 22 deletions

View File

@ -433,13 +433,17 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
@implements(np.var, skip_params=['out'])
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
if correction is None:
correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
raise ValueError("ddof and correction can't be provided simultaneously.")
return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("var", a)
dtypes.check_user_dtype_supported(dtype, "var")
@ -465,7 +469,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
dtype=computation_dtype, keepdims=keepdims)
normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype))
normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype))
result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
return _where(normalizer > 0, lax.div(result, normalizer).astype(dtype), np.nan)
@ -494,13 +498,17 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
@implements(np.std, skip_params=['out'])
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
if correction is None:
correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
raise ValueError("ddof and correction can't be provided simultaneously.")
return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("std", a)
dtypes.check_user_dtype_supported(dtype, "std")
@ -508,7 +516,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where))
@implements(np.ptp, skip_params=['out'])

View File

@ -14,13 +14,12 @@
import jax
# TODO(micky774): Remove after deprecating ddof-->correction in jnp.std and
# jnp.var
def std(x, /, *, axis=None, correction=0.0, keepdims=False):
"""Calculates the standard deviation of the input array x."""
return jax.numpy.std(x, axis=axis, ddof=correction, keepdims=keepdims)
return jax.numpy.std(x, axis=axis, correction=correction, keepdims=keepdims)
def var(x, /, *, axis=None, correction=0.0, keepdims=False):
"""Calculates the variance of the input array x."""
return jax.numpy.var(x, axis=axis, ddof=correction, keepdims=keepdims)
return jax.numpy.var(x, axis=axis, correction=correction, keepdims=keepdims)

View File

@ -783,7 +783,7 @@ def stack(
) -> Array: ...
def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
where: Optional[ArrayLike] = ...) -> Array: ...
where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ...
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def sum(
a: ArrayLike,
@ -894,7 +894,7 @@ def vander(
) -> Array: ...
def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
where: Optional[ArrayLike] = ...) -> Array: ...
where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ...
def vdot(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ...,
preferred_element_type: Optional[DTypeLike] = ...) -> Array: ...

View File

@ -540,31 +540,42 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
rtol=tol, atol=tol)
@jtu.sample_product(
test_fns=[(np.var, jnp.var), (np.std, jnp.std)],
shape=[(5,), (10, 5)],
dtype=all_dtypes,
out_dtype=inexact_dtypes,
axis=[None, 0, -1],
ddof=[0, 1, 2],
ddof_correction=[(0, None), (1, None), (1, 0), (0, 0), (0, 1), (0, 2)],
keepdims=[False, True],
)
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, ddof_correction, keepdims):
np_fn, jnp_fn = test_fns
ddof, correction = ddof_correction
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.")
@jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x):
# setup ddof and correction kwargs excluding case when correction is not specified
ddof_correction_kwargs = {"ddof": ddof}
if correction is not None:
key = "correction" if numpy_version >= (2, 0) else "ddof"
ddof_correction_kwargs[key] = correction
# Numpy fails with bfloat16 inputs
out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
out = np_fn(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
axis=axis, ddof=ddof, keepdims=keepdims)
axis=axis, keepdims=keepdims, **ddof_correction_kwargs)
return out.astype(out_dtype)
jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
jnp_fun = partial(jnp_fn, dtype=out_dtype, axis=axis, ddof=ddof, correction=correction,
keepdims=keepdims)
tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
np.float64: 1e-3, np.complex128: 1e-6})
if (jnp.issubdtype(dtype, jnp.complexfloating) and
not jnp.issubdtype(out_dtype, jnp.complexfloating)):
self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
self.assertRaises(ValueError, jnp_fun, *args_maker())
elif (correction is not None and ddof != 0):
self.assertRaises(ValueError, jnp_fun, *args_maker())
else:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=tol)

View File

@ -5960,9 +5960,9 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'reshape': ['shape', 'copy'],
'row_stack': ['casting'],
'stack': ['casting'],
'std': ['correction', 'mean'],
'std': ['mean'],
'tri': ['like'],
'var': ['correction', 'mean'],
'var': ['mean'],
'vstack': ['casting'],
'zeros_like': ['subok', 'order']
}