mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-18 09:06:43 +00:00
[mlir][fix] Add callback functions for ModuleToObject (#116916)
Here is the [merged MR](https://github.com/llvm/llvm-project/pull/116007), which caused a failure and [was reverted](https://github.com/llvm/llvm-project/pull/116811). Thanks to @joker-eph for the help, I fixed it (the original change missed constructing `ModuleToObject` with the callback functions in `mlir/lib/Target/LLVM/NVVM/Target.cpp`) and split the unit tests that don't need `ptxas` out of the original test so that they can run more widely.
This commit is contained in:
parent
33fcd6acc7
commit
08e7609692
@ -14,6 +14,7 @@
|
||||
#define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
|
||||
namespace llvm {
|
||||
class IRBuilderBase;
|
||||
@ -52,7 +53,11 @@ public:
|
||||
StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
|
||||
StringRef cmdOptions = {},
|
||||
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
|
||||
function_ref<SymbolTable *()> getSymbolTableCallback = {});
|
||||
function_ref<SymbolTable *()> getSymbolTableCallback = {},
|
||||
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
|
||||
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
|
||||
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
|
||||
function_ref<void(StringRef)> isaCallback = {});
|
||||
|
||||
/// Returns the typeID.
|
||||
TypeID getTypeID() const;
|
||||
@ -80,6 +85,22 @@ public:
|
||||
/// table.
|
||||
SymbolTable *getSymbolTable() const;
|
||||
|
||||
/// Returns the callback invoked with the initial LLVM IR for the device
|
||||
/// module.
|
||||
function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
|
||||
|
||||
/// Returns the callback invoked with LLVM IR for the device module
|
||||
/// after linking the device libraries.
|
||||
function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
|
||||
|
||||
/// Returns the callback invoked with LLVM IR for the device module after
|
||||
/// LLVM optimizations but before codegen.
|
||||
function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
|
||||
|
||||
/// Returns the callback invoked with the target ISA for the device,
|
||||
/// for example PTX assembly.
|
||||
function_ref<void(StringRef)> getISACallback() const;
|
||||
|
||||
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
|
||||
static CompilationTarget getDefaultCompilationTarget();
|
||||
|
||||
@ -90,7 +111,11 @@ protected:
|
||||
TypeID typeID, StringRef toolkitPath = {},
|
||||
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
|
||||
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
|
||||
function_ref<SymbolTable *()> getSymbolTableCallback = {});
|
||||
function_ref<SymbolTable *()> getSymbolTableCallback = {},
|
||||
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
|
||||
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
|
||||
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
|
||||
function_ref<void(StringRef)> isaCallback = {});
|
||||
|
||||
/// Path to the target toolkit.
|
||||
std::string toolkitPath;
|
||||
@ -109,6 +134,21 @@ protected:
|
||||
/// being serialized.
|
||||
function_ref<SymbolTable *()> getSymbolTableCallback;
|
||||
|
||||
/// Callback invoked with the initial LLVM IR for the device module.
|
||||
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
|
||||
|
||||
/// Callback invoked with LLVM IR for the device module after
|
||||
/// linking the device libraries.
|
||||
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
|
||||
|
||||
/// Callback invoked with LLVM IR for the device module after
|
||||
/// LLVM optimizations but before codegen.
|
||||
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
|
||||
|
||||
/// Callback invoked with the target ISA for the device,
|
||||
/// for example PTX assembly.
|
||||
function_ref<void(StringRef)> isaCallback;
|
||||
|
||||
private:
|
||||
TypeID typeID;
|
||||
};
|
||||
|
@ -29,8 +29,13 @@ class ModuleTranslation;
|
||||
/// operations being transformed must be translatable into LLVM IR.
|
||||
class ModuleToObject {
|
||||
public:
|
||||
ModuleToObject(Operation &module, StringRef triple, StringRef chip,
|
||||
StringRef features = {}, int optLevel = 3);
|
||||
ModuleToObject(
|
||||
Operation &module, StringRef triple, StringRef chip,
|
||||
StringRef features = {}, int optLevel = 3,
|
||||
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
|
||||
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
|
||||
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
|
||||
function_ref<void(StringRef)> isaCallback = {});
|
||||
virtual ~ModuleToObject();
|
||||
|
||||
/// Returns the operation being serialized.
|
||||
@ -114,6 +119,21 @@ protected:
|
||||
/// Optimization level.
|
||||
int optLevel;
|
||||
|
||||
/// Callback invoked with the initial LLVM IR for the device module.
|
||||
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
|
||||
|
||||
/// Callback invoked with LLVM IR for the device module after
|
||||
/// linking the device libraries.
|
||||
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
|
||||
|
||||
/// Callback invoked with LLVM IR for the device module after
|
||||
/// LLVM optimizations but before codegen.
|
||||
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
|
||||
|
||||
/// Callback invoked with the target ISA for the device,
|
||||
/// for example PTX assembly.
|
||||
function_ref<void(StringRef)> isaCallback;
|
||||
|
||||
private:
|
||||
/// The TargetMachine created for the given Triple, if available.
|
||||
/// Accessible through `getOrCreateTargetMachine()`.
|
||||
|
@ -2302,17 +2302,31 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
|
||||
/// Public constructor. Delegates to the TypeID-taking constructor, passing
/// `TypeID::get<TargetOptions>()` as the type ID and forwarding the remaining
/// arguments unchanged.
/// NOTE(review): this listing appears to show both the pre-change and the
/// post-change rows of the diff (the parameter list and the delegating call
/// each appear twice) — read the longer variant as the current code.
TargetOptions::TargetOptions(
|
||||
StringRef toolkitPath, ArrayRef<std::string> linkFiles,
|
||||
StringRef cmdOptions, CompilationTarget compilationTarget,
|
||||
function_ref<SymbolTable *()> getSymbolTableCallback)
|
||||
function_ref<SymbolTable *()> getSymbolTableCallback,
|
||||
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
|
||||
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
|
||||
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
|
||||
function_ref<void(StringRef)> isaCallback)
|
||||
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
|
||||
cmdOptions, compilationTarget, getSymbolTableCallback) {}
|
||||
cmdOptions, compilationTarget, getSymbolTableCallback,
|
||||
initialLlvmIRCallback, linkedLlvmIRCallback,
|
||||
optimizedLlvmIRCallback, isaCallback) {}
|
||||
|
||||
/// Constructor taking an explicit `typeID`. Initializes every member from the
/// argument of the same name; `toolkitPath` and `cmdOptions` are converted
/// with `.str()` before being stored.
/// NOTE(review): the callbacks are kept as `function_ref` members, which as
/// far as I know do not own the callable they refer to — presumably the caller
/// must keep the callables alive for as long as this object is used; confirm.
/// NOTE(review): this listing appears to show both the pre-change and the
/// post-change rows of the diff; read the longer variant as the current code.
TargetOptions::TargetOptions(
|
||||
TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
|
||||
StringRef cmdOptions, CompilationTarget compilationTarget,
|
||||
function_ref<SymbolTable *()> getSymbolTableCallback)
|
||||
function_ref<SymbolTable *()> getSymbolTableCallback,
|
||||
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
|
||||
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
|
||||
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
|
||||
function_ref<void(StringRef)> isaCallback)
|
||||
: toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
|
||||
cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
|
||||
getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
|
||||
getSymbolTableCallback(getSymbolTableCallback),
|
||||
initialLlvmIRCallback(initialLlvmIRCallback),
|
||||
linkedLlvmIRCallback(linkedLlvmIRCallback),
|
||||
optimizedLlvmIRCallback(optimizedLlvmIRCallback),
|
||||
isaCallback(isaCallback), typeID(typeID) {}
|
||||
|
||||
/// Returns the `typeID` member that was set by the constructor.
TypeID TargetOptions::getTypeID() const { return typeID; }
|
||||
|
||||
@ -2326,6 +2340,25 @@ SymbolTable *TargetOptions::getSymbolTable() const {
|
||||
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
|
||||
}
|
||||
|
||||
/// Accessor for the callback that receives the device module's initial LLVM
/// IR. Hands back the stored `function_ref` unchanged. The constructors
/// default this callback to `{}`, so callers are expected to test it before
/// invoking it (as `ModuleToObject::run` does).
function_ref<void(llvm::Module &)>
|
||||
TargetOptions::getInitialLlvmIRCallback() const {
|
||||
return initialLlvmIRCallback;
|
||||
}
|
||||
|
||||
/// Accessor for the callback that receives the device module's LLVM IR after
/// the device libraries have been linked. Hands back the stored
/// `function_ref` unchanged. The constructors default this callback to `{}`,
/// so callers are expected to test it before invoking it.
function_ref<void(llvm::Module &)>
|
||||
TargetOptions::getLinkedLlvmIRCallback() const {
|
||||
return linkedLlvmIRCallback;
|
||||
}
|
||||
|
||||
/// Accessor for the callback that receives the device module's LLVM IR after
/// LLVM optimizations. Hands back the stored `function_ref` unchanged. The
/// constructors default this callback to `{}`, so callers are expected to
/// test it before invoking it.
function_ref<void(llvm::Module &)>
|
||||
TargetOptions::getOptimizedLlvmIRCallback() const {
|
||||
return optimizedLlvmIRCallback;
|
||||
}
|
||||
|
||||
/// Accessor for the callback that receives the target ISA as text. Hands back
/// the stored `function_ref` unchanged. The constructors default this
/// callback to `{}`, so callers are expected to test it before invoking it.
function_ref<void(StringRef)> TargetOptions::getISACallback() const {
|
||||
return isaCallback;
|
||||
}
|
||||
|
||||
/// Returns the `compilationTarget` member that was set by the constructor.
CompilationTarget TargetOptions::getCompilationTarget() const {
|
||||
return compilationTarget;
|
||||
}
|
||||
|
@ -34,10 +34,17 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::LLVM;
|
||||
|
||||
/// Constructs the serializer by initializing each member from the argument of
/// the same name: the operation to serialize, the target triple / chip /
/// features, the optimization level, and the four optional callbacks.
/// NOTE(review): the callbacks are kept as `function_ref` members, which as
/// far as I know do not own the callable they refer to — presumably the caller
/// must keep the callables alive until serialization has finished; confirm.
/// NOTE(review): this listing appears to show both the pre-change and the
/// post-change rows of the diff; read the longer variant as the current code.
ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
|
||||
StringRef chip, StringRef features, int optLevel)
|
||||
ModuleToObject::ModuleToObject(
|
||||
Operation &module, StringRef triple, StringRef chip, StringRef features,
|
||||
int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
|
||||
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
|
||||
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
|
||||
function_ref<void(StringRef)> isaCallback)
|
||||
: module(module), triple(triple), chip(chip), features(features),
|
||||
optLevel(optLevel) {}
|
||||
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
|
||||
linkedLlvmIRCallback(linkedLlvmIRCallback),
|
||||
optimizedLlvmIRCallback(optimizedLlvmIRCallback),
|
||||
isaCallback(isaCallback) {}
|
||||
|
||||
/// Destructor, defaulted here in the implementation file.
ModuleToObject::~ModuleToObject() = default;
|
||||
|
||||
@ -215,6 +222,9 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
|
||||
}
|
||||
setDataLayoutAndTriple(*llvmModule);
|
||||
|
||||
if (initialLlvmIRCallback)
|
||||
initialLlvmIRCallback(*llvmModule);
|
||||
|
||||
// Link bitcode files.
|
||||
handleModulePreLink(*llvmModule);
|
||||
{
|
||||
@ -227,10 +237,16 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
|
||||
handleModulePostLink(*llvmModule);
|
||||
}
|
||||
|
||||
if (linkedLlvmIRCallback)
|
||||
linkedLlvmIRCallback(*llvmModule);
|
||||
|
||||
// Optimize the module.
|
||||
if (failed(optimizeModule(*llvmModule, optLevel)))
|
||||
return std::nullopt;
|
||||
|
||||
if (optimizedLlvmIRCallback)
|
||||
optimizedLlvmIRCallback(*llvmModule);
|
||||
|
||||
// Return the serialized object.
|
||||
return moduleToObject(*llvmModule);
|
||||
}
|
||||
|
@ -86,7 +86,11 @@ SerializeGPUModuleBase::SerializeGPUModuleBase(
|
||||
Operation &module, NVVMTargetAttr target,
|
||||
const gpu::TargetOptions &targetOptions)
|
||||
: ModuleToObject(module, target.getTriple(), target.getChip(),
|
||||
target.getFeatures(), target.getO()),
|
||||
target.getFeatures(), target.getO(),
|
||||
targetOptions.getInitialLlvmIRCallback(),
|
||||
targetOptions.getLinkedLlvmIRCallback(),
|
||||
targetOptions.getOptimizedLlvmIRCallback(),
|
||||
targetOptions.getISACallback()),
|
||||
target(target), toolkitPath(targetOptions.getToolkitPath()),
|
||||
fileList(targetOptions.getLinkFiles()) {
|
||||
|
||||
@ -572,6 +576,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
|
||||
getOperation().emitError() << "Failed translating the module to ISA.";
|
||||
return std::nullopt;
|
||||
}
|
||||
if (isaCallback)
|
||||
isaCallback(serializedISA.value());
|
||||
|
||||
#define DEBUG_TYPE "serialize-to-isa"
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";
|
||||
|
@ -156,3 +156,62 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToBinary)) {
|
||||
ASSERT_TRUE(!object->empty());
|
||||
}
|
||||
}
|
||||
|
||||
// Test callback functions invoked with LLVM IR and ISA.
// Installs all four callbacks on the target options, serializes every
// `gpu::GPUModuleOp` found in `moduleStr` with the `Assembly` compilation
// target, and checks that each callback left non-empty text behind.
|
||||
TEST_F(MLIRTargetLLVMNVVM,
|
||||
SKIP_WITHOUT_NVPTX(CallbackInvokedWithLLVMIRAndISA)) {
|
||||
MLIRContext context(registry);
|
||||
|
||||
OwningOpRef<ModuleOp> module =
|
||||
parseSourceString<ModuleOp>(moduleStr, &context);
|
||||
ASSERT_TRUE(!!module);
|
||||
|
||||
NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
|
||||
|
||||
// The target attribute must provide the serialization interface.
auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
|
||||
ASSERT_TRUE(!!serializer);
|
||||
|
||||
// Each callback below captures its destination string by reference and
// prints the module (or copies the ISA text) into it when invoked.
std::string initialLLVMIR;
|
||||
auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
|
||||
llvm::raw_string_ostream ros(initialLLVMIR);
|
||||
module.print(ros, nullptr);
|
||||
};
|
||||
|
||||
std::string linkedLLVMIR;
|
||||
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
|
||||
llvm::raw_string_ostream ros(linkedLLVMIR);
|
||||
module.print(ros, nullptr);
|
||||
};
|
||||
|
||||
std::string optimizedLLVMIR;
|
||||
auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
|
||||
llvm::raw_string_ostream ros(optimizedLLVMIR);
|
||||
module.print(ros, nullptr);
|
||||
};
|
||||
|
||||
std::string isaResult;
|
||||
auto isaCallback = [&isaResult](llvm::StringRef isa) {
|
||||
isaResult = isa.str();
|
||||
};
|
||||
|
||||
// The lambdas are named locals, so they stay alive while `options` refers to
// them.
gpu::TargetOptions options({}, {}, {}, gpu::CompilationTarget::Assembly, {},
|
||||
initialCallback, linkedCallback, optimizedCallback,
|
||||
isaCallback);
|
||||
|
||||
for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
|
||||
std::optional<SmallVector<char, 0>> object =
|
||||
serializer.serializeToObject(gpuModule, options);
|
||||
|
||||
ASSERT_TRUE(object != std::nullopt);
|
||||
ASSERT_TRUE(!object->empty());
|
||||
ASSERT_TRUE(!initialLLVMIR.empty());
|
||||
ASSERT_TRUE(!linkedLLVMIR.empty());
|
||||
ASSERT_TRUE(!optimizedLLVMIR.empty());
|
||||
ASSERT_TRUE(!isaResult.empty());
|
||||
|
||||
// Reset the captured text so the next GPU module is checked on its own.
initialLLVMIR.clear();
|
||||
linkedLLVMIR.clear();
|
||||
optimizedLLVMIR.clear();
|
||||
isaResult.clear();
|
||||
}
|
||||
}
|
||||
|
@ -105,7 +105,9 @@ TargetAttrImpl::serializeToObject(Attribute attribute, Operation *module,
|
||||
// Set a dummy attr to be retrieved by `createObject`.
|
||||
module->setAttr("serialize_attr", UnitAttr::get(module->getContext()));
|
||||
std::string targetTriple = llvm::sys::getProcessTriple();
|
||||
LLVM::ModuleToObject serializer(*module, targetTriple, "", "");
|
||||
LLVM::ModuleToObject serializer(
|
||||
*module, targetTriple, "", "", 3, options.getInitialLlvmIRCallback(),
|
||||
options.getLinkedLlvmIRCallback(), options.getOptimizedLlvmIRCallback());
|
||||
return serializer.run();
|
||||
}
|
||||
|
||||
@ -153,3 +155,88 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(TargetAttrAPI)) {
|
||||
// `serializeToObject`.
|
||||
ASSERT_TRUE(properties.contains("serialize_attr"));
|
||||
}
|
||||
|
||||
// Test callback function invoked with initial LLVM IR.
// Only the initial-IR callback is installed; the test passes when
// serialization yields a non-empty object and the callback captured
// non-empty IR text.
|
||||
TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
|
||||
MLIRContext context(registry);
|
||||
|
||||
OwningOpRef<ModuleOp> module =
|
||||
parseSourceString<ModuleOp>(moduleStr, &context);
|
||||
ASSERT_TRUE(!!module);
|
||||
Builder builder(&context);
|
||||
IntegerAttr target = builder.getI32IntegerAttr(0);
|
||||
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
|
||||
// Check the attribute holds the interface; calling through a null interface
// below would crash instead of reporting a test failure.
ASSERT_TRUE(!!targetAttr);
|
||||
|
||||
// The callback prints the module into `initialLLVMIR`, captured by reference.
std::string initialLLVMIR;
|
||||
auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
|
||||
llvm::raw_string_ostream ros(initialLLVMIR);
|
||||
module.print(ros, nullptr);
|
||||
};
|
||||
|
||||
gpu::TargetOptions opts(
|
||||
{}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {},
|
||||
initialCallback);
|
||||
std::optional<SmallVector<char, 0>> serializedBinary =
|
||||
targetAttr.serializeToObject(*module, opts);
|
||||
|
||||
ASSERT_TRUE(serializedBinary != std::nullopt);
|
||||
ASSERT_TRUE(!serializedBinary->empty());
|
||||
ASSERT_TRUE(!initialLLVMIR.empty());
|
||||
}
|
||||
|
||||
// Test callback function invoked with linked LLVM IR.
// Only the linked-IR callback is installed (the initial-IR slot is left
// empty); the test passes when serialization yields a non-empty object and
// the callback captured non-empty IR text.
|
||||
TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
|
||||
MLIRContext context(registry);
|
||||
|
||||
OwningOpRef<ModuleOp> module =
|
||||
parseSourceString<ModuleOp>(moduleStr, &context);
|
||||
ASSERT_TRUE(!!module);
|
||||
Builder builder(&context);
|
||||
IntegerAttr target = builder.getI32IntegerAttr(0);
|
||||
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
|
||||
// Check the attribute holds the interface; calling through a null interface
// below would crash instead of reporting a test failure.
ASSERT_TRUE(!!targetAttr);
|
||||
|
||||
// The callback prints the module into `linkedLLVMIR`, captured by reference.
std::string linkedLLVMIR;
|
||||
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
|
||||
llvm::raw_string_ostream ros(linkedLLVMIR);
|
||||
module.print(ros, nullptr);
|
||||
};
|
||||
|
||||
gpu::TargetOptions opts(
|
||||
{}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {},
|
||||
{}, linkedCallback);
|
||||
std::optional<SmallVector<char, 0>> serializedBinary =
|
||||
targetAttr.serializeToObject(*module, opts);
|
||||
|
||||
ASSERT_TRUE(serializedBinary != std::nullopt);
|
||||
ASSERT_TRUE(!serializedBinary->empty());
|
||||
ASSERT_TRUE(!linkedLLVMIR.empty());
|
||||
}
|
||||
|
||||
// Test callback function invoked with optimized LLVM IR.
// Only the optimized-IR callback is installed (the initial and linked slots
// are left empty); the test passes when serialization yields a non-empty
// object and the callback captured non-empty IR text.
|
||||
TEST_F(MLIRTargetLLVM,
|
||||
SKIP_WITHOUT_NATIVE(CallbackInvokedWithOptimizedLLVMIR)) {
|
||||
MLIRContext context(registry);
|
||||
|
||||
OwningOpRef<ModuleOp> module =
|
||||
parseSourceString<ModuleOp>(moduleStr, &context);
|
||||
ASSERT_TRUE(!!module);
|
||||
Builder builder(&context);
|
||||
IntegerAttr target = builder.getI32IntegerAttr(0);
|
||||
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
|
||||
// Check the attribute holds the interface; calling through a null interface
// below would crash instead of reporting a test failure.
ASSERT_TRUE(!!targetAttr);
|
||||
|
||||
// The callback prints the module into `optimizedLLVMIR`, captured by
// reference.
std::string optimizedLLVMIR;
|
||||
auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
|
||||
llvm::raw_string_ostream ros(optimizedLLVMIR);
|
||||
module.print(ros, nullptr);
|
||||
};
|
||||
|
||||
gpu::TargetOptions opts(
|
||||
{}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {},
|
||||
{}, {}, optimizedCallback);
|
||||
std::optional<SmallVector<char, 0>> serializedBinary =
|
||||
targetAttr.serializeToObject(*module, opts);
|
||||
|
||||
ASSERT_TRUE(serializedBinary != std::nullopt);
|
||||
ASSERT_TRUE(!serializedBinary->empty());
|
||||
ASSERT_TRUE(!optimizedLLVMIR.empty());
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user