Merge pull request #15123 from jakevdp:fix-mean-large-dims

PiperOrigin-RevId: 518852476
This commit is contained in:
jax authors 2023-03-23 07:23:20 -07:00
commit 54e8101f00
2 changed files with 12 additions and 5 deletions

View File

@ -327,7 +327,11 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, keepdims: bool = False, *,
where: Optional[ArrayLike] = None) -> Array:
check_arraylike("mean", a)
dtypes.check_user_dtype_supported(dtype, "mean")
if dtype is None:
dtype = dtypes.to_inexact_dtype(dtypes.dtype(a))
else:
dtypes.check_user_dtype_supported(dtype, "mean")
dtype = dtypes.canonicalize_dtype(dtype)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
@ -339,10 +343,6 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)
if dtype is None:
dtype = dtypes.to_inexact_dtype(dtypes.dtype(a))
dtype = dtypes.canonicalize_dtype(dtype)
return lax.div(
sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
lax.convert_element_type(normalizer, dtype))

View File

@ -5075,6 +5075,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
actual = jnp.fromstring(s, sep=',', dtype=int)
self.assertArraysEqual(expected, actual)
def testMeanLargeArray(self):
# https://github.com/google/jax/issues/15068
raise unittest.SkipTest("test is slow, but it passes!")
x = jnp.ones((16, 32, 1280, 4096), dtype='int8')
self.assertEqual(1.0, jnp.mean(x))
self.assertEqual(1.0, jnp.mean(x, where=True))
# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.