mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add float8_e4m3 and float8_e3m4 types support
This commit is contained in:
parent
2b55bd5a24
commit
78da9fa432
@ -90,12 +90,17 @@ class ExtendedDType(StrictABC):
|
||||
|
||||
|
||||
# fp8 support
|
||||
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
|
||||
float8_e3m4: type[np.generic] | None = None
|
||||
float8_e4m3: type[np.generic] | None = None
|
||||
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
|
||||
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
|
||||
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
|
||||
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
|
||||
float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz
|
||||
|
||||
_float8_e3m4_dtype: np.dtype | None = None
|
||||
_float8_e4m3_dtype: np.dtype | None = None
|
||||
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
|
||||
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
|
||||
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
|
||||
@ -137,6 +142,20 @@ _float8_dtypes = [
|
||||
_float8_e5m2fnuz_dtype,
|
||||
]
|
||||
|
||||
# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
|
||||
if hasattr(ml_dtypes, "float8_e4m3"):
|
||||
float8_e4m3 = ml_dtypes.float8_e4m3
|
||||
_float8_e4m3_dtype = np.dtype(float8_e4m3)
|
||||
_custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type]
|
||||
_custom_float_dtypes.insert(0, _float8_e4m3_dtype)
|
||||
_float8_dtypes.insert(0, _float8_e4m3_dtype)
|
||||
if hasattr(ml_dtypes, "float8_e3m4"):
|
||||
float8_e3m4 = ml_dtypes.float8_e3m4
|
||||
_float8_e3m4_dtype = np.dtype(float8_e3m4)
|
||||
_custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type]
|
||||
_custom_float_dtypes.insert(0, _float8_e3m4_dtype)
|
||||
_float8_dtypes.insert(0, _float8_e3m4_dtype)
|
||||
|
||||
# 2-bit integer support
|
||||
int2: type[np.generic] | None = None
|
||||
uint2: type[np.generic] | None = None
|
||||
|
@ -67,6 +67,8 @@ enum DType: byte {
|
||||
i4 = 15,
|
||||
ui4 = 16,
|
||||
|
||||
f8_e3m4 = 24,
|
||||
f8_e4m3 = 23,
|
||||
f8_e4m3b11fnuz = 17,
|
||||
f8_e4m3fn = 18,
|
||||
f8_e4m3fnuz = 19,
|
||||
|
@ -359,6 +359,10 @@ _dtype_to_dtype_kind = {
|
||||
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
|
||||
}
|
||||
|
||||
if dtypes._float8_e3m4_dtype is not None:
|
||||
_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
|
||||
if dtypes._float8_e4m3_dtype is not None:
|
||||
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
|
||||
|
||||
_dtype_kind_to_dtype = {
|
||||
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
|
||||
|
@ -53,6 +53,8 @@ class DType(object):
|
||||
bf16 = 14
|
||||
i4 = 15
|
||||
ui4 = 16
|
||||
f8_e3m4 = 24
|
||||
f8_e4m3 = 23
|
||||
f8_e4m3b11fnuz = 17
|
||||
f8_e4m3fn = 18
|
||||
f8_e4m3fnuz = 19
|
||||
|
@ -184,13 +184,13 @@ _dtype_to_ir_type : dict[np.dtype, Callable[[], ir.Type]] = {
|
||||
|
||||
if dtypes.int2 is not None:
|
||||
assert dtypes.uint2 is not None
|
||||
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(
|
||||
ir.IntegerType.get_signless, 2
|
||||
)
|
||||
_dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(
|
||||
ir.IntegerType.get_unsigned, 2
|
||||
)
|
||||
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2)
|
||||
_dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2)
|
||||
|
||||
if dtypes.float8_e3m4 is not None:
|
||||
_dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
|
||||
if dtypes.float8_e4m3 is not None:
|
||||
_dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get
|
||||
|
||||
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
|
||||
if isinstance(dtype, core.bint):
|
||||
|
@ -937,11 +937,15 @@ class DotAlgorithmPreset(enum.Enum):
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
|
||||
):
|
||||
fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz),
|
||||
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz),
|
||||
np.dtype(dtypes.float8_e4m3fn),
|
||||
np.dtype(dtypes.float8_e4m3fnuz),
|
||||
np.dtype(dtypes.float8_e5m2),
|
||||
np.dtype(dtypes.float8_e5m2fnuz))
|
||||
np.dtype(dtypes.float8_e5m2fnuz)]
|
||||
if dtypes.float8_e3m4 is not None:
|
||||
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
|
||||
if dtypes.float8_e4m3 is not None:
|
||||
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
|
||||
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
|
||||
raise ValueError(
|
||||
f"The dot algorithm '{self}' requires both inputs to have float8 "
|
||||
@ -3625,6 +3629,10 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
|
||||
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
|
||||
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
|
||||
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
|
||||
if dtypes.float8_e3m4 is not None:
|
||||
fp8_dtypes += (dtypes.float8_e3m4,)
|
||||
if dtypes.float8_e4m3 is not None:
|
||||
fp8_dtypes += (dtypes.float8_e4m3,)
|
||||
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
|
||||
del preferred_element_type # Implied by the output aval
|
||||
lhs_aval, rhs_aval = ctx.avals_in
|
||||
|
@ -217,6 +217,10 @@ int8 = _make_scalar_type(np.int8)
|
||||
int16 = _make_scalar_type(np.int16)
|
||||
int32 = _make_scalar_type(np.int32)
|
||||
int64 = _make_scalar_type(np.int64)
|
||||
if dtypes.float8_e3m4 is not None:
|
||||
float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
|
||||
if dtypes.float8_e4m3 is not None:
|
||||
float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
|
||||
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
|
||||
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
|
||||
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
|
||||
|
@ -90,6 +90,14 @@ default_gradient_tolerance = {
|
||||
np.dtype(np.complex128): 1e-5,
|
||||
}
|
||||
|
||||
# TODO: make this unconditional when ml_dtypes>=0.5.0 is required
|
||||
if _dtypes.float8_e3m4 is not None:
|
||||
_default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
|
||||
default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
|
||||
if _dtypes.float8_e4m3 is not None:
|
||||
_default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
|
||||
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
|
||||
|
||||
def is_python_scalar(val):
|
||||
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
|
||||
|
||||
@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
|
||||
_dtypes.float8_e5m2fnuz,
|
||||
_dtypes.bfloat16,
|
||||
]
|
||||
|
||||
if _dtypes.float8_e4m3 is not None:
|
||||
custom_float_dtypes.insert(0, _dtypes.float8_e4m3)
|
||||
if _dtypes.float8_e3m4 is not None:
|
||||
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
|
||||
|
||||
def maybe_upcast(x):
|
||||
if x.dtype in custom_float_dtypes:
|
||||
return x.astype(np.float32)
|
||||
|
@ -1431,10 +1431,19 @@ class _LazyDtypes:
|
||||
|
||||
@_cached_property
|
||||
def custom_floats(self):
|
||||
return [np.dtype(t) for t in [
|
||||
_dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz,
|
||||
_dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz,
|
||||
_dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]]
|
||||
float_dtypes = [
|
||||
_dtypes.bfloat16,
|
||||
_dtypes.float8_e4m3b11fnuz,
|
||||
_dtypes.float8_e4m3fn,
|
||||
_dtypes.float8_e4m3fnuz,
|
||||
_dtypes.float8_e5m2,
|
||||
_dtypes.float8_e5m2fnuz,
|
||||
]
|
||||
if _dtypes.float8_e3m4 is not None:
|
||||
float_dtypes += [_dtypes.float8_e3m4]
|
||||
if _dtypes.float8_e4m3 is not None:
|
||||
float_dtypes += [_dtypes.float8_e4m3]
|
||||
return [np.dtype(t) for t in float_dtypes]
|
||||
|
||||
@_cached_property
|
||||
def floating(self):
|
||||
|
@ -273,6 +273,15 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0
|
||||
try:
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
float8_e3m4 as float8_e3m4,
|
||||
float8_e4m3 as float8_e4m3,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from jax._src.numpy.array_api_metadata import (
|
||||
__array_api_version__ as __array_api_version__,
|
||||
__array_namespace_info__ as __array_namespace_info__,
|
||||
|
@ -64,6 +64,10 @@ custom_float_dtypes = [np.dtype(dtypes.bfloat16)]
|
||||
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn),
|
||||
np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2),
|
||||
np.dtype(dtypes.float8_e5m2fnuz)]
|
||||
if dtypes.float8_e3m4 is not None:
|
||||
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
|
||||
if dtypes.float8_e4m3 is not None:
|
||||
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
|
||||
float_dtypes += fp8_dtypes
|
||||
custom_float_dtypes += fp8_dtypes
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user