From 638c6ae046c46c17072a3b7e9cd941497a0b6bf7 Mon Sep 17 00:00:00 2001 From: wenscarl Date: Wed, 22 Jan 2025 21:57:43 +0000 Subject: [PATCH] Add e8m0fnu support by conditional dtype. --- jax/_src/dtypes.py | 8 ++++++++ jax/_src/export/serialization.fbs | 1 + jax/_src/export/serialization.py | 3 ++- jax/_src/export/serialization_generated.py | 1 + jax/_src/interpreters/mlir.py | 2 ++ jax/_src/lax/lax.py | 4 ++++ jax/_src/numpy/lax_numpy.py | 2 ++ jax/_src/public_test_util.py | 5 +++++ jax/_src/test_util.py | 2 ++ jax/numpy/__init__.py | 1 + tests/dtypes_test.py | 2 ++ 11 files changed, 30 insertions(+), 1 deletion(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 04b07843a..24c0f1a62 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -93,6 +93,7 @@ class ExtendedDType(StrictABC): # TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 float8_e3m4: type[np.generic] | None = None float8_e4m3: type[np.generic] | None = None +float8_e8m0fnu: type[np.generic] | None = None float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz @@ -101,6 +102,7 @@ float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz _float8_e3m4_dtype: np.dtype | None = None _float8_e4m3_dtype: np.dtype | None = None +_float8_e8m0fnu_dtype: np.dtype | None = None _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -155,6 +157,12 @@ if hasattr(ml_dtypes, "float8_e3m4"): _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] _custom_float_dtypes.insert(0, _float8_e3m4_dtype) _float8_dtypes.insert(0, _float8_e3m4_dtype) +if hasattr(ml_dtypes, "float8_e8m0fnu"): + float8_e8m0fnu = ml_dtypes.float8_e8m0fnu + _float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu) + _custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype) + _float8_dtypes.insert(0, _float8_e8m0fnu_dtype) # 2-bit integer support int2: type[np.generic] | None = None diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 2c2eb0f69..dd0ae3edc 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -74,6 +74,7 @@ enum DType: byte { f8_e4m3fnuz = 19, f8_e5m2 = 20, f8_e5m2fnuz = 21, + f8_e8m0fnu = 25, } table AbstractValue { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 0d9ce961b..7707670f1 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -363,7 +363,8 @@ if dtypes._float8_e3m4_dtype is not None: _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 if dtypes._float8_e4m3_dtype is not None: _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 - +if dtypes._float8_e8m0fnu_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() } diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index 70d298020..69092cd7e 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -61,6 +61,7 @@ class DType(object): f8_e5m2 = 20 f8_e5m2fnuz = 21 f0 = 22 + f8_e8m0fnu = 25 class ShardingKind(object): diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index fba5cc8bb..0bda64f50 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -193,6 +193,8 @@ if dtypes.float8_e3m4 is not None: _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get if dtypes.float8_e4m3 is not None: _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get +if dtypes.float8_e8m0fnu is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 10dc7ca0d..d0e27285c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1205,6 +1205,8 @@ class DotAlgorithmPreset(enum.Enum): fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] if dtypes.float8_e4m3 is not None: fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] + if dtypes.float8_e8m0fnu is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -3965,6 +3967,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, fp8_dtypes += (dtypes.float8_e3m4,) if dtypes.float8_e4m3 is not None: fp8_dtypes += (dtypes.float8_e4m3,) + if dtypes.float8_e8m0fnu is not None: + fp8_dtypes += (dtypes.float8_e8m0fnu,) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index dc689b619..be5475476 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -225,6 +225,8 @@ if dtypes.float8_e3m4 is not None: float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) if dtypes.float8_e4m3 is not None: float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) +if dtypes.float8_e8m0fnu is not None: + float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 6bbcdd084..93a6c29c2 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -97,6 +97,9 @@ if _dtypes.float8_e3m4 is not None: if _dtypes.float8_e4m3 is not None: _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 +if _dtypes.float8_e8m0fnu is not None: + _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 + default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -119,6 +122,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): custom_float_dtypes.insert(0, _dtypes.float8_e4m3) if _dtypes.float8_e3m4 is not None: custom_float_dtypes.insert(0, _dtypes.float8_e3m4) + if _dtypes.float8_e8m0fnu is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu) def maybe_upcast(x): if x.dtype in custom_float_dtypes: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 3017890e6..d8bf467f4 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1616,6 +1616,8 @@ class _LazyDtypes: float_dtypes += [_dtypes.float8_e3m4] if _dtypes.float8_e4m3 is not None: float_dtypes += [_dtypes.float8_e4m3] + if _dtypes.float8_e8m0fnu is not None: + float_dtypes += [_dtypes.float8_e8m0fnu] return self.supported(float_dtypes) @_cached_property diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index c447b0844..6873e5b7c 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -281,6 +281,7 @@ try: from jax._src.numpy.lax_numpy import ( float8_e3m4 as float8_e3m4, float8_e4m3 as float8_e4m3, + float8_e8m0fnu as float8_e8m0fnu, ) except ImportError: pass diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index f0b8f5367..5c9f4f140 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -68,6 +68,8 @@ if dtypes.float8_e3m4 is not None: fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] if dtypes.float8_e4m3 is not None: fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] +if dtypes.float8_e8m0fnu is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes