[mlir][IR] Trigger notifyOperationReplaced on replaceAllOpUsesWith (#84721)

Before this change: `notifyOperationReplaced` was triggered when calling
`RewriteBase::replaceOp`.
After this change: `notifyOperationReplaced` is triggered when
`RewriterBase::replaceAllOpUsesWith` or `RewriterBase::replaceOp` is
called.

Until now, every `notifyOperationReplaced` was always sent together with
a `notifyOperationErased`, which made that `notifyOperationErased`
callback irrelevant. More importantly, when a user called
`RewriterBase::replaceAllOpUsesWith`+`RewriterBase::eraseOp` instead of
`RewriterBase::replaceOp`, no `notifyOperationReplaced` callback was
sent, even though the two notations are semantically equivalent. As an
example, this can be a problem when applying patterns with the transform
dialect because the `TrackingListener` will only see the
`notifyOperationErased` callback and the payload op is dropped from the
mappings.

Note: It is still possible to write semantically equivalent code that
does not trigger a `notifyOperationReplaced` (e.g., when op results are
replaced one-by-one), but this commit already improves the situation a
lot.
This commit is contained in:
Matthias Springer 2024-04-02 10:53:57 +09:00 committed by GitHub
parent d7a43a00fe
commit 38113a0832
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 37 additions and 21 deletions

View File

@ -409,9 +409,9 @@ public:
/// Notify the listener that the specified operation was modified in-place. /// Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationModified(Operation *op) {} virtual void notifyOperationModified(Operation *op) {}
/// Notify the listener that the specified operation is about to be replaced /// Notify the listener that all uses of the specified operation's results
/// with another operation. This is called before the uses of the old /// are about to be replaced with the results of another operation. This is
/// operation have been changed. /// called before the uses of the old operation have been changed.
/// ///
/// By default, this function calls the "operation replaced with values" /// By default, this function calls the "operation replaced with values"
/// notification. /// notification.
@ -420,9 +420,10 @@ public:
notifyOperationReplaced(op, replacement->getResults()); notifyOperationReplaced(op, replacement->getResults());
} }
/// Notify the listener that the specified operation is about to be replaced /// Notify the listener that all uses of the specified operation's results
/// with the a range of values, potentially produced by other operations. /// are about to be replaced with the a range of values, potentially
/// This is called before the uses of the operation have been changed. /// produced by other operations. This is called before the uses of the
/// operation have been changed.
virtual void notifyOperationReplaced(Operation *op, virtual void notifyOperationReplaced(Operation *op,
ValueRange replacement) {} ValueRange replacement) {}
@ -648,12 +649,16 @@ public:
for (auto it : llvm::zip(from, to)) for (auto it : llvm::zip(from, to))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
} }
// Note: This function cannot be called `replaceAllUsesWith` because the
// overload resolution, when called with an op that can be implicitly /// Find uses of `from` and replace them with `to`. Also notify the listener
// converted to a Value, would be ambiguous. /// about every in-place op modification (for every use that was replaced)
void replaceAllOpUsesWith(Operation *from, ValueRange to) { /// and that the `from` operation is about to be replaced.
replaceAllUsesWith(from->getResults(), to); ///
} /// Note: This function cannot be called `replaceAllUsesWith` because the
/// overload resolution, when called with an op that can be implicitly
/// converted to a Value, would be ambiguous.
void replaceAllOpUsesWith(Operation *from, ValueRange to);
void replaceAllOpUsesWith(Operation *from, Operation *to);
/// Find uses of `from` and replace them with `to` if the `functor` returns /// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. Also notify the listener about every in-place op modification (for /// true. Also notify the listener about every in-place op modification (for

View File

@ -110,6 +110,22 @@ RewriterBase::~RewriterBase() {
// Out of line to provide a vtable anchor for the class. // Out of line to provide a vtable anchor for the class.
} }
void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
// Notify the listener that we're about to replace this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(from, to);
replaceAllUsesWith(from->getResults(), to);
}
void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
// Notify the listener that we're about to replace this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(from, to);
replaceAllUsesWith(from->getResults(), to->getResults());
}
/// This method replaces the results of the operation with the specified list of /// This method replaces the results of the operation with the specified list of
/// values. The number of provided values must match the number of results of /// values. The number of provided values must match the number of results of
/// the operation. The replaced op is erased. /// the operation. The replaced op is erased.
@ -117,10 +133,6 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() && assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values"); "incorrect # of replacement values");
// Notify the listener that we're about to replace this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newValues);
// Replace all result uses. Also notifies the listener of modifications. // Replace all result uses. Also notifies the listener of modifications.
replaceAllOpUsesWith(op, newValues); replaceAllOpUsesWith(op, newValues);
@ -136,10 +148,6 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
assert(op->getNumResults() == newOp->getNumResults() && assert(op->getNumResults() == newOp->getNumResults() &&
"ops have different number of results"); "ops have different number of results");
// Notify the listener that we're about to replace this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newOp);
// Replace all result uses. Also notifies the listener of modifications. // Replace all result uses. Also notifies the listener of modifications.
replaceAllOpUsesWith(op, newOp->getResults()); replaceAllOpUsesWith(op, newOp->getResults());

View File

@ -489,7 +489,10 @@ private:
OperationName("test.new_op", op->getContext()).getIdentifier(), OperationName("test.new_op", op->getContext()).getIdentifier(),
op->getOperands(), op->getResultTypes()); op->getOperands(), op->getResultTypes());
} }
rewriter.replaceOp(op, newOp->getResults()); // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
// A "notifyOperationReplaced" callback is triggered in either case.
rewriter.replaceAllOpUsesWith(op, newOp->getResults());
rewriter.eraseOp(op);
return success(); return success();
} }
}; };