Add float8_e4m3 and float8_e3m4 types support

This commit is contained in:
Sergei Lebedev 2024-10-07 15:33:24 -07:00 committed by Alexander Pivovarov
parent 2b55bd5a24
commit 78da9fa432
11 changed files with 87 additions and 12 deletions

View File

@ -90,12 +90,17 @@ class ExtendedDType(StrictABC):
# fp8 support
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
float8_e3m4: type[np.generic] | None = None
float8_e4m3: type[np.generic] | None = None
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz
_float8_e3m4_dtype: np.dtype | None = None
_float8_e4m3_dtype: np.dtype | None = None
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
@ -137,6 +142,20 @@ _float8_dtypes = [
_float8_e5m2fnuz_dtype,
]
# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
if hasattr(ml_dtypes, "float8_e4m3"):
float8_e4m3 = ml_dtypes.float8_e4m3
_float8_e4m3_dtype = np.dtype(float8_e4m3)
_custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e4m3_dtype)
_float8_dtypes.insert(0, _float8_e4m3_dtype)
if hasattr(ml_dtypes, "float8_e3m4"):
float8_e3m4 = ml_dtypes.float8_e3m4
_float8_e3m4_dtype = np.dtype(float8_e3m4)
_custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e3m4_dtype)
_float8_dtypes.insert(0, _float8_e3m4_dtype)
# 2-bit integer support
int2: type[np.generic] | None = None
uint2: type[np.generic] | None = None

View File

@ -67,6 +67,8 @@ enum DType: byte {
i4 = 15,
ui4 = 16,
f8_e3m4 = 24,
f8_e4m3 = 23,
f8_e4m3b11fnuz = 17,
f8_e4m3fn = 18,
f8_e4m3fnuz = 19,

View File

@ -359,6 +359,10 @@ _dtype_to_dtype_kind = {
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
}
if dtypes._float8_e3m4_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
if dtypes._float8_e4m3_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()

View File

@ -53,6 +53,8 @@ class DType(object):
bf16 = 14
i4 = 15
ui4 = 16
f8_e3m4 = 24
f8_e4m3 = 23
f8_e4m3b11fnuz = 17
f8_e4m3fn = 18
f8_e4m3fnuz = 19

View File

@ -184,13 +184,13 @@ _dtype_to_ir_type : dict[np.dtype, Callable[[], ir.Type]] = {
if dtypes.int2 is not None:
assert dtypes.uint2 is not None
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(
ir.IntegerType.get_signless, 2
)
_dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(
ir.IntegerType.get_unsigned, 2
)
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2)
_dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2)
if dtypes.float8_e3m4 is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
if dtypes.float8_e4m3 is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):

View File

@ -937,11 +937,15 @@ class DotAlgorithmPreset(enum.Enum):
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
):
fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz),
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz),
np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz),
np.dtype(dtypes.float8_e5m2),
np.dtype(dtypes.float8_e5m2fnuz))
np.dtype(dtypes.float8_e5m2fnuz)]
if dtypes.float8_e3m4 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
raise ValueError(
f"The dot algorithm '{self}' requires both inputs to have float8 "
@ -3625,6 +3629,10 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
if dtypes.float8_e3m4 is not None:
fp8_dtypes += (dtypes.float8_e3m4,)
if dtypes.float8_e4m3 is not None:
fp8_dtypes += (dtypes.float8_e4m3,)
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
del preferred_element_type # Implied by the output aval
lhs_aval, rhs_aval = ctx.avals_in

View File

@ -217,6 +217,10 @@ int8 = _make_scalar_type(np.int8)
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
if dtypes.float8_e3m4 is not None:
float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
if dtypes.float8_e4m3 is not None:
float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)

View File

@ -90,6 +90,14 @@ default_gradient_tolerance = {
np.dtype(np.complex128): 1e-5,
}
# TODO: make this unconditional when ml_dtypes>=0.5.0 is required
if _dtypes.float8_e3m4 is not None:
_default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
if _dtypes.float8_e4m3 is not None:
_default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
_dtypes.float8_e5m2fnuz,
_dtypes.bfloat16,
]
if _dtypes.float8_e4m3 is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e4m3)
if _dtypes.float8_e3m4 is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
return x.astype(np.float32)

View File

@ -1431,10 +1431,19 @@ class _LazyDtypes:
@_cached_property
def custom_floats(self):
return [np.dtype(t) for t in [
_dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz,
_dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz,
_dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]]
float_dtypes = [
_dtypes.bfloat16,
_dtypes.float8_e4m3b11fnuz,
_dtypes.float8_e4m3fn,
_dtypes.float8_e4m3fnuz,
_dtypes.float8_e5m2,
_dtypes.float8_e5m2fnuz,
]
if _dtypes.float8_e3m4 is not None:
float_dtypes += [_dtypes.float8_e3m4]
if _dtypes.float8_e4m3 is not None:
float_dtypes += [_dtypes.float8_e4m3]
return [np.dtype(t) for t in float_dtypes]
@_cached_property
def floating(self):

View File

@ -273,6 +273,15 @@ try:
except ImportError:
pass
# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0
try:
from jax._src.numpy.lax_numpy import (
float8_e3m4 as float8_e3m4,
float8_e4m3 as float8_e4m3,
)
except ImportError:
pass
from jax._src.numpy.array_api_metadata import (
__array_api_version__ as __array_api_version__,
__array_namespace_info__ as __array_namespace_info__,

View File

@ -64,6 +64,10 @@ custom_float_dtypes = [np.dtype(dtypes.bfloat16)]
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2),
np.dtype(dtypes.float8_e5m2fnuz)]
if dtypes.float8_e3m4 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
float_dtypes += fp8_dtypes
custom_float_dtypes += fp8_dtypes