mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[JAX] Compute FP bias from the min exponent
The min exponent defines the bias, not the number of bits. The number of bits will work for standard FP types but not for FP8 types. PiperOrigin-RevId: 580949955
This commit is contained in:
parent
340e655ac2
commit
8bac6d7877
@ -515,7 +515,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x2 = lax.convert_element_type(x2, int_type)
|
||||
|
||||
mask = (1 << info.nexp) - 1
|
||||
bias = ((1 << info.nexp) - 1) >> 1
|
||||
bias = 1 - info.minexp
|
||||
x, e = _normalize_float(x1)
|
||||
x2 += e + ((x >> info.nmant) & mask) - bias
|
||||
|
||||
@ -555,7 +555,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
|
||||
dtype = dtypes.dtype(x)
|
||||
info = dtypes.finfo(dtype)
|
||||
mask = (1 << info.nexp) - 1
|
||||
bias = ((1 << info.nexp) - 1) >> 1
|
||||
bias = 1 - info.minexp
|
||||
|
||||
x1, x2 = _normalize_float(x)
|
||||
x2 += ((x1 >> info.nmant) & mask) - bias + 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user