diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc index 709f3cb3b..90dd77209 100644 --- a/jaxlib/xla/py_values.cc +++ b/jaxlib/xla/py_values.cc @@ -694,16 +694,25 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; - // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler; - // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; - // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; - // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler; + // TODO(upwind): Explore if we can remove std::optional for these types + // in xla/python/types.h and xla/python/types.cc + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = numpy_array_handler; + } (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; - (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = numpy_array_handler; + } (*p)[dtypes.np_float16.ptr()] = numpy_array_handler; (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler; (*p)[dtypes.np_float32.ptr()] = numpy_array_handler; diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index e460a1773..660e62bd8 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -208,15 +208,14 @@ NB_MODULE(xla_extension, m) { .value("U64", U64) .value("F16", F16) .value("F4E2M1FN", F4E2M1FN) - // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - // .value("F8E3M4", F8E3M4) - // .value("F8E4M3", F8E4M3) - .value("F8E8M0FNU", F8E8M0FNU) + .value("F8E3M4", F8E3M4) + .value("F8E4M3", F8E4M3) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) .value("F8E5M2", F8E5M2) .value("F8E5M2FNUZ", F8E5M2FNUZ) + .value("F8E8M0FNU", F8E8M0FNU) .value("BF16", BF16) .value("F32", F32) .value("F64", F64) diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index fa31d1764..637d7d060 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -260,16 +260,15 @@ XLA_ELEMENT_TYPE_TO_DTYPE = { PrimitiveType.U16: np.dtype('uint16'), PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), - # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), - # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), - # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), - # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), + PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), + PrimitiveType.F8E3M4: np.dtype(float8_e3m4), + PrimitiveType.F8E4M3: np.dtype(float8_e4m3), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), - PrimitiveType.F8E5M2: np.dtype(float8_e5m2), PrimitiveType.F8E4M3FNUZ: np.dtype(float8_e4m3fnuz), + PrimitiveType.F8E5M2: np.dtype(float8_e5m2), PrimitiveType.F8E5M2FNUZ: np.dtype(float8_e5m2fnuz), + PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), PrimitiveType.BF16: np.dtype(bfloat16), PrimitiveType.F16: np.dtype('float16'), PrimitiveType.F32: np.dtype('float32'), diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi index b182eb65b..382858d2a 100644 --- a/jaxlib/xla/xla_client.pyi +++ b/jaxlib/xla/xla_client.pyi @@ -63,16 +63,15 @@ _ifrt_version: int mlir_api_version: int bfloat16: type[numpy.generic] -# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. -# float4_e2m1fn: type[numpy.generic] -# float8_e3m4: type[numpy.generic] -# float8_e4m3: type[numpy.generic] -# float8_e8m0fnu: type[numpy.generic] +float4_e2m1fn: type[numpy.generic] +float8_e3m4: type[numpy.generic] +float8_e4m3: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] float8_e5m2: type[numpy.generic] float8_e5m2fnuz: type[numpy.generic] +float8_e8m0fnu: type[numpy.generic] XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype] _NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py index 7de905d9e..9c6625610 100644 --- a/jaxlib/xla/xla_client_test.py +++ b/jaxlib/xla/xla_client_test.py @@ -48,12 +48,12 @@ bfloat16 = ml_dtypes.bfloat16 float4_e2m1fn = ml_dtypes.float4_e2m1fn float8_e3m4 = ml_dtypes.float8_e3m4 float8_e4m3 = ml_dtypes.float8_e4m3 -float8_e8m0fnu = ml_dtypes.float8_e8m0fnu float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e5m2 = ml_dtypes.float8_e5m2 float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +float8_e8m0fnu = ml_dtypes.float8_e8m0fnu ops = xla_client.ops xla_computation_to_mlir_module = ( xla_client._xla.mlir.xla_computation_to_mlir_module) @@ -178,10 +178,17 @@ def TestFactory(xla_backend, # TODO(zhangqiaorjc): test fp8 types when XLA support is complete. # standard_dtypes is only used for BufferProtocolTest so we only test fp8 # round trip tests. - fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] + fp8_dtypes = [ + float8_e3m4, + float8_e4m3, + float8_e4m3fn, + float8_e4m3b11fnuz, + float8_e5m2, + float8_e8m0fnu, + ] standard_dtypes += fp8_dtypes - # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu] + # TODO(upwind): testRoundTrip and testLiveBuffers fail for float4_e2m1fn type + # standard_dtypes += [float4_e2m1fn] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): @@ -1228,9 +1235,19 @@ module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, for dtype in standard_dtypes: if dtype == np.complex128: continue + # float8_e8m0fnu is not supported on TPU. + if dtype == float8_e8m0fnu and self.backend.platform == "tpu": + continue # float8_e4m3b11fnuz not supported on some TPU backends. if ( - dtype in [float8_e5m2fnuz, float8_e4m3fnuz, float8_e4m3b11fnuz] + dtype + in [ + float8_e3m4, + float8_e4m3, + float8_e4m3fnuz, + float8_e4m3b11fnuz, + float8_e5m2fnuz, + ] and self.backend.platform == "tpu" ): if self.backend.platform_version.find("TPU") == -1: @@ -2253,6 +2270,11 @@ module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, "dtype": dtype, } for dtype in float_dtypes + fp8_dtypes) def testNextAfter(self, dtype): + if dtype == float8_e8m0fnu: + # TODO(b/409114865): Test fails with Mismatched elements error. + self.skipTest("b/409114865: Test fails with Mismatched elements error") + if dtype in [float8_e3m4, float8_e4m3] and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float8_e3m4 or float8_e4m3") if dtype == np.float64 and self.backend.platform == "tpu": self.skipTest("TPU doesn't support float64") if dtype == bfloat16 and self.backend.platform == "tpu":