jax.numpy reductions: avoid upcast of f16 when dtype is specified by user

This commit is contained in:
Jake VanderPlas 2025-02-12 11:49:39 -08:00
parent 5b697728c7
commit b5e7b60d6a
2 changed files with 56 additions and 5 deletions

View File

@ -231,7 +231,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
bool_op=lax.bitwise_or, upcast_f16_for_computation=(dtype is None),
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.psum,
promote_integers=promote_integers)
@ -319,7 +319,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric,
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
bool_op=lax.bitwise_and, upcast_f16_for_computation=(dtype is None),
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, promote_integers=promote_integers)
@ -865,9 +865,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
[6. ]], dtype=float32)
"""
return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
where=where)
where=where, upcast_f16_for_computation=(dtype is None))
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'upcast_f16_for_computation'),
inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, *,
upcast_f16_for_computation: bool = True,

View File

@ -231,7 +231,12 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
np.uint32: 3e-7, np.float32: 1e-3, np.complex64: 1e-3,
np.float64: 1e-5, np.complex128: 1e-5}
tol = jtu.tolerance(dtype, tol_spec)
tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
if out_dtype in [np.float16, dtypes.bfloat16]:
# For a 16-bit out_dtype, NumPy will accumulate in float32, while JAX
# accumulates in 16-bit, so we need a larger tolerance.
tol = 1e-1
else:
tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=jnp.bfloat16 not in (dtype, out_dtype),
tol=tol)
@ -930,5 +935,50 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
  op=['sum', 'prod'],
  dtype=['float16', 'bfloat16'],
)
def testReducerF16Casts(self, op, dtype):
  """Check which primitives jnp.sum / jnp.prod trace to for 16-bit floats.

  The jaxpr of the reducer is inspected twice on the same length-10 input:
  once with no ``dtype`` argument, and once with ``dtype`` set to the input's
  own 16-bit dtype. The sequence of equation primitives is compared against
  the expected list in each case.
  """
  make_values = jtu.rand_default(self.rng())
  operand = jnp.asarray(make_values((10,), dtype))
  reducer = getattr(jnp, op)
  cast_p = jax.lax.convert_element_type_p
  reduction_p = getattr(jax.lax, f"reduce_{op}_p")

  def traced_primitives(fn):
    # Primitives of the jaxpr equations produced by tracing fn on the operand.
    return [eqn.primitive for eqn in jax.make_jaxpr(fn)(operand).eqns]

  # No dtype given: a cast precedes and follows the reduction.
  self.assertEqual(
      traced_primitives(reducer),
      [cast_p, reduction_p, cast_p])

  # dtype given explicitly: the reduction is the only equation, with no cast.
  self.assertEqual(
      traced_primitives(partial(reducer, dtype=dtype)),
      [reduction_p])
@jtu.sample_product(
  dtype=['float16', 'bfloat16'],
)
def testMeanF16Casts(self, dtype):
  """Check which primitives jnp.mean traces to for 16-bit float inputs.

  The jaxpr of ``jnp.mean`` is inspected twice on the same length-10 input:
  once with no ``dtype`` argument, and once with ``dtype`` set to the input's
  own 16-bit dtype. The sequence of equation primitives is compared against
  the expected list in each case.
  """
  make_values = jtu.rand_default(self.rng())
  operand = jnp.asarray(make_values((10,), dtype))
  cast_p = jax.lax.convert_element_type_p
  sum_p = jax.lax.reduce_sum_p
  divide_p = jax.lax.div_p

  def traced_primitives(fn):
    # Primitives of the jaxpr equations produced by tracing fn on the operand.
    return [eqn.primitive for eqn in jax.make_jaxpr(fn)(operand).eqns]

  # No dtype given: the sum and divide sit between an input and an output cast.
  self.assertEqual(
      traced_primitives(jnp.mean),
      [cast_p, sum_p, divide_p, cast_p])

  # dtype given explicitly: only the sum and the divide appear, with no cast.
  self.assertEqual(
      traced_primitives(partial(jnp.mean, dtype=dtype)),
      [sum_p, divide_p])
# Script entry point: when run directly, hand control to absltest using the loader from jtu.
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())