Merge pull request #11259 from jakevdp:average-keepdims

PiperOrigin-RevId: 457777845
This commit is contained in:
jax authors 2022-06-28 11:20:38 -07:00
commit 6835dc18e3
2 changed files with 28 additions and 21 deletions

View File

@ -286,16 +286,16 @@ def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
@_wraps(np.average)
def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
returned=False):
return _average(a, _ensure_optional_axes(axis), weights, returned)
returned=False, keepdims=False):
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
@partial(api.jit, static_argnames=('axis', 'returned'), inline=True)
@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
returned=False):
returned=False, keepdims=False):
if weights is None: # Treat all weights as 1
_check_arraylike("average", a)
a, = _promote_dtypes_inexact(a)
avg = mean(a, axis=axis)
avg = mean(a, axis=axis, keepdims=keepdims)
if axis is None:
weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
else:
@ -324,8 +324,8 @@ def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None
weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
weights = _moveaxis(weights, -1, axis)
weights_sum = sum(weights, axis=axis)
avg = sum(a * weights, axis=axis) / weights_sum
weights_sum = sum(weights, axis=axis, keepdims=keepdims)
avg = sum(a * weights, axis=axis, keepdims=keepdims) / weights_sum
if returned:
if avg.shape != weights_sum.shape:

View File

@ -3829,39 +3829,47 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format(
{"testcase_name": "_shape={}_axis={}_weights={}_returned={}_keepdims={}".format(
jtu.format_shape_dtype_string(shape, dtype),
axis,
(None if weights_shape is None else jtu.format_shape_dtype_string(weights_shape, dtype)),
returned),
returned, keepdims),
"shape": shape, "dtype": dtype, "axis": axis,
"weights_shape": weights_shape, "returned": returned}
"weights_shape": weights_shape, "returned": returned, "keepdims": keepdims}
for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes)
for axis in list(range(-len(shape), len(shape))) + [None]
# `weights_shape` is either `None`, same as the averaged axis, or same as
# that of the input
for weights_shape in ([None, shape] if axis is None or len(shape) == 1
else [None, (shape[axis],), shape])
for keepdims in ([False, True] if numpy_version >= (1, 23) else [None])
for returned in [False, True]))
def testAverage(self, shape, dtype, axis, weights_shape, returned):
def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims):
rng = jtu.rand_default(self.rng())
kwds = dict(returned=returned)
if keepdims is not None:
kwds['keepdims'] = keepdims
if weights_shape is None:
np_fun = lambda x: np.average(x, axis, returned=returned)
jnp_fun = lambda x: jnp.average(x, axis, returned=returned)
np_fun = lambda x: np.average(x, axis, **kwds)
jnp_fun = lambda x: jnp.average(x, axis, **kwds)
args_maker = lambda: [rng(shape, dtype)]
else:
np_fun = lambda x, weights: np.average(x, axis, weights, returned)
jnp_fun = lambda x, weights: jnp.average(x, axis, weights, returned)
np_fun = lambda x, weights: np.average(x, axis, weights, **kwds)
jnp_fun = lambda x, weights: jnp.average(x, axis, weights, **kwds)
args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)]
np_fun = _promote_like_jnp(np_fun, inexact=True)
tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5,
np.float64: 1e-12, np.complex64: 1e-5}
check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE and numpy_version >= (1, 22)
try:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=check_dtypes, tol=tol)
except ZeroDivisionError:
self.skipTest("don't support checking for ZeroDivisionError")
if numpy_version == (1, 23, 0) and keepdims and weights_shape is not None and axis is not None:
# Known failure: https://github.com/numpy/numpy/issues/21850
pass
else:
try:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=check_dtypes, tol=tol)
except ZeroDivisionError:
self.skipTest("don't support checking for ZeroDivisionError")
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
rtol=tol, atol=tol)
@ -6339,7 +6347,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
unsupported_params = {
'asarray': ['like'],
'average': ['keepdims'],
'broadcast_to': ['subok', 'array'],
'clip': ['kwargs'],
'copy': ['subok'],