Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)
Address previous FP8-related TODOs in jaxlib/XLA.
The ml_dtypes requirement in JAX was updated to version 0.5.0+ on Mar 20, 2025 (commit 4b7ead4). This update allows us to address previous FP8-related TODOs in jaxlib/XLA.

PiperOrigin-RevId: 744943824
This commit is contained in:
parent 86de4783bb
commit bb515aa74f
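For context: ml_dtypes 0.5.0 is the first release that ships float4_e2m1fn, float8_e3m4, float8_e4m3, and float8_e8m0fnu, which is what lets the guarded registrations below be enabled. A minimal availability check (illustrative sketch, not code from this commit):

import ml_dtypes
import numpy as np

# These four types only exist in ml_dtypes >= 0.5.0; on older releases the
# getattr() below raises AttributeError, which is why the C++ handles for
# them were kept optional until the minimum version was raised.
for name in ("float4_e2m1fn", "float8_e3m4", "float8_e4m3", "float8_e8m0fnu"):
    print(name, np.dtype(getattr(ml_dtypes, name)))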
@@ -694,16 +694,25 @@ absl::StatusOr<PyArgSignature> PyArgSignatureOfValue(nb::handle arg,
     (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler;
     (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
     (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
-    // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
-    // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler;
-    // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler;
-    // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler;
-    // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler;
+    // TODO(upwind): Explore if we can remove std::optional for these types
+    // in xla/python/types.h and xla/python/types.cc
+    if (dtypes.np_float4_e2m1fn.has_value()) {
+      (*p)[dtypes.np_float4_e2m1fn->ptr()] = numpy_array_handler;
+    }
+    if (dtypes.np_float8_e3m4.has_value()) {
+      (*p)[dtypes.np_float8_e3m4->ptr()] = numpy_array_handler;
+    }
+    if (dtypes.np_float8_e4m3.has_value()) {
+      (*p)[dtypes.np_float8_e4m3->ptr()] = numpy_array_handler;
+    }
     (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler;
     (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler;
-    (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler;
     (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler;
+    (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler;
     (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler;
+    if (dtypes.np_float8_e8m0fnu.has_value()) {
+      (*p)[dtypes.np_float8_e8m0fnu->ptr()] = numpy_array_handler;
+    }
     (*p)[dtypes.np_float16.ptr()] = numpy_array_handler;
     (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler;
     (*p)[dtypes.np_float32.ptr()] = numpy_array_handler;
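The has_value() guards above exist because these dtype handles are std::optional in xla/python/types.h, so they stay empty when the installed ml_dtypes lacks the type; the TODO(upwind) asks whether that indirection can now go away. A rough Python analogue of the guard pattern (hypothetical sketch; numpy_array_handler here is a stand-in, not the real C++ handler):

import ml_dtypes

def numpy_array_handler(value):
    # Stand-in for the real handler; hypothetical.
    return value

handlers = {}
# Mirror of the C++ has_value() checks: register a handler only when the
# running ml_dtypes actually provides the type.
for name in ("float4_e2m1fn", "float8_e3m4", "float8_e4m3", "float8_e8m0fnu"):
    ty = getattr(ml_dtypes, name, None)  # None plays the role of an empty optional
    if ty is not None:
        handlers[ty] = numpy_array_handler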
@@ -208,15 +208,14 @@ NB_MODULE(xla_extension, m) {
       .value("U64", U64)
       .value("F16", F16)
       .value("F4E2M1FN", F4E2M1FN)
-      // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
-      // .value("F8E3M4", F8E3M4)
-      // .value("F8E4M3", F8E4M3)
-      .value("F8E8M0FNU", F8E8M0FNU)
+      .value("F8E3M4", F8E3M4)
+      .value("F8E4M3", F8E4M3)
       .value("F8E4M3FN", F8E4M3FN)
       .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ)
       .value("F8E4M3FNUZ", F8E4M3FNUZ)
       .value("F8E5M2", F8E5M2)
       .value("F8E5M2FNUZ", F8E5M2FNUZ)
+      .value("F8E8M0FNU", F8E8M0FNU)
       .value("BF16", BF16)
       .value("F32", F32)
       .value("F64", F64)
@@ -260,16 +260,15 @@ XLA_ELEMENT_TYPE_TO_DTYPE = {
     PrimitiveType.U16: np.dtype('uint16'),
     PrimitiveType.U32: np.dtype('uint32'),
     PrimitiveType.U64: np.dtype('uint64'),
-    # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
-    # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn),
-    # PrimitiveType.F8E3M4: np.dtype(float8_e3m4),
-    # PrimitiveType.F8E4M3: np.dtype(float8_e4m3),
-    # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu),
+    PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn),
+    PrimitiveType.F8E3M4: np.dtype(float8_e3m4),
+    PrimitiveType.F8E4M3: np.dtype(float8_e4m3),
     PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn),
     PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz),
-    PrimitiveType.F8E5M2: np.dtype(float8_e5m2),
     PrimitiveType.F8E4M3FNUZ: np.dtype(float8_e4m3fnuz),
+    PrimitiveType.F8E5M2: np.dtype(float8_e5m2),
     PrimitiveType.F8E5M2FNUZ: np.dtype(float8_e5m2fnuz),
+    PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu),
     PrimitiveType.BF16: np.dtype(bfloat16),
     PrimitiveType.F16: np.dtype('float16'),
     PrimitiveType.F32: np.dtype('float32'),
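With the table fully populated, an FP8 element type coming back from the runtime maps straight to a NumPy dtype. A minimal usage sketch (the jax._src.lib import path is an assumption, and ml_dtypes >= 0.5.0 is required):

import ml_dtypes
import numpy as np
from jax._src.lib import xla_client  # import path is an assumption

# Before this change the F4E2M1FN/F8E3M4/F8E4M3/F8E8M0FNU keys were
# commented out of the table.
dt = xla_client.XLA_ELEMENT_TYPE_TO_DTYPE[xla_client.PrimitiveType.F8E4M3]
assert dt == np.dtype(ml_dtypes.float8_e4m3)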
@@ -63,16 +63,15 @@ _ifrt_version: int
 mlir_api_version: int
 
 bfloat16: type[numpy.generic]
-# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
-# float4_e2m1fn: type[numpy.generic]
-# float8_e3m4: type[numpy.generic]
-# float8_e4m3: type[numpy.generic]
-# float8_e8m0fnu: type[numpy.generic]
+float4_e2m1fn: type[numpy.generic]
+float8_e3m4: type[numpy.generic]
+float8_e4m3: type[numpy.generic]
 float8_e4m3fn: type[numpy.generic]
 float8_e4m3b11fnuz: type[numpy.generic]
 float8_e4m3fnuz: type[numpy.generic]
 float8_e5m2: type[numpy.generic]
 float8_e5m2fnuz: type[numpy.generic]
+float8_e8m0fnu: type[numpy.generic]
 XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype]
 
 _NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]]
@@ -48,12 +48,12 @@ bfloat16 = ml_dtypes.bfloat16
 float4_e2m1fn = ml_dtypes.float4_e2m1fn
 float8_e3m4 = ml_dtypes.float8_e3m4
 float8_e4m3 = ml_dtypes.float8_e4m3
-float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
 float8_e4m3fn = ml_dtypes.float8_e4m3fn
 float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz
 float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz
 float8_e5m2 = ml_dtypes.float8_e5m2
 float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz
+float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
 ops = xla_client.ops
 xla_computation_to_mlir_module = (
     xla_client._xla.mlir.xla_computation_to_mlir_module)
@@ -178,10 +178,17 @@ def TestFactory(xla_backend,
   # TODO(zhangqiaorjc): test fp8 types when XLA support is complete.
   # standard_dtypes is only used for BufferProtocolTest so we only test fp8
   # round trip tests.
-  fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2]
+  fp8_dtypes = [
+      float8_e3m4,
+      float8_e4m3,
+      float8_e4m3fn,
+      float8_e4m3b11fnuz,
+      float8_e5m2,
+      float8_e8m0fnu,
+  ]
   standard_dtypes += fp8_dtypes
-  # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
-  # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu]
+  # TODO(upwind): testRoundTrip and testLiveBuffers fail for float4_e2m1fn type
+  # standard_dtypes += [float4_e2m1fn]
   dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes
 
 class ComputationTest(parameterized.TestCase):
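fp8_dtypes feeds the BufferProtocolTest round-trip cases mentioned in the comment above. The round-trip idea in isolation (standalone sketch using plain NumPy, not the test harness itself):

import ml_dtypes
import numpy as np

# Values exactly representable in float8_e4m3fn (powers of two here) must
# survive a cast to fp8 and back to float32 unchanged.
x = np.array([0.5, 1.0, 2.0], dtype=np.float32)
assert np.array_equal(x.astype(ml_dtypes.float8_e4m3fn).astype(np.float32), x)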
@@ -1228,9 +1235,19 @@ module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
     for dtype in standard_dtypes:
       if dtype == np.complex128:
         continue
+      # float8_e8m0fnu is not supported on TPU.
+      if dtype == float8_e8m0fnu and self.backend.platform == "tpu":
+        continue
       # float8_e4m3b11fnuz not supported on some TPU backends.
       if (
-          dtype in [float8_e5m2fnuz, float8_e4m3fnuz, float8_e4m3b11fnuz]
+          dtype
+          in [
+              float8_e3m4,
+              float8_e4m3,
+              float8_e4m3fnuz,
+              float8_e4m3b11fnuz,
+              float8_e5m2fnuz,
+          ]
           and self.backend.platform == "tpu"
       ):
         if self.backend.platform_version.find("TPU") == -1:
@@ -2253,6 +2270,11 @@ module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
           "dtype": dtype,
       } for dtype in float_dtypes + fp8_dtypes)
   def testNextAfter(self, dtype):
+    if dtype == float8_e8m0fnu:
+      # TODO(b/409114865): Test fails with Mismatched elements error.
+      self.skipTest("b/409114865: Test fails with Mismatched elements error")
+    if dtype in [float8_e3m4, float8_e4m3] and self.backend.platform == "tpu":
+      self.skipTest("TPU doesn't support float8_e3m4 or float8_e4m3")
     if dtype == np.float64 and self.backend.platform == "tpu":
       self.skipTest("TPU doesn't support float64")
     if dtype == bfloat16 and self.backend.platform == "tpu":
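testNextAfter runs np.nextafter over float_dtypes plus the enlarged fp8 list, with the new skips for float8_e8m0fnu (b/409114865) and for e3m4/e4m3 on TPU. What the test checks, in miniature (assumes ml_dtypes implements the nextafter ufunc for its fp8 types):

import ml_dtypes
import numpy as np

# float8_e4m3fn has 3 mantissa bits, so the step from 1.0 toward 2.0 is
# 2**-3 = 0.125; nextafter(1.0, 2.0) should therefore be 1.125.
one = np.array(1.0, dtype=ml_dtypes.float8_e4m3fn)
two = np.array(2.0, dtype=ml_dtypes.float8_e4m3fn)
assert float(np.nextafter(one, two)) == 1.125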