mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Use ml_dtypes definition for jnp.finfo
This commit is contained in:
parent
40d730be49
commit
59e6ed213e
@ -195,177 +195,12 @@ def coerce_to_array(x: Any, dtype: Optional[DTypeLike] = None) -> np.ndarray:
|
||||
return np.asarray(x, dtype)
|
||||
|
||||
iinfo = np.iinfo
|
||||
|
||||
class _Bfloat16MachArLike:
|
||||
def __init__(self):
|
||||
smallest_normal = float.fromhex("0x1p-126")
|
||||
self.smallest_normal = bfloat16(smallest_normal)
|
||||
smallest_subnormal = float.fromhex("0x1p-133")
|
||||
self.smallest_subnormal = bfloat16(smallest_subnormal)
|
||||
|
||||
class _Float8E4m3FnMachArLike:
|
||||
def __init__(self):
|
||||
smallest_normal = float.fromhex("0x1p-6")
|
||||
self.smallest_normal = float8_e4m3fn(smallest_normal)
|
||||
smallest_subnormal = float.fromhex("0x1p-9")
|
||||
self.smallest_subnormal = float8_e4m3fn(smallest_subnormal)
|
||||
|
||||
class _Float8E5m2MachArLike:
|
||||
def __init__(self):
|
||||
smallest_normal = float.fromhex("0x1p-14")
|
||||
self.smallest_normal = float8_e5m2(smallest_normal)
|
||||
smallest_subnormal = float.fromhex("0x1p-16")
|
||||
self.smallest_subnormal = float8_e5m2(smallest_subnormal)
|
||||
|
||||
class finfo(np.finfo):
|
||||
__doc__ = np.finfo.__doc__
|
||||
_finfo_cache: Dict[np.dtype, np.finfo] = {}
|
||||
@staticmethod
|
||||
def _bfloat16_finfo():
|
||||
def float_to_str(f):
|
||||
return "%12.4e" % float(f)
|
||||
|
||||
bfloat16 = _bfloat16_dtype.type
|
||||
tiny = float.fromhex("0x1p-126")
|
||||
resolution = 0.01
|
||||
eps = float.fromhex("0x1p-7")
|
||||
epsneg = float.fromhex("0x1p-8")
|
||||
max = float.fromhex("0x1.FEp127")
|
||||
|
||||
obj = object.__new__(np.finfo)
|
||||
obj.dtype = _bfloat16_dtype
|
||||
obj.bits = 16
|
||||
obj.eps = bfloat16(eps)
|
||||
obj.epsneg = bfloat16(epsneg)
|
||||
obj.machep = -7
|
||||
obj.negep = -8
|
||||
obj.max = bfloat16(max)
|
||||
obj.min = bfloat16(-max)
|
||||
obj.nexp = 8
|
||||
obj.nmant = 7
|
||||
obj.iexp = obj.nexp
|
||||
obj.maxexp = 128
|
||||
obj.minexp = -126
|
||||
obj.precision = 2
|
||||
obj.resolution = bfloat16(resolution)
|
||||
obj._machar = _Bfloat16MachArLike()
|
||||
if not hasattr(obj, "tiny"):
|
||||
obj.tiny = bfloat16(tiny)
|
||||
if not hasattr(obj, "smallest_normal"):
|
||||
obj.smallest_normal = obj._machar.smallest_normal
|
||||
obj.smallest_subnormal = obj._machar.smallest_subnormal
|
||||
|
||||
obj._str_tiny = float_to_str(tiny)
|
||||
obj._str_smallest_normal = float_to_str(tiny)
|
||||
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
|
||||
obj._str_max = float_to_str(max)
|
||||
obj._str_epsneg = float_to_str(epsneg)
|
||||
obj._str_eps = float_to_str(eps)
|
||||
obj._str_resolution = float_to_str(resolution)
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def _float8_e4m3fn_finfo():
|
||||
def float_to_str(f):
|
||||
return "%6.2e" % float(f)
|
||||
|
||||
float8_e4m3fn = _float8_e4m3fn_dtype.type
|
||||
tiny = float.fromhex("0x1p-6")
|
||||
resolution = 0.1
|
||||
eps = float.fromhex("0x1p-3")
|
||||
epsneg = float.fromhex("0x1p-4")
|
||||
max = float.fromhex("0x1.Cp8")
|
||||
|
||||
obj = object.__new__(np.finfo)
|
||||
obj.dtype = _float8_e4m3fn_dtype
|
||||
obj.bits = 8
|
||||
obj.eps = float8_e4m3fn(eps)
|
||||
obj.epsneg = float8_e4m3fn(epsneg)
|
||||
obj.machep = -3
|
||||
obj.negep = -4
|
||||
obj.max = float8_e4m3fn(max)
|
||||
obj.min = float8_e4m3fn(-max)
|
||||
obj.nexp = 4
|
||||
obj.nmant = 3
|
||||
obj.iexp = obj.nexp
|
||||
obj.maxexp = 9
|
||||
obj.minexp = -6
|
||||
obj.precision = 1
|
||||
obj.resolution = float8_e4m3fn(resolution)
|
||||
obj._machar = _Float8E4m3FnMachArLike()
|
||||
if not hasattr(obj, "tiny"):
|
||||
obj.tiny = float8_e4m3fn(tiny)
|
||||
if not hasattr(obj, "smallest_normal"):
|
||||
obj.smallest_normal = obj._machar.smallest_normal
|
||||
obj.smallest_subnormal = obj._machar.smallest_subnormal
|
||||
|
||||
obj._str_tiny = float_to_str(tiny)
|
||||
obj._str_smallest_normal = float_to_str(tiny)
|
||||
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
|
||||
obj._str_max = float_to_str(max)
|
||||
obj._str_epsneg = float_to_str(epsneg)
|
||||
obj._str_eps = float_to_str(eps)
|
||||
obj._str_resolution = float_to_str(resolution)
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def _float8_e5m2_finfo():
|
||||
def float_to_str(f):
|
||||
return "%6.2e" % float(f)
|
||||
|
||||
float8_e5m2 = _float8_e5m2_dtype.type
|
||||
tiny = float.fromhex("0x1p-14")
|
||||
resolution = 0.1
|
||||
eps = float.fromhex("0x1p-2")
|
||||
epsneg = float.fromhex("0x1p-3")
|
||||
max = float.fromhex("0x1.Cp15")
|
||||
|
||||
obj = object.__new__(np.finfo)
|
||||
obj.dtype = _float8_e5m2_dtype
|
||||
obj.bits = 8
|
||||
obj.eps = float8_e5m2(eps)
|
||||
obj.epsneg = float8_e5m2(epsneg)
|
||||
obj.machep = -2
|
||||
obj.negep = -3
|
||||
obj.max = float8_e5m2(max)
|
||||
obj.min = float8_e5m2(-max)
|
||||
obj.nexp = 5
|
||||
obj.nmant = 2
|
||||
obj.iexp = obj.nexp
|
||||
obj.maxexp = 16
|
||||
obj.minexp = -14
|
||||
obj.precision = 1
|
||||
obj.resolution = float8_e5m2(resolution)
|
||||
obj._machar = _Float8E5m2MachArLike()
|
||||
if not hasattr(obj, "tiny"):
|
||||
obj.tiny = float8_e5m2(tiny)
|
||||
if not hasattr(obj, "smallest_normal"):
|
||||
obj.smallest_normal = obj._machar.smallest_normal
|
||||
obj.smallest_subnormal = obj._machar.smallest_subnormal
|
||||
|
||||
obj._str_tiny = float_to_str(tiny)
|
||||
obj._str_smallest_normal = float_to_str(tiny)
|
||||
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
|
||||
obj._str_max = float_to_str(max)
|
||||
obj._str_epsneg = float_to_str(epsneg)
|
||||
obj._str_eps = float_to_str(eps)
|
||||
obj._str_resolution = float_to_str(resolution)
|
||||
return obj
|
||||
|
||||
def __new__(cls, dtype):
|
||||
if isinstance(dtype, str) and dtype == 'bfloat16' or dtype == _bfloat16_dtype:
|
||||
if _bfloat16_dtype not in cls._finfo_cache:
|
||||
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
|
||||
return cls._finfo_cache[_bfloat16_dtype]
|
||||
if isinstance(dtype, str) and dtype == 'float8_e4m3fn' or dtype == _float8_e4m3fn_dtype:
|
||||
if _float8_e4m3fn_dtype not in cls._finfo_cache:
|
||||
cls._finfo_cache[_float8_e4m3fn_dtype] = cls._float8_e4m3fn_finfo()
|
||||
return cls._finfo_cache[_float8_e4m3fn_dtype]
|
||||
if isinstance(dtype, str) and dtype == 'float8_e5m2' or dtype == _float8_e5m2_dtype:
|
||||
if _float8_e5m2_dtype not in cls._finfo_cache:
|
||||
cls._finfo_cache[_float8_e5m2_dtype] = cls._float8_e5m2_finfo()
|
||||
return cls._finfo_cache[_float8_e5m2_dtype]
|
||||
return super().__new__(cls, dtype)
|
||||
try:
|
||||
finfo = ml_dtypes.finfo
|
||||
except AttributeError as err:
|
||||
_ml_dtypes_version = getattr(ml_dtypes, "__version__", "<unknown>")
|
||||
raise ImportError("JAX requires package ml_dtypes>=0.1.0. "
|
||||
f"Installed version is {_ml_dtypes_version}.") from err
|
||||
|
||||
def _issubclass(a: Any, b: Any) -> bool:
|
||||
"""Determines if ``a`` is a subclass of ``b``.
|
||||
|
@ -46,7 +46,7 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib', 'jaxlib.xla_extension'],
|
||||
python_requires='>=3.8',
|
||||
install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.0.3'],
|
||||
install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.1.0'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
classifiers=[
|
||||
|
Loading…
x
Reference in New Issue
Block a user