mirror of https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

support e2m1fn

This commit is contained in:
  parent 4493889cda
  commit c099e8081d
@@ -109,6 +109,12 @@ _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
 _float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)
 _float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)
 
+# fp4 support
+# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
+float4_e2m1fn: type[np.generic] | None = None
+
+_float4_e2m1fn_dtype: np.dtype | None = None
+
 def supports_inf(dtype: DTypeLike) -> bool:
   """Return true if the dtype supports infinity, else return False."""
   typ = np.dtype(dtype).type

@@ -144,6 +150,8 @@ _float8_dtypes = [
     _float8_e5m2fnuz_dtype,
 ]
 
+_float4_dtypes: list[np.dtype] = []
+
 # TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
 if hasattr(ml_dtypes, "float8_e4m3"):
   float8_e4m3 = ml_dtypes.float8_e4m3

@@ -163,6 +171,12 @@ if hasattr(ml_dtypes, "float8_e8m0fnu"):
   _custom_float_scalar_types.insert(0, float8_e8m0fnu)  # type: ignore[arg-type]
   _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
   _float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
+if hasattr(ml_dtypes, "float4_e2m1fn"):
+  float4_e2m1fn = ml_dtypes.float4_e2m1fn
+  _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
+  _custom_float_scalar_types.insert(0, float4_e2m1fn)  # type: ignore[arg-type]
+  _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
+  _float4_dtypes.insert(0, _float4_e2m1fn_dtype)
 
 # 2-bit integer support
 int2: type[np.generic] | None = None
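The registration above mirrors the optional-fp8 pattern already in the file: float4_e2m1fn is bound only when the installed ml_dtypes provides it. A minimal sketch (not part of the commit) of probing for the type and inspecting the e2m1fn format, assuming ml_dtypes >= 0.5.0:

    import numpy as np
    import ml_dtypes

    # float4_e2m1fn only exists in ml_dtypes >= 0.5.0, hence the guarded lookup.
    float4_e2m1fn = getattr(ml_dtypes, "float4_e2m1fn", None)
    if float4_e2m1fn is not None:
      info = ml_dtypes.finfo(float4_e2m1fn)
      print(info.max, info.eps)                 # 6.0 and 0.5 for e2m1fn
      x = np.array([0.5, 1.5, 6.0], dtype=float4_e2m1fn)
      print(x.astype(np.float32))               # all three values are exactly representable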
@@ -75,6 +75,7 @@ enum DType: byte {
   f8_e5m2 = 20,
   f8_e5m2fnuz = 21,
   f8_e8m0fnu = 25,
+  f4_e2m1fn = 26,
 }
 
 table AbstractValue {
@@ -365,6 +365,8 @@ if dtypes._float8_e4m3_dtype is not None:
   _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
 if dtypes._float8_e8m0fnu_dtype is not None:
   _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
+if dtypes._float4_e2m1fn_dtype is not None:
+  _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn
 _dtype_kind_to_dtype = {
     kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
 }
@@ -62,6 +62,7 @@ class DType(object):
   f8_e5m2fnuz = 21
   f0 = 22
   f8_e8m0fnu = 25
+  f4_e2m1fn = 26
 
 
 class ShardingKind(object):
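Together, these serialization hunks give the new dtype a stable enum value (f4_e2m1fn = 26) in the flatbuffer schema and its generated Python counterpart, and add it to the dtype-to-kind table only when the dtype exists. The reverse table is derived by inverting that mapping, so deserialization picks up the new entry automatically; a hypothetical standalone illustration of that invariant (stand-in values, not commit code):

    # The kind values written during serialization must invert cleanly so
    # deserialization can recover exactly one dtype per kind.
    dtype_to_kind = {"float8_e8m0fnu": 25, "float4_e2m1fn": 26}   # enum values from the schema
    kind_to_dtype = {kind: dtype for dtype, kind in dtype_to_kind.items()}
    assert kind_to_dtype[26] == "float4_e2m1fn"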
@@ -199,6 +199,9 @@ if dtypes.float8_e4m3 is not None:
 if dtypes.float8_e8m0fnu is not None:
   _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
 
+if dtypes.float4_e2m1fn is not None:
+  _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get
+
 def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
   if isinstance(dtype, core.bint):
     # TODO Support different-size underlying dtypes to take advantage of the
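With the MLIR element type registered, fp4 values can flow through lowering like any other float. One hedged way to confirm the wiring is to lower a trivial cast and look for the f4E2M1FN element type in the emitted StableHLO text; this is an illustrative check only, assuming a jaxlib whose ir bindings expose Float4E2M1FNType and an ml_dtypes that ships fp4:

    import jax
    import jax.numpy as jnp
    import ml_dtypes

    fp4 = getattr(ml_dtypes, "float4_e2m1fn", None)     # None on ml_dtypes < 0.5.0
    if fp4 is not None:
      x = jnp.zeros((4,), dtype=fp4)
      lowered = jax.jit(lambda a: a.astype(jnp.float32)).lower(x)
      # The argument should appear with element type f4E2M1FN in the StableHLO text.
      print(lowered.as_text())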
@@ -93,6 +93,8 @@ float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
 float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
 float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
 float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
+if dtypes.float4_e2m1fn is not None:
+  float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
 bfloat16 = _make_scalar_type(dtypes.bfloat16)
 float16 = _make_scalar_type(np.float16)
 float32 = single = _make_scalar_type(np.float32)
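This hunk exposes the scalar type on the public numpy-like namespace, guarded the same way as the optional fp8 types. Hypothetical usage once it is present (assumes a jax/XLA build that accepts fp4 element types):

    import jax.numpy as jnp

    if getattr(jnp, "float4_e2m1fn", None) is not None:  # exposed only when ml_dtypes >= 0.5.0
      x = jnp.array([0.5, 1.5, 6.0], dtype=jnp.float4_e2m1fn)
      print(x.dtype)                    # float4_e2m1fn
      print(x.astype(jnp.float32))      # [0.5 1.5 6. ]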
@@ -100,6 +100,9 @@ if _dtypes.float8_e4m3 is not None:
 if _dtypes.float8_e8m0fnu is not None:
   _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
   default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
+if _dtypes.float4_e2m1fn is not None:
+  _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
+  default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
 
 def is_python_scalar(val):
   return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
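The loose 1e0 tolerances follow from how coarse the format is: e2m1fn represents only 0, 0.5, 1, 1.5, 2, 3, 4 and 6 (and their negatives), so a float32 value can move by up to 1.0 when round-tripped through fp4 near the top of the range. A standalone illustration (not commit code), assuming ml_dtypes >= 0.5.0:

    import numpy as np
    import ml_dtypes

    a = np.array([5.0, 2.5], dtype=np.float32)
    b = a.astype(ml_dtypes.float4_e2m1fn).astype(np.float32)   # round to nearest fp4 value
    print(b)                                                    # e.g. [4. 2.]
    np.testing.assert_allclose(a, b, atol=1e0)                  # consistent with the 1e0 tolerance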
@@ -124,6 +127,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
     custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
   if _dtypes.float8_e8m0fnu is not None:
     custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
+  if _dtypes.float4_e2m1fn is not None:
+    custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn)
 
   def maybe_upcast(x):
     if x.dtype in custom_float_dtypes:
@@ -1640,6 +1640,8 @@ class _LazyDtypes:
       float_dtypes += [_dtypes.float8_e4m3]
     if _dtypes.float8_e8m0fnu is not None:
       float_dtypes += [_dtypes.float8_e8m0fnu]
+    if _dtypes.float4_e2m1fn is not None:
+      float_dtypes += [_dtypes.float4_e2m1fn]
     return self.supported(float_dtypes)
 
   @_cached_property
@@ -310,6 +310,7 @@ try:
       float8_e3m4 as float8_e3m4,
       float8_e4m3 as float8_e4m3,
       float8_e8m0fnu as float8_e8m0fnu,
+      float4_e2m1fn as float4_e2m1fn,
   )
 except ImportError:
   pass
@@ -73,6 +73,12 @@ if dtypes.float8_e8m0fnu is not None:
 float_dtypes += fp8_dtypes
 custom_float_dtypes += fp8_dtypes
 
+fp4_dtypes = []
+if dtypes.float4_e2m1fn is not None:
+  fp4_dtypes += [np.dtype(dtypes.float4_e2m1fn)]
+float_dtypes += fp4_dtypes
+custom_float_dtypes += fp4_dtypes
+
 complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')]
 
 
@@ -238,6 +244,8 @@ class DtypesTest(jtu.JaxTestCase):
         continue
       if t1 in intn_dtypes:
         continue
+      if t1 in fp4_dtypes:
+        continue
       self.assertEqual(np.dtype(np.complex128),
                        dtypes.promote_types(t1, np.complex128))
 
@@ -247,6 +255,8 @@ class DtypesTest(jtu.JaxTestCase):
         continue
       if t2 in intn_dtypes:
         continue
+      if t2 in fp4_dtypes:
+        continue
       # Symmetry
       self.assertEqual(dtypes.promote_types(t1, t2),
                        dtypes.promote_types(t2, t1))
@@ -261,6 +271,8 @@ class DtypesTest(jtu.JaxTestCase):
       # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8.
       if t in fp8_dtypes:
         continue
+      if t in fp4_dtypes:
+        continue
       if t in intn_dtypes or i in intn_dtypes:
         continue
       self.assertEqual(t, dtypes.promote_types(t, i))
@@ -951,10 +963,12 @@ class TestPromotionTables(jtu.JaxTestCase):
       self.skipTest("XLA support for int2 and int4 is incomplete.")
     if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
       self.skipTest("TPU does not support float8_e8m0fnu.")
+    if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
+      self.skipTest("TPU does not support float4_e2m1fn.")
     x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
     if weak_type:
       expected = dtypes.canonicalize_dtype(
-        dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes] else x.dtype.kind])
+        dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes, *fp4_dtypes] else x.dtype.kind])
     else:
       expected = x.dtype
     self.assertEqual(dtypes.result_type(x), expected)
@@ -971,6 +985,17 @@ class TestPromotionTables(jtu.JaxTestCase):
                                   ".*8-bit floats do not support implicit promotion"):
         x + y
 
+  @jax.numpy_dtype_promotion('standard')
+  def testFloat4PromotionError(self):
+    for dtype in fp4_dtypes:
+      if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
+        self.skipTest("TPU does not support float4_e2m1fn.")
+      x = jnp.array(1, dtype=dtype)
+      y = jnp.array(1, dtype='float32')
+      with self.assertRaisesRegex(dtypes.TypePromotionError,
+                                  ".*4-bit floats do not support implicit promotion"):
+        x + y
+
   @jax.numpy_dtype_promotion('standard')
   @jtu.run_on_devices('tpu')
   def testInt2PromotionError(self):
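testFloat4PromotionError mirrors the fp8 test above it: under 'standard' promotion, mixing a 4-bit float with float32 raises TypePromotionError instead of silently upcasting. A hypothetical sketch (not commit code) of the explicit cast callers would write instead, assuming fp4 is available:

    import jax
    import jax.numpy as jnp

    with jax.numpy_dtype_promotion('standard'):
      x = jnp.array(1, dtype=jnp.float4_e2m1fn)
      y = jnp.array(1, dtype='float32')
      # x + y would raise dtypes.TypePromotionError; cast explicitly instead:
      z = x.astype(jnp.float32) + y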
@@ -995,6 +1020,8 @@ class TestPromotionTables(jtu.JaxTestCase):
   def testBinaryNonPromotion(self, dtype, weak_type, promotion):
     if dtype in fp8_dtypes:
       self.skipTest("XLA support for float8 is incomplete.")
+    if dtype in fp4_dtypes:
+      self.skipTest("XLA support for float4 is incomplete.")
     if dtype in intn_dtypes:
       self.skipTest("XLA support for int2 and int4 is incomplete.")
     # Regression test for https://github.com/jax-ml/jax/issues/6051
@@ -1027,6 +1054,8 @@ class TestPromotionTables(jtu.JaxTestCase):
       self.skipTest('XLA support for int2 is incomplete.')
     if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
       self.skipTest('TPU does not support float8_e8m0fnu.')
+    if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
+      self.skipTest('TPU does not support float4_e2m1fn.')
     val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
     rep = repr(val)
     self.assertStartsWith(rep, 'Array(')
@@ -1014,6 +1014,8 @@ class JaxExportTest(jtu.JaxTestCase):
       self.skipTest(f"TODO: serialization not supported for {str(dtype)}")
     if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
       self.skipTest("TPU does not support float8_e8m0fnu.")
+    if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
+      self.skipTest("TPU does not support float4_e2m1fn.")
     @jax.jit
     def f_jax(x):
       return x + x
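This last hunk lets the export serialization test cover fp4, exercising the f4_e2m1fn enum value added earlier. A hedged sketch of the same round trip through the public jax.export API (illustrative only; assumes a backend and jaxlib that accept fp4):

    import jax
    import jax.numpy as jnp
    from jax import export

    @jax.jit
    def f_jax(x):
      return x + x

    x = jnp.ones((4,), dtype=jnp.float4_e2m1fn)   # assumes fp4 is available
    exp = export.export(f_jax)(x)                 # trace and lower
    blob = exp.serialize()                        # flatbuffer bytes; dtype recorded as f4_e2m1fn
    restored = export.deserialize(blob)
    print(restored.in_avals)                      # should show float4_e2m1fn[4]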