[MLIR] Harmonize the behavior of the folding API functions (#88508)

This commit changes `OpBuilder::tryFold` to behave more similarly to
`Operation::fold`. Concretely, this ensures that even an in-place fold
returns `success`.
This is necessary to fix a bug in the dialect conversion that occurred
when an in-place folding made an operation legal. The dialect conversion
infrastructure did not check if the result of an in-place folding
legalized the operation and just went ahead and tried to apply pattern
anyways.

The added test contains a simplified version of a breakage we observed
downstream.
This commit is contained in:
Christian Ulmann 2024-04-23 08:05:55 +02:00 committed by GitHub
parent b28f4d4dd0
commit 4513050f52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 63 additions and 15 deletions

View File

@ -517,7 +517,7 @@ public:
/// Create an operation of specific op type at the current insertion point,
/// and immediately try to fold it. This functions populates 'results' with
/// the results after folding the operation.
/// the results of the operation.
template <typename OpTy, typename... Args>
void createOrFold(SmallVectorImpl<Value> &results, Location location,
Args &&...args) {
@ -530,10 +530,17 @@ public:
if (block)
block->getOperations().insert(insertPoint, op);
// Fold the operation. If successful erase it, otherwise notify.
if (succeeded(tryFold(op, results)))
// Attempt to fold the operation.
if (succeeded(tryFold(op, results)) && !results.empty()) {
// Erase the operation, if the fold removed the need for this operation.
// Note: The fold already populated the results in this case.
op->erase();
else if (block && listener)
return;
}
ResultRange opResults = op->getResults();
results.assign(opResults.begin(), opResults.end());
if (block && listener)
listener->notifyOperationInserted(op, /*previous=*/{});
}
@ -560,7 +567,8 @@ public:
}
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// `results`. Returns success if the operation was folded, failure otherwise.
/// If the fold was in-place, `results` will not be filled.
/// Note: This function does not erase the operation on a successful fold.
LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);

View File

@ -2831,7 +2831,8 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
/// Folds a cast op that can be chained.
template <typename T>
static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
static OpFoldResult foldChainableCast(T castOp,
typename T::FoldAdaptor adaptor) {
// cast(x : T0, T0) -> x
if (castOp.getArg().getType() == castOp.getType())
return castOp.getArg();

View File

@ -476,16 +476,14 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.
LogicalResult OpBuilder::tryFold(Operation *op,
SmallVectorImpl<Value> &results) {
assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();
results.reserve(opResults.size());
auto cleanupFailure = [&] {
results.assign(opResults.begin(), opResults.end());
results.clear();
return failure();
};
@ -495,20 +493,24 @@ LogicalResult OpBuilder::tryFold(Operation *op,
// Try to fold the operation.
SmallVector<OpFoldResult, 4> foldResults;
if (failed(op->fold(foldResults)) || foldResults.empty())
if (failed(op->fold(foldResults)))
return cleanupFailure();
// An in-place fold does not require generation of any constants.
if (foldResults.empty())
return success();
// A temporary builder used for creating constants during folding.
OpBuilder cstBuilder(context);
SmallVector<Operation *, 1> generatedConstants;
// Populate the results with the folded results.
Dialect *dialect = op->getDialect();
for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
Type expectedType = std::get<1>(it);
for (auto [foldResult, expectedType] :
llvm::zip_equal(foldResults, opResults.getTypes())) {
// Normal values get pushed back directly.
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
if (auto value = llvm::dyn_cast_if_present<Value>(foldResult)) {
results.push_back(value);
continue;
}
@ -518,7 +520,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
return cleanupFailure();
// Ask the dialect to materialize a constant operation for this value.
Attribute attr = std::get<0>(it).get<Attribute>();
Attribute attr = foldResult.get<Attribute>();
auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
op->getLoc());
if (!constOp) {

View File

@ -2072,6 +2072,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);

View File

@ -427,3 +427,13 @@ func.func @use_of_replaced_bbarg(%arg0: i64) {
}) : (i64) -> (i64)
"test.invalid"(%0) : (i64) -> ()
}
// -----
// CHECK-LABEL: @fold_legalization
func.func @fold_legalization() -> i32 {
// CHECK: op_in_place_self_fold
// CHECK-SAME: folded
%1 = "test.op_in_place_self_fold"() : () -> (i32)
"test.return"(%1) : (i32) -> ()
}

View File

@ -825,6 +825,19 @@ LogicalResult CompareOp::verify() {
return success();
}
//===----------------------------------------------------------------------===//
// TestOpInPlaceSelfFold
//===----------------------------------------------------------------------===//
OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
if (!getFolded()) {
// The folder adds the "folded" if not present.
setFolded(true);
return getResult();
}
return {};
}
//===----------------------------------------------------------------------===//
// TestOpFoldWithFoldAdaptor
//===----------------------------------------------------------------------===//

View File

@ -1351,6 +1351,12 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
let hasFolder = 1;
}
def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
let arguments = (ins UnitAttr:$folded);
let results = (outs I32);
let hasFolder = 1;
}
// Test op that simply returns success.
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
let results = (outs Variadic<I1>);

View File

@ -1168,6 +1168,10 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
[](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
// Create a dynamically legal rule that can only be legalized by folding it.
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;