Merge pull request #17792 from jakevdp:mean-cast-f16

PiperOrigin-RevId: 569019549
This commit is contained in:
jax authors 2023-09-27 18:30:06 -07:00
commit 59360794c1
2 changed files with 24 additions and 7 deletions

View File

@ -320,16 +320,23 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
out: None = None, keepdims: bool = False, *,
upcast_f16_for_computation: bool = True,
where: Optional[ArrayLike] = None) -> Array:
check_arraylike("mean", a)
if dtype is None:
dtype = dtypes.to_inexact_dtype(dtypes.dtype(a))
else:
dtypes.check_user_dtype_supported(dtype, "mean")
dtype = dtypes.canonicalize_dtype(dtype)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
if dtype is None:
result_dtype = dtypes.to_inexact_dtype(dtypes.dtype(a, canonicalize=True))
else:
dtypes.check_user_dtype_supported(dtype, "mean")
result_dtype = dtypes.canonicalize_dtype(dtype)
if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
computation_dtype = _upcast_f16(result_dtype)
else:
computation_dtype = result_dtype
if where is None:
if axis is None:
normalizer = core.dimension_as_value(np.size(a))
@ -339,8 +346,9 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)
return lax.div(
sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
lax.convert_element_type(normalizer, dtype))
sum(a, axis, dtype=computation_dtype, keepdims=keepdims, where=where),
lax.convert_element_type(normalizer, computation_dtype)
).astype(result_dtype)
@overload
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,

View File

@ -758,6 +758,15 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self.assertEqual(0.0, jnp.std(x))
self.assertEqual(0.0, jnp.std(x, where=True))
@jtu.sample_product(
dtype=[np.dtype(np.float16), np.dtype(dtypes.bfloat16)],
)
def test_f16_mean(self, dtype):
x = np.full(100_000, 1E-5, dtype=dtype)
expected = np.mean(x.astype('float64')).astype(dtype)
actual = jnp.mean(x)
self.assertAllClose(expected, actual, atol=0)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())