mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-24 04:26:07 +00:00
[mlir][IR] Trigger notifyOperationRemoved
callback for nested ops (#66771)
When cloning an op, the `notifyOperationInserted` callback is triggered for all nested ops. Similarly, the `notifyOperationRemoved` callback should be triggered for all nested ops when removing an op. Listeners may inspect the IR during a `notifyOperationRemoved` callback. Therefore, when multiple ops are removed in a single `RewriterBase::eraseOp` call, the notifications must be triggered in an order in which the ops could have been removed one-by-one: * Op removals must be interleaved with `notifyOperationRemoved` callbacks. A callback is triggered right before the respective op is removed. * Ops are removed post-order and in reverse order. Other traversal orders could delete an op that still has uses. (This is not avoidable in graph regions and with cyclic block graphs.) Differential Revision: Imported from https://reviews.llvm.org/D144193.
This commit is contained in:
parent
a317afaf00
commit
695a5a6a66
@ -43,6 +43,12 @@ public:
|
||||
/// not implement the RegionKindInterface.
|
||||
bool mayHaveSSADominance(Region ®ion);
|
||||
|
||||
/// Return "true" if the given region may be a graph region without SSA
|
||||
/// dominance. This function returns "true" in case the owner op is an
|
||||
/// unregistered op. It returns "false" if it is a registered op that does not
|
||||
/// implement the RegionKindInterface.
|
||||
bool mayBeGraphRegion(Region ®ion);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#include "mlir/IR/RegionKindInterface.h.inc"
|
||||
|
@ -394,12 +394,9 @@ public:
|
||||
|
||||
protected:
|
||||
void notifyOperationRemoved(Operation *op) override {
|
||||
// TODO: Walk can be removed when D144193 has landed.
|
||||
op->walk([&](Operation *op) {
|
||||
erasedOps.insert(op);
|
||||
// Erase if present.
|
||||
toMemrefOps.erase(op);
|
||||
});
|
||||
erasedOps.insert(op);
|
||||
// Erase if present.
|
||||
toMemrefOps.erase(op);
|
||||
}
|
||||
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
|
@ -8,6 +8,8 @@
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Iterators.h"
|
||||
#include "mlir/IR/RegionKindInterface.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@ -275,7 +277,7 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
|
||||
for (auto it : llvm::zip(op->getResults(), newValues))
|
||||
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
|
||||
|
||||
// Erase the op.
|
||||
// Erase op and notify listener.
|
||||
eraseOp(op);
|
||||
}
|
||||
|
||||
@ -295,7 +297,7 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
|
||||
for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
|
||||
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
|
||||
|
||||
// Erase the old op.
|
||||
// Erase op and notify listener.
|
||||
eraseOp(op);
|
||||
}
|
||||
|
||||
@ -303,9 +305,71 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
|
||||
/// the given operation *must* be known to be dead.
|
||||
void RewriterBase::eraseOp(Operation *op) {
|
||||
assert(op->use_empty() && "expected 'op' to have no uses");
|
||||
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
|
||||
auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
|
||||
|
||||
// Fast path: If no listener is attached, the op can be dropped in one go.
|
||||
if (!rewriteListener) {
|
||||
op->erase();
|
||||
return;
|
||||
}
|
||||
|
||||
// Helper function that erases a single op.
|
||||
auto eraseSingleOp = [&](Operation *op) {
|
||||
#ifndef NDEBUG
|
||||
// All nested ops should have been erased already.
|
||||
assert(
|
||||
llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
|
||||
"expected empty regions");
|
||||
// All users should have been erased already if the op is in a region with
|
||||
// SSA dominance.
|
||||
if (!op->use_empty() && op->getParentOp())
|
||||
assert(mayBeGraphRegion(*op->getParentRegion()) &&
|
||||
"expected that op has no uses");
|
||||
#endif // NDEBUG
|
||||
rewriteListener->notifyOperationRemoved(op);
|
||||
op->erase();
|
||||
|
||||
// Explicitly drop all uses in case the op is in a graph region.
|
||||
op->dropAllUses();
|
||||
op->erase();
|
||||
};
|
||||
|
||||
// Nested ops must be erased one-by-one, so that listeners have a consistent
|
||||
// view of the IR every time a notification is triggered. Users must be
|
||||
// erased before definitions. I.e., post-order, reverse dominance.
|
||||
std::function<void(Operation *)> eraseTree = [&](Operation *op) {
|
||||
// Erase nested ops.
|
||||
for (Region &r : llvm::reverse(op->getRegions())) {
|
||||
// Erase all blocks in the right order. Successors should be erased
|
||||
// before predecessors because successor blocks may use values defined
|
||||
// in predecessor blocks. A post-order traversal of blocks within a
|
||||
// region visits successors before predecessors. Repeat the traversal
|
||||
// until the region is empty. (The block graph could be disconnected.)
|
||||
while (!r.empty()) {
|
||||
SmallVector<Block *> erasedBlocks;
|
||||
for (Block *b : llvm::post_order(&r.front())) {
|
||||
// Visit ops in reverse order.
|
||||
for (Operation &op :
|
||||
llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
|
||||
eraseTree(&op);
|
||||
// Do not erase the block immediately. This is not supprted by the
|
||||
// post_order iterator.
|
||||
erasedBlocks.push_back(b);
|
||||
}
|
||||
for (Block *b : erasedBlocks) {
|
||||
// Explicitly drop all uses in case there is a cycle in the block
|
||||
// graph.
|
||||
for (BlockArgument bbArg : b->getArguments())
|
||||
bbArg.dropAllUses();
|
||||
b->dropAllUses();
|
||||
b->erase();
|
||||
}
|
||||
}
|
||||
}
|
||||
// Then erase the enclosing op.
|
||||
eraseSingleOp(op);
|
||||
};
|
||||
|
||||
eraseTree(op);
|
||||
}
|
||||
|
||||
void RewriterBase::eraseBlock(Block *block) {
|
||||
|
@ -18,9 +18,17 @@ using namespace mlir;
|
||||
#include "mlir/IR/RegionKindInterface.cpp.inc"
|
||||
|
||||
bool mlir::mayHaveSSADominance(Region ®ion) {
|
||||
auto regionKindOp =
|
||||
dyn_cast_if_present<RegionKindInterface>(region.getParentOp());
|
||||
auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
|
||||
if (!regionKindOp)
|
||||
return true;
|
||||
return regionKindOp.hasSSADominance(region.getRegionNumber());
|
||||
}
|
||||
|
||||
bool mlir::mayBeGraphRegion(Region ®ion) {
|
||||
if (!region.getParentOp()->isRegistered())
|
||||
return true;
|
||||
auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
|
||||
if (!regionKindOp)
|
||||
return false;
|
||||
return !regionKindOp.hasSSADominance(region.getRegionNumber());
|
||||
}
|
||||
|
@ -421,8 +421,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
|
||||
|
||||
// If the operation is trivially dead - remove it.
|
||||
if (isOpTriviallyDead(op)) {
|
||||
notifyOperationRemoved(op);
|
||||
op->erase();
|
||||
eraseOp(op);
|
||||
changed = true;
|
||||
|
||||
LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
|
||||
@ -567,10 +566,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
|
||||
config.listener->notifyOperationRemoved(op);
|
||||
|
||||
addOperandsToWorklist(op->getOperands());
|
||||
op->walk([this](Operation *operation) {
|
||||
worklist.remove(operation);
|
||||
folder.notifyRemoval(operation);
|
||||
});
|
||||
worklist.remove(op);
|
||||
folder.notifyRemoval(op);
|
||||
|
||||
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
|
||||
strictModeFilteredOps.erase(op);
|
||||
|
@ -12,9 +12,9 @@
|
||||
|
||||
// CHECK-EN-LABEL: func @test_erase
|
||||
// CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true}
|
||||
// CHECK-EN: test.arg0
|
||||
// CHECK-EN: test.arg1
|
||||
// CHECK-EN-NOT: test.erase_op
|
||||
// CHECK-EN: "test.arg0"
|
||||
// CHECK-EN: "test.arg1"
|
||||
// CHECK-EN-NOT: "test.erase_op"
|
||||
func.func @test_erase() {
|
||||
%0 = "test.arg0"() : () -> (i32)
|
||||
%1 = "test.arg1"() : () -> (i32)
|
||||
@ -51,13 +51,13 @@ func.func @test_replace_with_new_op() {
|
||||
|
||||
// CHECK-EN-LABEL: func @test_replace_with_erase_op
|
||||
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
|
||||
// CHECK-EN-NOT: test.replace_with_new_op
|
||||
// CHECK-EN-NOT: test.erase_op
|
||||
// CHECK-EN-NOT: "test.replace_with_new_op"
|
||||
// CHECK-EN-NOT: "test.erase_op"
|
||||
|
||||
// CHECK-EX-LABEL: func @test_replace_with_erase_op
|
||||
// CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
|
||||
// CHECK-EX-NOT: test.replace_with_new_op
|
||||
// CHECK-EX: test.erase_op
|
||||
// CHECK-EX-NOT: "test.replace_with_new_op"
|
||||
// CHECK-EX: "test.erase_op"
|
||||
func.func @test_replace_with_erase_op() {
|
||||
"test.replace_with_new_op"() {create_erase_op} : () -> ()
|
||||
return
|
||||
@ -83,3 +83,149 @@ func.func @test_trigger_rewrite_through_block() {
|
||||
// in turn, replaces the successor with bb3.
|
||||
"test.implicit_change_op"() [^bb1] : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-AN: notifyOperationRemoved: test.foo_b
|
||||
// CHECK-AN: notifyOperationRemoved: test.foo_a
|
||||
// CHECK-AN: notifyOperationRemoved: test.graph_region
|
||||
// CHECK-AN: notifyOperationRemoved: test.erase_op
|
||||
// CHECK-AN-LABEL: func @test_remove_graph_region()
|
||||
// CHECK-AN-NEXT: return
|
||||
func.func @test_remove_graph_region() {
|
||||
"test.erase_op"() ({
|
||||
test.graph_region {
|
||||
%0 = "test.foo_a"(%1) : (i1) -> (i1)
|
||||
%1 = "test.foo_b"(%0) : (i1) -> (i1)
|
||||
}
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.bar
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.foo
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.dummy_op
|
||||
// CHECK-AN: notifyOperationRemoved: test.erase_op
|
||||
// CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
|
||||
// CHECK-AN-NEXT: return
|
||||
func.func @test_remove_cyclic_blocks() {
|
||||
"test.erase_op"() ({
|
||||
%x = "test.dummy_op"() : () -> (i1)
|
||||
cf.br ^bb1(%x: i1)
|
||||
^bb1(%arg0: i1):
|
||||
"test.foo"(%x) : (i1) -> ()
|
||||
cf.br ^bb2(%arg0: i1)
|
||||
^bb2(%arg1: i1):
|
||||
"test.bar"(%x) : (i1) -> ()
|
||||
cf.br ^bb1(%arg1: i1)
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-AN: notifyOperationRemoved: test.dummy_op
|
||||
// CHECK-AN: notifyOperationRemoved: test.bar
|
||||
// CHECK-AN: notifyOperationRemoved: test.qux
|
||||
// CHECK-AN: notifyOperationRemoved: test.qux_unreachable
|
||||
// CHECK-AN: notifyOperationRemoved: test.nested_dummy
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.foo
|
||||
// CHECK-AN: notifyOperationRemoved: test.erase_op
|
||||
// CHECK-AN-LABEL: func @test_remove_dead_blocks()
|
||||
// CHECK-AN-NEXT: return
|
||||
func.func @test_remove_dead_blocks() {
|
||||
"test.erase_op"() ({
|
||||
"test.dummy_op"() : () -> (i1)
|
||||
// The following blocks are not reachable. Still, ^bb2 should be deleted
|
||||
// befire ^bb1.
|
||||
^bb1(%arg0: i1):
|
||||
"test.foo"() : () -> ()
|
||||
cf.br ^bb2(%arg0: i1)
|
||||
^bb2(%arg1: i1):
|
||||
"test.nested_dummy"() ({
|
||||
"test.qux"() : () -> ()
|
||||
// The following block is unreachable.
|
||||
^bb3:
|
||||
"test.qux_unreachable"() : () -> ()
|
||||
}) : () -> ()
|
||||
"test.bar"() : () -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// test.nested_* must be deleted before test.foo.
|
||||
// test.bar must be deleted before test.foo.
|
||||
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.bar
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.nested_b
|
||||
// CHECK-AN: notifyOperationRemoved: test.nested_a
|
||||
// CHECK-AN: notifyOperationRemoved: test.nested_d
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.nested_e
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.nested_c
|
||||
// CHECK-AN: notifyOperationRemoved: test.foo
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.dummy_op
|
||||
// CHECK-AN: notifyOperationRemoved: test.erase_op
|
||||
// CHECK-AN-LABEL: func @test_remove_nested_ops()
|
||||
// CHECK-AN-NEXT: return
|
||||
func.func @test_remove_nested_ops() {
|
||||
"test.erase_op"() ({
|
||||
%x = "test.dummy_op"() : () -> (i1)
|
||||
cf.br ^bb1(%x: i1)
|
||||
^bb1(%arg0: i1):
|
||||
"test.foo"() ({
|
||||
"test.nested_a"() : () -> ()
|
||||
"test.nested_b"() : () -> ()
|
||||
^dead1:
|
||||
"test.nested_c"() : () -> ()
|
||||
cf.br ^dead3
|
||||
^dead2:
|
||||
"test.nested_d"() : () -> ()
|
||||
^dead3:
|
||||
"test.nested_e"() : () -> ()
|
||||
cf.br ^dead2
|
||||
}) : () -> ()
|
||||
cf.br ^bb2(%arg0: i1)
|
||||
^bb2(%arg1: i1):
|
||||
"test.bar"(%x) : (i1) -> ()
|
||||
cf.br ^bb1(%arg1: i1)
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-AN: notifyOperationRemoved: test.qux
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.foo
|
||||
// CHECK-AN: notifyOperationRemoved: cf.br
|
||||
// CHECK-AN: notifyOperationRemoved: test.bar
|
||||
// CHECK-AN: notifyOperationRemoved: cf.cond_br
|
||||
// CHECK-AN-LABEL: func @test_remove_diamond(
|
||||
// CHECK-AN-NEXT: return
|
||||
func.func @test_remove_diamond(%c: i1) {
|
||||
"test.erase_op"() ({
|
||||
cf.cond_br %c, ^bb1, ^bb2
|
||||
^bb1:
|
||||
"test.foo"() : () -> ()
|
||||
cf.br ^bb3
|
||||
^bb2:
|
||||
"test.bar"() : () -> ()
|
||||
cf.br ^bb3
|
||||
^bb3:
|
||||
"test.qux"() : () -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
@ -239,6 +239,12 @@ struct TestPatternDriver
|
||||
llvm::cl::init(GreedyRewriteConfig().maxIterations)};
|
||||
};
|
||||
|
||||
struct DumpNotifications : public RewriterBase::Listener {
|
||||
void notifyOperationRemoved(Operation *op) override {
|
||||
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
struct TestStrictPatternDriver
|
||||
: public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
|
||||
public:
|
||||
@ -275,7 +281,9 @@ public:
|
||||
}
|
||||
});
|
||||
|
||||
DumpNotifications dumpNotifications;
|
||||
GreedyRewriteConfig config;
|
||||
config.listener = &dumpNotifications;
|
||||
if (strictMode == "AnyOp") {
|
||||
config.strictMode = GreedyRewriteStrictness::AnyOp;
|
||||
} else if (strictMode == "ExistingAndNewOps") {
|
||||
|
Loading…
x
Reference in New Issue
Block a user