mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-19 10:06:54 +00:00

This is a companion to #118583, although it can be landed independently because, since #117922, dialects no longer have to use the same Python binding framework as the Python core code. This PR ports all of the in-tree dialect and pass extensions to nanobind, except for those that remain for testing pybind11 support. This PR also: * removes CollectDiagnosticsToStringScope from NanobindAdaptors.h. This was overlooked in a previous PR, and the class is duplicated in Diagnostics.h. --------- Co-authored-by: Jacques Pienaar <jpienaar@google.com>
177 lines
7.7 KiB
C++
177 lines
7.7 KiB
C++
//===- Pass.cpp - Pass Management -----------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "Pass.h"
|
|
|
|
#include "IRModule.h"
|
|
#include "mlir-c/Pass.h"
|
|
#include "mlir/Bindings/Python/Nanobind.h"
|
|
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
|
|
|
|
namespace nb = nanobind;
|
|
using namespace nb::literals;
|
|
using namespace mlir;
|
|
using namespace mlir::python;
|
|
|
|
namespace {
|
|
|
|
/// Owning Wrapper around a PassManager.
|
|
class PyPassManager {
|
|
public:
|
|
PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
|
|
PyPassManager(PyPassManager &&other) noexcept
|
|
: passManager(other.passManager) {
|
|
other.passManager.ptr = nullptr;
|
|
}
|
|
~PyPassManager() {
|
|
if (!mlirPassManagerIsNull(passManager))
|
|
mlirPassManagerDestroy(passManager);
|
|
}
|
|
MlirPassManager get() { return passManager; }
|
|
|
|
void release() { passManager.ptr = nullptr; }
|
|
nb::object getCapsule() {
|
|
return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
|
|
}
|
|
|
|
static nb::object createFromCapsule(nb::object capsule) {
|
|
MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
|
|
if (mlirPassManagerIsNull(rawPm))
|
|
throw nb::python_error();
|
|
return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
|
|
}
|
|
|
|
private:
|
|
MlirPassManager passManager;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Create the `mlir.passmanager` here.
|
|
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
|
|
//----------------------------------------------------------------------------
|
|
// Mapping of the top-level PassManager
|
|
//----------------------------------------------------------------------------
|
|
nb::class_<PyPassManager>(m, "PassManager")
|
|
.def(
|
|
"__init__",
|
|
[](PyPassManager &self, const std::string &anchorOp,
|
|
DefaultingPyMlirContext context) {
|
|
MlirPassManager passManager = mlirPassManagerCreateOnOperation(
|
|
context->get(),
|
|
mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
|
|
new (&self) PyPassManager(passManager);
|
|
},
|
|
"anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(),
|
|
"Create a new PassManager for the current (or provided) Context.")
|
|
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
|
|
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
|
|
.def("_testing_release", &PyPassManager::release,
|
|
"Releases (leaks) the backing pass manager (testing)")
|
|
.def(
|
|
"enable_ir_printing",
|
|
[](PyPassManager &passManager, bool printBeforeAll,
|
|
bool printAfterAll, bool printModuleScope, bool printAfterChange,
|
|
bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
|
|
bool enableDebugInfo, bool printGenericOpForm,
|
|
std::optional<std::string> optionalTreePrintingPath) {
|
|
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
|
|
if (largeElementsLimit)
|
|
mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
|
|
*largeElementsLimit);
|
|
if (enableDebugInfo)
|
|
mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
|
|
/*prettyForm=*/false);
|
|
if (printGenericOpForm)
|
|
mlirOpPrintingFlagsPrintGenericOpForm(flags);
|
|
std::string treePrintingPath = "";
|
|
if (optionalTreePrintingPath.has_value())
|
|
treePrintingPath = optionalTreePrintingPath.value();
|
|
mlirPassManagerEnableIRPrinting(
|
|
passManager.get(), printBeforeAll, printAfterAll,
|
|
printModuleScope, printAfterChange, printAfterFailure, flags,
|
|
mlirStringRefCreate(treePrintingPath.data(),
|
|
treePrintingPath.size()));
|
|
mlirOpPrintingFlagsDestroy(flags);
|
|
},
|
|
"print_before_all"_a = false, "print_after_all"_a = true,
|
|
"print_module_scope"_a = false, "print_after_change"_a = false,
|
|
"print_after_failure"_a = false,
|
|
"large_elements_limit"_a.none() = nb::none(),
|
|
"enable_debug_info"_a = false, "print_generic_op_form"_a = false,
|
|
"tree_printing_dir_path"_a.none() = nb::none(),
|
|
"Enable IR printing, default as mlir-print-ir-after-all.")
|
|
.def(
|
|
"enable_verifier",
|
|
[](PyPassManager &passManager, bool enable) {
|
|
mlirPassManagerEnableVerifier(passManager.get(), enable);
|
|
},
|
|
"enable"_a, "Enable / disable verify-each.")
|
|
.def_static(
|
|
"parse",
|
|
[](const std::string &pipeline, DefaultingPyMlirContext context) {
|
|
MlirPassManager passManager = mlirPassManagerCreate(context->get());
|
|
PyPrintAccumulator errorMsg;
|
|
MlirLogicalResult status = mlirParsePassPipeline(
|
|
mlirPassManagerGetAsOpPassManager(passManager),
|
|
mlirStringRefCreate(pipeline.data(), pipeline.size()),
|
|
errorMsg.getCallback(), errorMsg.getUserData());
|
|
if (mlirLogicalResultIsFailure(status))
|
|
throw nb::value_error(errorMsg.join().c_str());
|
|
return new PyPassManager(passManager);
|
|
},
|
|
"pipeline"_a, "context"_a.none() = nb::none(),
|
|
"Parse a textual pass-pipeline and return a top-level PassManager "
|
|
"that can be applied on a Module. Throw a ValueError if the pipeline "
|
|
"can't be parsed")
|
|
.def(
|
|
"add",
|
|
[](PyPassManager &passManager, const std::string &pipeline) {
|
|
PyPrintAccumulator errorMsg;
|
|
MlirLogicalResult status = mlirOpPassManagerAddPipeline(
|
|
mlirPassManagerGetAsOpPassManager(passManager.get()),
|
|
mlirStringRefCreate(pipeline.data(), pipeline.size()),
|
|
errorMsg.getCallback(), errorMsg.getUserData());
|
|
if (mlirLogicalResultIsFailure(status))
|
|
throw nb::value_error(errorMsg.join().c_str());
|
|
},
|
|
"pipeline"_a,
|
|
"Add textual pipeline elements to the pass manager. Throws a "
|
|
"ValueError if the pipeline can't be parsed.")
|
|
.def(
|
|
"run",
|
|
[](PyPassManager &passManager, PyOperationBase &op,
|
|
bool invalidateOps) {
|
|
if (invalidateOps) {
|
|
op.getOperation().getContext()->clearOperationsInside(op);
|
|
}
|
|
// Actually run the pass manager.
|
|
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
|
|
MlirLogicalResult status = mlirPassManagerRunOnOp(
|
|
passManager.get(), op.getOperation().get());
|
|
if (mlirLogicalResultIsFailure(status))
|
|
throw MLIRError("Failure while executing pass pipeline",
|
|
errors.take());
|
|
},
|
|
"operation"_a, "invalidate_ops"_a = true,
|
|
"Run the pass manager on the provided operation, raising an "
|
|
"MLIRError on failure.")
|
|
.def(
|
|
"__str__",
|
|
[](PyPassManager &self) {
|
|
MlirPassManager passManager = self.get();
|
|
PyPrintAccumulator printAccum;
|
|
mlirPrintPassPipeline(
|
|
mlirPassManagerGetAsOpPassManager(passManager),
|
|
printAccum.getCallback(), printAccum.getUserData());
|
|
return printAccum.join();
|
|
},
|
|
"Print the textual representation for this PassManager, suitable to "
|
|
"be passed to `parse` for round-tripping.");
|
|
}
|