Address previous FP8-related TODOs in jaxlib/XLA.

The ml_dtypes requirement in JAX was updated to version 0.5.0+ on Mar 20, 2025 (commit 4b7ead4).

This update allows us to address previous FP8-related TODOs in jaxlib/XLA.

PiperOrigin-RevId: 744943824
Author: Alex Pivovarov, 2025-04-07 20:00:46 -07:00 (committed by jax authors)
Commit: bb515aa74f (parent: 86de4783bb)
5 changed files with 54 additions and 26 deletions
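
For context, a brief sketch of what the new ml_dtypes floor provides (illustrative only; it assumes ml_dtypes 0.5.0+ is installed). These are the four dtypes that the previously commented-out registrations in the diffs below depend on:

import ml_dtypes
import numpy as np

# Quick check of the ml_dtypes floor that motivates this change. ml_dtypes 0.5.0
# is the release that exposes float4_e2m1fn, float8_e3m4, float8_e4m3, and
# float8_e8m0fnu, per the commit description above.
print(ml_dtypes.__version__)  # expected: 0.5.0 or newer

for name in ("float4_e2m1fn", "float8_e3m4", "float8_e4m3", "float8_e8m0fnu"):
    t = getattr(ml_dtypes, name)  # raises AttributeError on older ml_dtypes
    print(np.dtype(t).name, np.dtype(t).itemsize)  # each is stored in one byte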


@@ -694,16 +694,25 @@ absl::StatusOr<PyArgSignature> PyArgSignatureOfValue(nb::handle arg,
(*p)[dtypes.np_uint16.ptr()] = numpy_array_handler;
(*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
(*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
// (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler;
// (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler;
// (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler;
// (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler;
// TODO(upwind): Explore if we can remove std::optional for these types
// in xla/python/types.h and xla/python/types.cc
if (dtypes.np_float4_e2m1fn.has_value()) {
(*p)[dtypes.np_float4_e2m1fn->ptr()] = numpy_array_handler;
}
if (dtypes.np_float8_e3m4.has_value()) {
(*p)[dtypes.np_float8_e3m4->ptr()] = numpy_array_handler;
}
if (dtypes.np_float8_e4m3.has_value()) {
(*p)[dtypes.np_float8_e4m3->ptr()] = numpy_array_handler;
}
(*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler;
if (dtypes.np_float8_e8m0fnu.has_value()) {
(*p)[dtypes.np_float8_e8m0fnu->ptr()] = numpy_array_handler;
}
(*p)[dtypes.np_float16.ptr()] = numpy_array_handler;
(*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler;
(*p)[dtypes.np_float32.ptr()] = numpy_array_handler;
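
The hunk above switches the FP8/FP4 handler registrations on while keeping the std::optional guards in place (see the TODO(upwind) note about removing them). A rough Python analogue of that guard pattern, purely illustrative and not the actual jaxlib code:

import ml_dtypes

handlers = {}

def numpy_array_handler(value):
    # Placeholder standing in for the real handler, which builds a PyArgSignature.
    return value

# Register a handler only when the installed ml_dtypes actually provides the
# type, mirroring the has_value() checks in the C++ hunk above.
for name in ("float4_e2m1fn", "float8_e3m4", "float8_e4m3", "float8_e8m0fnu"):
    dtype = getattr(ml_dtypes, name, None)  # None if ml_dtypes is too old
    if dtype is not None:
        handlers[dtype] = numpy_array_handler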


@@ -208,15 +208,14 @@ NB_MODULE(xla_extension, m) {
.value("U64", U64)
.value("F16", F16)
.value("F4E2M1FN", F4E2M1FN)
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
// .value("F8E3M4", F8E3M4)
// .value("F8E4M3", F8E4M3)
.value("F8E8M0FNU", F8E8M0FNU)
.value("F8E3M4", F8E3M4)
.value("F8E4M3", F8E4M3)
.value("F8E4M3FN", F8E4M3FN)
.value("F8E4M3B11FNUZ", F8E4M3B11FNUZ)
.value("F8E4M3FNUZ", F8E4M3FNUZ)
.value("F8E5M2", F8E5M2)
.value("F8E5M2FNUZ", F8E5M2FNUZ)
.value("F8E8M0FNU", F8E8M0FNU)
.value("BF16", BF16)
.value("F32", F32)
.value("F64", F64)


@@ -260,16 +260,15 @@ XLA_ELEMENT_TYPE_TO_DTYPE = {
PrimitiveType.U16: np.dtype('uint16'),
PrimitiveType.U32: np.dtype('uint32'),
PrimitiveType.U64: np.dtype('uint64'),
# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
# PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn),
# PrimitiveType.F8E3M4: np.dtype(float8_e3m4),
# PrimitiveType.F8E4M3: np.dtype(float8_e4m3),
# PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu),
PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn),
PrimitiveType.F8E3M4: np.dtype(float8_e3m4),
PrimitiveType.F8E4M3: np.dtype(float8_e4m3),
PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn),
PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz),
PrimitiveType.F8E5M2: np.dtype(float8_e5m2),
PrimitiveType.F8E4M3FNUZ: np.dtype(float8_e4m3fnuz),
PrimitiveType.F8E5M2: np.dtype(float8_e5m2),
PrimitiveType.F8E5M2FNUZ: np.dtype(float8_e5m2fnuz),
PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu),
PrimitiveType.BF16: np.dtype(bfloat16),
PrimitiveType.F16: np.dtype('float16'),
PrimitiveType.F32: np.dtype('float32'),
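
With these mapping entries uncommented, a PrimitiveType can be converted straight to its NumPy dtype. A small usage sketch, assuming a jaxlib build that includes this change; the xla_client import path shown is the one JAX uses internally and may differ between versions:

import ml_dtypes
import numpy as np
from jax._src.lib import xla_client  # import path assumed; may vary by version

# Look up the NumPy dtype registered for one of the newly enabled FP8 element types.
dt = xla_client.XLA_ELEMENT_TYPE_TO_DTYPE[xla_client.PrimitiveType.F8E4M3]
assert dt == np.dtype(ml_dtypes.float8_e4m3)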


@@ -63,16 +63,15 @@ _ifrt_version: int
mlir_api_version: int
bfloat16: type[numpy.generic]
# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
# float4_e2m1fn: type[numpy.generic]
# float8_e3m4: type[numpy.generic]
# float8_e4m3: type[numpy.generic]
# float8_e8m0fnu: type[numpy.generic]
float4_e2m1fn: type[numpy.generic]
float8_e3m4: type[numpy.generic]
float8_e4m3: type[numpy.generic]
float8_e4m3fn: type[numpy.generic]
float8_e4m3b11fnuz: type[numpy.generic]
float8_e4m3fnuz: type[numpy.generic]
float8_e5m2: type[numpy.generic]
float8_e5m2fnuz: type[numpy.generic]
float8_e8m0fnu: type[numpy.generic]
XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype]
_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]]


@@ -48,12 +48,12 @@ bfloat16 = ml_dtypes.bfloat16
float4_e2m1fn = ml_dtypes.float4_e2m1fn
float8_e3m4 = ml_dtypes.float8_e3m4
float8_e4m3 = ml_dtypes.float8_e4m3
float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
float8_e4m3fn = ml_dtypes.float8_e4m3fn
float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz
float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz
float8_e5m2 = ml_dtypes.float8_e5m2
float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz
float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
ops = xla_client.ops
xla_computation_to_mlir_module = (
xla_client._xla.mlir.xla_computation_to_mlir_module)
@@ -178,10 +178,17 @@ def TestFactory(xla_backend,
# TODO(zhangqiaorjc): test fp8 types when XLA support is complete.
# standard_dtypes is only used for BufferProtocolTest so we only test fp8
# round trip tests.
fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2]
fp8_dtypes = [
float8_e3m4,
float8_e4m3,
float8_e4m3fn,
float8_e4m3b11fnuz,
float8_e5m2,
float8_e8m0fnu,
]
standard_dtypes += fp8_dtypes
# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
# standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu]
# TODO(upwind): testRoundTrip and testLiveBuffers fail for float4_e2m1fn type
# standard_dtypes += [float4_e2m1fn]
dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes
class ComputationTest(parameterized.TestCase):
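
The expanded fp8_dtypes list above feeds the buffer-protocol round-trip tests. A standalone sketch of the kind of round trip being exercised, using plain NumPy and ml_dtypes rather than the real test harness (float8_e8m0fnu is left out because it can only represent powers of two):

import ml_dtypes
import numpy as np

for dtype in (ml_dtypes.float8_e3m4, ml_dtypes.float8_e4m3,
              ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e4m3b11fnuz,
              ml_dtypes.float8_e5m2):
    # Values 0..7 are exactly representable in all of these FP8 formats,
    # so narrowing and widening again should be lossless.
    x = np.arange(8, dtype=np.float32).astype(dtype)
    roundtripped = x.astype(np.float32).astype(dtype)
    np.testing.assert_array_equal(x.astype(np.float32),
                                  roundtripped.astype(np.float32))
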
@@ -1228,9 +1235,19 @@ module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
for dtype in standard_dtypes:
if dtype == np.complex128:
continue
# float8_e8m0fnu is not supported on TPU.
if dtype == float8_e8m0fnu and self.backend.platform == "tpu":
continue
# float8_e4m3b11fnuz not supported on some TPU backends.
if (
dtype in [float8_e5m2fnuz, float8_e4m3fnuz, float8_e4m3b11fnuz]
dtype
in [
float8_e3m4,
float8_e4m3,
float8_e4m3fnuz,
float8_e4m3b11fnuz,
float8_e5m2fnuz,
]
and self.backend.platform == "tpu"
):
if self.backend.platform_version.find("TPU") == -1:
@@ -2253,6 +2270,11 @@ module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
"dtype": dtype,
} for dtype in float_dtypes + fp8_dtypes)
def testNextAfter(self, dtype):
if dtype == float8_e8m0fnu:
# TODO(b/409114865): Test fails with Mismatched elements error.
self.skipTest("b/409114865: Test fails with Mismatched elements error")
if dtype in [float8_e3m4, float8_e4m3] and self.backend.platform == "tpu":
self.skipTest("TPU doesn't support float8_e3m4 or float8_e4m3")
if dtype == np.float64 and self.backend.platform == "tpu":
self.skipTest("TPU doesn't support float64")
if dtype == bfloat16 and self.backend.platform == "tpu":