mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-24 04:16:08 +00:00
[MLIR] Harmonize the behavior of the folding API functions (#88508)
This commit changes `OpBuilder::tryFold` to behave more similarly to `Operation::fold`. Concretely, this ensures that even an in-place fold returns `success`. This is necessary to fix a bug in the dialect conversion that occurred when an in-place folding made an operation legal. The dialect conversion infrastructure did not check if the result of an in-place folding legalized the operation and just went ahead and tried to apply pattern anyways. The added test contains a simplified version of a breakage we observed downstream.
This commit is contained in:
parent
b28f4d4dd0
commit
4513050f52
@ -517,7 +517,7 @@ public:
|
||||
|
||||
/// Create an operation of specific op type at the current insertion point,
|
||||
/// and immediately try to fold it. This functions populates 'results' with
|
||||
/// the results after folding the operation.
|
||||
/// the results of the operation.
|
||||
template <typename OpTy, typename... Args>
|
||||
void createOrFold(SmallVectorImpl<Value> &results, Location location,
|
||||
Args &&...args) {
|
||||
@ -530,10 +530,17 @@ public:
|
||||
if (block)
|
||||
block->getOperations().insert(insertPoint, op);
|
||||
|
||||
// Fold the operation. If successful erase it, otherwise notify.
|
||||
if (succeeded(tryFold(op, results)))
|
||||
// Attempt to fold the operation.
|
||||
if (succeeded(tryFold(op, results)) && !results.empty()) {
|
||||
// Erase the operation, if the fold removed the need for this operation.
|
||||
// Note: The fold already populated the results in this case.
|
||||
op->erase();
|
||||
else if (block && listener)
|
||||
return;
|
||||
}
|
||||
|
||||
ResultRange opResults = op->getResults();
|
||||
results.assign(opResults.begin(), opResults.end());
|
||||
if (block && listener)
|
||||
listener->notifyOperationInserted(op, /*previous=*/{});
|
||||
}
|
||||
|
||||
@ -560,7 +567,8 @@ public:
|
||||
}
|
||||
|
||||
/// Attempts to fold the given operation and places new results within
|
||||
/// 'results'. Returns success if the operation was folded, failure otherwise.
|
||||
/// `results`. Returns success if the operation was folded, failure otherwise.
|
||||
/// If the fold was in-place, `results` will not be filled.
|
||||
/// Note: This function does not erase the operation on a successful fold.
|
||||
LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
|
||||
|
||||
|
@ -2831,7 +2831,8 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
|
||||
|
||||
/// Folds a cast op that can be chained.
|
||||
template <typename T>
|
||||
static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
|
||||
static OpFoldResult foldChainableCast(T castOp,
|
||||
typename T::FoldAdaptor adaptor) {
|
||||
// cast(x : T0, T0) -> x
|
||||
if (castOp.getArg().getType() == castOp.getType())
|
||||
return castOp.getArg();
|
||||
|
@ -476,16 +476,14 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
|
||||
return create(state);
|
||||
}
|
||||
|
||||
/// Attempts to fold the given operation and places new results within
|
||||
/// 'results'. Returns success if the operation was folded, failure otherwise.
|
||||
/// Note: This function does not erase the operation on a successful fold.
|
||||
LogicalResult OpBuilder::tryFold(Operation *op,
|
||||
SmallVectorImpl<Value> &results) {
|
||||
assert(results.empty() && "expected empty results");
|
||||
ResultRange opResults = op->getResults();
|
||||
|
||||
results.reserve(opResults.size());
|
||||
auto cleanupFailure = [&] {
|
||||
results.assign(opResults.begin(), opResults.end());
|
||||
results.clear();
|
||||
return failure();
|
||||
};
|
||||
|
||||
@ -495,20 +493,24 @@ LogicalResult OpBuilder::tryFold(Operation *op,
|
||||
|
||||
// Try to fold the operation.
|
||||
SmallVector<OpFoldResult, 4> foldResults;
|
||||
if (failed(op->fold(foldResults)) || foldResults.empty())
|
||||
if (failed(op->fold(foldResults)))
|
||||
return cleanupFailure();
|
||||
|
||||
// An in-place fold does not require generation of any constants.
|
||||
if (foldResults.empty())
|
||||
return success();
|
||||
|
||||
// A temporary builder used for creating constants during folding.
|
||||
OpBuilder cstBuilder(context);
|
||||
SmallVector<Operation *, 1> generatedConstants;
|
||||
|
||||
// Populate the results with the folded results.
|
||||
Dialect *dialect = op->getDialect();
|
||||
for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
|
||||
Type expectedType = std::get<1>(it);
|
||||
for (auto [foldResult, expectedType] :
|
||||
llvm::zip_equal(foldResults, opResults.getTypes())) {
|
||||
|
||||
// Normal values get pushed back directly.
|
||||
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
|
||||
if (auto value = llvm::dyn_cast_if_present<Value>(foldResult)) {
|
||||
results.push_back(value);
|
||||
continue;
|
||||
}
|
||||
@ -518,7 +520,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
|
||||
return cleanupFailure();
|
||||
|
||||
// Ask the dialect to materialize a constant operation for this value.
|
||||
Attribute attr = std::get<0>(it).get<Attribute>();
|
||||
Attribute attr = foldResult.get<Attribute>();
|
||||
auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
|
||||
op->getLoc());
|
||||
if (!constOp) {
|
||||
|
@ -2072,6 +2072,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
|
||||
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
|
||||
return failure();
|
||||
}
|
||||
// An empty list of replacement values indicates that the fold was in-place.
|
||||
// As the operation changed, a new legalization needs to be attempted.
|
||||
if (replacementValues.empty())
|
||||
return legalize(op, rewriter);
|
||||
|
||||
// Insert a replacement for 'op' with the folded replacement values.
|
||||
rewriter.replaceOp(op, replacementValues);
|
||||
|
@ -427,3 +427,13 @@ func.func @use_of_replaced_bbarg(%arg0: i64) {
|
||||
}) : (i64) -> (i64)
|
||||
"test.invalid"(%0) : (i64) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_legalization
|
||||
func.func @fold_legalization() -> i32 {
|
||||
// CHECK: op_in_place_self_fold
|
||||
// CHECK-SAME: folded
|
||||
%1 = "test.op_in_place_self_fold"() : () -> (i32)
|
||||
"test.return"(%1) : (i32) -> ()
|
||||
}
|
||||
|
@ -825,6 +825,19 @@ LogicalResult CompareOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestOpInPlaceSelfFold
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
|
||||
if (!getFolded()) {
|
||||
// The folder adds the "folded" if not present.
|
||||
setFolded(true);
|
||||
return getResult();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestOpFoldWithFoldAdaptor
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1351,6 +1351,12 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
|
||||
let arguments = (ins UnitAttr:$folded);
|
||||
let results = (outs I32);
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// Test op that simply returns success.
|
||||
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
|
||||
let results = (outs Variadic<I1>);
|
||||
|
@ -1168,6 +1168,10 @@ struct TestLegalizePatternDriver
|
||||
target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
|
||||
[](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
|
||||
|
||||
// Create a dynamically legal rule that can only be legalized by folding it.
|
||||
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
|
||||
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
|
||||
|
||||
// Handle a partial conversion.
|
||||
if (mode == ConversionMode::Partial) {
|
||||
DenseSet<Operation *> unlegalizedOps;
|
||||
|
Loading…
x
Reference in New Issue
Block a user