[mlir][IR] Trigger notifyOperationRemoved callback for nested ops (#66771)

When cloning an op, the `notifyOperationInserted` callback is triggered
for all nested ops. Similarly, the `notifyOperationRemoved` callback
should be triggered for all nested ops when removing an op.

Listeners may inspect the IR during a `notifyOperationRemoved` callback.
Therefore, when multiple ops are removed in a single
`RewriterBase::eraseOp` call, the notifications must be triggered in an
order in which the ops could have been removed one-by-one:

* Op removals must be interleaved with `notifyOperationRemoved`
callbacks. A callback is triggered right before the respective op is
removed.
* Ops are removed post-order and in reverse order. Other traversal
orders could delete an op that still has uses. (This is not avoidable in
graph regions and with cyclic block graphs.)

Differential Revision: Imported from https://reviews.llvm.org/D144193.
This commit is contained in:
Matthias Springer 2023-09-20 08:45:46 +02:00 committed by GitHub
parent a317afaf00
commit 695a5a6a66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 251 additions and 25 deletions

View File

@ -43,6 +43,12 @@ public:
/// not implement the RegionKindInterface.
bool mayHaveSSADominance(Region &region);
/// Return "true" if the given region may be a graph region without SSA
/// dominance. This function returns "true" in case the owner op is an
/// unregistered op. It returns "false" if it is a registered op that does not
/// implement the RegionKindInterface.
bool mayBeGraphRegion(Region &region);
} // namespace mlir
#include "mlir/IR/RegionKindInterface.h.inc"

View File

