[JAX] Compute FP bias from the min exponent

The min exponent defines the bias, not the number of bits. The number of bits will work for standard FP types but not for FP8 types.

PiperOrigin-RevId: 580949955
This commit is contained in:
David Majnemer 2023-11-09 10:21:39 -08:00 committed by jax authors
parent 340e655ac2
commit 8bac6d7877

View File

@ -515,7 +515,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x2 = lax.convert_element_type(x2, int_type)
mask = (1 << info.nexp) - 1
bias = ((1 << info.nexp) - 1) >> 1
bias = 1 - info.minexp
x, e = _normalize_float(x1)
x2 += e + ((x >> info.nmant) & mask) - bias
@ -555,7 +555,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
dtype = dtypes.dtype(x)
info = dtypes.finfo(dtype)
mask = (1 << info.nexp) - 1
bias = ((1 << info.nexp) - 1) >> 1
bias = 1 - info.minexp
x1, x2 = _normalize_float(x)
x2 += ((x1 >> info.nmant) & mask) - bias + 1