diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 53047c5a9..985b296bc 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -231,7 +231,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric, - bool_op=lax.bitwise_or, upcast_f16_for_computation=True, + bool_op=lax.bitwise_or, upcast_f16_for_computation=(dtype is None), axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.psum, promote_integers=promote_integers) @@ -319,7 +319,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric, - bool_op=lax.bitwise_and, upcast_f16_for_computation=True, + bool_op=lax.bitwise_and, upcast_f16_for_computation=(dtype is None), axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, promote_integers=promote_integers) @@ -865,9 +865,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, [6. 
]], dtype=float32) """ return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims, - where=where) + where=where, upcast_f16_for_computation=(dtype is None)) -@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True) +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'upcast_f16_for_computation'), + inline=True) def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, upcast_f16_for_computation: bool = True, diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 027eac86f..0c3f1d147 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -231,7 +231,12 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): np.uint32: 3e-7, np.float32: 1e-3, np.complex64: 1e-3, np.float64: 1e-5, np.complex128: 1e-5} tol = jtu.tolerance(dtype, tol_spec) - tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol + if out_dtype in [np.float16, dtypes.bfloat16]: + # For 16-bit out_dtype, NumPy will accumulate in float32, while JAX + # accumulates in 16-bit, so we need a larger tolerance. + tol = 1e-1 + else: + tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=jnp.bfloat16 not in (dtype, out_dtype), tol=tol) @@ -930,5 +935,50 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + op=['sum', 'prod'], + dtype=['float16', 'bfloat16'], + ) + def testReducerF16Casts(self, op, dtype): + rng = jtu.rand_default(self.rng()) + x = jnp.asarray(rng((10,), dtype)) + + func = getattr(jnp, op) + reduce_p = getattr(jax.lax, f"reduce_{op}_p") + conv_elem_p = jax.lax.convert_element_type_p + + # Without dtype specified, the reduction is sandwiched between two casts. 
+ jaxpr1 = jax.make_jaxpr(func)(x) + self.assertEqual( + [eqn.primitive for eqn in jaxpr1.eqns], + [conv_elem_p, reduce_p, conv_elem_p]) + + # With dtype specified, the reduction happens without a cast. + jaxpr2 = jax.make_jaxpr(partial(func, dtype=dtype))(x) + self.assertEqual([eqn.primitive for eqn in jaxpr2.eqns], [reduce_p]) + + @jtu.sample_product( + dtype=['float16', 'bfloat16'], + ) + def testMeanF16Casts(self, dtype): + rng = jtu.rand_default(self.rng()) + x = jnp.asarray(rng((10,), dtype)) + + reduce_sum_p = jax.lax.reduce_sum_p + div_p = jax.lax.div_p + conv_elem_p = jax.lax.convert_element_type_p + + # Without dtype specified, the reduction is sandwiched between two casts. + jaxpr1 = jax.make_jaxpr(jnp.mean)(x) + self.assertEqual( + [eqn.primitive for eqn in jaxpr1.eqns], + [conv_elem_p, reduce_sum_p, div_p, conv_elem_p]) + + # With dtype specified, the reduction happens without a cast. + jaxpr2 = jax.make_jaxpr(partial(jnp.mean, dtype=dtype))(x) + self.assertEqual( + [eqn.primitive for eqn in jaxpr2.eqns], + [reduce_sum_p, div_p]) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())