Support float4_e2m1fn (fp4)

shuw 2025-02-13 20:53:26 +00:00
parent 4493889cda
commit c099e8081d
11 changed files with 63 additions and 1 deletion

View File

@@ -109,6 +109,12 @@ _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)
# fp4 support
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
float4_e2m1fn: type[np.generic] | None = None
_float4_e2m1fn_dtype: np.dtype | None = None
def supports_inf(dtype: DTypeLike) -> bool:
"""Return true if the dtype supports infinity, else return False."""
typ = np.dtype(dtype).type
@@ -144,6 +150,8 @@ _float8_dtypes = [
_float8_e5m2fnuz_dtype,
]
_float4_dtypes: list[np.dtype] = []
# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
if hasattr(ml_dtypes, "float8_e4m3"):
float8_e4m3 = ml_dtypes.float8_e4m3
@@ -163,6 +171,12 @@ if hasattr(ml_dtypes, "float8_e8m0fnu"):
_custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
_float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
if hasattr(ml_dtypes, "float4_e2m1fn"):
float4_e2m1fn = ml_dtypes.float4_e2m1fn
_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
_custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
_float4_dtypes.insert(0, _float4_e2m1fn_dtype)
# 2-bit integer support
int2: type[np.generic] | None = None
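A minimal sketch (not part of the diff) of how the conditional definitions above can be probed at runtime; it assumes only that ml_dtypes may or may not be >= 0.5.0:

import numpy as np
from jax._src import dtypes

if dtypes.float4_e2m1fn is not None:
    # e2m1fn can represent only +/- {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
    x = np.array([0.5, 3.0, 6.0], dtype=dtypes._float4_e2m1fn_dtype)
    print(x.dtype.itemsize)  # ml_dtypes stores one fp4 value per byte
else:
    print("ml_dtypes < 0.5.0: float4_e2m1fn is unavailable")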

View File

@@ -75,6 +75,7 @@ enum DType: byte {
f8_e5m2 = 20,
f8_e5m2fnuz = 21,
f8_e8m0fnu = 25,
f4_e2m1fn = 26,
}
table AbstractValue {

View File

@@ -365,6 +365,8 @@ if dtypes._float8_e4m3_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
if dtypes._float8_e8m0fnu_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
if dtypes._float4_e2m1fn_dtype is not None:
_dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn
_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
}
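A small consistency sketch for the mappings above; the module path (jax/_src/export/serialization.py) is assumed, and the check only applies when fp4 is available:

from jax._src import dtypes
from jax._src.export import serialization

if dtypes._float4_e2m1fn_dtype is not None:
    # The forward and derived inverse tables should round-trip the fp4 dtype.
    kind = serialization._dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype]
    assert serialization._dtype_kind_to_dtype[kind] == dtypes._float4_e2m1fn_dtype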

View File

@@ -62,6 +62,7 @@ class DType(object):
f8_e5m2fnuz = 21
f0 = 22
f8_e8m0fnu = 25
f4_e2m1fn = 26
class ShardingKind(object):

View File

@@ -199,6 +199,9 @@ if dtypes.float8_e4m3 is not None:
if dtypes.float8_e8m0fnu is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
if dtypes.float4_e2m1fn is not None:
_dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):
# TODO Support different-size underlying dtypes to take advantage of the
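A hedged sketch of resolving the IR type registered above; it assumes ml_dtypes >= 0.5.0 and that make_ir_context() can be entered as shown, and is purely illustrative:

import numpy as np
from jax._src import dtypes
from jax._src.interpreters import mlir

if dtypes.float4_e2m1fn is not None:
    with mlir.make_ir_context():
        ir_type = mlir.dtype_to_ir_type(np.dtype(dtypes.float4_e2m1fn))
        print(ir_type)  # expected to be the MLIR 4-bit float type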

View File

@@ -93,6 +93,8 @@ float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
if dtypes.float4_e2m1fn is not None:
float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)
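A minimal usage sketch of the new scalar type, assuming ml_dtypes >= 0.5.0 so that jnp.float4_e2m1fn exists:

import jax.numpy as jnp

if hasattr(jnp, "float4_e2m1fn"):
    # 0.5, 1.5 and 6.0 are all exactly representable in e2m1fn.
    x = jnp.array([0.5, 1.5, 6.0], dtype=jnp.float4_e2m1fn)
    print(x.dtype)  # float4_e2m1fn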

View File

@@ -100,6 +100,9 @@ if _dtypes.float8_e4m3 is not None:
if _dtypes.float8_e8m0fnu is not None:
_default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
if _dtypes.float4_e2m1fn is not None:
_default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@@ -124,6 +127,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
if _dtypes.float8_e8m0fnu is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
if _dtypes.float4_e2m1fn is not None:
custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn)
def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
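A hedged sketch of the comparison strategy the tolerances and maybe_upcast imply: custom low-precision floats are upcast before a numpy comparison. The helper name and tolerance below are illustrative, not part of the change:

import numpy as np
from jax._src import dtypes as _dtypes

def _allclose_fp4(a, b, atol=1.0):
    # Hypothetical helper: upcast fp4 operands, then compare loosely.
    if _dtypes.float4_e2m1fn is not None and a.dtype == np.dtype(_dtypes.float4_e2m1fn):
        a, b = a.astype(np.float32), b.astype(np.float32)
    np.testing.assert_allclose(a, b, atol=atol)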

View File

@@ -1640,6 +1640,8 @@ class _LazyDtypes:
float_dtypes += [_dtypes.float8_e4m3]
if _dtypes.float8_e8m0fnu is not None:
float_dtypes += [_dtypes.float8_e8m0fnu]
if _dtypes.float4_e2m1fn is not None:
float_dtypes += [_dtypes.float4_e2m1fn]
return self.supported(float_dtypes)
@_cached_property

View File

@@ -310,6 +310,7 @@ try:
float8_e3m4 as float8_e3m4,
float8_e4m3 as float8_e4m3,
float8_e8m0fnu as float8_e8m0fnu,
float4_e2m1fn as float4_e2m1fn,
)
except ImportError:
pass
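Downstream code can consume this conditional re-export with the same guard pattern; a sketch using only the public jax.numpy namespace:

try:
    from jax.numpy import float4_e2m1fn
except ImportError:
    float4_e2m1fn = None  # older ml_dtypes: fp4 is not exposed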

View File

@@ -73,6 +73,12 @@ if dtypes.float8_e8m0fnu is not None:
float_dtypes += fp8_dtypes
custom_float_dtypes += fp8_dtypes
fp4_dtypes = []
if dtypes.float4_e2m1fn is not None:
fp4_dtypes += [np.dtype(dtypes.float4_e2m1fn)]
float_dtypes += fp4_dtypes
custom_float_dtypes += fp4_dtypes
complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')]
@@ -238,6 +244,8 @@ class DtypesTest(jtu.JaxTestCase):
continue
if t1 in intn_dtypes:
continue
if t1 in fp4_dtypes:
continue
self.assertEqual(np.dtype(np.complex128),
dtypes.promote_types(t1, np.complex128))
@@ -247,6 +255,8 @@ class DtypesTest(jtu.JaxTestCase):
continue
if t2 in intn_dtypes:
continue
if t2 in fp4_dtypes:
continue
# Symmetry
self.assertEqual(dtypes.promote_types(t1, t2),
dtypes.promote_types(t2, t1))
@@ -261,6 +271,8 @@ class DtypesTest(jtu.JaxTestCase):
# TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8.
if t in fp8_dtypes:
continue
if t in fp4_dtypes:
continue
if t in intn_dtypes or i in intn_dtypes:
continue
self.assertEqual(t, dtypes.promote_types(t, i))
@@ -951,10 +963,12 @@ class TestPromotionTables(jtu.JaxTestCase):
self.skipTest("XLA support for int2 and int4 is incomplete.")
if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
self.skipTest("TPU does not support float8_e8m0fnu.")
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
self.skipTest("TPU does not support float4_e2m1fn.")
x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
if weak_type:
expected = dtypes.canonicalize_dtype(
dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes] else x.dtype.kind])
dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes, *fp4_dtypes] else x.dtype.kind])
else:
expected = x.dtype
self.assertEqual(dtypes.result_type(x), expected)
@@ -971,6 +985,17 @@ class TestPromotionTables(jtu.JaxTestCase):
".*8-bit floats do not support implicit promotion"):
x + y
@jax.numpy_dtype_promotion('standard')
def testFloat4PromotionError(self):
for dtype in fp4_dtypes:
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
self.skipTest("TPU does not support float4_e2m1fn.")
x = jnp.array(1, dtype=dtype)
y = jnp.array(1, dtype='float32')
with self.assertRaisesRegex(dtypes.TypePromotionError,
".*4-bit floats do not support implicit promotion"):
x + y
@jax.numpy_dtype_promotion('standard')
@jtu.run_on_devices('tpu')
def testInt2PromotionError(self):
@@ -995,6 +1020,8 @@ class TestPromotionTables(jtu.JaxTestCase):
def testBinaryNonPromotion(self, dtype, weak_type, promotion):
if dtype in fp8_dtypes:
self.skipTest("XLA support for float8 is incomplete.")
if dtype in fp4_dtypes:
self.skipTest("XLA support for float4 is incomplete.")
if dtype in intn_dtypes:
self.skipTest("XLA support for int2 and int4 is incomplete.")
# Regression test for https://github.com/jax-ml/jax/issues/6051
@@ -1027,6 +1054,8 @@ class TestPromotionTables(jtu.JaxTestCase):
self.skipTest('XLA support for int2 is incomplete.')
if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
self.skipTest('TPU does not support float8_e8m0fnu.')
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
self.skipTest('TPU does not support float4_e2m1fn.')
val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
rep = repr(val)
self.assertStartsWith(rep, 'Array(')
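An illustrative workaround for the promotion error exercised above: cast fp4 operands explicitly before mixing them with wider floats (assumes fp4 is available and the backend is not a TPU):

import jax.numpy as jnp
from jax._src import dtypes

if dtypes.float4_e2m1fn is not None:
    x = jnp.array(1, dtype=dtypes.float4_e2m1fn)
    y = jnp.array(1, dtype='float32')
    z = x.astype('float32') + y  # explicit upcast; no TypePromotionError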

View File

@@ -1014,6 +1014,8 @@ class JaxExportTest(jtu.JaxTestCase):
self.skipTest(f"TODO: serialization not supported for {str(dtype)}")
if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
self.skipTest("TPU does not support float8_e8m0fnu.")
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
self.skipTest("TPU does not support float4_e2m1fn.")
@jax.jit
def f_jax(x):
return x + x
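A hedged sketch of the round trip this test exercises, using the public jax.export API; the exact serialized contents are assumed, not verified:

import jax
import jax.numpy as jnp
from jax import export
from jax._src import dtypes

if dtypes.float4_e2m1fn is not None:
    @jax.jit
    def f_jax(x):
        return x + x

    x = jnp.ones((4,), dtype=dtypes.float4_e2m1fn)
    exported = export.export(f_jax)(x)
    restored = export.deserialize(exported.serialize())
    print(restored.in_avals)  # should reflect the float4_e2m1fn input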