mirror of
https://github.com/llvm/llvm-project.git
synced 2025-05-03 00:06:06 +00:00
[mlir][IR] Trigger notifyOperationReplaced
on replaceAllOpUsesWith
(#84721)
Before this change: `notifyOperationReplaced` was triggered when calling `RewriteBase::replaceOp`. After this change: `notifyOperationReplaced` is triggered when `RewriterBase::replaceAllOpUsesWith` or `RewriterBase::replaceOp` is called. Until now, every `notifyOperationReplaced` was always sent together with a `notifyOperationErased`, which made that `notifyOperationErased` callback irrelevant. More importantly, when a user called `RewriterBase::replaceAllOpUsesWith`+`RewriterBase::eraseOp` instead of `RewriterBase::replaceOp`, no `notifyOperationReplaced` callback was sent, even though the two notations are semantically equivalent. As an example, this can be a problem when applying patterns with the transform dialect because the `TrackingListener` will only see the `notifyOperationErased` callback and the payload op is dropped from the mappings. Note: It is still possible to write semantically equivalent code that does not trigger a `notifyOperationReplaced` (e.g., when op results are replaced one-by-one), but this commit already improves the situation a lot.
This commit is contained in:
parent
d7a43a00fe
commit
38113a0832
@ -409,9 +409,9 @@ public:
|
||||
/// Notify the listener that the specified operation was modified in-place.
|
||||
virtual void notifyOperationModified(Operation *op) {}
|
||||
|
||||
/// Notify the listener that the specified operation is about to be replaced
|
||||
/// with another operation. This is called before the uses of the old
|
||||
/// operation have been changed.
|
||||
/// Notify the listener that all uses of the specified operation's results
|
||||
/// are about to be replaced with the results of another operation. This is
|
||||
/// called before the uses of the old operation have been changed.
|
||||
///
|
||||
/// By default, this function calls the "operation replaced with values"
|
||||
/// notification.
|
||||
@ -420,9 +420,10 @@ public:
|
||||
notifyOperationReplaced(op, replacement->getResults());
|
||||
}
|
||||
|
||||
/// Notify the listener that the specified operation is about to be replaced
|
||||
/// with the a range of values, potentially produced by other operations.
|
||||
/// This is called before the uses of the operation have been changed.
|
||||
/// Notify the listener that all uses of the specified operation's results
|
||||
/// are about to be replaced with the a range of values, potentially
|
||||
/// produced by other operations. This is called before the uses of the
|
||||
/// operation have been changed.
|
||||
virtual void notifyOperationReplaced(Operation *op,
|
||||
ValueRange replacement) {}
|
||||
|
||||
@ -648,12 +649,16 @@ public:
|
||||
for (auto it : llvm::zip(from, to))
|
||||
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
|
||||
}
|
||||
// Note: This function cannot be called `replaceAllUsesWith` because the
|
||||
// overload resolution, when called with an op that can be implicitly
|
||||
// converted to a Value, would be ambiguous.
|
||||
void replaceAllOpUsesWith(Operation *from, ValueRange to) {
|
||||
replaceAllUsesWith(from->getResults(), to);
|
||||
}
|
||||
|
||||
/// Find uses of `from` and replace them with `to`. Also notify the listener
|
||||
/// about every in-place op modification (for every use that was replaced)
|
||||
/// and that the `from` operation is about to be replaced.
|
||||
///
|
||||
/// Note: This function cannot be called `replaceAllUsesWith` because the
|
||||
/// overload resolution, when called with an op that can be implicitly
|
||||
/// converted to a Value, would be ambiguous.
|
||||
void replaceAllOpUsesWith(Operation *from, ValueRange to);
|
||||
void replaceAllOpUsesWith(Operation *from, Operation *to);
|
||||
|
||||
/// Find uses of `from` and replace them with `to` if the `functor` returns
|
||||
/// true. Also notify the listener about every in-place op modification (for
|
||||
|
@ -110,6 +110,22 @@ RewriterBase::~RewriterBase() {
|
||||
// Out of line to provide a vtable anchor for the class.
|
||||
}
|
||||
|
||||
void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
|
||||
// Notify the listener that we're about to replace this op.
|
||||
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
|
||||
rewriteListener->notifyOperationReplaced(from, to);
|
||||
|
||||
replaceAllUsesWith(from->getResults(), to);
|
||||
}
|
||||
|
||||
void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
|
||||
// Notify the listener that we're about to replace this op.
|
||||
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
|
||||
rewriteListener->notifyOperationReplaced(from, to);
|
||||
|
||||
replaceAllUsesWith(from->getResults(), to->getResults());
|
||||
}
|
||||
|
||||
/// This method replaces the results of the operation with the specified list of
|
||||
/// values. The number of provided values must match the number of results of
|
||||
/// the operation. The replaced op is erased.
|
||||
@ -117,10 +133,6 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
|
||||
assert(op->getNumResults() == newValues.size() &&
|
||||
"incorrect # of replacement values");
|
||||
|
||||
// Notify the listener that we're about to replace this op.
|
||||
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
|
||||
rewriteListener->notifyOperationReplaced(op, newValues);
|
||||
|
||||
// Replace all result uses. Also notifies the listener of modifications.
|
||||
replaceAllOpUsesWith(op, newValues);
|
||||
|
||||
@ -136,10 +148,6 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
|
||||
assert(op->getNumResults() == newOp->getNumResults() &&
|
||||
"ops have different number of results");
|
||||
|
||||
// Notify the listener that we're about to replace this op.
|
||||
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
|
||||
rewriteListener->notifyOperationReplaced(op, newOp);
|
||||
|
||||
// Replace all result uses. Also notifies the listener of modifications.
|
||||
replaceAllOpUsesWith(op, newOp->getResults());
|
||||
|
||||
|
@ -489,7 +489,10 @@ private:
|
||||
OperationName("test.new_op", op->getContext()).getIdentifier(),
|
||||
op->getOperands(), op->getResultTypes());
|
||||
}
|
||||
rewriter.replaceOp(op, newOp->getResults());
|
||||
// "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
|
||||
// A "notifyOperationReplaced" callback is triggered in either case.
|
||||
rewriter.replaceAllOpUsesWith(op, newOp->getResults());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user