mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-24 17:06:07 +00:00
[mlir][NFC] Migrate rest of the dialects to the new fold API
This commit is contained in:
parent
7039bd2509
commit
7df761217c
@ -24,6 +24,7 @@ def Affine_Dialect : Dialect {
|
||||
let cppNamespace = "mlir";
|
||||
let hasConstantMaterializer = 1;
|
||||
let dependentDialects = ["arith::ArithDialect"];
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
// Base class for Affine dialect ops.
|
||||
|
@ -69,6 +69,7 @@ def Bufferization_Dialect : Dialect {
|
||||
kEscapeAttrName = "bufferization.escape";
|
||||
}];
|
||||
let hasOperationAttrVerify = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
#endif // BUFFERIZATION_BASE
|
||||
|
@ -22,6 +22,7 @@ def Complex_Dialect : Dialect {
|
||||
let dependentDialects = ["arith::ArithDialect"];
|
||||
let hasConstantMaterializer = 1;
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
#endif // COMPLEX_BASE
|
||||
|
@ -31,6 +31,7 @@ def EmitC_Dialect : Dialect {
|
||||
let hasConstantMaterializer = 1;
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_EMITC_IR_EMITCBASE
|
||||
|
@ -23,6 +23,7 @@ def Func_Dialect : Dialect {
|
||||
let cppNamespace = "::mlir::func";
|
||||
let dependentDialects = ["cf::ControlFlowDialect"];
|
||||
let hasConstantMaterializer = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
// Base class for Func dialect ops.
|
||||
|
@ -56,6 +56,7 @@ def GPU_Dialect : Dialect {
|
||||
let dependentDialects = ["arith::ArithDialect"];
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
def GPU_AsyncToken : DialectType<
|
||||
|
@ -31,6 +31,7 @@ def LLVM_Dialect : Dialect {
|
||||
let hasRegionArgAttrVerify = 1;
|
||||
let hasRegionResultAttrVerify = 1;
|
||||
let hasOperationAttrVerify = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Name of the data layout attributes.
|
||||
|
@ -46,6 +46,7 @@ def Linalg_Dialect : Dialect {
|
||||
let hasCanonicalizer = 1;
|
||||
let hasOperationAttrVerify = 1;
|
||||
let hasConstantMaterializer = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
let extraClassDeclaration = [{
|
||||
/// Attribute name used to to memoize indexing maps for named ops.
|
||||
constexpr const static ::llvm::StringLiteral
|
||||
|
@ -20,6 +20,7 @@ def Quantization_Dialect : Dialect {
|
||||
let cppNamespace = "::mlir::quant";
|
||||
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -25,6 +25,7 @@ def SCF_Dialect : Dialect {
|
||||
let name = "scf";
|
||||
let cppNamespace = "::mlir::scf";
|
||||
let dependentDialects = ["arith::ArithDialect"];
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
// Base class for SCF dialect ops.
|
||||
|
@ -83,6 +83,7 @@ def SparseTensor_Dialect : Dialect {
|
||||
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
#endif // SPARSETENSOR_BASE
|
||||
|
@ -23,6 +23,8 @@ def Transform_Dialect : Dialect {
|
||||
"::mlir::pdl_interp::PDLInterpDialect",
|
||||
];
|
||||
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Returns the named PDL constraint functions available in the dialect
|
||||
/// as a map from their name to the function.
|
||||
|
@ -34,6 +34,8 @@ def Builtin_Dialect : Dialect {
|
||||
|
||||
public:
|
||||
}];
|
||||
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
#endif // BUILTIN_BASE
|
||||
|
@ -562,7 +562,7 @@ bool AffineApplyOp::isValidSymbol(Region *region) {
|
||||
});
|
||||
}
|
||||
|
||||
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
|
||||
auto map = getAffineMap();
|
||||
|
||||
// Fold dims and symbols to existing values.
|
||||
@ -574,7 +574,7 @@ OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
// Otherwise, default to folding the map.
|
||||
SmallVector<Attribute, 1> result;
|
||||
if (failed(map.constantFold(operands, result)))
|
||||
if (failed(map.constantFold(adaptor.getMapOperands(), result)))
|
||||
return {};
|
||||
return result[0];
|
||||
}
|
||||
@ -2135,7 +2135,7 @@ static bool hasTrivialZeroTripCount(AffineForOp op) {
|
||||
return tripCount && *tripCount == 0;
|
||||
}
|
||||
|
||||
LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
|
||||
LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
bool folded = succeeded(foldLoopBounds(*this));
|
||||
folded |= succeeded(canonicalizeLoopBounds(*this));
|
||||
@ -2723,7 +2723,7 @@ static void composeSetAndOperands(IntegerSet &set,
|
||||
}
|
||||
|
||||
/// Canonicalize an affine if op's conditional (integer set + operands).
|
||||
LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
|
||||
LogicalResult AffineIfOp::fold(FoldAdaptor,
|
||||
SmallVectorImpl<OpFoldResult> &) {
|
||||
auto set = getIntegerSet();
|
||||
SmallVector<Value, 4> operands(getOperands());
|
||||
@ -2858,7 +2858,7 @@ void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
results.add<SimplifyAffineOp<AffineLoadOp>>(context);
|
||||
}
|
||||
|
||||
OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
|
||||
OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
|
||||
/// load(memrefcast) -> load
|
||||
if (succeeded(memref::foldMemRefCast(*this)))
|
||||
return getResult();
|
||||
@ -2975,7 +2975,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
results.add<SimplifyAffineOp<AffineStoreOp>>(context);
|
||||
}
|
||||
|
||||
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
|
||||
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
/// store(memrefcast) -> store
|
||||
return memref::foldMemRefCast(*this, getValueToStore());
|
||||
@ -3282,8 +3282,8 @@ struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
|
||||
// %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
|
||||
//
|
||||
|
||||
OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
|
||||
return foldMinMaxOp(*this, operands);
|
||||
OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
|
||||
return foldMinMaxOp(*this, adaptor.getOperands());
|
||||
}
|
||||
|
||||
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
@ -3310,8 +3310,8 @@ void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
|
||||
// %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
|
||||
//
|
||||
|
||||
OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
|
||||
return foldMinMaxOp(*this, operands);
|
||||
OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
|
||||
return foldMinMaxOp(*this, adaptor.getOperands());
|
||||
}
|
||||
|
||||
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
@ -3431,7 +3431,7 @@ void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
|
||||
}
|
||||
|
||||
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
|
||||
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
/// prefetch(memrefcast) -> prefetch
|
||||
return memref::foldMemRefCast(*this);
|
||||
@ -3705,7 +3705,7 @@ static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
|
||||
LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
return canonicalizeLoopBounds(*this);
|
||||
}
|
||||
|
@ -458,7 +458,7 @@ void CloneOp::getEffects(
|
||||
SideEffects::DefaultResource::get());
|
||||
}
|
||||
|
||||
OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
|
||||
return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
|
||||
}
|
||||
|
||||
@ -560,7 +560,7 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
|
||||
// ToTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
|
||||
OpFoldResult ToTensorOp::fold(FoldAdaptor) {
|
||||
if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
|
||||
// Approximate alias analysis by conservatively folding only when no there
|
||||
// is no interleaved operation.
|
||||
@ -596,7 +596,7 @@ void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
// ToMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
|
||||
OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
|
||||
if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
|
||||
if (memrefToTensor.getMemref().getType() == getType())
|
||||
return memrefToTensor.getMemref();
|
||||
|
@ -17,8 +17,7 @@ using namespace mlir::complex;
|
||||
// ConstantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
|
||||
return getValue();
|
||||
}
|
||||
|
||||
@ -68,8 +67,7 @@ LogicalResult ConstantOp::verify() {
|
||||
// CreateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "binary op takes two operands");
|
||||
OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
|
||||
// Fold complex.create(complex.re(op), complex.im(op)).
|
||||
if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
|
||||
if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
|
||||
@ -85,9 +83,8 @@ OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
|
||||
// ImOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1 && "unary op takes 1 operand");
|
||||
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
|
||||
OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
|
||||
ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
|
||||
if (arrayAttr && arrayAttr.size() == 2)
|
||||
return arrayAttr[1];
|
||||
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
|
||||
@ -99,9 +96,8 @@ OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
|
||||
// ReOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1 && "unary op takes 1 operand");
|
||||
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
|
||||
OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
|
||||
ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
|
||||
if (arrayAttr && arrayAttr.size() == 2)
|
||||
return arrayAttr[0];
|
||||
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
|
||||
@ -113,9 +109,7 @@ OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
|
||||
// AddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "binary op takes 2 operands");
|
||||
|
||||
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
|
||||
// complex.add(complex.sub(a, b), b) -> a
|
||||
if (auto sub = getLhs().getDefiningOp<SubOp>())
|
||||
if (getRhs() == sub.getRhs())
|
||||
@ -142,9 +136,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
|
||||
// SubOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "binary op takes 2 operands");
|
||||
|
||||
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
|
||||
// complex.sub(complex.add(a, b), b) -> a
|
||||
if (auto add = getLhs().getDefiningOp<AddOp>())
|
||||
if (getRhs() == add.getRhs())
|
||||
@ -166,9 +158,7 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
|
||||
// NegOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1 && "unary op takes 1 operand");
|
||||
|
||||
OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
|
||||
// complex.neg(complex.neg(a)) -> a
|
||||
if (auto negOp = getOperand().getDefiningOp<NegOp>())
|
||||
return negOp.getOperand();
|
||||
@ -180,9 +170,7 @@ OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
|
||||
// LogOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1 && "unary op takes 1 operand");
|
||||
|
||||
OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
|
||||
// complex.log(complex.exp(a)) -> a
|
||||
if (auto expOp = getOperand().getDefiningOp<ExpOp>())
|
||||
return expOp.getOperand();
|
||||
@ -194,9 +182,7 @@ OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
|
||||
// ExpOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1 && "unary op takes 1 operand");
|
||||
|
||||
OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
|
||||
// complex.exp(complex.log(a)) -> a
|
||||
if (auto logOp = getOperand().getDefiningOp<LogOp>())
|
||||
return logOp.getOperand();
|
||||
@ -208,9 +194,7 @@ OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
|
||||
// ConjOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ConjOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1 && "unary op takes 1 operand");
|
||||
|
||||
OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
|
||||
// complex.conj(complex.conj(a)) -> a
|
||||
if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
|
||||
return conjOp.getOperand();
|
||||
|
@ -129,8 +129,7 @@ LogicalResult emitc::ConstantOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult emitc::ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) {
|
||||
return getValue();
|
||||
}
|
||||
|
||||
|
@ -201,8 +201,7 @@ LogicalResult ConstantOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1286,12 +1286,12 @@ LogicalResult SubgroupMmaComputeOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
|
||||
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<::mlir::OpFoldResult> &results) {
|
||||
return memref::foldMemRefCast(*this);
|
||||
}
|
||||
|
||||
LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
|
||||
LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<::mlir::OpFoldResult> &results) {
|
||||
return memref::foldMemRefCast(*this);
|
||||
}
|
||||
|
@ -1441,7 +1441,7 @@ static Type getInsertExtractValueElementType(Type llvmType,
|
||||
return llvmType;
|
||||
}
|
||||
|
||||
OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
|
||||
auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
|
||||
OpFoldResult result = {};
|
||||
while (insertValueOp) {
|
||||
@ -2275,7 +2275,7 @@ LogicalResult LLVM::ConstantOp::verify() {
|
||||
}
|
||||
|
||||
// Constant op constant-folds to its value.
|
||||
OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); }
|
||||
OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility functions for parsing atomic ops
|
||||
@ -2513,7 +2513,7 @@ LogicalResult FenceOp::verify() {
|
||||
// Folder for LLVM::BitcastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
|
||||
// bitcast(x : T0, T0) -> x
|
||||
if (getArg().getType() == getType())
|
||||
return getArg();
|
||||
@ -2528,7 +2528,7 @@ OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Folder for LLVM::AddrSpaceCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
|
||||
// addrcast(x : T0, T0) -> x
|
||||
if (getArg().getType() == getType())
|
||||
return getArg();
|
||||
@ -2543,9 +2543,9 @@ OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Folder for LLVM::GEPOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
|
||||
GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
|
||||
operands.drop_front());
|
||||
adaptor.getDynamicIndices());
|
||||
|
||||
// gep %x:T, 0 -> %x
|
||||
if (getBase().getType() == getType() && indices.size() == 1)
|
||||
|
@ -980,8 +980,7 @@ void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
results.add<EraseIdentityGenericOp>(context);
|
||||
}
|
||||
|
||||
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
|
||||
SmallVectorImpl<OpFoldResult> &) {
|
||||
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
|
||||
return memref::foldMemRefCast(*this);
|
||||
}
|
||||
|
||||
|
@ -36,7 +36,7 @@ void QuantizationDialect::initialize() {
|
||||
addBytecodeInterface(this);
|
||||
}
|
||||
|
||||
OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
|
||||
// Matches x -> [scast -> scast] -> y, replacing the second scast with the
|
||||
// value of x if the casts invert each other.
|
||||
auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();
|
||||
|
@ -1598,7 +1598,7 @@ void IfOp::getSuccessorRegions(std::optional<unsigned> index,
|
||||
regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
|
||||
}
|
||||
|
||||
LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
|
||||
LogicalResult IfOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
// if (!c) then A() else B() -> if c then B() else A()
|
||||
if (getElseRegion().empty())
|
||||
|
@ -467,7 +467,7 @@ LogicalResult ConvertOp::verify() {
|
||||
return emitError("unexpected type in convert");
|
||||
}
|
||||
|
||||
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
|
||||
Type dstType = getType();
|
||||
// Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse
|
||||
// convert for codegen to remove. This is because we use trivial
|
||||
@ -531,7 +531,7 @@ static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
|
||||
return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
|
||||
}
|
||||
|
||||
OpFoldResult GetStorageSpecifierOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
|
||||
StorageSpecifierKind kind = getSpecifierKind();
|
||||
std::optional<APInt> dim = getDim();
|
||||
for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
|
||||
|
@ -463,7 +463,7 @@ void transform::MergeHandlesOp::getEffects(
|
||||
// manipulation.
|
||||
}
|
||||
|
||||
OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
|
||||
if (getDeduplicate() || getHandles().size() != 1)
|
||||
return {};
|
||||
|
||||
|
@ -190,7 +190,7 @@ LogicalResult ModuleOp::verify() {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
|
||||
UnrealizedConversionCastOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &foldResults) {
|
||||
OperandRange operands = getInputs();
|
||||
ResultRange results = getOutputs();
|
||||
|
@ -1099,32 +1099,31 @@ void TestOpWithRegionPattern::getCanonicalizationPatterns(
|
||||
results.add<TestRemoveOpWithInnerOps>(context);
|
||||
}
|
||||
|
||||
OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
|
||||
return getOperand();
|
||||
}
|
||||
|
||||
OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) {
|
||||
return getValue();
|
||||
}
|
||||
|
||||
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
|
||||
ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
|
||||
FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
|
||||
for (Value input : this->getOperands()) {
|
||||
results.push_back(input);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1);
|
||||
if (operands.front()) {
|
||||
(*this)->setAttr("attr", operands.front());
|
||||
OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
|
||||
if (adaptor.getOp()) {
|
||||
(*this)->setAttr("attr", adaptor.getOp());
|
||||
return getResult();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) {
|
||||
return getOperand();
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,7 @@ def Test_Dialect : Dialect {
|
||||
let hasNonDefaultDestructor = 1;
|
||||
let useDefaultTypePrinterParser = 0;
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
let isExtensible = 1;
|
||||
let dependentDialects = ["::mlir::DLTIDialect"];
|
||||
|
||||
|
@ -1290,7 +1290,7 @@ def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
|
||||
let results = (outs Variadic<I1>);
|
||||
let hasFolder = 1;
|
||||
let extraClassDefinition = [{
|
||||
::mlir::LogicalResult $cppClass::fold(ArrayRef<Attribute> operands,
|
||||
::mlir::LogicalResult $cppClass::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
return success();
|
||||
}
|
||||
@ -1315,11 +1315,7 @@ def TestOpFoldWithFoldAdaptor
|
||||
$op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword
|
||||
}];
|
||||
|
||||
let hasFolder = 0;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// An op that always fold itself.
|
||||
|
@ -18,13 +18,13 @@ using namespace test;
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
|
||||
ArrayRef<Attribute> operands) {
|
||||
FoldAdaptor adaptor) {
|
||||
// This failure should cause the trait fold to run instead.
|
||||
return {};
|
||||
}
|
||||
|
||||
OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
|
||||
ArrayRef<Attribute> operands) {
|
||||
FoldAdaptor adaptor) {
|
||||
auto argumentOp = getOperand();
|
||||
// The success case should cause the trait fold to be supressed.
|
||||
return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{};
|
||||
|
@ -654,7 +654,7 @@ ArrayAttr {0}::getIndexingMaps() {{
|
||||
// Parameters:
|
||||
// {0}: Class name
|
||||
const char structuredOpFoldersFormat[] = R"FMT(
|
||||
LogicalResult {0}::fold(ArrayRef<Attribute>,
|
||||
LogicalResult {0}::fold(FoldAdaptor,
|
||||
SmallVectorImpl<OpFoldResult> &) {{
|
||||
return memref::foldMemRefCast(*this);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user