diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index ac0418932..c9710c287 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,12 +90,17 @@ class ExtendedDType(StrictABC): # fp8 support +# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 +float8_e3m4: type[np.generic] | None = None +float8_e4m3: type[np.generic] | None = None float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz +_float8_e3m4_dtype: np.dtype | None = None +_float8_e4m3_dtype: np.dtype | None = None _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -137,6 +142,20 @@ _float8_dtypes = [ _float8_e5m2fnuz_dtype, ] +# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 +if hasattr(ml_dtypes, "float8_e4m3"): + float8_e4m3 = ml_dtypes.float8_e4m3 + _float8_e4m3_dtype = np.dtype(float8_e4m3) + _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e4m3_dtype) + _float8_dtypes.insert(0, _float8_e4m3_dtype) +if hasattr(ml_dtypes, "float8_e3m4"): + float8_e3m4 = ml_dtypes.float8_e3m4 + _float8_e3m4_dtype = np.dtype(float8_e3m4) + _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e3m4_dtype) + _float8_dtypes.insert(0, _float8_e3m4_dtype) + # 2-bit integer support int2: type[np.generic] | None = None uint2: type[np.generic] | None = None diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 3198f83aa..b71b377d8 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -67,6 +67,8 @@ enum DType: byte { i4 = 15, ui4 = 16, + f8_e3m4 = 24, + f8_e4m3 = 23, f8_e4m3b11fnuz = 17, f8_e4m3fn = 18, f8_e4m3fnuz = 19, diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index e392289da..0d9ce961b 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -359,6 +359,10 @@ _dtype_to_dtype_kind = { dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, } +if dtypes._float8_e3m4_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 +if dtypes._float8_e4m3_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index 18dd2c3cb..70d298020 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -53,6 +53,8 @@ class DType(object): bf16 = 14 i4 = 15 ui4 = 16 + f8_e3m4 = 24 + f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 2c0e26019..54a85f92c 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -184,13 +184,13 @@ _dtype_to_ir_type : dict[np.dtype, Callable[[], ir.Type]] = { if dtypes.int2 is not None: assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial( - ir.IntegerType.get_signless, 2 - ) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial( - ir.IntegerType.get_unsigned, 2 - ) + _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) + _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2) +if dtypes.float8_e3m4 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get +if dtypes.float8_e4m3 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8b6a517a5..7fa2dd4ac 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -937,11 +937,15 @@ class DotAlgorithmPreset(enum.Enum): DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): - fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz), + fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)) + np.dtype(dtypes.float8_e5m2fnuz)] + if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] + if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -3625,6 +3629,10 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) + if dtypes.float8_e3m4 is not None: + fp8_dtypes += (dtypes.float8_e3m4,) + if dtypes.float8_e4m3 is not None: + fp8_dtypes += (dtypes.float8_e4m3,) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d2e898339..c419c083f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -217,6 +217,10 @@ int8 = _make_scalar_type(np.int8) int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) +if dtypes.float8_e3m4 is not None: + float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +if dtypes.float8_e4m3 is not None: + float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 9859eb64c..6bbcdd084 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -90,6 +90,14 @@ default_gradient_tolerance = { np.dtype(np.complex128): 1e-5, } +# TODO: make this unconditional when ml_dtypes>=0.5.0 is required +if _dtypes.float8_e3m4 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 +if _dtypes.float8_e4m3 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.float8_e5m2fnuz, _dtypes.bfloat16, ] + + if _dtypes.float8_e4m3 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e4m3) + if _dtypes.float8_e3m4 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e3m4) + def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index bb81c979b..e7707f58f 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1431,10 +1431,19 @@ class _LazyDtypes: @_cached_property def custom_floats(self): - return [np.dtype(t) for t in [ - _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz, - _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, - _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]] + float_dtypes = [ + _dtypes.bfloat16, + _dtypes.float8_e4m3b11fnuz, + _dtypes.float8_e4m3fn, + _dtypes.float8_e4m3fnuz, + _dtypes.float8_e5m2, + _dtypes.float8_e5m2fnuz, + ] + if _dtypes.float8_e3m4 is not None: + float_dtypes += [_dtypes.float8_e3m4] + if _dtypes.float8_e4m3 is not None: + float_dtypes += [_dtypes.float8_e4m3] + return [np.dtype(t) for t in float_dtypes] @_cached_property def floating(self): diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 9be73e96a..9a643bf49 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -273,6 +273,15 @@ try: except ImportError: pass +# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 +try: + from jax._src.numpy.lax_numpy import ( + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, + ) +except ImportError: + pass + from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 89d70871a..6c7e9e3ab 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -64,6 +64,10 @@ custom_float_dtypes = [np.dtype(dtypes.bfloat16)] fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz)] +if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] +if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes