mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-24 04:16:08 +00:00
[mlir][IR] Add RewriterBase::moveBlockBefore
and fix bug in moveOpBefore
(#79579)
This commit adds a new method to the rewriter API: `moveBlockBefore`. This op is utilized by `inlineRegionBefore` and covered by dialect conversion test cases. Also fixes a bug in `moveOpBefore`, where the previous op location was not passed correctly. Adds a test case to `test-strict-pattern-driver.mlir`.
This commit is contained in:
parent
b210cbbd0e
commit
da784a2555
@ -67,6 +67,10 @@ public:
|
||||
/// specific block.
|
||||
void moveBefore(Block *block);
|
||||
|
||||
/// Unlink this block from its current region and insert it right before the
|
||||
/// block that the given iterator points to in the region region.
|
||||
void moveBefore(Region *region, llvm::iplist<Block>::iterator iterator);
|
||||
|
||||
/// Unlink this Block from its parent region and delete it.
|
||||
void erase();
|
||||
|
||||
|
@ -614,6 +614,13 @@ public:
|
||||
virtual void moveOpAfter(Operation *op, Block *block,
|
||||
Block::iterator iterator);
|
||||
|
||||
/// Unlink this block and insert it right before `existingBlock`.
|
||||
void moveBlockBefore(Block *block, Block *anotherBlock);
|
||||
|
||||
/// Unlink this block and insert it right before the location that the given
|
||||
/// iterator points to in the given region.
|
||||
void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
|
||||
|
||||
/// This method is used to notify the rewriter that an in-place operation
|
||||
/// modification is about to happen. A call to this function *must* be
|
||||
/// followed by a call to either `finalizeOpModification` or
|
||||
|
@ -52,8 +52,13 @@ void Block::insertAfter(Block *block) {
|
||||
/// specific block.
|
||||
void Block::moveBefore(Block *block) {
|
||||
assert(block->getParent() && "cannot insert before a block without a parent");
|
||||
block->getParent()->getBlocks().splice(
|
||||
block->getIterator(), getParent()->getBlocks(), getIterator());
|
||||
moveBefore(block->getParent(), block->getIterator());
|
||||
}
|
||||
|
||||
/// Unlink this block from its current region and insert it right before the
|
||||
/// block that the given iterator points to in the region region.
|
||||
void Block::moveBefore(Region *region, llvm::iplist<Block>::iterator iterator) {
|
||||
region->getBlocks().splice(iterator, getParent()->getBlocks(), getIterator());
|
||||
}
|
||||
|
||||
/// Unlink this Block from its parent Region and delete it.
|
||||
|
@ -350,11 +350,8 @@ void RewriterBase::inlineRegionBefore(Region ®ion, Region &parent,
|
||||
}
|
||||
|
||||
// Move blocks from the beginning of the region one-by-one.
|
||||
while (!region.empty()) {
|
||||
Block *block = ®ion.front();
|
||||
parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
|
||||
listener->notifyBlockInserted(block, ®ion, region.begin());
|
||||
}
|
||||
while (!region.empty())
|
||||
moveBlockBefore(®ion.front(), &parent, before);
|
||||
}
|
||||
void RewriterBase::inlineRegionBefore(Region ®ion, Block *before) {
|
||||
inlineRegionBefore(region, *before->getParent(), before->getIterator());
|
||||
@ -378,6 +375,21 @@ void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) {
|
||||
cloneRegionBefore(region, *before->getParent(), before->getIterator());
|
||||
}
|
||||
|
||||
void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
|
||||
moveBlockBefore(block, anotherBlock->getParent(),
|
||||
anotherBlock->getIterator());
|
||||
}
|
||||
|
||||
void RewriterBase::moveBlockBefore(Block *block, Region *region,
|
||||
Region::iterator iterator) {
|
||||
Region *currentRegion = block->getParent();
|
||||
Region::iterator nextIterator = std::next(block->getIterator());
|
||||
block->moveBefore(region, iterator);
|
||||
if (listener)
|
||||
listener->notifyBlockInserted(block, /*previous=*/currentRegion,
|
||||
/*previousIt=*/nextIterator);
|
||||
}
|
||||
|
||||
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
|
||||
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
|
||||
}
|
||||
@ -385,11 +397,11 @@ void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
|
||||
void RewriterBase::moveOpBefore(Operation *op, Block *block,
|
||||
Block::iterator iterator) {
|
||||
Block *currentBlock = op->getBlock();
|
||||
Block::iterator currentIterator = op->getIterator();
|
||||
Block::iterator nextIterator = std::next(op->getIterator());
|
||||
op->moveBefore(block, iterator);
|
||||
if (listener)
|
||||
listener->notifyOperationInserted(
|
||||
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
|
||||
op, /*previous=*/InsertPoint(currentBlock, nextIterator));
|
||||
}
|
||||
|
||||
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
|
||||
@ -398,10 +410,6 @@ void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
|
||||
|
||||
void RewriterBase::moveOpAfter(Operation *op, Block *block,
|
||||
Block::iterator iterator) {
|
||||
Block *currentBlock = op->getBlock();
|
||||
Block::iterator currentIterator = op->getIterator();
|
||||
op->moveAfter(block, iterator);
|
||||
if (listener)
|
||||
listener->notifyOperationInserted(
|
||||
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
|
||||
assert(iterator != block->end() && "cannot move after end of block");
|
||||
moveOpBefore(op, block, std::next(iterator));
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ func.func @test_erase() {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-EN: notifyOperationInserted: test.insert_same_op, was unlinked
|
||||
// CHECK-EN-LABEL: func @test_insert_same_op
|
||||
// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true}
|
||||
// CHECK-EN: "test.insert_same_op"() {skip = true}
|
||||
@ -35,6 +36,7 @@ func.func @test_insert_same_op() {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-EN: notifyOperationInserted: test.new_op, was unlinked
|
||||
// CHECK-EN-LABEL: func @test_replace_with_new_op
|
||||
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
|
||||
// CHECK-EN: %[[n:.*]] = "test.new_op"
|
||||
@ -49,6 +51,9 @@ func.func @test_replace_with_new_op() {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
|
||||
// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
|
||||
// CHECK-EN: notifyOperationRemoved: test.erase_op
|
||||
// CHECK-EN-LABEL: func @test_replace_with_erase_op
|
||||
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
|
||||
// CHECK-EN-NOT: "test.replace_with_new_op"
|
||||
@ -229,3 +234,18 @@ func.func @test_remove_diamond(%c: i1) {
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-AN: notifyOperationInserted: test.move_before_parent_op, previous = test.dummy_terminator
|
||||
// CHECK-AN-LABEL: func @test_move_op_before(
|
||||
// CHECK-AN: test.move_before_parent_op
|
||||
// CHECK-AN: test.op_with_region
|
||||
// CHECK-AN: test.dummy_terminator
|
||||
func.func @test_move_op_before() {
|
||||
"test.op_with_region"() ({
|
||||
"test.move_before_parent_op"() : () -> ()
|
||||
"test.dummy_terminator"() : () ->()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
@ -198,6 +198,21 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// This pattern moves "test.move_before_parent_op" before the parent op.
|
||||
struct MoveBeforeParentOp : public RewritePattern {
|
||||
MoveBeforeParentOp(MLIRContext *context)
|
||||
: RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Do not hoist past functions.
|
||||
if (isa<FunctionOpInterface>(op->getParentOp()))
|
||||
return failure();
|
||||
rewriter.moveOpBefore(op, op->getParentOp());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TestPatternDriver
|
||||
: public PassWrapper<TestPatternDriver, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
|
||||
@ -238,6 +253,20 @@ struct TestPatternDriver
|
||||
};
|
||||
|
||||
struct DumpNotifications : public RewriterBase::Listener {
|
||||
void notifyOperationInserted(Operation *op,
|
||||
OpBuilder::InsertPoint previous) override {
|
||||
llvm::outs() << "notifyOperationInserted: " << op->getName();
|
||||
if (!previous.isSet()) {
|
||||
llvm::outs() << ", was unlinked\n";
|
||||
} else {
|
||||
if (previous.getPoint() == previous.getBlock()->end()) {
|
||||
llvm::outs() << ", was last in block\n";
|
||||
} else {
|
||||
llvm::outs() << ", previous = " << previous.getPoint()->getName()
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
void notifyOperationRemoved(Operation *op) override {
|
||||
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
|
||||
}
|
||||
@ -267,14 +296,16 @@ public:
|
||||
ReplaceWithNewOp,
|
||||
EraseOp,
|
||||
ChangeBlockOp,
|
||||
ImplicitChangeOp
|
||||
ImplicitChangeOp,
|
||||
MoveBeforeParentOp
|
||||
// clang-format on
|
||||
>(ctx);
|
||||
SmallVector<Operation *> ops;
|
||||
getOperation()->walk([&](Operation *op) {
|
||||
StringRef opName = op->getName().getStringRef();
|
||||
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
|
||||
opName == "test.replace_with_new_op" || opName == "test.erase_op") {
|
||||
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
|
||||
opName == "test.move_before_parent_op") {
|
||||
ops.push_back(op);
|
||||
}
|
||||
});
|
||||
|
Loading…
x
Reference in New Issue
Block a user