[mlir][llvmir] expose Type(To/From)LLVMIRTranslator C API (#124864)

This commit is contained in:
Maksim Levental 2025-01-30 11:43:22 -06:00 committed by GitHub
parent 1128343727
commit 7ae964c55b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 123 additions and 3 deletions

View File

@ -52,6 +52,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type,
intptr_t pos);
/// Returns the return type of the function type.
MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type);
/// Returns `true` if the type is an LLVM dialect struct type.
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);

View File

@ -16,6 +16,7 @@
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "llvm-c/Core.h"
#include "llvm-c/Support.h"
#ifdef __cplusplus
@ -32,6 +33,48 @@ extern "C" {
MLIR_CAPI_EXPORTED LLVMModuleRef
mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context);
struct MlirTypeFromLLVMIRTranslator {
void *ptr;
};
typedef struct MlirTypeFromLLVMIRTranslator MlirTypeFromLLVMIRTranslator;
/// Create an LLVM::TypeFromLLVMIRTranslator and transfer ownership to the
/// caller.
MLIR_CAPI_EXPORTED MlirTypeFromLLVMIRTranslator
mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx);
/// Takes an LLVM::TypeFromLLVMIRTranslator owned by the caller and destroys it.
/// It is the responsibility of the user to only pass an
/// LLVM::TypeFromLLVMIRTranslator class.
MLIR_CAPI_EXPORTED void
mlirTypeFromLLVMIRTranslatorDestroy(MlirTypeFromLLVMIRTranslator translator);
/// Translates the given LLVM IR type to the MLIR LLVM dialect.
MLIR_CAPI_EXPORTED MlirType mlirTypeFromLLVMIRTranslatorTranslateType(
MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType);
struct MlirTypeToLLVMIRTranslator {
void *ptr;
};
typedef struct MlirTypeToLLVMIRTranslator MlirTypeToLLVMIRTranslator;
/// Create an LLVM::TypeToLLVMIRTranslator and transfer ownership to the
/// caller.
MLIR_CAPI_EXPORTED MlirTypeToLLVMIRTranslator
mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx);
/// Takes an LLVM::TypeToLLVMIRTranslator owned by the caller and destroys it.
/// It is the responsibility of the user to only pass an
/// LLVM::TypeToLLVMIRTranslator class.
MLIR_CAPI_EXPORTED void
mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator);
/// Translates the given MLIR LLVM dialect to the LLVM IR type.
MLIR_CAPI_EXPORTED LLVMTypeRef mlirTypeToLLVMIRTranslatorTranslateType(
MlirTypeToLLVMIRTranslator translator, MlirType mlirType);
#ifdef __cplusplus
}
#endif

View File

@ -65,6 +65,10 @@ MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos) {
.getParamType(static_cast<unsigned>(pos)));
}
MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type) {
return wrap(llvm::cast<LLVM::LLVMFunctionType>(unwrap(type)).getReturnType());
}
bool mlirTypeIsALLVMStructType(MlirType type) {
return isa<LLVM::LLVMStructType>(unwrap(type));
}

View File

@ -8,16 +8,15 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/Target/LLVMIR.h"
#include "llvm-c/Support.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include <memory>
#include "llvm/IR/Type.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Target/LLVMIR/TypeFromLLVM.h"
using namespace mlir;
@ -34,3 +33,47 @@ LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module,
return moduleRef;
}
DEFINE_C_API_PTR_METHODS(MlirTypeFromLLVMIRTranslator,
mlir::LLVM::TypeFromLLVMIRTranslator);
MlirTypeFromLLVMIRTranslator
mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx) {
MLIRContext *context = unwrap(ctx);
auto *translator = new LLVM::TypeFromLLVMIRTranslator(*context);
return wrap(translator);
}
void mlirTypeFromLLVMIRTranslatorDestroy(
MlirTypeFromLLVMIRTranslator translator) {
delete static_cast<LLVM::TypeFromLLVMIRTranslator *>(unwrap(translator));
}
MlirType mlirTypeFromLLVMIRTranslatorTranslateType(
MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType) {
LLVM::TypeFromLLVMIRTranslator *translator_ = unwrap(translator);
mlir::Type type = translator_->translateType(llvm::unwrap(llvmType));
return wrap(type);
}
DEFINE_C_API_PTR_METHODS(MlirTypeToLLVMIRTranslator,
mlir::LLVM::TypeToLLVMIRTranslator);
MlirTypeToLLVMIRTranslator
mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx) {
llvm::LLVMContext *context = llvm::unwrap(ctx);
auto *translator = new LLVM::TypeToLLVMIRTranslator(*context);
return wrap(translator);
}
void mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator) {
delete static_cast<LLVM::TypeToLLVMIRTranslator *>(unwrap(translator));
}
LLVMTypeRef
mlirTypeToLLVMIRTranslatorTranslateType(MlirTypeToLLVMIRTranslator translator,
MlirType mlirType) {
LLVM::TypeToLLVMIRTranslator *translator_ = unwrap(translator);
llvm::Type *type = translator_->translateType(unwrap(mlirType));
return llvm::wrap(type);
}

View File

@ -58,11 +58,38 @@ static void testToLLVMIR(MlirContext ctx) {
LLVMContextDispose(llvmCtx);
}
// CHECK-LABEL: testTypeToFromLLVMIRTranslator
static void testTypeToFromLLVMIRTranslator(MlirContext ctx) {
fprintf(stderr, "testTypeToFromLLVMIRTranslator\n");
LLVMContextRef llvmCtx = LLVMContextCreate();
LLVMTypeRef llvmTy = LLVMInt32TypeInContext(llvmCtx);
MlirTypeFromLLVMIRTranslator fromLLVMTranslator =
mlirTypeFromLLVMIRTranslatorCreate(ctx);
MlirType mlirTy =
mlirTypeFromLLVMIRTranslatorTranslateType(fromLLVMTranslator, llvmTy);
// CHECK: i32
mlirTypeDump(mlirTy);
MlirTypeToLLVMIRTranslator toLLVMTranslator =
mlirTypeToLLVMIRTranslatorCreate(llvmCtx);
LLVMTypeRef llvmTy2 =
mlirTypeToLLVMIRTranslatorTranslateType(toLLVMTranslator, mlirTy);
// CHECK: i32
LLVMDumpType(llvmTy2);
fprintf(stderr, "\n");
mlirTypeFromLLVMIRTranslatorDestroy(fromLLVMTranslator);
mlirTypeToLLVMIRTranslatorDestroy(toLLVMTranslator);
LLVMContextDispose(llvmCtx);
}
int main(void) {
MlirContext ctx = mlirContextCreate();
mlirDialectHandleRegisterDialect(mlirGetDialectHandle__llvm__(), ctx);
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("llvm"));
testToLLVMIR(ctx);
testTypeToFromLLVMIRTranslator(ctx);
mlirContextDestroy(ctx);
return 0;
}