@ -394,12 +394,9 @@ public:
protected:
void notifyOperationRemoved(Operation *op) override {
// TODO: Walk can be removed when D144193 has landed.
op->walk([&](Operation *op) {
erasedOps.insert(op);
// Erase if present.
toMemrefOps.erase(op);
});
erasedOps.insert(op);
// Erase if present.
toMemrefOps.erase(op);
}
void notifyOperationInserted(Operation *op) override {

View File

@ -8,6 +8,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
using namespace mlir;
@ -275,7 +277,7 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
for (auto it : llvm::zip(op->getResults(), newValues))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
// Erase the op.
// Erase op and notify listener.
eraseOp(op);
}
@ -295,7 +297,7 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
// Erase the old op.
// Erase op and notify listener.
eraseOp(op);
}
@ -303,9 +305,71 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
/// the given operation *must* be known to be dead.
void RewriterBase::eraseOp(Operation *op) {
assert(op->use_empty() && "expected 'op' to have no uses");
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
// Fast path: If no listener is attached, the op can be dropped in one go.
if (!rewriteListener) {
op->erase();
return;
}
// Helper function that erases a single op.
auto eraseSingleOp = [&](Operation *op) {
#ifndef NDEBUG
// All nested ops should have been erased already.
assert(
llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
"expected empty regions");
// All users should have been erased already if the op is in a region with
// SSA dominance.
if (!op->use_empty() && op->getParentOp())
assert(mayBeGraphRegion(*op->getParentRegion()) &&
"expected that op has no uses");
#endif // NDEBUG
rewriteListener->notifyOperationRemoved(op);
op->erase();
// Explicitly drop all uses in case the op is in a graph region.
op->dropAllUses();
op->erase();
};
// Nested ops must be erased one-by-one, so that listeners have a consistent
// view of the IR every time a notification is triggered. Users must be
// erased before definitions. I.e., post-order, reverse dominance.
std::function<void(Operation *)> eraseTree = [&](Operation *op) {
// Erase nested ops.
for (Region &r : llvm::reverse(op->getRegions())) {
// Erase all blocks in the right order. Successors should be erased
// before predecessors because successor blocks may use values defined
// in predecessor blocks. A post-order traversal of blocks within a
// region visits successors before predecessors. Repeat the traversal
// until the region is empty. (The block graph could be disconnected.)
while (!r.empty()) {
SmallVector<Block *> erasedBlocks;
for (Block *b : llvm::post_order(&r.front())) {
// Visit ops in reverse order.
for (Operation &op :
llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
eraseTree(&op);
// Do not erase the block immediately. This is not supprted by the
// post_order iterator.
erasedBlocks.push_back(b);
}
for (Block *b : erasedBlocks) {
// Explicitly drop all uses in case there is a cycle in the block
// graph.
for (BlockArgument bbArg : b->getArguments())
bbArg.dropAllUses();
b->dropAllUses();
b->erase();
}
}
}
// Then erase the enclosing op.
eraseSingleOp(op);
};
eraseTree(op);
}
void RewriterBase::eraseBlock(Block *block) {

View File

@ -18,9 +18,17 @@ using namespace mlir;
#include "mlir/IR/RegionKindInterface.cpp.inc"
bool mlir::mayHaveSSADominance(Region &region) {
auto regionKindOp =
dyn_cast_if_present<RegionKindInterface>(region.getParentOp());
auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
if (!regionKindOp)
return true;
return regionKindOp.hasSSADominance(region.getRegionNumber());
}
bool mlir::mayBeGraphRegion(Region &region) {
if (!region.getParentOp()->isRegistered())
return true;
auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
if (!regionKindOp)
return false;
return !regionKindOp.hasSSADominance(region.getRegionNumber());
}

View File

@ -421,8 +421,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// If the operation is trivially dead - remove it.
if (isOpTriviallyDead(op)) {
notifyOperationRemoved(op);
op->erase();
eraseOp(op);
changed = true;
LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@ -567,10 +566,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
config.listener->notifyOperationRemoved(op);
addOperandsToWorklist(op->getOperands());
op->walk([this](Operation *operation) {
worklist.remove(operation);
folder.notifyRemoval(operation);
});
worklist.remove(op);
folder.notifyRemoval(op);
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);

View File

@ -12,9 +12,9 @@
// CHECK-EN-LABEL: func @test_erase
// CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: test.arg0
// CHECK-EN: test.arg1
// CHECK-EN-NOT: test.erase_op
// CHECK-EN: "test.arg0"
// CHECK-EN: "test.arg1"
// CHECK-EN-NOT: "test.erase_op"
func.func @test_erase() {
%0 = "test.arg0"() : () -> (i32)
%1 = "test.arg1"() : () -> (i32)
@ -51,13 +51,13 @@ func.func @test_replace_with_new_op() {
// CHECK-EN-LABEL: func @test_replace_with_erase_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN-NOT: test.replace_with_new_op
// CHECK-EN-NOT: test.erase_op
// CHECK-EN-NOT: "test.replace_with_new_op"
// CHECK-EN-NOT: "test.erase_op"
// CHECK-EX-LABEL: func @test_replace_with_erase_op
// CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EX-NOT: test.replace_with_new_op
// CHECK-EX: test.erase_op
// CHECK-EX-NOT: "test.replace_with_new_op"
// CHECK-EX: "test.erase_op"
func.func @test_replace_with_erase_op() {
"test.replace_with_new_op"() {create_erase_op} : () -> ()
return
@ -83,3 +83,149 @@ func.func @test_trigger_rewrite_through_block() {
// in turn, replaces the successor with bb3.
"test.implicit_change_op"() [^bb1] : () -> ()
}
// -----
// CHECK-AN: notifyOperationRemoved: test.foo_b
// CHECK-AN: notifyOperationRemoved: test.foo_a
// CHECK-AN: notifyOperationRemoved: test.graph_region
// CHECK-AN: notifyOperationRemoved: test.erase_op
// CHECK-AN-LABEL: func @test_remove_graph_region()
// CHECK-AN-NEXT: return
func.func @test_remove_graph_region() {
"test.erase_op"() ({
test.graph_region {
%0 = "test.foo_a"(%1) : (i1) -> (i1)
%1 = "test.foo_b"(%0) : (i1) -> (i1)
}
}) : () -> ()
return
}
// -----
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.bar
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.foo
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.dummy_op
// CHECK-AN: notifyOperationRemoved: test.erase_op
// CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
// CHECK-AN-NEXT: return
func.func @test_remove_cyclic_blocks() {
"test.erase_op"() ({
%x = "test.dummy_op"() : () -> (i1)
cf.br ^bb1(%x: i1)
^bb1(%arg0: i1):
"test.foo"(%x) : (i1) -> ()
cf.br ^bb2(%arg0: i1)
^bb2(%arg1: i1):
"test.bar"(%x) : (i1) -> ()
cf.br ^bb1(%arg1: i1)
}) : () -> ()
return
}
// -----
// CHECK-AN: notifyOperationRemoved: test.dummy_op
// CHECK-AN: notifyOperationRemoved: test.bar
// CHECK-AN: notifyOperationRemoved: test.qux
// CHECK-AN: notifyOperationRemoved: test.qux_unreachable
// CHECK-AN: notifyOperationRemoved: test.nested_dummy
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.foo
// CHECK-AN: notifyOperationRemoved: test.erase_op
// CHECK-AN-LABEL: func @test_remove_dead_blocks()
// CHECK-AN-NEXT: return
func.func @test_remove_dead_blocks() {
"test.erase_op"() ({
"test.dummy_op"() : () -> (i1)
// The following blocks are not reachable. Still, ^bb2 should be deleted
// befire ^bb1.
^bb1(%arg0: i1):
"test.foo"() : () -> ()
cf.br ^bb2(%arg0: i1)
^bb2(%arg1: i1):
"test.nested_dummy"() ({
"test.qux"() : () -> ()
// The following block is unreachable.
^bb3:
"test.qux_unreachable"() : () -> ()
}) : () -> ()
"test.bar"() : () -> ()
}) : () -> ()
return
}
// -----
// test.nested_* must be deleted before test.foo.
// test.bar must be deleted before test.foo.
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.bar
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.nested_b
// CHECK-AN: notifyOperationRemoved: test.nested_a
// CHECK-AN: notifyOperationRemoved: test.nested_d
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.nested_e
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.nested_c
// CHECK-AN: notifyOperationRemoved: test.foo
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.dummy_op
// CHECK-AN: notifyOperationRemoved: test.erase_op
// CHECK-AN-LABEL: func @test_remove_nested_ops()
// CHECK-AN-NEXT: return
func.func @test_remove_nested_ops() {
"test.erase_op"() ({
%x = "test.dummy_op"() : () -> (i1)
cf.br ^bb1(%x: i1)
^bb1(%arg0: i1):
"test.foo"() ({
"test.nested_a"() : () -> ()
"test.nested_b"() : () -> ()
^dead1:
"test.nested_c"() : () -> ()
cf.br ^dead3
^dead2:
"test.nested_d"() : () -> ()
^dead3:
"test.nested_e"() : () -> ()
cf.br ^dead2
}) : () -> ()
cf.br ^bb2(%arg0: i1)
^bb2(%arg1: i1):
"test.bar"(%x) : (i1) -> ()
cf.br ^bb1(%arg1: i1)
}) : () -> ()
return
}
// -----
// CHECK-AN: notifyOperationRemoved: test.qux
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.foo
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.bar
// CHECK-AN: notifyOperationRemoved: cf.cond_br
// CHECK-AN-LABEL: func @test_remove_diamond(
// CHECK-AN-NEXT: return
func.func @test_remove_diamond(%c: i1) {
"test.erase_op"() ({
cf.cond_br %c, ^bb1, ^bb2
^bb1:
"test.foo"() : () -> ()
cf.br ^bb3
^bb2:
"test.bar"() : () -> ()
cf.br ^bb3
^bb3:
"test.qux"() : () -> ()
}) : () -> ()
return
}

View File

@ -239,6 +239,12 @@ struct TestPatternDriver
llvm::cl::init(GreedyRewriteConfig().maxIterations)};
};
struct DumpNotifications : public RewriterBase::Listener {
void notifyOperationRemoved(Operation *op) override {
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
}
};
struct TestStrictPatternDriver
: public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
public:
@ -275,7 +281,9 @@ public:
}
});
DumpNotifications dumpNotifications;
GreedyRewriteConfig config;
config.listener = &dumpNotifications;
if (strictMode == "AnyOp") {
config.strictMode = GreedyRewriteStrictness::AnyOp;
} else if (strictMode == "ExistingAndNewOps") {