mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 23:26:04 +00:00
[mlir][IR] Add rewriter API for moving operations (#78988)
The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`. After an operation was moved, the `notifyOperationInserted` callback is triggered. This allows listeners such as the greedy pattern rewrite driver to react to IR changes. This commit narrows the discrepancy between the kind of IR modification that can be performed and the kind of IR modifications that can be listened to.
This commit is contained in:
parent
45fec0c110
commit
5cc0f76d34
@ -490,7 +490,11 @@ public:
|
||||
LLVM_DUMP_METHOD void dumpFunc();
|
||||
|
||||
/// FirOpBuilder hook for creating new operation.
|
||||
void notifyOperationInserted(mlir::Operation *op) override {
|
||||
void notifyOperationInserted(mlir::Operation *op,
|
||||
mlir::OpBuilder::InsertPoint previous) override {
|
||||
// We only care about newly created operations.
|
||||
if (previous.isSet())
|
||||
return;
|
||||
setCommonAttributes(op);
|
||||
}
|
||||
|
||||
|
@ -730,9 +730,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
|
||||
HLFIRListener(fir::FirOpBuilder &builder,
|
||||
mlir::ConversionPatternRewriter &rewriter)
|
||||
: builder{builder}, rewriter{rewriter} {}
|
||||
void notifyOperationInserted(mlir::Operation *op) override {
|
||||
builder.notifyOperationInserted(op);
|
||||
rewriter.notifyOperationInserted(op);
|
||||
void notifyOperationInserted(mlir::Operation *op,
|
||||
mlir::OpBuilder::InsertPoint previous) override {
|
||||
builder.notifyOperationInserted(op, previous);
|
||||
rewriter.notifyOperationInserted(op, previous);
|
||||
}
|
||||
virtual void notifyBlockCreated(mlir::Block *block) override {
|
||||
builder.notifyBlockCreated(block);
|
||||
|
@ -205,6 +205,7 @@ protected:
|
||||
/// automatically inserted at an insertion point. The builder is copyable.
|
||||
class OpBuilder : public Builder {
|
||||
public:
|
||||
class InsertPoint;
|
||||
struct Listener;
|
||||
|
||||
/// Create a builder with the given context.
|
||||
@ -285,12 +286,17 @@ public:
|
||||
|
||||
virtual ~Listener() = default;
|
||||
|
||||
/// Notification handler for when an operation is inserted into the builder.
|
||||
/// `op` is the operation that was inserted.
|
||||
virtual void notifyOperationInserted(Operation *op) {}
|
||||
/// Notify the listener that the specified operation was inserted.
|
||||
///
|
||||
/// * If the operation was moved, then `previous` is the previous location
|
||||
/// of the op.
|
||||
/// * If the operation was unlinked before it was inserted, then `previous`
|
||||
/// is empty.
|
||||
///
|
||||
/// Note: Creating an (unlinked) op does not trigger this notification.
|
||||
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
|
||||
|
||||
/// Notification handler for when a block is created using the builder.
|
||||
/// `block` is the block that was created.
|
||||
/// Notify the listener that the specified block was inserted.
|
||||
virtual void notifyBlockCreated(Block *block) {}
|
||||
|
||||
protected:
|
||||
@ -517,7 +523,7 @@ public:
|
||||
if (succeeded(tryFold(op, results)))
|
||||
op->erase();
|
||||
else if (listener)
|
||||
listener->notifyOperationInserted(op);
|
||||
listener->notifyOperationInserted(op, /*previous=*/{});
|
||||
}
|
||||
|
||||
/// Overload to create or fold a single result operation.
|
||||
|
@ -428,6 +428,8 @@ public:
|
||||
|
||||
/// Notify the listener that the specified operation is about to be erased.
|
||||
/// At this point, the operation has zero uses.
|
||||
///
|
||||
/// Note: This notification is not triggered when unlinking an operation.
|
||||
virtual void notifyOperationRemoved(Operation *op) {}
|
||||
|
||||
/// Notify the listener that the pattern failed to match the given
|
||||
@ -450,8 +452,8 @@ public:
|
||||
struct ForwardingListener : public RewriterBase::Listener {
|
||||
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
|
||||
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
listener->notifyOperationInserted(op);
|
||||
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
|
||||
listener->notifyOperationInserted(op, previous);
|
||||
}
|
||||
void notifyBlockCreated(Block *block) override {
|
||||
listener->notifyBlockCreated(block);
|
||||
@ -591,6 +593,26 @@ public:
|
||||
/// block into a new block, and return it.
|
||||
virtual Block *splitBlock(Block *block, Block::iterator before);
|
||||
|
||||
/// Unlink this operation from its current block and insert it right before
|
||||
/// `existingOp` which may be in the same or another block in the same
|
||||
/// function.
|
||||
void moveOpBefore(Operation *op, Operation *existingOp);
|
||||
|
||||
/// Unlink this operation from its current block and insert it right before
|
||||
/// `iterator` in the specified block.
|
||||
virtual void moveOpBefore(Operation *op, Block *block,
|
||||
Block::iterator iterator);
|
||||
|
||||
/// Unlink this operation from its current block and insert it right after
|
||||
/// `existingOp` which may be in the same or another block in the same
|
||||
/// function.
|
||||
void moveOpAfter(Operation *op, Operation *existingOp);
|
||||
|
||||
/// Unlink this operation from its current block and insert it right after
|
||||
/// `iterator` in the specified block.
|
||||
virtual void moveOpAfter(Operation *op, Block *block,
|
||||
Block::iterator iterator);
|
||||
|
||||
/// This method is used to notify the rewriter that an in-place operation
|
||||
/// modification is about to happen. A call to this function *must* be
|
||||
/// followed by a call to either `finalizeOpModification` or
|
||||
|
@ -737,7 +737,7 @@ public:
|
||||
using PatternRewriter::cloneRegionBefore;
|
||||
|
||||
/// PatternRewriter hook for inserting a new operation.
|
||||
void notifyOperationInserted(Operation *op) override;
|
||||
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
|
||||
|
||||
/// PatternRewriter hook for updating the given operation in-place.
|
||||
/// Note: These methods only track updates to the given operation itself,
|
||||
@ -761,9 +761,15 @@ public:
|
||||
detail::ConversionPatternRewriterImpl &getImpl();
|
||||
|
||||
private:
|
||||
// Hide unsupported pattern rewriter API.
|
||||
using OpBuilder::getListener;
|
||||
using OpBuilder::setListener;
|
||||
|
||||
void moveOpBefore(Operation *op, Block *block,
|
||||
Block::iterator iterator) override;
|
||||
void moveOpAfter(Operation *op, Block *block,
|
||||
Block::iterator iterator) override;
|
||||
|
||||
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
|
||||
};
|
||||
|
||||
|
@ -1206,7 +1206,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
|
||||
if (failed(applyOp->fold(constOperands, foldResults)) ||
|
||||
foldResults.empty()) {
|
||||
if (OpBuilder::Listener *listener = b.getListener())
|
||||
listener->notifyOperationInserted(applyOp);
|
||||
listener->notifyOperationInserted(applyOp, /*previous=*/{});
|
||||
return applyOp.getResult();
|
||||
}
|
||||
|
||||
@ -1274,7 +1274,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
|
||||
if (failed(minMaxOp->fold(constOperands, foldResults)) ||
|
||||
foldResults.empty()) {
|
||||
if (OpBuilder::Listener *listener = b.getListener())
|
||||
listener->notifyOperationInserted(minMaxOp);
|
||||
listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
|
||||
return minMaxOp.getResult();
|
||||
}
|
||||
|
||||
|
@ -273,7 +273,7 @@ static ParallelComputeFunction createParallelComputeFunction(
|
||||
// Insert function into the module symbol table and assign it unique name.
|
||||
SymbolTable symbolTable(module);
|
||||
symbolTable.insert(func);
|
||||
rewriter.getListener()->notifyOperationInserted(func);
|
||||
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
|
||||
|
||||
// Create function entry block.
|
||||
Block *block =
|
||||
@ -489,7 +489,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
|
||||
// Insert function into the module symbol table and assign it unique name.
|
||||
SymbolTable symbolTable(module);
|
||||
symbolTable.insert(func);
|
||||
rewriter.getListener()->notifyOperationInserted(func);
|
||||
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
|
||||
|
||||
// Create function entry block.
|
||||
Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
|
||||
|
@ -371,7 +371,11 @@ protected:
|
||||
toMemrefOps.erase(op);
|
||||
}
|
||||
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
|
||||
// We only care about newly created ops.
|
||||
if (previous.isSet())
|
||||
return;
|
||||
|
||||
erasedOps.erase(op);
|
||||
|
||||
// Gather statistics about allocs.
|
||||
|
@ -214,8 +214,12 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
ForwardingListener::notifyOperationInserted(op);
|
||||
void notifyOperationInserted(Operation *op,
|
||||
OpBuilder::InsertPoint previous) override {
|
||||
ForwardingListener::notifyOperationInserted(op, previous);
|
||||
// We only care about newly created ops.
|
||||
if (previous.isSet())
|
||||
return;
|
||||
auto inserted = newOps.insert(op);
|
||||
(void)inserted;
|
||||
assert(inserted.second && "expected newly created op");
|
||||
|
@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
|
||||
|
||||
// Inline for-loop body operations into 'after' region.
|
||||
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
|
||||
arg.moveBefore(afterBlock, afterBlock->end());
|
||||
rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
|
||||
|
||||
// Add incremented IV to yield operations
|
||||
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
|
||||
|
@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
|
||||
for (Operation *user : srcBuffer->getUsers()) {
|
||||
if (hasEffect<MemoryEffects::Free>(user)) {
|
||||
if (user->getBlock() == parallelCombiningParent->getBlock())
|
||||
user->moveBefore(user->getBlock()->getTerminator());
|
||||
rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -412,7 +412,7 @@ Operation *OpBuilder::insert(Operation *op) {
|
||||
block->getOperations().insert(insertPoint, op);
|
||||
|
||||
if (listener)
|
||||
listener->notifyOperationInserted(op);
|
||||
listener->notifyOperationInserted(op, /*previous=*/{});
|
||||
return op;
|
||||
}
|
||||
|
||||
@ -530,7 +530,7 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
|
||||
// about any ops that got inserted inside those regions as part of cloning.
|
||||
if (listener) {
|
||||
auto walkFn = [&](Operation *walkedOp) {
|
||||
listener->notifyOperationInserted(walkedOp);
|
||||
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
|
||||
};
|
||||
for (Region ®ion : newOp->getRegions())
|
||||
region.walk<WalkOrder::PreOrder>(walkFn);
|
||||
|
@ -366,3 +366,31 @@ void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent,
|
||||
void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) {
|
||||
cloneRegionBefore(region, *before->getParent(), before->getIterator());
|
||||
}
|
||||
|
||||
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
|
||||
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
|
||||
}
|
||||
|
||||
void RewriterBase::moveOpBefore(Operation *op, Block *block,
|
||||
Block::iterator iterator) {
|
||||
Block *currentBlock = op->getBlock();
|
||||
Block::iterator currentIterator = op->getIterator();
|
||||
op->moveBefore(block, iterator);
|
||||
if (listener)
|
||||
listener->notifyOperationInserted(
|
||||
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
|
||||
}
|
||||
|
||||
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
|
||||
moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
|
||||
}
|
||||
|
||||
void RewriterBase::moveOpAfter(Operation *op, Block *block,
|
||||
Block::iterator iterator) {
|
||||
Block *currentBlock = op->getBlock();
|
||||
Block::iterator currentIterator = op->getIterator();
|
||||
op->moveAfter(block, iterator);
|
||||
if (listener)
|
||||
listener->notifyOperationInserted(
|
||||
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
|
||||
}
|
||||
|
@ -1602,11 +1602,13 @@ void ConversionPatternRewriter::cloneRegionBefore(Region ®ion,
|
||||
Block *cloned = mapping.lookup(&b);
|
||||
impl->notifyCreatedBlock(cloned);
|
||||
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
|
||||
[&](Operation *op) { notifyOperationInserted(op); });
|
||||
[&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
|
||||
}
|
||||
}
|
||||
|
||||
void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
|
||||
void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
|
||||
InsertPoint previous) {
|
||||
assert(!previous.isSet() && "expected newly created op");
|
||||
LLVM_DEBUG({
|
||||
impl->logger.startLine()
|
||||
<< "** Insert : '" << op->getName() << "'(" << op << ")\n";
|
||||
@ -1651,6 +1653,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
|
||||
return impl->notifyMatchFailure(loc, reasonCallback);
|
||||
}
|
||||
|
||||
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
|
||||
Block::iterator iterator) {
|
||||
llvm_unreachable(
|
||||
"moving single ops is not supported in a dialect conversion");
|
||||
}
|
||||
|
||||
void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
|
||||
Block::iterator iterator) {
|
||||
llvm_unreachable(
|
||||
"moving single ops is not supported in a dialect conversion");
|
||||
}
|
||||
|
||||
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
|
||||
return *impl;
|
||||
}
|
||||
|
@ -133,9 +133,16 @@ protected:
|
||||
}
|
||||
}
|
||||
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
RewriterBase::ForwardingListener::notifyOperationInserted(op);
|
||||
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
|
||||
RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
|
||||
// Invalidate the finger print of the op that owns the block into which the
|
||||
// op was inserted into.
|
||||
invalidateFingerPrint(op->getParentOp());
|
||||
|
||||
// Also invalidate the finger print of the op that owns the block from which
|
||||
// the op was moved from. (Only applicable if the op was moved.)
|
||||
if (previous.isSet())
|
||||
invalidateFingerPrint(previous.getBlock()->getParentOp());
|
||||
}
|
||||
|
||||
void notifyOperationModified(Operation *op) override {
|
||||
@ -331,7 +338,7 @@ protected:
|
||||
/// Notify the driver that the specified operation was inserted. Update the
|
||||
/// worklist as needed: The operation is enqueued depending on scope and
|
||||
/// strict mode.
|
||||
void notifyOperationInserted(Operation *op) override;
|
||||
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
|
||||
|
||||
/// Notify the driver that the specified operation was removed. Update the
|
||||
/// worklist as needed: The operation and its children are removed from the
|
||||
@ -641,13 +648,14 @@ void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
|
||||
config.listener->notifyBlockRemoved(block);
|
||||
}
|
||||
|
||||
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
|
||||
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
|
||||
InsertPoint previous) {
|
||||
LLVM_DEBUG({
|
||||
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
|
||||
<< ")\n";
|
||||
});
|
||||
if (config.listener)
|
||||
config.listener->notifyOperationInserted(op);
|
||||
config.listener->notifyOperationInserted(op, previous);
|
||||
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
|
||||
strictModeFilteredOps.insert(op);
|
||||
addToWorklist(op);
|
||||
|
@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
|
||||
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
|
||||
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
|
||||
OpResult newLoopResult = loopLike.getLoopResults()->back();
|
||||
extractionOp->moveBefore(loopLike);
|
||||
insertionOp->moveAfter(loopLike);
|
||||
rewriter.moveOpBefore(extractionOp, loopLike);
|
||||
rewriter.moveOpAfter(insertionOp, loopLike);
|
||||
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
|
||||
insertionOp.getDestinationOperand().get());
|
||||
extractionOp.getSourceOperand().set(
|
||||
|
@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
|
||||
auto ifOp =
|
||||
rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
|
||||
// True branch.
|
||||
op->moveBefore(&ifOp.getThenRegion().front(),
|
||||
ifOp.getThenRegion().front().begin());
|
||||
rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
|
||||
ifOp.getThenRegion().front().begin());
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
if (op->getNumResults() > 0)
|
||||
rewriter.create<scf::YieldOp>(loc, op->getResults());
|
||||
|
@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
|
||||
return failure();
|
||||
if (!toBeHoisted->hasAttr("eligible"))
|
||||
return failure();
|
||||
// Hoisting means removing an op from the enclosing op. I.e., the enclosing
|
||||
// op is modified.
|
||||
rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
|
||||
rewriter.moveOpBefore(toBeHoisted, op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -15,7 +15,8 @@ using namespace mlir;
|
||||
namespace {
|
||||
|
||||
struct DumpNotifications : public OpBuilder::Listener {
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
void notifyOperationInserted(Operation *op,
|
||||
OpBuilder::InsertPoint previous) override {
|
||||
llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n";
|
||||
}
|
||||
};
|
||||
|
@ -27,7 +27,8 @@ struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
|
||||
void foldOperation(Operation *op, OperationFolder &helper);
|
||||
void runOnOperation() override;
|
||||
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
void notifyOperationInserted(Operation *op,
|
||||
OpBuilder::InsertPoint previous) override {
|
||||
existingConstants.push_back(op);
|
||||
}
|
||||
void notifyOperationRemoved(Operation *op) override {
|
||||
|
Loading…
x
Reference in New Issue
Block a user