[mlir][IR] Add rewriter API for moving operations (#78988)

The pattern rewriter documentation states that "*all* IR mutations [...]
are required to be performed via the `PatternRewriter`." This commit
adds two functions that were missing from the rewriter API:
`moveOpBefore` and `moveOpAfter`.

After an operation was moved, the `notifyOperationInserted` callback is
triggered. This allows listeners such as the greedy pattern rewrite
driver to react to IR changes.

This commit narrows the discrepancy between the kind of IR modification
that can be performed and the kind of IR modifications that can be
listened to.
This commit is contained in:
Matthias Springer 2024-01-25 11:01:28 +01:00 committed by GitHub
parent 45fec0c110
commit 5cc0f76d34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 137 additions and 40 deletions

View File

@ -490,7 +490,11 @@ public:
LLVM_DUMP_METHOD void dumpFunc();
/// FirOpBuilder hook for creating new operation.
void notifyOperationInserted(mlir::Operation *op) override {
void notifyOperationInserted(mlir::Operation *op,
mlir::OpBuilder::InsertPoint previous) override {
// We only care about newly created operations.
if (previous.isSet())
return;
setCommonAttributes(op);
}

View File

@ -730,9 +730,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
HLFIRListener(fir::FirOpBuilder &builder,
mlir::ConversionPatternRewriter &rewriter)
: builder{builder}, rewriter{rewriter} {}
void notifyOperationInserted(mlir::Operation *op) override {
builder.notifyOperationInserted(op);
rewriter.notifyOperationInserted(op);
void notifyOperationInserted(mlir::Operation *op,
mlir::OpBuilder::InsertPoint previous) override {
builder.notifyOperationInserted(op, previous);
rewriter.notifyOperationInserted(op, previous);
}
virtual void notifyBlockCreated(mlir::Block *block) override {
builder.notifyBlockCreated(block);

View File

@ -205,6 +205,7 @@ protected:
/// automatically inserted at an insertion point. The builder is copyable.
class OpBuilder : public Builder {
public:
class InsertPoint;
struct Listener;
/// Create a builder with the given context.
@ -285,12 +286,17 @@ public:
virtual ~Listener() = default;
/// Notification handler for when an operation is inserted into the builder.
/// `op` is the operation that was inserted.
virtual void notifyOperationInserted(Operation *op) {}
/// Notify the listener that the specified operation was inserted.
///
/// * If the operation was moved, then `previous` is the previous location
/// of the op.
/// * If the operation was unlinked before it was inserted, then `previous`
/// is empty.
///
/// Note: Creating an (unlinked) op does not trigger this notification.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
/// Notification handler for when a block is created using the builder.
/// `block` is the block that was created.
/// Notify the listener that the specified block was inserted.
virtual void notifyBlockCreated(Block *block) {}
protected:
@ -517,7 +523,7 @@ public:
if (succeeded(tryFold(op, results)))
op->erase();
else if (listener)
listener->notifyOperationInserted(op);
listener->notifyOperationInserted(op, /*previous=*/{});
}
/// Overload to create or fold a single result operation.

View File

@ -428,6 +428,8 @@ public:
/// Notify the listener that the specified operation is about to be erased.
/// At this point, the operation has zero uses.
///
/// Note: This notification is not triggered when unlinking an operation.
virtual void notifyOperationRemoved(Operation *op) {}
/// Notify the listener that the pattern failed to match the given
@ -450,8 +452,8 @@ public:
struct ForwardingListener : public RewriterBase::Listener {
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
void notifyOperationInserted(Operation *op) override {
listener->notifyOperationInserted(op);
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
listener->notifyOperationInserted(op, previous);
}
void notifyBlockCreated(Block *block) override {
listener->notifyBlockCreated(block);
@ -591,6 +593,26 @@ public:
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);
/// Unlink this operation from its current block and insert it right before
/// `existingOp` which may be in the same or another block in the same
/// function.
void moveOpBefore(Operation *op, Operation *existingOp);
/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
virtual void moveOpBefore(Operation *op, Block *block,
Block::iterator iterator);
/// Unlink this operation from its current block and insert it right after
/// `existingOp` which may be in the same or another block in the same
/// function.
void moveOpAfter(Operation *op, Operation *existingOp);
/// Unlink this operation from its current block and insert it right after
/// `iterator` in the specified block.
virtual void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator);
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeOpModification` or

View File

@ -737,7 +737,7 @@ public:
using PatternRewriter::cloneRegionBefore;
/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op) override;
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
@ -761,9 +761,15 @@ public:
detail::ConversionPatternRewriterImpl &getImpl();
private:
// Hide unsupported pattern rewriter API.
using OpBuilder::getListener;
using OpBuilder::setListener;
void moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) override;
void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) override;
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};

View File

@ -1206,7 +1206,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
if (failed(applyOp->fold(constOperands, foldResults)) ||
foldResults.empty()) {
if (OpBuilder::Listener *listener = b.getListener())
listener->notifyOperationInserted(applyOp);
listener->notifyOperationInserted(applyOp, /*previous=*/{});
return applyOp.getResult();
}
@ -1274,7 +1274,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
if (failed(minMaxOp->fold(constOperands, foldResults)) ||
foldResults.empty()) {
if (OpBuilder::Listener *listener = b.getListener())
listener->notifyOperationInserted(minMaxOp);
listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
return minMaxOp.getResult();
}

View File

@ -273,7 +273,7 @@ static ParallelComputeFunction createParallelComputeFunction(
// Insert function into the module symbol table and assign it unique name.
SymbolTable symbolTable(module);
symbolTable.insert(func);
rewriter.getListener()->notifyOperationInserted(func);
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
// Create function entry block.
Block *block =
@ -489,7 +489,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
// Insert function into the module symbol table and assign it unique name.
SymbolTable symbolTable(module);
symbolTable.insert(func);
rewriter.getListener()->notifyOperationInserted(func);
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
// Create function entry block.
Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),

View File

@ -371,7 +371,11 @@ protected:
toMemrefOps.erase(op);
}
void notifyOperationInserted(Operation *op) override {
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
// We only care about newly created ops.
if (previous.isSet())
return;
erasedOps.erase(op);
// Gather statistics about allocs.

View File

@ -214,8 +214,12 @@ public:
}
private:
void notifyOperationInserted(Operation *op) override {
ForwardingListener::notifyOperationInserted(op);
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
ForwardingListener::notifyOperationInserted(op, previous);
// We only care about newly created ops.
if (previous.isSet())
return;
auto inserted = newOps.insert(op);
(void)inserted;
assert(inserted.second && "expected newly created op");

View File

@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
// Inline for-loop body operations into 'after' region.
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
arg.moveBefore(afterBlock, afterBlock->end());
rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
// Add incremented IV to yield operations
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {

View File

@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
for (Operation *user : srcBuffer->getUsers()) {
if (hasEffect<MemoryEffects::Free>(user)) {
if (user->getBlock() == parallelCombiningParent->getBlock())
user->moveBefore(user->getBlock()->getTerminator());
rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
break;
}
}

View File

@ -412,7 +412,7 @@ Operation *OpBuilder::insert(Operation *op) {
block->getOperations().insert(insertPoint, op);
if (listener)
listener->notifyOperationInserted(op);
listener->notifyOperationInserted(op, /*previous=*/{});
return op;
}
@ -530,7 +530,7 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp);
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
};
for (Region &region : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);

View File

@ -366,3 +366,31 @@ void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
}
void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
Block::iterator currentIterator = op->getIterator();
op->moveBefore(block, iterator);
if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
}
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
}
void RewriterBase::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
Block::iterator currentIterator = op->getIterator();
op->moveAfter(block, iterator);
if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
}

View File

@ -1602,11 +1602,13 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
Block *cloned = mapping.lookup(&b);
impl->notifyCreatedBlock(cloned);
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
[&](Operation *op) { notifyOperationInserted(op); });
[&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
}
}
void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
InsertPoint previous) {
assert(!previous.isSet() && "expected newly created op");
LLVM_DEBUG({
impl->logger.startLine()
<< "** Insert : '" << op->getName() << "'(" << op << ")\n";
@ -1651,6 +1653,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
return impl->notifyMatchFailure(loc, reasonCallback);
}
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
"moving single ops is not supported in a dialect conversion");
}
void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
"moving single ops is not supported in a dialect conversion");
}
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}

View File

@ -133,9 +133,16 @@ protected:
}
}
void notifyOperationInserted(Operation *op) override {
RewriterBase::ForwardingListener::notifyOperationInserted(op);
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
// Invalidate the finger print of the op that owns the block into which the
// op was inserted into.
invalidateFingerPrint(op->getParentOp());
// Also invalidate the finger print of the op that owns the block from which
// the op was moved from. (Only applicable if the op was moved.)
if (previous.isSet())
invalidateFingerPrint(previous.getBlock()->getParentOp());
}
void notifyOperationModified(Operation *op) override {
@ -331,7 +338,7 @@ protected:
/// Notify the driver that the specified operation was inserted. Update the
/// worklist as needed: The operation is enqueued depending on scope and
/// strict mode.
void notifyOperationInserted(Operation *op) override;
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
/// Notify the driver that the specified operation was removed. Update the
/// worklist as needed: The operation and its children are removed from the
@ -641,13 +648,14 @@ void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
config.listener->notifyBlockRemoved(block);
}
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
InsertPoint previous) {
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
if (config.listener)
config.listener->notifyOperationInserted(op);
config.listener->notifyOperationInserted(op, previous);
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
strictModeFilteredOps.insert(op);
addToWorklist(op);

View File

@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
OpResult newLoopResult = loopLike.getLoopResults()->back();
extractionOp->moveBefore(loopLike);
insertionOp->moveAfter(loopLike);
rewriter.moveOpBefore(extractionOp, loopLike);
rewriter.moveOpAfter(insertionOp, loopLike);
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
insertionOp.getDestinationOperand().get());
extractionOp.getSourceOperand().set(

View File

@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
auto ifOp =
rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
// True branch.
op->moveBefore(&ifOp.getThenRegion().front(),
ifOp.getThenRegion().front().begin());
rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
ifOp.getThenRegion().front().begin());
rewriter.setInsertionPointAfter(op);
if (op->getNumResults() > 0)
rewriter.create<scf::YieldOp>(loc, op->getResults());

View File

@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
return failure();
if (!toBeHoisted->hasAttr("eligible"))
return failure();
// Hoisting means removing an op from the enclosing op. I.e., the enclosing
// op is modified.
rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
rewriter.moveOpBefore(toBeHoisted, op);
return success();
}
};

View File

@ -15,7 +15,8 @@ using namespace mlir;
namespace {
struct DumpNotifications : public OpBuilder::Listener {
void notifyOperationInserted(Operation *op) override {
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n";
}
};

View File

@ -27,7 +27,8 @@ struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
void foldOperation(Operation *op, OperationFolder &helper);
void runOnOperation() override;
void notifyOperationInserted(Operation *op) override {
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
existingConstants.push_back(op);
}
void notifyOperationRemoved(Operation *op) override {