mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #17792 from jakevdp:mean-cast-f16
PiperOrigin-RevId: 569019549
This commit is contained in:
commit
59360794c1
@ -320,16 +320,23 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
|
||||
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
|
||||
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
|
||||
out: None = None, keepdims: bool = False, *,
|
||||
upcast_f16_for_computation: bool = True,
|
||||
where: Optional[ArrayLike] = None) -> Array:
|
||||
check_arraylike("mean", a)
|
||||
if dtype is None:
|
||||
dtype = dtypes.to_inexact_dtype(dtypes.dtype(a))
|
||||
else:
|
||||
dtypes.check_user_dtype_supported(dtype, "mean")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
|
||||
|
||||
if dtype is None:
|
||||
result_dtype = dtypes.to_inexact_dtype(dtypes.dtype(a, canonicalize=True))
|
||||
else:
|
||||
dtypes.check_user_dtype_supported(dtype, "mean")
|
||||
result_dtype = dtypes.canonicalize_dtype(dtype)
|
||||
|
||||
if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
|
||||
computation_dtype = _upcast_f16(result_dtype)
|
||||
else:
|
||||
computation_dtype = result_dtype
|
||||
|
||||
if where is None:
|
||||
if axis is None:
|
||||
normalizer = core.dimension_as_value(np.size(a))
|
||||
@ -339,8 +346,9 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
|
||||
normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims)
|
||||
|
||||
return lax.div(
|
||||
sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
|
||||
lax.convert_element_type(normalizer, dtype))
|
||||
sum(a, axis, dtype=computation_dtype, keepdims=keepdims, where=where),
|
||||
lax.convert_element_type(normalizer, computation_dtype)
|
||||
).astype(result_dtype)
|
||||
|
||||
@overload
|
||||
def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None,
|
||||
|
@ -758,6 +758,15 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
self.assertEqual(0.0, jnp.std(x))
|
||||
self.assertEqual(0.0, jnp.std(x, where=True))
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=[np.dtype(np.float16), np.dtype(dtypes.bfloat16)],
|
||||
)
|
||||
def test_f16_mean(self, dtype):
|
||||
x = np.full(100_000, 1E-5, dtype=dtype)
|
||||
expected = np.mean(x.astype('float64')).astype(dtype)
|
||||
actual = jnp.mean(x)
|
||||
self.assertAllClose(expected, actual, atol=0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user