mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 19:26:06 +00:00
Reapply "[MLIR][Python] add ctype python binding support for bf16" (#101271)
Reapply the PR which was reverted due to built-bots, and now the bots get updated. https://discourse.llvm.org/t/need-a-help-with-the-built-bots/79437 original PR: https://github.com/llvm/llvm-project/pull/92489, reverted in https://github.com/llvm/llvm-project/pull/93771
This commit is contained in:
parent
8300eaad0f
commit
5ef087b705
@ -7,6 +7,12 @@
|
||||
import numpy as np
|
||||
import ctypes
|
||||
|
||||
try:
|
||||
import ml_dtypes
|
||||
except ModuleNotFoundError:
|
||||
# The third-party ml_dtypes provides some optional low precision data-types for NumPy.
|
||||
ml_dtypes = None
|
||||
|
||||
|
||||
class C128(ctypes.Structure):
|
||||
"""A ctype representation for MLIR's Double Complex."""
|
||||
@ -26,6 +32,12 @@ class F16(ctypes.Structure):
|
||||
_fields_ = [("f16", ctypes.c_int16)]
|
||||
|
||||
|
||||
class BF16(ctypes.Structure):
|
||||
"""A ctype representation for MLIR's BFloat16."""
|
||||
|
||||
_fields_ = [("bf16", ctypes.c_int16)]
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
|
||||
def as_ctype(dtp):
|
||||
"""Converts dtype to ctype."""
|
||||
@ -35,6 +47,8 @@ def as_ctype(dtp):
|
||||
return C64
|
||||
if dtp == np.dtype(np.float16):
|
||||
return F16
|
||||
if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
|
||||
return BF16
|
||||
return np.ctypeslib.as_ctypes_type(dtp)
|
||||
|
||||
|
||||
@ -46,6 +60,11 @@ def to_numpy(array):
|
||||
return array.view("complex64")
|
||||
if array.dtype == F16:
|
||||
return array.view("float16")
|
||||
assert not (
|
||||
array.dtype == BF16 and ml_dtypes is None
|
||||
), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
|
||||
if array.dtype == BF16:
|
||||
return array.view("bfloat16")
|
||||
return array
|
||||
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
numpy>=1.19.5, <=1.26
|
||||
pybind11>=2.9.0, <=2.10.3
|
||||
PyYAML>=5.3.1, <=6.0.1
|
||||
PyYAML>=5.3.1, <=6.0.1
|
||||
ml_dtypes # provides several NumPy dtype extensions, including the bf16
|
@ -5,6 +5,7 @@ from mlir.ir import *
|
||||
from mlir.passmanager import *
|
||||
from mlir.execution_engine import *
|
||||
from mlir.runtime import *
|
||||
from ml_dtypes import bfloat16
|
||||
|
||||
|
||||
# Log everything to stderr and flush so that we have a unified stream to match
|
||||
@ -521,6 +522,45 @@ def testComplexUnrankedMemrefAdd():
|
||||
run(testComplexUnrankedMemrefAdd)
|
||||
|
||||
|
||||
# Test bf16 memrefs
|
||||
# CHECK-LABEL: TEST: testBF16Memref
|
||||
def testBF16Memref():
|
||||
with Context():
|
||||
module = Module.parse(
|
||||
"""
|
||||
module {
|
||||
func.func @main(%arg0: memref<1xbf16>,
|
||||
%arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = memref.load %arg0[%0] : memref<1xbf16>
|
||||
memref.store %1, %arg1[%0] : memref<1xbf16>
|
||||
return
|
||||
}
|
||||
} """
|
||||
)
|
||||
|
||||
arg1 = np.array([0.5]).astype(bfloat16)
|
||||
arg2 = np.array([0.0]).astype(bfloat16)
|
||||
|
||||
arg1_memref_ptr = ctypes.pointer(
|
||||
ctypes.pointer(get_ranked_memref_descriptor(arg1))
|
||||
)
|
||||
arg2_memref_ptr = ctypes.pointer(
|
||||
ctypes.pointer(get_ranked_memref_descriptor(arg2))
|
||||
)
|
||||
|
||||
execution_engine = ExecutionEngine(lowerToLLVM(module))
|
||||
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
|
||||
|
||||
# test to-numpy utility
|
||||
# CHECK: [0.5]
|
||||
npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
|
||||
log(npout)
|
||||
|
||||
|
||||
run(testBF16Memref)
|
||||
|
||||
|
||||
# Test addition of two 2d_memref
|
||||
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
|
||||
def testDynamicMemrefAdd2D():
|
||||
|
Loading…
x
Reference in New Issue
Block a user