diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index f6b706f9bc8a..882b2751921b 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -7,6 +7,12 @@ import numpy as np import ctypes +try: + import ml_dtypes +except ModuleNotFoundError: + # The third-party ml_dtypes provides some optional low precision data-types for NumPy. + ml_dtypes = None + class C128(ctypes.Structure): """A ctype representation for MLIR's Double Complex.""" @@ -26,6 +32,12 @@ class F16(ctypes.Structure): _fields_ = [("f16", ctypes.c_int16)] +class BF16(ctypes.Structure): + """A ctype representation for MLIR's BFloat16.""" + + _fields_ = [("bf16", ctypes.c_int16)] + + # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): """Converts dtype to ctype.""" @@ -35,6 +47,8 @@ def as_ctype(dtp): return C64 if dtp == np.dtype(np.float16): return F16 + if ml_dtypes is not None and dtp == ml_dtypes.bfloat16: + return BF16 return np.ctypeslib.as_ctypes_type(dtp) @@ -46,6 +60,11 @@ def to_numpy(array): return array.view("complex64") if array.dtype == F16: return array.view("float16") + assert not ( + array.dtype == BF16 and ml_dtypes is None + ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + if array.dtype == BF16: + return array.view("bfloat16") return array diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index acd6dbb25eda..6ec63e43adf8 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.19.5, <=1.26 pybind11>=2.9.0, <=2.10.3 -PyYAML>=5.3.1, <=6.0.1 \ No newline at end of file +PyYAML>=5.3.1, <=6.0.1 +ml_dtypes # provides several NumPy dtype extensions, including the bf16 \ No newline at end of file diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index e8b47007a890..8125bf3fb8fc 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -5,6 +5,7 @@ from mlir.ir import * from mlir.passmanager import * from mlir.execution_engine import * from mlir.runtime import * +from ml_dtypes import bfloat16 # Log everything to stderr and flush so that we have a unified stream to match @@ -521,6 +522,45 @@ def testComplexUnrankedMemrefAdd(): run(testComplexUnrankedMemrefAdd) +# Test bf16 memrefs +# CHECK-LABEL: TEST: testBF16Memref +def testBF16Memref(): + with Context(): + module = Module.parse( + """ + module { + func.func @main(%arg0: memref<1xbf16>, + %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } { + %0 = arith.constant 0 : index + %1 = memref.load %arg0[%0] : memref<1xbf16> + memref.store %1, %arg1[%0] : memref<1xbf16> + return + } + } """ + ) + + arg1 = np.array([0.5]).astype(bfloat16) + arg2 = np.array([0.0]).astype(bfloat16) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg1)) + ) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg2)) + ) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) + + # test to-numpy utility + # CHECK: [0.5] + npout = ranked_memref_to_numpy(arg2_memref_ptr[0]) + log(npout) + + +run(testBF16Memref) + + # Test addition of two 2d_memref # CHECK-LABEL: TEST: testDynamicMemrefAdd2D def testDynamicMemrefAdd2D():