Add e8m0fnu support by conditional dtype.

This commit is contained in:
wenscarl 2025-01-22 21:57:43 +00:00
parent fc9356085e
commit 638c6ae046
11 changed files with 30 additions and 1 deletions

View File

@ -93,6 +93,7 @@ class ExtendedDType(StrictABC):
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
float8_e3m4: type[np.generic] | None = None
float8_e4m3: type[np.generic] | None = None
float8_e8m0fnu: type[np.generic] | None = None
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
@ -101,6 +102,7 @@ float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz
_float8_e3m4_dtype: np.dtype | None = None
_float8_e4m3_dtype: np.dtype | None = None
_float8_e8m0fnu_dtype: np.dtype | None = None
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
@ -155,6 +157,12 @@ if hasattr(ml_dtypes, "float8_e3m4"):
_custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e3m4_dtype)
_float8_dtypes.insert(0, _float8_e3m4_dtype)
if hasattr(ml_dtypes, "float8_e8m0fnu"):
float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
_float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu)
_custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
_float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
# 2-bit integer support
int2: type[np.generic] | None = None

View File

@ -74,6 +74,7 @@ enum DType: byte {
f8_e4m3fnuz = 19,
f8_e5m2 = 20,
f8_e5m2fnuz = 21,
f8_e8m0fnu = 25,
}
table AbstractValue {

View File

@ -363,7 +363,8 @@ if dtypes._float8_e3m4_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
if dtypes._float8_e4m3_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
if dtypes._float8_e8m0fnu_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
}

View File

@ -61,6 +61,7 @@ class DType(object):
f8_e5m2 = 20
f8_e5m2fnuz = 21
f0 = 22
f8_e8m0fnu = 25
class ShardingKind(object):

View File

@ -193,6 +193,8 @@ if dtypes.float8_e3m4 is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
if dtypes.float8_e4m3 is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get
if dtypes.float8_e8m0fnu is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):

View File

@ -1205,6 +1205,8 @@ class DotAlgorithmPreset(enum.Enum):
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
if dtypes.float8_e8m0fnu is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
raise ValueError(
f"The dot algorithm '{self}' requires both inputs to have float8 "
@ -3965,6 +3967,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
fp8_dtypes += (dtypes.float8_e3m4,)
if dtypes.float8_e4m3 is not None:
fp8_dtypes += (dtypes.float8_e4m3,)
if dtypes.float8_e8m0fnu is not None:
fp8_dtypes += (dtypes.float8_e8m0fnu,)
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
del preferred_element_type # Implied by the output aval
lhs_aval, rhs_aval = ctx.avals_in

View File

@ -225,6 +225,8 @@ if dtypes.float8_e3m4 is not None:
float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
if dtypes.float8_e4m3 is not None:
float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
if dtypes.float8_e8m0fnu is not None:
float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)

View File

@ -97,6 +97,9 @@ if _dtypes.float8_e3m4 is not None:
if _dtypes.float8_e4m3 is not None:
_default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
if _dtypes.float8_e8m0fnu is not None:
_default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@ -119,6 +122,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
custom_float_dtypes.insert(0, _dtypes.float8_e4m3)
if _dtypes.float8_e3m4 is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
if _dtypes.float8_e8m0fnu is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
def maybe_upcast(x):
if x.dtype in custom_float_dtypes:

View File

@ -1616,6 +1616,8 @@ class _LazyDtypes:
float_dtypes += [_dtypes.float8_e3m4]
if _dtypes.float8_e4m3 is not None:
float_dtypes += [_dtypes.float8_e4m3]
if _dtypes.float8_e8m0fnu is not None:
float_dtypes += [_dtypes.float8_e8m0fnu]
return self.supported(float_dtypes)
@_cached_property

View File

@ -281,6 +281,7 @@ try:
from jax._src.numpy.lax_numpy import (
float8_e3m4 as float8_e3m4,
float8_e4m3 as float8_e4m3,
float8_e8m0fnu as float8_e8m0fnu,
)
except ImportError:
pass

View File

@ -68,6 +68,8 @@ if dtypes.float8_e3m4 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
if dtypes.float8_e8m0fnu is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
float_dtypes += fp8_dtypes
custom_float_dtypes += fp8_dtypes