mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #11259 from jakevdp:average-keepdims
PiperOrigin-RevId: 457777845
This commit is contained in:
commit
6835dc18e3
@ -286,16 +286,16 @@ def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
|
||||
@_wraps(np.average)
|
||||
def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
|
||||
returned=False):
|
||||
return _average(a, _ensure_optional_axes(axis), weights, returned)
|
||||
returned=False, keepdims=False):
|
||||
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
|
||||
|
||||
@partial(api.jit, static_argnames=('axis', 'returned'), inline=True)
|
||||
@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
|
||||
def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
|
||||
returned=False):
|
||||
returned=False, keepdims=False):
|
||||
if weights is None: # Treat all weights as 1
|
||||
_check_arraylike("average", a)
|
||||
a, = _promote_dtypes_inexact(a)
|
||||
avg = mean(a, axis=axis)
|
||||
avg = mean(a, axis=axis, keepdims=keepdims)
|
||||
if axis is None:
|
||||
weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
|
||||
else:
|
||||
@ -324,8 +324,8 @@ def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None
|
||||
weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
|
||||
weights = _moveaxis(weights, -1, axis)
|
||||
|
||||
weights_sum = sum(weights, axis=axis)
|
||||
avg = sum(a * weights, axis=axis) / weights_sum
|
||||
weights_sum = sum(weights, axis=axis, keepdims=keepdims)
|
||||
avg = sum(a * weights, axis=axis, keepdims=keepdims) / weights_sum
|
||||
|
||||
if returned:
|
||||
if avg.shape != weights_sum.shape:
|
||||
|
@ -3829,39 +3829,47 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format(
|
||||
{"testcase_name": "_shape={}_axis={}_weights={}_returned={}_keepdims={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
axis,
|
||||
(None if weights_shape is None else jtu.format_shape_dtype_string(weights_shape, dtype)),
|
||||
returned),
|
||||
returned, keepdims),
|
||||
"shape": shape, "dtype": dtype, "axis": axis,
|
||||
"weights_shape": weights_shape, "returned": returned}
|
||||
"weights_shape": weights_shape, "returned": returned, "keepdims": keepdims}
|
||||
for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes)
|
||||
for axis in list(range(-len(shape), len(shape))) + [None]
|
||||
# `weights_shape` is either `None`, same as the averaged axis, or same as
|
||||
# that of the input
|
||||
for weights_shape in ([None, shape] if axis is None or len(shape) == 1
|
||||
else [None, (shape[axis],), shape])
|
||||
for keepdims in ([False, True] if numpy_version >= (1, 23) else [None])
|
||||
for returned in [False, True]))
|
||||
def testAverage(self, shape, dtype, axis, weights_shape, returned):
|
||||
def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
kwds = dict(returned=returned)
|
||||
if keepdims is not None:
|
||||
kwds['keepdims'] = keepdims
|
||||
if weights_shape is None:
|
||||
np_fun = lambda x: np.average(x, axis, returned=returned)
|
||||
jnp_fun = lambda x: jnp.average(x, axis, returned=returned)
|
||||
np_fun = lambda x: np.average(x, axis, **kwds)
|
||||
jnp_fun = lambda x: jnp.average(x, axis, **kwds)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
else:
|
||||
np_fun = lambda x, weights: np.average(x, axis, weights, returned)
|
||||
jnp_fun = lambda x, weights: jnp.average(x, axis, weights, returned)
|
||||
np_fun = lambda x, weights: np.average(x, axis, weights, **kwds)
|
||||
jnp_fun = lambda x, weights: jnp.average(x, axis, weights, **kwds)
|
||||
args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)]
|
||||
np_fun = _promote_like_jnp(np_fun, inexact=True)
|
||||
tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5,
|
||||
np.float64: 1e-12, np.complex64: 1e-5}
|
||||
check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE and numpy_version >= (1, 22)
|
||||
try:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
check_dtypes=check_dtypes, tol=tol)
|
||||
except ZeroDivisionError:
|
||||
self.skipTest("don't support checking for ZeroDivisionError")
|
||||
if numpy_version == (1, 23, 0) and keepdims and weights_shape is not None and axis is not None:
|
||||
# Known failure: https://github.com/numpy/numpy/issues/21850
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
check_dtypes=check_dtypes, tol=tol)
|
||||
except ZeroDivisionError:
|
||||
self.skipTest("don't support checking for ZeroDivisionError")
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
@ -6339,7 +6347,6 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
|
||||
unsupported_params = {
|
||||
'asarray': ['like'],
|
||||
'average': ['keepdims'],
|
||||
'broadcast_to': ['subok', 'array'],
|
||||
'clip': ['kwargs'],
|
||||
'copy': ['subok'],
|
||||
|
Loading…
x
Reference in New Issue
Block a user