Use ml_dtypes definition for jnp.finfo

This commit is contained in:
Jake VanderPlas 2023-05-04 10:40:44 -07:00
parent 40d730be49
commit 59e6ed213e
3 changed files with 8 additions and 173 deletions

View File

@ -195,177 +195,12 @@ def coerce_to_array(x: Any, dtype: Optional[DTypeLike] = None) -> np.ndarray:
return np.asarray(x, dtype)
iinfo = np.iinfo
class _Bfloat16MachArLike:
def __init__(self):
smallest_normal = float.fromhex("0x1p-126")
self.smallest_normal = bfloat16(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-133")
self.smallest_subnormal = bfloat16(smallest_subnormal)
class _Float8E4m3FnMachArLike:
def __init__(self):
smallest_normal = float.fromhex("0x1p-6")
self.smallest_normal = float8_e4m3fn(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-9")
self.smallest_subnormal = float8_e4m3fn(smallest_subnormal)
class _Float8E5m2MachArLike:
def __init__(self):
smallest_normal = float.fromhex("0x1p-14")
self.smallest_normal = float8_e5m2(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-16")
self.smallest_subnormal = float8_e5m2(smallest_subnormal)
class finfo(np.finfo):
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[np.dtype, np.finfo] = {}
@staticmethod
def _bfloat16_finfo():
def float_to_str(f):
return "%12.4e" % float(f)
bfloat16 = _bfloat16_dtype.type
tiny = float.fromhex("0x1p-126")
resolution = 0.01
eps = float.fromhex("0x1p-7")
epsneg = float.fromhex("0x1p-8")
max = float.fromhex("0x1.FEp127")
obj = object.__new__(np.finfo)
obj.dtype = _bfloat16_dtype
obj.bits = 16
obj.eps = bfloat16(eps)
obj.epsneg = bfloat16(epsneg)
obj.machep = -7
obj.negep = -8
obj.max = bfloat16(max)
obj.min = bfloat16(-max)
obj.nexp = 8
obj.nmant = 7
obj.iexp = obj.nexp
obj.maxexp = 128
obj.minexp = -126
obj.precision = 2
obj.resolution = bfloat16(resolution)
obj._machar = _Bfloat16MachArLike()
if not hasattr(obj, "tiny"):
obj.tiny = bfloat16(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
obj.smallest_subnormal = obj._machar.smallest_subnormal
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(max)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)
obj._str_resolution = float_to_str(resolution)
return obj
@staticmethod
def _float8_e4m3fn_finfo():
def float_to_str(f):
return "%6.2e" % float(f)
float8_e4m3fn = _float8_e4m3fn_dtype.type
tiny = float.fromhex("0x1p-6")
resolution = 0.1
eps = float.fromhex("0x1p-3")
epsneg = float.fromhex("0x1p-4")
max = float.fromhex("0x1.Cp8")
obj = object.__new__(np.finfo)
obj.dtype = _float8_e4m3fn_dtype
obj.bits = 8
obj.eps = float8_e4m3fn(eps)
obj.epsneg = float8_e4m3fn(epsneg)
obj.machep = -3
obj.negep = -4
obj.max = float8_e4m3fn(max)
obj.min = float8_e4m3fn(-max)
obj.nexp = 4
obj.nmant = 3
obj.iexp = obj.nexp
obj.maxexp = 9
obj.minexp = -6
obj.precision = 1
obj.resolution = float8_e4m3fn(resolution)
obj._machar = _Float8E4m3FnMachArLike()
if not hasattr(obj, "tiny"):
obj.tiny = float8_e4m3fn(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
obj.smallest_subnormal = obj._machar.smallest_subnormal
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(max)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)
obj._str_resolution = float_to_str(resolution)
return obj
@staticmethod
def _float8_e5m2_finfo():
def float_to_str(f):
return "%6.2e" % float(f)
float8_e5m2 = _float8_e5m2_dtype.type
tiny = float.fromhex("0x1p-14")
resolution = 0.1
eps = float.fromhex("0x1p-2")
epsneg = float.fromhex("0x1p-3")
max = float.fromhex("0x1.Cp15")
obj = object.__new__(np.finfo)
obj.dtype = _float8_e5m2_dtype
obj.bits = 8
obj.eps = float8_e5m2(eps)
obj.epsneg = float8_e5m2(epsneg)
obj.machep = -2
obj.negep = -3
obj.max = float8_e5m2(max)
obj.min = float8_e5m2(-max)
obj.nexp = 5
obj.nmant = 2
obj.iexp = obj.nexp
obj.maxexp = 16
obj.minexp = -14
obj.precision = 1
obj.resolution = float8_e5m2(resolution)
obj._machar = _Float8E5m2MachArLike()
if not hasattr(obj, "tiny"):
obj.tiny = float8_e5m2(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
obj.smallest_subnormal = obj._machar.smallest_subnormal
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(max)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)
obj._str_resolution = float_to_str(resolution)
return obj
def __new__(cls, dtype):
if isinstance(dtype, str) and dtype == 'bfloat16' or dtype == _bfloat16_dtype:
if _bfloat16_dtype not in cls._finfo_cache:
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
return cls._finfo_cache[_bfloat16_dtype]
if isinstance(dtype, str) and dtype == 'float8_e4m3fn' or dtype == _float8_e4m3fn_dtype:
if _float8_e4m3fn_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3fn_dtype] = cls._float8_e4m3fn_finfo()
return cls._finfo_cache[_float8_e4m3fn_dtype]
if isinstance(dtype, str) and dtype == 'float8_e5m2' or dtype == _float8_e5m2_dtype:
if _float8_e5m2_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e5m2_dtype] = cls._float8_e5m2_finfo()
return cls._finfo_cache[_float8_e5m2_dtype]
return super().__new__(cls, dtype)
try:
finfo = ml_dtypes.finfo
except AttributeError as err:
_ml_dtypes_version = getattr(ml_dtypes, "__version__", "<unknown>")
raise ImportError("JAX requires package ml_dtypes>=0.1.0. "
f"Installed version is {_ml_dtypes_version}.") from err
def _issubclass(a: Any, b: Any) -> bool:
"""Determines if ``a`` is a subclass of ``b``.

View File

@ -46,7 +46,7 @@ setup(
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension'],
python_requires='>=3.8',
install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.0.3'],
install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.1.0'],
url='https://github.com/google/jax',
license='Apache-2.0',
classifiers=[

View File

@ -63,7 +63,7 @@ setup(
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
python_requires='>=3.8',
install_requires=[
'ml_dtypes>=0.0.3',
'ml_dtypes>=0.1.0',
'numpy>=1.21',
'opt_einsum',
'scipy>=1.7',