[mlir][IR] Add RewriterBase::moveBlockBefore and fix bug in moveOpBefore (#79579)

This commit adds a new method to the rewriter API: `moveBlockBefore`.
This op is utilized by `inlineRegionBefore` and covered by dialect
conversion test cases.

Also fixes a bug in `moveOpBefore`, where the previous op location was
not passed correctly. Adds a test case to
`test-strict-pattern-driver.mlir`.
This commit is contained in:
Matthias Springer 2024-01-31 11:25:11 +01:00 committed by GitHub
parent b210cbbd0e
commit da784a2555
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 92 additions and 17 deletions

View File

@ -67,6 +67,10 @@ public:
/// specific block.
void moveBefore(Block *block);
/// Unlink this block from its current region and insert it right before the
/// block that the given iterator points to in the region region.
void moveBefore(Region *region, llvm::iplist<Block>::iterator iterator);
/// Unlink this Block from its parent region and delete it.
void erase();

View File

@ -614,6 +614,13 @@ public:
virtual void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator);
/// Unlink this block and insert it right before `existingBlock`.
void moveBlockBefore(Block *block, Block *anotherBlock);
/// Unlink this block and insert it right before the location that the given
/// iterator points to in the given region.
void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeOpModification` or

View File

@ -52,8 +52,13 @@ void Block::insertAfter(Block *block) {
/// specific block.
void Block::moveBefore(Block *block) {
assert(block->getParent() && "cannot insert before a block without a parent");
block->getParent()->getBlocks().splice(
block->getIterator(), getParent()->getBlocks(), getIterator());
moveBefore(block->getParent(), block->getIterator());
}
/// Unlink this block from its current region and insert it right before the
/// block that the given iterator points to in the region region.
void Block::moveBefore(Region *region, llvm::iplist<Block>::iterator iterator) {
region->getBlocks().splice(iterator, getParent()->getBlocks(), getIterator());
}
/// Unlink this Block from its parent Region and delete it.

View File

@ -350,11 +350,8 @@ void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
}
// Move blocks from the beginning of the region one-by-one.
while (!region.empty()) {
Block *block = &region.front();
parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
listener->notifyBlockInserted(block, &region, region.begin());
}
while (!region.empty())
moveBlockBefore(&region.front(), &parent, before);
}
void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
inlineRegionBefore(region, *before->getParent(), before->getIterator());
@ -378,6 +375,21 @@ void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
moveBlockBefore(block, anotherBlock->getParent(),
anotherBlock->getIterator());
}
void RewriterBase::moveBlockBefore(Block *block, Region *region,
Region::iterator iterator) {
Region *currentRegion = block->getParent();
Region::iterator nextIterator = std::next(block->getIterator());
block->moveBefore(region, iterator);
if (listener)
listener->notifyBlockInserted(block, /*previous=*/currentRegion,
/*previousIt=*/nextIterator);
}
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
}
@ -385,11 +397,11 @@ void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
Block::iterator currentIterator = op->getIterator();
Block::iterator nextIterator = std::next(op->getIterator());
op->moveBefore(block, iterator);
if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
op, /*previous=*/InsertPoint(currentBlock, nextIterator));
}
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
@ -398,10 +410,6 @@ void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
void RewriterBase::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
Block::iterator currentIterator = op->getIterator();
op->moveAfter(block, iterator);
if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
assert(iterator != block->end() && "cannot move after end of block");
moveOpBefore(op, block, std::next(iterator));
}

View File

@ -24,6 +24,7 @@ func.func @test_erase() {
// -----
// CHECK-EN: notifyOperationInserted: test.insert_same_op, was unlinked
// CHECK-EN-LABEL: func @test_insert_same_op
// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true}
// CHECK-EN: "test.insert_same_op"() {skip = true}
@ -35,6 +36,7 @@ func.func @test_insert_same_op() {
// -----
// CHECK-EN: notifyOperationInserted: test.new_op, was unlinked
// CHECK-EN-LABEL: func @test_replace_with_new_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: %[[n:.*]] = "test.new_op"
@ -49,6 +51,9 @@ func.func @test_replace_with_new_op() {
// -----
// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
// CHECK-EN: notifyOperationRemoved: test.erase_op
// CHECK-EN-LABEL: func @test_replace_with_erase_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN-NOT: "test.replace_with_new_op"
@ -229,3 +234,18 @@ func.func @test_remove_diamond(%c: i1) {
}) : () -> ()
return
}
// -----
// CHECK-AN: notifyOperationInserted: test.move_before_parent_op, previous = test.dummy_terminator
// CHECK-AN-LABEL: func @test_move_op_before(
// CHECK-AN: test.move_before_parent_op
// CHECK-AN: test.op_with_region
// CHECK-AN: test.dummy_terminator
func.func @test_move_op_before() {
"test.op_with_region"() ({
"test.move_before_parent_op"() : () -> ()
"test.dummy_terminator"() : () ->()
}) : () -> ()
return
}

View File

@ -198,6 +198,21 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
}
};
/// This pattern moves "test.move_before_parent_op" before the parent op.
struct MoveBeforeParentOp : public RewritePattern {
MoveBeforeParentOp(MLIRContext *context)
: RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Do not hoist past functions.
if (isa<FunctionOpInterface>(op->getParentOp()))
return failure();
rewriter.moveOpBefore(op, op->getParentOp());
return success();
}
};
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@ -238,6 +253,20 @@ struct TestPatternDriver
};
struct DumpNotifications : public RewriterBase::Listener {
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
llvm::outs() << "notifyOperationInserted: " << op->getName();
if (!previous.isSet()) {
llvm::outs() << ", was unlinked\n";
} else {
if (previous.getPoint() == previous.getBlock()->end()) {
llvm::outs() << ", was last in block\n";
} else {
llvm::outs() << ", previous = " << previous.getPoint()->getName()
<< "\n";
}
}
}
void notifyOperationRemoved(Operation *op) override {
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
}
@ -267,14 +296,16 @@ public:
ReplaceWithNewOp,
EraseOp,
ChangeBlockOp,
ImplicitChangeOp
ImplicitChangeOp,
MoveBeforeParentOp
// clang-format on
>(ctx);
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
opName == "test.replace_with_new_op" || opName == "test.erase_op") {
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
opName == "test.move_before_parent_op") {
ops.push_back(op);
}
});