diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 808d129ba..853fb5d1c 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -109,6 +109,12 @@ _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) _float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2) _float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz) +# fp4 support +# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 +float4_e2m1fn: type[np.generic] | None = None + +_float4_e2m1fn_dtype: np.dtype | None = None + def supports_inf(dtype: DTypeLike) -> bool: """Return true if the dtype supports infinity, else return False.""" typ = np.dtype(dtype).type @@ -144,6 +150,8 @@ _float8_dtypes = [ _float8_e5m2fnuz_dtype, ] +_float4_dtypes: list[np.dtype] = [] + # TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 if hasattr(ml_dtypes, "float8_e4m3"): float8_e4m3 = ml_dtypes.float8_e4m3 @@ -163,6 +171,12 @@ if hasattr(ml_dtypes, "float8_e8m0fnu"): _custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type] _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype) _float8_dtypes.insert(0, _float8_e8m0fnu_dtype) +if hasattr(ml_dtypes, "float4_e2m1fn"): + float4_e2m1fn = ml_dtypes.float4_e2m1fn + _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn) + _custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype) + _float4_dtypes.insert(0, _float4_e2m1fn_dtype) # 2-bit integer support int2: type[np.generic] | None = None diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index dd0ae3edc..7d3e342f1 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -75,6 +75,7 @@ enum DType: byte { f8_e5m2 = 20, f8_e5m2fnuz = 21, f8_e8m0fnu = 25, + f4_e2m1fn = 26, } table AbstractValue { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 7707670f1..ac97c11d1 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -365,6 +365,8 @@ if dtypes._float8_e4m3_dtype is not None: _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 if dtypes._float8_e8m0fnu_dtype is not None: _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu +if dtypes._float4_e2m1fn_dtype is not None: + _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() } diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index 69092cd7e..b1fc13333 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -62,6 +62,7 @@ class DType(object): f8_e5m2fnuz = 21 f0 = 22 f8_e8m0fnu = 25 + f4_e2m1fn = 26 class ShardingKind(object): diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 7c10c7b8d..c20fa34d4 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -199,6 +199,9 @@ if dtypes.float8_e4m3 is not None: if dtypes.float8_e8m0fnu is not None: _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get +if dtypes.float4_e2m1fn is not None: + _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get + def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): # TODO Support different-size underlying dtypes to take advantage of the diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py index 5d20b73af..585a5484a 100644 --- a/jax/_src/numpy/scalar_types.py +++ b/jax/_src/numpy/scalar_types.py @@ -93,6 +93,8 @@ float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz) float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz) +if dtypes.float4_e2m1fn is not None: + float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) bfloat16 = _make_scalar_type(dtypes.bfloat16) float16 = _make_scalar_type(np.float16) float32 = single = _make_scalar_type(np.float32) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 220342ce5..455a3b98c 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -100,6 +100,9 @@ if _dtypes.float8_e4m3 is not None: if _dtypes.float8_e8m0fnu is not None: _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 +if _dtypes.float4_e2m1fn is not None: + _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 + default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -124,6 +127,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): custom_float_dtypes.insert(0, _dtypes.float8_e3m4) if _dtypes.float8_e8m0fnu is not None: custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu) + if _dtypes.float4_e2m1fn is not None: + custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn) def maybe_upcast(x): if x.dtype in custom_float_dtypes: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 9f2bab2b4..18f7efa16 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1640,6 +1640,8 @@ class _LazyDtypes: float_dtypes += [_dtypes.float8_e4m3] if _dtypes.float8_e8m0fnu is not None: float_dtypes += [_dtypes.float8_e8m0fnu] + if _dtypes.float4_e2m1fn is not None: + float_dtypes += [_dtypes.float4_e2m1fn] return self.supported(float_dtypes) @_cached_property diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index ad71b9f74..cb291bdca 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -310,6 +310,7 @@ try: float8_e3m4 as float8_e3m4, float8_e4m3 as float8_e4m3, float8_e8m0fnu as float8_e8m0fnu, + float4_e2m1fn as float4_e2m1fn, ) except ImportError: pass diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 8127aed7a..fca3f4320 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -73,6 +73,12 @@ if dtypes.float8_e8m0fnu is not None: float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes +fp4_dtypes = [] +if dtypes.float4_e2m1fn is not None: + fp4_dtypes += [np.dtype(dtypes.float4_e2m1fn)] +float_dtypes += fp4_dtypes +custom_float_dtypes += fp4_dtypes + complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')] @@ -238,6 +244,8 @@ class DtypesTest(jtu.JaxTestCase): continue if t1 in intn_dtypes: continue + if t1 in fp4_dtypes: + continue self.assertEqual(np.dtype(np.complex128), dtypes.promote_types(t1, np.complex128)) @@ -247,6 +255,8 @@ class DtypesTest(jtu.JaxTestCase): continue if t2 in intn_dtypes: continue + if t2 in fp4_dtypes: + continue # Symmetry self.assertEqual(dtypes.promote_types(t1, t2), dtypes.promote_types(t2, t1)) @@ -261,6 +271,8 @@ class DtypesTest(jtu.JaxTestCase): # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8. if t in fp8_dtypes: continue + if t in fp4_dtypes: + continue if t in intn_dtypes or i in intn_dtypes: continue self.assertEqual(t, dtypes.promote_types(t, i)) @@ -951,10 +963,12 @@ class TestPromotionTables(jtu.JaxTestCase): self.skipTest("XLA support for int2 and int4 is incomplete.") if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']): self.skipTest("TPU does not support float8_e8m0fnu.") + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest("TPU does not support float4_e2m1fn.") x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) if weak_type: expected = dtypes.canonicalize_dtype( - dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes] else x.dtype.kind]) + dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes, *fp4_dtypes] else x.dtype.kind]) else: expected = x.dtype self.assertEqual(dtypes.result_type(x), expected) @@ -971,6 +985,17 @@ class TestPromotionTables(jtu.JaxTestCase): ".*8-bit floats do not support implicit promotion"): x + y + @jax.numpy_dtype_promotion('standard') + def testFloat4PromotionError(self): + for dtype in fp4_dtypes: + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest("TPU does not support float4_e2m1fn.") + x = jnp.array(1, dtype=dtype) + y = jnp.array(1, dtype='float32') + with self.assertRaisesRegex(dtypes.TypePromotionError, + ".*4-bit floats do not support implicit promotion"): + x + y + @jax.numpy_dtype_promotion('standard') @jtu.run_on_devices('tpu') def testInt2PromotionError(self): @@ -995,6 +1020,8 @@ class TestPromotionTables(jtu.JaxTestCase): def testBinaryNonPromotion(self, dtype, weak_type, promotion): if dtype in fp8_dtypes: self.skipTest("XLA support for float8 is incomplete.") + if dtype in fp4_dtypes: + self.skipTest("XLA support for float4 is incomplete.") if dtype in intn_dtypes: self.skipTest("XLA support for int2 and int4 is incomplete.") # Regression test for https://github.com/jax-ml/jax/issues/6051 @@ -1027,6 +1054,8 @@ class TestPromotionTables(jtu.JaxTestCase): self.skipTest('XLA support for int2 is incomplete.') if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']): self.skipTest('TPU does not support float8_e8m0fnu.') + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest('TPU does not support float4_e2m1fn.') val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) rep = repr(val) self.assertStartsWith(rep, 'Array(') diff --git a/tests/export_test.py b/tests/export_test.py index 60c96fca4..6baecebe1 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1014,6 +1014,8 @@ class JaxExportTest(jtu.JaxTestCase): self.skipTest(f"TODO: serialization not supported for {str(dtype)}") if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']): self.skipTest("TPU does not support float8_e8m0fnu.") + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest("TPU does not support float4_e2m1fn.") @jax.jit def f_jax(x): return x + x