mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15123 from jakevdp:fix-mean-large-dims
PiperOrigin-RevId: 518852476
This commit is contained in:
commit
54e8101f00
@ -327,7 +327,11 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
|
||||
out: None = None, keepdims: bool = False, *,
|
||||
where: Optional[ArrayLike] = None) -> Array:
|
||||
check_arraylike("mean", a)
|
||||
dtypes.check_user_dtype_supported(dtype, "mean")
|
||||
if dtype is None:
|
||||
dtype = dtypes.to_inexact_dtype(dtypes.dtype(a))
|
||||
else:
|
||||
dtypes.check_user_dtype_supported(dtype, "mean")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
|
||||
|
||||
@ -339,10 +343,6 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
|
||||
else:
|
||||
normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)
|
||||
|
||||
if dtype is None:
|
||||
dtype = dtypes.to_inexact_dtype(dtypes.dtype(a))
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
|
||||
return lax.div(
|
||||
sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
|
||||
lax.convert_element_type(normalizer, dtype))
|
||||
|
@ -5075,6 +5075,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
actual = jnp.fromstring(s, sep=',', dtype=int)
|
||||
self.assertArraysEqual(expected, actual)
|
||||
|
||||
def testMeanLargeArray(self):
|
||||
# https://github.com/google/jax/issues/15068
|
||||
raise unittest.SkipTest("test is slow, but it passes!")
|
||||
x = jnp.ones((16, 32, 1280, 4096), dtype='int8')
|
||||
self.assertEqual(1.0, jnp.mean(x))
|
||||
self.assertEqual(1.0, jnp.mean(x, where=True))
|
||||
|
||||
|
||||
# Most grad tests are at the lax level (see lax_test.py), but we add some here
|
||||
# as needed for e.g. particular compound ops of interest.
|
||||
|
Loading…
x
Reference in New Issue
Block a user