mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-16 11:36:46 +00:00

This PR implements `verify-diagnostics=only-expected` which is a "best effort" verification - i.e., `unexpected`s and `near-misses` will not be considered failures. The purpose is to enable narrowly scoped checking of verification remarks (just as we have for lit where only a subset of lines get `CHECK`ed).
412 lines
17 KiB
C++
412 lines
17 KiB
C++
//===- mlir-transform-opt.cpp -----------------------------------*- C++ -*-===//
|
|
//
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/Utils.h"
|
|
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/DialectRegistry.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/InitAllDialects.h"
|
|
#include "mlir/InitAllExtensions.h"
|
|
#include "mlir/InitAllPasses.h"
|
|
#include "mlir/Parser/Parser.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/InitLLVM.h"
|
|
#include "llvm/Support/SourceMgr.h"
|
|
#include "llvm/Support/ToolOutputFile.h"
|
|
#include <cstdlib>
|
|
|
|
namespace {
|
|
|
|
using namespace llvm;
|
|
|
|
/// Structure containing command line options for the tool, these will get
|
|
/// initialized when an instance is created.
|
|
struct MlirTransformOptCLOptions {
|
|
cl::opt<bool> allowUnregisteredDialects{
|
|
"allow-unregistered-dialect",
|
|
cl::desc("Allow operations coming from an unregistered dialect"),
|
|
cl::init(false)};
|
|
|
|
cl::opt<mlir::SourceMgrDiagnosticVerifierHandler::Level> verifyDiagnostics{
|
|
"verify-diagnostics", llvm::cl::ValueOptional,
|
|
cl::desc("Check that emitted diagnostics match expected-* lines on the "
|
|
"corresponding line"),
|
|
cl::values(
|
|
clEnumValN(
|
|
mlir::SourceMgrDiagnosticVerifierHandler::Level::All, "all",
|
|
"Check all diagnostics (expected, unexpected, near-misses)"),
|
|
// Implicit value: when passed with no arguments, e.g.
|
|
// `--verify-diagnostics` or `--verify-diagnostics=`.
|
|
clEnumValN(
|
|
mlir::SourceMgrDiagnosticVerifierHandler::Level::All, "",
|
|
"Check all diagnostics (expected, unexpected, near-misses)"),
|
|
clEnumValN(
|
|
mlir::SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected,
|
|
"only-expected", "Check only expected diagnostics"))};
|
|
|
|
cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>"),
|
|
cl::init("-")};
|
|
|
|
cl::opt<std::string> outputFilename{"o", cl::desc("Output filename"),
|
|
cl::value_desc("filename"),
|
|
cl::init("-")};
|
|
|
|
cl::opt<std::string> transformMainFilename{
|
|
"transform",
|
|
cl::desc("File containing entry point of the transform script, if "
|
|
"different from the input file"),
|
|
cl::value_desc("filename"), cl::init("")};
|
|
|
|
cl::list<std::string> transformLibraryFilenames{
|
|
"transform-library", cl::desc("File(s) containing definitions of "
|
|
"additional transform script symbols")};
|
|
|
|
cl::opt<std::string> transformEntryPoint{
|
|
"transform-entry-point",
|
|
cl::desc("Name of the entry point transform symbol"),
|
|
cl::init(mlir::transform::TransformDialect::kTransformEntryPointSymbolName
|
|
.str())};
|
|
|
|
cl::opt<bool> disableExpensiveChecks{
|
|
"disable-expensive-checks",
|
|
cl::desc("Disables potentially expensive checks in the transform "
|
|
"interpreter, providing more speed at the expense of "
|
|
"potential memory problems and silent corruptions"),
|
|
cl::init(false)};
|
|
|
|
cl::opt<bool> dumpLibraryModule{
|
|
"dump-library-module",
|
|
cl::desc("Prints the combined library module before the output"),
|
|
cl::init(false)};
|
|
};
|
|
} // namespace
|
|
|
|
/// "Managed" static instance of the command-line options structure. This makes
|
|
/// them locally-scoped and explicitly initialized/deinitialized. While this is
|
|
/// not strictly necessary in the tool source file that is not being used as a
|
|
/// library (where the options would pollute the global list of options), it is
|
|
/// good practice to follow this.
|
|
static llvm::ManagedStatic<MlirTransformOptCLOptions> clOptions;
|
|
|
|
/// Explicitly registers command-line options.
|
|
static void registerCLOptions() { *clOptions; }
|
|
|
|
namespace {
|
|
/// A wrapper class for source managers diagnostic. This provides both unique
|
|
/// ownership and virtual function-like overload for a pair of
|
|
/// inheritance-related classes that do not use virtual functions.
|
|
class DiagnosticHandlerWrapper {
|
|
public:
|
|
/// Kind of the diagnostic handler to use.
|
|
enum class Kind { EmitDiagnostics, VerifyDiagnostics };
|
|
|
|
/// Constructs the diagnostic handler of the specified kind of the given
|
|
/// source manager and context.
|
|
DiagnosticHandlerWrapper(
|
|
Kind kind, llvm::SourceMgr &mgr, mlir::MLIRContext *context,
|
|
std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level> level =
|
|
{}) {
|
|
if (kind == Kind::EmitDiagnostics) {
|
|
handler = new mlir::SourceMgrDiagnosticHandler(mgr, context);
|
|
} else {
|
|
assert(level.has_value() && "expected level");
|
|
handler =
|
|
new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context, *level);
|
|
}
|
|
}
|
|
|
|
/// This object is non-copyable but movable.
|
|
DiagnosticHandlerWrapper(const DiagnosticHandlerWrapper &) = delete;
|
|
DiagnosticHandlerWrapper(DiagnosticHandlerWrapper &&other) = default;
|
|
DiagnosticHandlerWrapper &
|
|
operator=(const DiagnosticHandlerWrapper &) = delete;
|
|
DiagnosticHandlerWrapper &operator=(DiagnosticHandlerWrapper &&) = default;
|
|
|
|
/// Verifies the captured "expected-*" diagnostics if required.
|
|
llvm::LogicalResult verify() const {
|
|
if (auto *ptr =
|
|
handler.dyn_cast<mlir::SourceMgrDiagnosticVerifierHandler *>()) {
|
|
return ptr->verify();
|
|
}
|
|
return mlir::success();
|
|
}
|
|
|
|
/// Destructs the object of the same type as allocated.
|
|
~DiagnosticHandlerWrapper() {
|
|
if (auto *ptr = handler.dyn_cast<mlir::SourceMgrDiagnosticHandler *>()) {
|
|
delete ptr;
|
|
} else {
|
|
delete cast<mlir::SourceMgrDiagnosticVerifierHandler *>(handler);
|
|
}
|
|
}
|
|
|
|
private:
|
|
/// Internal storage is a type-safe union.
|
|
llvm::PointerUnion<mlir::SourceMgrDiagnosticHandler *,
|
|
mlir::SourceMgrDiagnosticVerifierHandler *>
|
|
handler;
|
|
};
|
|
|
|
/// MLIR has deeply rooted expectations that the LLVM source manager contains
|
|
/// exactly one buffer, until at least the lexer level. This class wraps
|
|
/// multiple LLVM source managers each managing a buffer to match MLIR's
|
|
/// expectations while still providing a centralized handling mechanism.
|
|
class TransformSourceMgr {
|
|
public:
|
|
/// Constructs the source manager indicating whether diagnostic messages will
|
|
/// be verified later on.
|
|
explicit TransformSourceMgr(
|
|
std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level>
|
|
verifyDiagnostics)
|
|
: verifyDiagnostics(verifyDiagnostics) {}
|
|
|
|
/// Deconstructs the source manager. Note that `checkResults` must have been
|
|
/// called on this instance before deconstructing it.
|
|
~TransformSourceMgr() {
|
|
assert(resultChecked && "must check the result of diagnostic handlers by "
|
|
"running TransformSourceMgr::checkResult");
|
|
}
|
|
|
|
/// Parses the given buffer and creates the top-level operation of the kind
|
|
/// specified as template argument in the given context. Additional parsing
|
|
/// options may be provided.
|
|
template <typename OpTy = mlir::Operation *>
|
|
mlir::OwningOpRef<OpTy> parseBuffer(std::unique_ptr<MemoryBuffer> buffer,
|
|
mlir::MLIRContext &context,
|
|
const mlir::ParserConfig &config) {
|
|
// Create a single-buffer LLVM source manager. Note that `unique_ptr` allows
|
|
// the code below to capture a reference to the source manager in such a way
|
|
// that it is not invalidated when the vector contents is eventually
|
|
// reallocated.
|
|
llvm::SourceMgr &mgr =
|
|
*sourceMgrs.emplace_back(std::make_unique<llvm::SourceMgr>());
|
|
mgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
|
|
|
// Choose the type of diagnostic handler depending on whether diagnostic
|
|
// verification needs to happen and store it.
|
|
if (verifyDiagnostics) {
|
|
diagHandlers.emplace_back(
|
|
DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, mgr, &context,
|
|
verifyDiagnostics);
|
|
} else {
|
|
diagHandlers.emplace_back(DiagnosticHandlerWrapper::Kind::EmitDiagnostics,
|
|
mgr, &context);
|
|
}
|
|
|
|
// Defer to MLIR's parser.
|
|
return mlir::parseSourceFile<OpTy>(mgr, config);
|
|
}
|
|
|
|
/// If diagnostic message verification has been requested upon construction of
|
|
/// this source manager, performs the verification, reports errors and returns
|
|
/// the result of the verification. Otherwise passes through the given value.
|
|
llvm::LogicalResult checkResult(llvm::LogicalResult result) {
|
|
resultChecked = true;
|
|
if (!verifyDiagnostics)
|
|
return result;
|
|
|
|
return mlir::failure(llvm::any_of(diagHandlers, [](const auto &handler) {
|
|
return mlir::failed(handler.verify());
|
|
}));
|
|
}
|
|
|
|
private:
|
|
/// Indicates whether diagnostic message verification is requested.
|
|
const std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level>
|
|
verifyDiagnostics;
|
|
|
|
/// Indicates that diagnostic message verification has taken place, and the
|
|
/// deconstruction is therefore safe.
|
|
bool resultChecked = false;
|
|
|
|
/// Storage for per-buffer source managers and diagnostic handlers. These are
|
|
/// wrapped into unique pointers in order to make it safe to capture
|
|
/// references to these objects: if the vector is reallocated, the unique
|
|
/// pointer objects are moved by the pointer addresses won't change. Also, for
|
|
/// handlers, this allows to store the pointer to the base class.
|
|
SmallVector<std::unique_ptr<llvm::SourceMgr>> sourceMgrs;
|
|
SmallVector<DiagnosticHandlerWrapper> diagHandlers;
|
|
};
|
|
} // namespace
|
|
|
|
/// Trivial wrapper around `applyTransforms` that doesn't support extra mapping
|
|
/// and doesn't enforce the entry point transform ops being top-level.
|
|
static llvm::LogicalResult
|
|
applyTransforms(mlir::Operation *payloadRoot,
|
|
mlir::transform::TransformOpInterface transformRoot,
|
|
const mlir::transform::TransformOptions &options) {
|
|
return applyTransforms(payloadRoot, transformRoot, {}, options,
|
|
/*enforceToplevelTransformOp=*/false);
|
|
}
|
|
|
|
/// Applies transforms indicated in the transform dialect script to the input
|
|
/// buffer. The transform script may be embedded in the input buffer or as a
|
|
/// separate buffer. The transform script may have external symbols, the
|
|
/// definitions of which must be provided in transform library buffers. If the
|
|
/// application is successful, prints the transformed input buffer into the
|
|
/// given output stream. Additional configuration options are derived from
|
|
/// command-line options.
|
|
static llvm::LogicalResult processPayloadBuffer(
|
|
raw_ostream &os, std::unique_ptr<MemoryBuffer> inputBuffer,
|
|
std::unique_ptr<llvm::MemoryBuffer> transformBuffer,
|
|
MutableArrayRef<std::unique_ptr<MemoryBuffer>> transformLibraries,
|
|
mlir::DialectRegistry ®istry) {
|
|
|
|
// Initialize the MLIR context, and various configurations.
|
|
mlir::MLIRContext context(registry, mlir::MLIRContext::Threading::DISABLED);
|
|
context.allowUnregisteredDialects(clOptions->allowUnregisteredDialects);
|
|
mlir::ParserConfig config(&context);
|
|
TransformSourceMgr sourceMgr(
|
|
/*verifyDiagnostics=*/clOptions->verifyDiagnostics.getNumOccurrences()
|
|
? std::optional{clOptions->verifyDiagnostics.getValue()}
|
|
: std::nullopt);
|
|
|
|
// Parse the input buffer that will be used as transform payload.
|
|
mlir::OwningOpRef<mlir::Operation *> payloadRoot =
|
|
sourceMgr.parseBuffer(std::move(inputBuffer), context, config);
|
|
if (!payloadRoot)
|
|
return sourceMgr.checkResult(mlir::failure());
|
|
|
|
// Identify the module containing the transform script entry point. This may
|
|
// be the same module as the input or a separate module. In the former case,
|
|
// make a copy of the module so it can be modified freely. Modification may
|
|
// happen in the script itself (at which point it could be rewriting itself
|
|
// during interpretation, leading to tricky memory errors) or by embedding
|
|
// library modules in the script.
|
|
mlir::OwningOpRef<mlir::ModuleOp> transformRoot;
|
|
if (transformBuffer) {
|
|
transformRoot = sourceMgr.parseBuffer<mlir::ModuleOp>(
|
|
std::move(transformBuffer), context, config);
|
|
if (!transformRoot)
|
|
return sourceMgr.checkResult(mlir::failure());
|
|
} else {
|
|
transformRoot = cast<mlir::ModuleOp>(payloadRoot->clone());
|
|
}
|
|
|
|
// Parse and merge the libraries into the main transform module.
|
|
for (auto &&transformLibrary : transformLibraries) {
|
|
mlir::OwningOpRef<mlir::ModuleOp> libraryModule =
|
|
sourceMgr.parseBuffer<mlir::ModuleOp>(std::move(transformLibrary),
|
|
context, config);
|
|
|
|
if (!libraryModule ||
|
|
mlir::failed(mlir::transform::detail::mergeSymbolsInto(
|
|
*transformRoot, std::move(libraryModule))))
|
|
return sourceMgr.checkResult(mlir::failure());
|
|
}
|
|
|
|
// If requested, dump the combined transform module.
|
|
if (clOptions->dumpLibraryModule)
|
|
transformRoot->dump();
|
|
|
|
// Find the entry point symbol. Even if it had originally been in the payload
|
|
// module, it was cloned into the transform module so only look there.
|
|
mlir::transform::TransformOpInterface entryPoint =
|
|
mlir::transform::detail::findTransformEntryPoint(
|
|
*transformRoot, mlir::ModuleOp(), clOptions->transformEntryPoint);
|
|
if (!entryPoint)
|
|
return sourceMgr.checkResult(mlir::failure());
|
|
|
|
// Apply the requested transformations.
|
|
mlir::transform::TransformOptions transformOptions;
|
|
transformOptions.enableExpensiveChecks(!clOptions->disableExpensiveChecks);
|
|
if (mlir::failed(applyTransforms(*payloadRoot, entryPoint, transformOptions)))
|
|
return sourceMgr.checkResult(mlir::failure());
|
|
|
|
// Print the transformed result and check the captured diagnostics if
|
|
// requested.
|
|
payloadRoot->print(os);
|
|
return sourceMgr.checkResult(mlir::success());
|
|
}
|
|
|
|
/// Tool entry point.
|
|
static llvm::LogicalResult runMain(int argc, char **argv) {
|
|
// Register all upstream dialects and extensions. Specific uses are advised
|
|
// not to register all dialects indiscriminately but rather hand-pick what is
|
|
// necessary for their use case.
|
|
mlir::DialectRegistry registry;
|
|
mlir::registerAllDialects(registry);
|
|
mlir::registerAllExtensions(registry);
|
|
mlir::registerAllPasses();
|
|
|
|
// Explicitly register the transform dialect. This is not strictly necessary
|
|
// since it has been already registered as part of the upstream dialect list,
|
|
// but useful for example purposes for cases when dialects to register are
|
|
// hand-picked. The transform dialect must be registered.
|
|
registry.insert<mlir::transform::TransformDialect>();
|
|
|
|
// Register various command-line options. Note that the LLVM initializer
|
|
// object is a RAII that ensures correct deconstruction of command-line option
|
|
// objects inside ManagedStatic.
|
|
llvm::InitLLVM y(argc, argv);
|
|
mlir::registerAsmPrinterCLOptions();
|
|
mlir::registerMLIRContextCLOptions();
|
|
registerCLOptions();
|
|
llvm::cl::ParseCommandLineOptions(argc, argv,
|
|
"Minimal Transform dialect driver\n");
|
|
|
|
// Try opening the main input file.
|
|
std::string errorMessage;
|
|
std::unique_ptr<llvm::MemoryBuffer> payloadFile =
|
|
mlir::openInputFile(clOptions->payloadFilename, &errorMessage);
|
|
if (!payloadFile) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return mlir::failure();
|
|
}
|
|
|
|
// Try opening the output file.
|
|
std::unique_ptr<llvm::ToolOutputFile> outputFile =
|
|
mlir::openOutputFile(clOptions->outputFilename, &errorMessage);
|
|
if (!outputFile) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return mlir::failure();
|
|
}
|
|
|
|
// Try opening the main transform file if provided.
|
|
std::unique_ptr<llvm::MemoryBuffer> transformRootFile;
|
|
if (!clOptions->transformMainFilename.empty()) {
|
|
if (clOptions->transformMainFilename == clOptions->payloadFilename) {
|
|
llvm::errs() << "warning: " << clOptions->payloadFilename
|
|
<< " is provided as both payload and transform file\n";
|
|
} else {
|
|
transformRootFile =
|
|
mlir::openInputFile(clOptions->transformMainFilename, &errorMessage);
|
|
if (!transformRootFile) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
}
|
|
|
|
// Try opening transform library files if provided.
|
|
SmallVector<std::unique_ptr<llvm::MemoryBuffer>> transformLibraries;
|
|
transformLibraries.reserve(clOptions->transformLibraryFilenames.size());
|
|
for (llvm::StringRef filename : clOptions->transformLibraryFilenames) {
|
|
transformLibraries.emplace_back(
|
|
mlir::openInputFile(filename, &errorMessage));
|
|
if (!transformLibraries.back()) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
|
|
return processPayloadBuffer(outputFile->os(), std::move(payloadFile),
|
|
std::move(transformRootFile), transformLibraries,
|
|
registry);
|
|
}
|
|
|
|
int main(int argc, char **argv) {
|
|
return mlir::asMainReturnCode(runMain(argc, argv));
|
|
}
|