mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)
jax.numpy reductions: avoid upcast of f16 when dtype is specified by user
This commit is contained in:
parent 5b697728c7 · commit b5e7b60d6a
@@ -231,7 +231,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                 initial: ArrayLike | None = None, where: ArrayLike | None = None,
                 promote_integers: bool = True) -> Array:
   return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
-                    bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
+                    bool_op=lax.bitwise_or, upcast_f16_for_computation=(dtype is None),
                     axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                     initial=initial, where_=where, parallel_reduce=lax.psum,
                     promote_integers=promote_integers)
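A standalone sketch of the behavior this hunk changes (illustrative, not part of the commit): with no dtype argument, a float16 sum is still accumulated in float32 and cast back at the end; once dtype is passed explicitly, the accumulation itself stays in float16.

import numpy as np
import jax.numpy as jnp

x = jnp.asarray(np.full(10_000, 0.1), dtype='float16')

# No dtype argument: upcast_f16_for_computation is True, so the sum is
# accumulated in float32 and only the final result is cast to float16.
print(jnp.sum(x))

# dtype specified: after this commit the accumulation runs in float16,
# so rounding error compounds and the two results can differ noticeably.
print(jnp.sum(x, dtype='float16'))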
@@ -319,7 +319,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                  initial: ArrayLike | None = None, where: ArrayLike | None = None,
                  promote_integers: bool = True) -> Array:
   return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric,
-                    bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
+                    bool_op=lax.bitwise_and, upcast_f16_for_computation=(dtype is None),
                     axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                     initial=initial, where_=where, promote_integers=promote_integers)
 
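This is the same one-line change as in _reduce_sum. The intent of upcast_f16_for_computation=(dtype is None) can be paraphrased by the following hypothetical helper (sum_with_optional_upcast is illustrative only, not the actual _reduction internals):

from jax import lax
import jax.numpy as jnp

def sum_with_optional_upcast(x, dtype=None):
  # Upcast 16-bit floats for the computation only when the caller did not
  # pin a dtype, mirroring upcast_f16_for_computation=(dtype is None).
  upcast = dtype is None and x.dtype in (jnp.float16, jnp.bfloat16)
  compute_dtype = jnp.float32 if upcast else (dtype or x.dtype)
  acc = lax.reduce(x.astype(compute_dtype), jnp.zeros((), compute_dtype),
                   lax.add, dimensions=(0,))
  return acc.astype(x.dtype) if upcast else acc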
@@ -865,9 +865,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
        [6. ]], dtype=float32)
   """
   return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
-               where=where)
+               where=where, upcast_f16_for_computation=(dtype is None))
 
-@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
+@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'upcast_f16_for_computation'),
+         inline=True)
 def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
           out: None = None, keepdims: bool = False, *,
+          upcast_f16_for_computation: bool = True,
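The new flag joins static_argnames because it is consumed in Python-level control flow while tracing: each value of the flag yields a different jaxpr, so jit must retrace when it changes rather than treat it as a traced array value. A minimal sketch of the same pattern (scaled_sum is a hypothetical example, not JAX code):

from functools import partial
import jax

@partial(jax.jit, static_argnames=('upcast_f16_for_computation',))
def scaled_sum(x, *, upcast_f16_for_computation=True):
  # A Python `if` on the flag only works because jit treats it as static.
  if upcast_f16_for_computation:
    return x.astype('float32').sum().astype(x.dtype) / x.size
  return x.sum(dtype=x.dtype) / x.size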
@@ -231,7 +231,12 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
                 np.uint32: 3e-7, np.float32: 1e-3, np.complex64: 1e-3,
                 np.float64: 1e-5, np.complex128: 1e-5}
     tol = jtu.tolerance(dtype, tol_spec)
-    tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
+    if out_dtype in [np.float16, dtypes.bfloat16]:
+      # For 16-bit out_type, NumPy will accumulate in float32, while JAX
+      # accumulates in 16-bit, so we need a larger tolerance.
+      tol = 1e-1
+    else:
+      tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
     self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                             check_dtypes=jnp.bfloat16 not in (dtype, out_dtype),
                             tol=tol)
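A quick way to see why the old per-dtype tolerances no longer suffice (illustrative snippet; it echoes the comment in the hunk above that NumPy accumulates 16-bit sums in float32 while JAX now accumulates in 16-bit):

import numpy as np
import jax.numpy as jnp

x = np.random.RandomState(0).randn(10_000).astype(np.float16)

# NumPy keeps a wider accumulator even when the output dtype is float16.
print(np.sum(x, dtype=np.float16))

# JAX with an explicit dtype runs the whole reduction in float16, so the
# difference from NumPy can approach the new 1e-1 tolerance.
print(jnp.sum(jnp.asarray(x), dtype='float16'))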
@@ -930,5 +935,50 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
     self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
     self._CompileAndCheck(jnp_fun, args_maker)
 
+  @jtu.sample_product(
+      op=['sum', 'prod'],
+      dtype=['float16', 'bfloat16'],
+  )
+  def testReducerF16Casts(self, op, dtype):
+    rng = jtu.rand_default(self.rng())
+    x = jnp.asarray(rng((10,), dtype))
+
+    func = getattr(jnp, op)
+    reduce_p = getattr(jax.lax, f"reduce_{op}_p")
+    conv_elem_p = jax.lax.convert_element_type_p
+
+    # Without dtype specified, the reduction is sandwiched between two casts.
+    jaxpr1 = jax.make_jaxpr(func)(x)
+    self.assertEqual(
+        [eqn.primitive for eqn in jaxpr1.eqns],
+        [conv_elem_p, reduce_p, conv_elem_p])
+
+    # With dtype specified, the reduction happens without a cast.
+    jaxpr2 = jax.make_jaxpr(partial(func, dtype=dtype))(x)
+    self.assertEqual([eqn.primitive for eqn in jaxpr2.eqns], [reduce_p])
+
+  @jtu.sample_product(
+      dtype=['float16', 'bfloat16'],
+  )
+  def testMeanF16Casts(self, dtype):
+    rng = jtu.rand_default(self.rng())
+    x = jnp.asarray(rng((10,), dtype))
+
+    reduce_sum_p = jax.lax.reduce_sum_p
+    div_p = jax.lax.div_p
+    conv_elem_p = jax.lax.convert_element_type_p
+
+    # Without dtype specified, the reduction is sandwiched between two casts.
+    jaxpr1 = jax.make_jaxpr(jnp.mean)(x)
+    self.assertEqual(
+        [eqn.primitive for eqn in jaxpr1.eqns],
+        [conv_elem_p, reduce_sum_p, div_p, conv_elem_p])
+
+    # With dtype specified, the reduction happens without a cast.
+    jaxpr2 = jax.make_jaxpr(partial(jnp.mean, dtype=dtype))(x)
+    self.assertEqual(
+        [eqn.primitive for eqn in jaxpr2.eqns],
+        [reduce_sum_p, div_p])
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
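To see what these tests assert, one can print the jaxprs directly. A sketch (the commented output is abbreviated and version-dependent; exact formatting may differ):

from functools import partial
import jax
import jax.numpy as jnp

x = jnp.ones(10, dtype='float16')

# Default path: convert to f32, reduce, convert back.
print(jax.make_jaxpr(jnp.sum)(x))
# { lambda ; a:f16[10]. let
#     b:f32[10] = convert_element_type[new_dtype=float32] a
#     c:f32[] = reduce_sum[axes=(0,)] b
#     d:f16[] = convert_element_type[new_dtype=float16] c
#   in (d,) }

# Explicit dtype: a single reduce_sum, no casts.
print(jax.make_jaxpr(partial(jnp.sum, dtype='float16'))(x))
# { lambda ; a:f16[10]. let b:f16[] = reduce_sum[axes=(0,)] a in (b,) }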