//===- Ops.cpp - Standard MLIR Operations ---------------------------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; //===----------------------------------------------------------------------===// // StandardOpsDialect Interfaces //===----------------------------------------------------------------------===// namespace { struct StdOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; /// Get a special name to use when printing the given operation. The desired /// name should be streamed into 'os'. void getOpResultName(Operation *op, raw_ostream &os) const final { if (ConstantOp constant = dyn_cast(op)) return getConstantOpResultName(constant, os); } /// Get a special name to use when printing the given constant. static void getConstantOpResultName(ConstantOp op, raw_ostream &os) { Type type = op.getType(); Attribute value = op.getValue(); if (auto intCst = value.dyn_cast()) { if (type.isIndex()) { os << 'c' << intCst.getInt(); } else if (type.cast().isInteger(1)) { // i1 constants get special names. os << (intCst.getInt() ? "true" : "false"); } else { os << 'c' << intCst.getInt() << '_' << type; } } else if (type.isa()) { os << 'f'; } else { os << "cst"; } } }; /// This class defines the interface for handling inlining with standard /// operations. struct StdInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; //===--------------------------------------------------------------------===// // Analysis Hooks //===--------------------------------------------------------------------===// /// All operations within standard ops can be inlined. bool isLegalToInline(Operation *, Region *, BlockAndValueMapping &) const final { return true; } //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, Block *newDest) const final { // Only "std.return" needs to be handled here. auto returnOp = dyn_cast(op); if (!returnOp) return; // Replace the return with a branch to the dest. OpBuilder builder(op); builder.create(op->getLoc(), newDest, llvm::to_vector<4>(returnOp.getOperands())); op->erase(); } /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, ArrayRef valuesToRepl) const final { // Only "std.return" needs to be handled here. auto returnOp = cast(op); // Replace the values directly with the return operands. assert(returnOp.getNumOperands() == valuesToRepl.size()); for (const auto &it : llvm::enumerate(returnOp.getOperands())) valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // StandardOpsDialect //===----------------------------------------------------------------------===// /// A custom unary operation printer that omits the "std." prefix from the /// operation names. static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) { assert(op->getNumOperands() == 1 && "unary op should have one operand"); assert(op->getNumResults() == 1 && "unary op should have one result"); int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << *op->getOperand(0); p.printOptionalAttrDict(op->getAttrs()); p << " : " << op->getOperand(0)->getType(); } /// A custom binary operation printer that omits the "std." prefix from the /// operation names. static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) { assert(op->getNumOperands() == 2 && "binary op should have two operands"); assert(op->getNumResults() == 1 && "binary op should have one result"); // If not all the operand and result types are the same, just use the // generic assembly form to avoid omitting information in printing. auto resultType = op->getResult(0)->getType(); if (op->getOperand(0)->getType() != resultType || op->getOperand(1)->getType() != resultType) { p.printGenericOp(op); return; } int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1); p.printOptionalAttrDict(op->getAttrs()); // Now we can output only one type for all operands and the result. p << " : " << op->getResult(0)->getType(); } /// A custom cast operation printer that omits the "std." prefix from the /// operation names. static void printStandardCastOp(Operation *op, OpAsmPrinter &p) { int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); } /// A custom cast operation verifier. template static LogicalResult verifyCastOp(T op) { auto opType = op.getOperand()->getType(); auto resType = op.getType(); if (!T::areCastCompatible(opType, resType)) return op.emitError("operand type ") << opType << " and result type " << resType << " are cast incompatible"; return success(); } StandardOpsDialect::StandardOpsDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations(); addInterfaces(); } void mlir::printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &p) { p << '('; p.printOperands(begin, begin + numDims); p << ')'; if (begin + numDims != end) { p << '['; p.printOperands(begin + numDims, end); p << ']'; } } // Parses dimension and symbol list, and sets 'numDims' to the number of // dimension operands parsed. // Returns 'false' on success and 'true' on error. ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, SmallVector &operands, unsigned &numDims) { SmallVector opInfos; if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) return failure(); // Store number of dimensions for validation by caller. numDims = opInfos.size(); // Parse the optional symbol operands. auto indexTy = parser.getBuilder().getIndexType(); if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::OptionalSquare) || parser.resolveOperands(opInfos, indexTy, operands)) return failure(); return success(); } /// Matches a ConstantIndexOp. /// TODO: This should probably just be a general matcher that uses m_Constant /// and checks the operation for an index type. static detail::op_matcher m_ConstantIndex() { return detail::op_matcher(); } //===----------------------------------------------------------------------===// // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// namespace { /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref_cast /// into the root operation directly. struct MemRefCastFolder : public RewritePattern { /// The rootOpName is the name of the root operation to match against. MemRefCastFolder(StringRef rootOpName, MLIRContext *context) : RewritePattern(rootOpName, 1, context) {} PatternMatchResult match(Operation *op) const override { for (auto *operand : op->getOperands()) if (matchPattern(operand, m_Op())) return matchSuccess(); return matchFailure(); } void rewrite(Operation *op, PatternRewriter &rewriter) const override { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) if (auto *memref = op->getOperand(i)->getDefiningOp()) if (auto cast = dyn_cast(memref)) op->setOperand(i, cast.getOperand()); rewriter.updatedRootInPlace(op); } }; /// Performs const folding `calculate` with element-wise behavior on the two /// attributes in `operands` and returns the result if possible. template > Attribute constFoldBinaryOp(ArrayRef operands, const CalculationT &calculate) { assert(operands.size() == 2 && "binary op takes two operands"); if (auto lhs = operands[0].dyn_cast_or_null()) { auto rhs = operands[1].dyn_cast_or_null(); if (!rhs || lhs.getType() != rhs.getType()) return {}; return AttrElementT::get(lhs.getType(), calculate(lhs.getValue(), rhs.getValue())); } else if (auto lhs = operands[0].dyn_cast_or_null()) { auto rhs = operands[1].dyn_cast_or_null(); if (!rhs || lhs.getType() != rhs.getType()) return {}; auto elementResult = constFoldBinaryOp( {lhs.getSplatValue(), rhs.getSplatValue()}, calculate); if (!elementResult) return {}; return DenseElementsAttr::get(lhs.getType(), elementResult); } return {}; } } // end anonymous namespace. //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// OpFoldResult AddFOp::fold(ArrayRef operands) { return constFoldBinaryOp( operands, [](APFloat a, APFloat b) { return a + b; }); } //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// OpFoldResult AddIOp::fold(ArrayRef operands) { /// addi(x, 0) -> x if (matchPattern(rhs(), m_Zero())) return lhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a + b; }); } //===----------------------------------------------------------------------===// // AllocOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, AllocOp op) { p << "alloc"; // Print dynamic dimension operands. MemRefType type = op.getType(); printDimAndSymbolList(op.operand_begin(), op.operand_end(), type.getNumDynamicDims(), p); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); p << " : " << type; } static ParseResult parseAllocOp(OpAsmParser &parser, OperationState &result) { MemRefType type; // Parse the dimension operands and optional symbol operands, followed by a // memref type. unsigned numDimOperands; if (parseDimAndSymbolList(parser, result.operands, numDimOperands) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type)) return failure(); // Check numDynamicDims against number of question marks in memref type. // Note: this check remains here (instead of in verify()), because the // partition between dim operands and symbol operands is lost after parsing. // Verification still checks that the total number of operands matches // the number of symbols in the affine map, plus the number of dynamic // dimensions in the memref. if (numDimOperands != type.getNumDynamicDims()) return parser.emitError(parser.getNameLoc()) << "dimension operand count does not equal memref dynamic dimension " "count"; result.types.push_back(type); return success(); } static LogicalResult verify(AllocOp op) { auto memRefType = op.getResult()->getType().dyn_cast(); if (!memRefType) return op.emitOpError("result must be a memref"); unsigned numSymbols = 0; if (!memRefType.getAffineMaps().empty()) { // Store number of symbols used in affine map (used in subsequent check). AffineMap affineMap = memRefType.getAffineMaps()[0]; numSymbols = affineMap.getNumSymbols(); } // Check that the total number of operands matches the number of symbols in // the affine map, plus the number of dynamic dimensions specified in the // memref type. unsigned numDynamicDims = memRefType.getNumDynamicDims(); if (op.getNumOperands() != numDynamicDims + numSymbols) return op.emitOpError( "operand count does not equal dimension plus symbol operand count"); // Verify that all operands are of type Index. for (auto operandType : op.getOperandTypes()) if (!operandType.isIndex()) return op.emitOpError("requires operands to be of type Index"); return success(); } namespace { /// Fold constant dimensions into an alloc operation. struct SimplifyAllocConst : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AllocOp alloc, PatternRewriter &rewriter) const override { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. if (llvm::none_of(alloc.getOperands(), [](Value *operand) { return matchPattern(operand, m_ConstantIndex()); })) return matchFailure(); auto memrefType = alloc.getType(); // Ok, we have one or more constant operands. Collect the non-constant ones // and keep track of the resultant memref type to build. SmallVector newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); SmallVector newOperands; SmallVector droppedOperands; unsigned dynamicDimPos = 0; for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { int64_t dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. if (dimSize != -1) { newShapeConstants.push_back(dimSize); continue; } auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); // Record to check for zero uses later below. droppedOperands.push_back(constantIndexOp); } else { // Dynamic shape dimension not folded; copy operand from old memref. newShapeConstants.push_back(-1); newOperands.push_back(alloc.getOperand(dynamicDimPos)); } dynamicDimPos++; } // Create new memref type (which will have fewer dynamic dimensions). auto newMemRefType = MemRefType::get( newShapeConstants, memrefType.getElementType(), memrefType.getAffineMaps(), memrefType.getMemorySpace()); assert(static_cast(newOperands.size()) == newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. auto newAlloc = rewriter.create(alloc.getLoc(), newMemRefType, newOperands); // Insert a cast so we have the same type as the old alloc. auto resultCast = rewriter.create(alloc.getLoc(), newAlloc, alloc.getType()); rewriter.replaceOp(alloc, {resultCast}, droppedOperands); return matchSuccess(); } }; /// Fold alloc operations with no uses. Alloc has side effects on the heap, /// but can still be deleted if it has zero uses. struct SimplifyDeadAlloc : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(AllocOp alloc, PatternRewriter &rewriter) const override { if (alloc.use_empty()) { rewriter.eraseOp(alloc); return matchSuccess(); } return matchFailure(); } }; } // end anonymous namespace. void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// namespace { /// Simplify a branch to a block that has a single predecessor. This effectively /// merges the two blocks. struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(BranchOp op, PatternRewriter &rewriter) const override { // Check that the successor block has a single predecessor. Block *succ = op.getDest(); Block *opParent = op.getOperation()->getBlock(); if (succ == opParent || !has_single_element(succ->getPredecessors())) return matchFailure(); // Merge the successor into the current block and erase the branch. rewriter.mergeBlocks(succ, opParent, llvm::to_vector<1>(op.getOperands())); rewriter.eraseOp(op); return matchSuccess(); } }; } // end anonymous namespace. static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { Block *dest; SmallVector destOperands; if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); result.addSuccessor(dest, destOperands); return success(); } static void print(OpAsmPrinter &p, BranchOp op) { p << "br "; p.printSuccessorAndUseList(op.getOperation(), 0); } Block *BranchOp::getDest() { return getSuccessor(0); } void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); } void BranchOp::eraseOperand(unsigned index) { getOperation()->eraseSuccessorOperand(0, index); } void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { SymbolRefAttr calleeAttr; FunctionType calleeType; SmallVector operands; auto calleeLoc = parser.getNameLoc(); if (parser.parseAttribute(calleeAttr, "callee", result.attributes) || parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(calleeType) || parser.addTypesToList(calleeType.getResults(), result.types) || parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, result.operands)) return failure(); return success(); } static void print(OpAsmPrinter &p, CallOp op) { p << "call " << op.getAttr("callee") << '('; p.printOperands(op.getOperands()); p << ')'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); p << " : "; p.printType(op.getCalleeType()); } static LogicalResult verify(CallOp op) { // Check that the callee attribute was specified. auto fnAttr = op.getAttrOfType("callee"); if (!fnAttr) return op.emitOpError("requires a 'callee' symbol reference attribute"); auto fn = op.getParentOfType().lookupSymbol(fnAttr.getValue()); if (!fn) return op.emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; // Verify that the operand and result types match the callee. auto fnType = fn.getType(); if (fnType.getNumInputs() != op.getNumOperands()) return op.emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) if (op.getOperand(i)->getType() != fnType.getInput(i)) return op.emitOpError("operand type mismatch"); if (fnType.getNumResults() != op.getNumResults()) return op.emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) if (op.getResult(i)->getType() != fnType.getResult(i)) return op.emitOpError("result type mismatch"); return success(); } FunctionType CallOp::getCalleeType() { SmallVector resultTypes(getResultTypes()); SmallVector argTypes(getOperandTypes()); return FunctionType::get(argTypes, resultTypes, getContext()); } //===----------------------------------------------------------------------===// // CallIndirectOp //===----------------------------------------------------------------------===// namespace { /// Fold indirect calls that have a constant function as the callee operand. struct SimplifyIndirectCallWithKnownCallee : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall, PatternRewriter &rewriter) const override { // Check that the callee is a constant callee. SymbolRefAttr calledFn; if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) return matchFailure(); // Replace with a direct call. SmallVector callResults(indirectCall.getResultTypes()); SmallVector callOperands(indirectCall.getArgOperands()); rewriter.replaceOpWithNewOp(indirectCall, calledFn.getValue(), callResults, callOperands); return matchSuccess(); } }; } // end anonymous namespace. static ParseResult parseCallIndirectOp(OpAsmParser &parser, OperationState &result) { FunctionType calleeType; OpAsmParser::OperandType callee; llvm::SMLoc operandsLoc; SmallVector operands; return failure( parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(calleeType) || parser.resolveOperand(callee, calleeType, result.operands) || parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, result.operands) || parser.addTypesToList(calleeType.getResults(), result.types)); } static void print(OpAsmPrinter &p, CallIndirectOp op) { p << "call_indirect "; p.printOperand(op.getCallee()); p << '('; p.printOperands(op.getArgOperands()); p << ')'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); p << " : " << op.getCallee()->getType(); } static LogicalResult verify(CallIndirectOp op) { // The callee must be a function. auto fnType = op.getCallee()->getType().dyn_cast(); if (!fnType) return op.emitOpError("callee must have function type"); // Verify that the operand and result types match the callee. if (fnType.getNumInputs() != op.getNumOperands() - 1) return op.emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) if (op.getOperand(i + 1)->getType() != fnType.getInput(i)) return op.emitOpError("operand type mismatch"); if (fnType.getNumResults() != op.getNumResults()) return op.emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) if (op.getResult(i)->getType() != fnType.getResult(i)) return op.emitOpError("result type mismatch"); return success(); } void CallIndirectOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // General helpers for comparison ops //===----------------------------------------------------------------------===// // Return the type of the same shape (scalar, vector or tensor) containing i1. static Type getCheckedI1SameShape(Builder *build, Type type) { auto i1Type = build->getI1Type(); if (type.isIntOrIndexOrFloat()) return i1Type; if (auto tensorType = type.dyn_cast()) return RankedTensorType::get(tensorType.getShape(), i1Type); if (type.isa()) return UnrankedTensorType::get(i1Type); if (auto vectorType = type.dyn_cast()) return VectorType::get(vectorType.getShape(), i1Type); return Type(); } static Type getI1SameShape(Builder *build, Type type) { Type res = getCheckedI1SameShape(build, type); assert(res && "expected type with valid i1 shape"); return res; } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// // Returns an array of mnemonics for CmpIPredicates indexed by values thereof. static inline const char *const *getCmpIPredicateNames() { static const char *predicateNames[]{ /*EQ*/ "eq", /*NE*/ "ne", /*SLT*/ "slt", /*SLE*/ "sle", /*SGT*/ "sgt", /*SGE*/ "sge", /*ULT*/ "ult", /*ULE*/ "ule", /*UGT*/ "ugt", /*UGE*/ "uge", }; static_assert(std::extent::value == (size_t)CmpIPredicate::NumPredicates, "wrong number of predicate names"); return predicateNames; } // Returns a value of the predicate corresponding to the given mnemonic. // Returns NumPredicates (one-past-end) if there is no such mnemonic. CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { return llvm::StringSwitch(name) .Case("eq", CmpIPredicate::EQ) .Case("ne", CmpIPredicate::NE) .Case("slt", CmpIPredicate::SLT) .Case("sle", CmpIPredicate::SLE) .Case("sgt", CmpIPredicate::SGT) .Case("sge", CmpIPredicate::SGE) .Case("ult", CmpIPredicate::ULT) .Case("ule", CmpIPredicate::ULE) .Case("ugt", CmpIPredicate::UGT) .Case("uge", CmpIPredicate::UGE) .Default(CmpIPredicate::NumPredicates); } static void buildCmpIOp(Builder *build, OperationState &result, CmpIPredicate predicate, Value *lhs, Value *rhs) { result.addOperands({lhs, rhs}); result.types.push_back(getI1SameShape(build, lhs->getType())); result.addAttribute( CmpIOp::getPredicateAttrName(), build->getI64IntegerAttr(static_cast(predicate))); } static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) { SmallVector ops; SmallVector attrs; Attribute predicateNameAttr; Type type; if (parser.parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(), attrs) || parser.parseComma() || parser.parseOperandList(ops, 2) || parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || parser.resolveOperands(ops, type, result.operands)) return failure(); if (!predicateNameAttr.isa()) return parser.emitError(parser.getNameLoc(), "expected string comparison predicate attribute"); // Rewrite string attribute to an enum value. StringRef predicateName = predicateNameAttr.cast().getValue(); auto predicate = CmpIOp::getPredicateByName(predicateName); if (predicate == CmpIPredicate::NumPredicates) return parser.emitError(parser.getNameLoc()) << "unknown comparison predicate \"" << predicateName << "\""; auto builder = parser.getBuilder(); Type i1Type = getCheckedI1SameShape(&builder, type); if (!i1Type) return parser.emitError(parser.getNameLoc(), "expected type with valid i1 shape"); attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); result.attributes = attrs; result.addTypes({i1Type}); return success(); } static void print(OpAsmPrinter &p, CmpIOp op) { p << "cmpi "; auto predicateValue = op.getAttrOfType(CmpIOp::getPredicateAttrName()).getInt(); assert(predicateValue >= static_cast(CmpIPredicate::FirstValidValue) && predicateValue < static_cast(CmpIPredicate::NumPredicates) && "unknown predicate index"); Builder b(op.getContext()); auto predicateStringAttr = b.getStringAttr(getCmpIPredicateNames()[predicateValue]); p.printAttribute(predicateStringAttr); p << ", "; p.printOperand(op.lhs()); p << ", "; p.printOperand(op.rhs()); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); p << " : " << op.lhs()->getType(); } static LogicalResult verify(CmpIOp op) { auto predicateAttr = op.getAttrOfType(CmpIOp::getPredicateAttrName()); if (!predicateAttr) return op.emitOpError("requires an integer attribute named 'predicate'"); auto predicate = predicateAttr.getInt(); if (predicate < (int64_t)CmpIPredicate::FirstValidValue || predicate >= (int64_t)CmpIPredicate::NumPredicates) return op.emitOpError("'predicate' attribute value out of range"); return success(); } // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer // comparison predicates. static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, const APInt &rhs) { switch (predicate) { case CmpIPredicate::EQ: return lhs.eq(rhs); case CmpIPredicate::NE: return lhs.ne(rhs); case CmpIPredicate::SLT: return lhs.slt(rhs); case CmpIPredicate::SLE: return lhs.sle(rhs); case CmpIPredicate::SGT: return lhs.sgt(rhs); case CmpIPredicate::SGE: return lhs.sge(rhs); case CmpIPredicate::ULT: return lhs.ult(rhs); case CmpIPredicate::ULE: return lhs.ule(rhs); case CmpIPredicate::UGT: return lhs.ugt(rhs); case CmpIPredicate::UGE: return lhs.uge(rhs); default: llvm_unreachable("unknown comparison predicate"); } } // Constant folding hook for comparisons. OpFoldResult CmpIOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "cmpi takes two arguments"); auto lhs = operands.front().dyn_cast_or_null(); auto rhs = operands.back().dyn_cast_or_null(); if (!lhs || !rhs) return {}; auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); } //===----------------------------------------------------------------------===// // CmpFOp //===----------------------------------------------------------------------===// // Returns an array of mnemonics for CmpFPredicates indexed by values thereof. static inline const char *const *getCmpFPredicateNames() { static const char *predicateNames[] = { /*AlwaysFalse*/ "false", /*OEQ*/ "oeq", /*OGT*/ "ogt", /*OGE*/ "oge", /*OLT*/ "olt", /*OLE*/ "ole", /*ONE*/ "one", /*ORD*/ "ord", /*UEQ*/ "ueq", /*UGT*/ "ugt", /*UGE*/ "uge", /*ULT*/ "ult", /*ULE*/ "ule", /*UNE*/ "une", /*UNO*/ "uno", /*AlwaysTrue*/ "true", }; static_assert(std::extent::value == (size_t)CmpFPredicate::NumPredicates, "wrong number of predicate names"); return predicateNames; } // Returns a value of the predicate corresponding to the given mnemonic. // Returns NumPredicates (one-past-end) if there is no such mnemonic. CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { return llvm::StringSwitch(name) .Case("false", CmpFPredicate::AlwaysFalse) .Case("oeq", CmpFPredicate::OEQ) .Case("ogt", CmpFPredicate::OGT) .Case("oge", CmpFPredicate::OGE) .Case("olt", CmpFPredicate::OLT) .Case("ole", CmpFPredicate::OLE) .Case("one", CmpFPredicate::ONE) .Case("ord", CmpFPredicate::ORD) .Case("ueq", CmpFPredicate::UEQ) .Case("ugt", CmpFPredicate::UGT) .Case("uge", CmpFPredicate::UGE) .Case("ult", CmpFPredicate::ULT) .Case("ule", CmpFPredicate::ULE) .Case("une", CmpFPredicate::UNE) .Case("uno", CmpFPredicate::UNO) .Case("true", CmpFPredicate::AlwaysTrue) .Default(CmpFPredicate::NumPredicates); } static void buildCmpFOp(Builder *build, OperationState &result, CmpFPredicate predicate, Value *lhs, Value *rhs) { result.addOperands({lhs, rhs}); result.types.push_back(getI1SameShape(build, lhs->getType())); result.addAttribute( CmpFOp::getPredicateAttrName(), build->getI64IntegerAttr(static_cast(predicate))); } static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) { SmallVector ops; SmallVector attrs; Attribute predicateNameAttr; Type type; if (parser.parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(), attrs) || parser.parseComma() || parser.parseOperandList(ops, 2) || parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || parser.resolveOperands(ops, type, result.operands)) return failure(); if (!predicateNameAttr.isa()) return parser.emitError(parser.getNameLoc(), "expected string comparison predicate attribute"); // Rewrite string attribute to an enum value. StringRef predicateName = predicateNameAttr.cast().getValue(); auto predicate = CmpFOp::getPredicateByName(predicateName); if (predicate == CmpFPredicate::NumPredicates) return parser.emitError(parser.getNameLoc(), "unknown comparison predicate \"" + predicateName + "\""); auto builder = parser.getBuilder(); Type i1Type = getCheckedI1SameShape(&builder, type); if (!i1Type) return parser.emitError(parser.getNameLoc(), "expected type with valid i1 shape"); attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); result.attributes = attrs; result.addTypes({i1Type}); return success(); } static void print(OpAsmPrinter &p, CmpFOp op) { p << "cmpf "; auto predicateValue = op.getAttrOfType(CmpFOp::getPredicateAttrName()).getInt(); assert(predicateValue >= static_cast(CmpFPredicate::FirstValidValue) && predicateValue < static_cast(CmpFPredicate::NumPredicates) && "unknown predicate index"); Builder b(op.getContext()); auto predicateStringAttr = b.getStringAttr(getCmpFPredicateNames()[predicateValue]); p.printAttribute(predicateStringAttr); p << ", "; p.printOperand(op.lhs()); p << ", "; p.printOperand(op.rhs()); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); p << " : " << op.lhs()->getType(); } static LogicalResult verify(CmpFOp op) { auto predicateAttr = op.getAttrOfType(CmpFOp::getPredicateAttrName()); if (!predicateAttr) return op.emitOpError("requires an integer attribute named 'predicate'"); auto predicate = predicateAttr.getInt(); if (predicate < (int64_t)CmpFPredicate::FirstValidValue || predicate >= (int64_t)CmpFPredicate::NumPredicates) return op.emitOpError("'predicate' attribute value out of range"); return success(); } // Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point // comparison predicates. static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs) { auto cmpResult = lhs.compare(rhs); switch (predicate) { case CmpFPredicate::AlwaysFalse: return false; case CmpFPredicate::OEQ: return cmpResult == APFloat::cmpEqual; case CmpFPredicate::OGT: return cmpResult == APFloat::cmpGreaterThan; case CmpFPredicate::OGE: return cmpResult == APFloat::cmpGreaterThan || cmpResult == APFloat::cmpEqual; case CmpFPredicate::OLT: return cmpResult == APFloat::cmpLessThan; case CmpFPredicate::OLE: return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; case CmpFPredicate::ONE: return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; case CmpFPredicate::ORD: return cmpResult != APFloat::cmpUnordered; case CmpFPredicate::UEQ: return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; case CmpFPredicate::UGT: return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpGreaterThan; case CmpFPredicate::UGE: return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpGreaterThan || cmpResult == APFloat::cmpEqual; case CmpFPredicate::ULT: return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpLessThan; case CmpFPredicate::ULE: return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; case CmpFPredicate::UNE: return cmpResult != APFloat::cmpEqual; case CmpFPredicate::UNO: return cmpResult == APFloat::cmpUnordered; case CmpFPredicate::AlwaysTrue: return true; default: llvm_unreachable("unknown comparison predicate"); } } // Constant folding hook for comparisons. OpFoldResult CmpFOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "cmpf takes two arguments"); auto lhs = operands.front().dyn_cast_or_null(); auto rhs = operands.back().dyn_cast_or_null(); // TODO(gcmn) We could actually do some intelligent things if we know only one // of the operands, but it's inf or nan. if (!lhs || !rhs) return {}; auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); } //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// namespace { /// cond_br true, ^bb1, ^bb2 -> br ^bb1 /// cond_br false, ^bb1, ^bb2 -> br ^bb2 /// struct SimplifyConstCondBranchPred : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { if (matchPattern(condbr.getCondition(), m_NonZero())) { // True branch taken. rewriter.replaceOpWithNewOp( condbr, condbr.getTrueDest(), llvm::to_vector<4>(condbr.getTrueOperands())); return matchSuccess(); } else if (matchPattern(condbr.getCondition(), m_Zero())) { // False branch taken. rewriter.replaceOpWithNewOp( condbr, condbr.getFalseDest(), llvm::to_vector<4>(condbr.getFalseOperands())); return matchSuccess(); } return matchFailure(); } }; } // end anonymous namespace. static ParseResult parseCondBranchOp(OpAsmParser &parser, OperationState &result) { SmallVector destOperands; Block *dest; OpAsmParser::OperandType condInfo; // Parse the condition. Type int1Ty = parser.getBuilder().getI1Type(); if (parser.parseOperand(condInfo) || parser.parseComma() || parser.resolveOperand(condInfo, int1Ty, result.operands)) { return parser.emitError(parser.getNameLoc(), "expected condition type was boolean (i1)"); } // Parse the true successor. if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); result.addSuccessor(dest, destOperands); // Parse the false successor. destOperands.clear(); if (parser.parseComma() || parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); result.addSuccessor(dest, destOperands); return success(); } static void print(OpAsmPrinter &p, CondBranchOp op) { p << "cond_br "; p.printOperand(op.getCondition()); p << ", "; p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); p << ", "; p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); } void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // Constant*Op //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ConstantOp &op) { p << "constant "; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); if (op.getAttrs().size() > 1) p << ' '; p.printAttribute(op.getValue()); // If the value is a symbol reference, print a trailing type. if (op.getValue().isa()) p << " : " << op.getType(); } static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &result) { Attribute valueAttr; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseAttribute(valueAttr, "value", result.attributes)) return failure(); // If the attribute is a symbol reference, then we expect a trailing type. Type type; if (!valueAttr.isa()) type = valueAttr.getType(); else if (parser.parseColonType(type)) return failure(); // Add the attribute type to the list. return parser.addTypeToList(type, result.types); } /// The constant op requires an attribute, and furthermore requires that it /// matches the return type. static LogicalResult verify(ConstantOp &op) { auto value = op.getValue(); if (!value) return op.emitOpError("requires a 'value' attribute"); auto type = op.getType(); if (!value.getType().isa() && type != value.getType()) return op.emitOpError() << "requires attribute's type (" << value.getType() << ") to match op's return type (" << type << ")"; if (type.isa() || value.isa()) return success(); if (auto intAttr = value.dyn_cast()) { // If the type has a known bitwidth we verify that the value can be // represented with the given bitwidth. auto bitwidth = type.cast().getWidth(); auto intVal = intAttr.getValue(); if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) return op.emitOpError("requires 'value' to be an integer within the " "range of the integer result type"); return success(); } if (type.isa()) { if (!value.isa()) return op.emitOpError("requires 'value' to be a floating point constant"); return success(); } if (type.isa()) { if (!value.isa()) return op.emitOpError("requires 'value' to be a shaped constant"); return success(); } if (type.isa()) { auto fnAttr = value.dyn_cast(); if (!fnAttr) return op.emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. auto fn = op.getParentOfType().lookupSymbol(fnAttr.getValue()); if (!fn) return op.emitOpError("reference to undefined function 'bar'"); // Check that the referenced function has the correct type. if (fn.getType() != type) return op.emitOpError("reference to function with mismatched type"); return success(); } if (type.isa() && value.isa()) return success(); return op.emitOpError("unsupported 'value' attribute: ") << value; } OpFoldResult ConstantOp::fold(ArrayRef operands) { assert(operands.empty() && "constant has no operands"); return getValue(); } /// Returns true if a constant operation can be built with the given value and /// result type. bool ConstantOp::isBuildableWith(Attribute value, Type type) { // SymbolRefAttr can only be used with a function type. if (value.isa()) return type.isa(); // Otherwise, the attribute must have the same type as 'type'. if (value.getType() != type) return false; // Finally, check that the attribute kind is handled. return value.isa() || value.isa() || value.isa() || value.isa() || value.isa(); } void ConstantFloatOp::build(Builder *builder, OperationState &result, const APFloat &value, FloatType type) { ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value)); } bool ConstantFloatOp::classof(Operation *op) { return ConstantOp::classof(op) && op->getResult(0)->getType().isa(); } /// ConstantIntOp only matches values whose result type is an IntegerType. bool ConstantIntOp::classof(Operation *op) { return ConstantOp::classof(op) && op->getResult(0)->getType().isa(); } void ConstantIntOp::build(Builder *builder, OperationState &result, int64_t value, unsigned width) { Type type = builder->getIntegerType(width); ConstantOp::build(builder, result, type, builder->getIntegerAttr(type, value)); } /// Build a constant int op producing an integer with the specified type, /// which must be an integer type. void ConstantIntOp::build(Builder *builder, OperationState &result, int64_t value, Type type) { assert(type.isa() && "ConstantIntOp can only have integer type"); ConstantOp::build(builder, result, type, builder->getIntegerAttr(type, value)); } /// ConstantIndexOp only matches values whose result type is Index. bool ConstantIndexOp::classof(Operation *op) { return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex(); } void ConstantIndexOp::build(Builder *builder, OperationState &result, int64_t value) { Type type = builder->getIndexType(); ConstantOp::build(builder, result, type, builder->getIntegerAttr(type, value)); } //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// namespace { /// Fold Dealloc operations that are deallocating an AllocOp that is only used /// by other Dealloc operations. struct SimplifyDeadDealloc : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(DeallocOp dealloc, PatternRewriter &rewriter) const override { // Check that the memref operand's defining operation is an AllocOp. Value *memref = dealloc.memref(); if (!isa_and_nonnull(memref->getDefiningOp())) return matchFailure(); // Check that all of the uses of the AllocOp are other DeallocOps. for (auto *user : memref->getUsers()) if (!isa(user)) return matchFailure(); // Erase the dealloc operation. rewriter.eraseOp(dealloc); return matchSuccess(); } }; } // end anonymous namespace. static void print(OpAsmPrinter &p, DeallocOp op) { p << "dealloc " << *op.memref() << " : " << op.memref()->getType(); } static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType memrefInfo; MemRefType type; return failure(parser.parseOperand(memrefInfo) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands)); } static LogicalResult verify(DeallocOp op) { if (!op.memref()->getType().isa()) return op.emitOpError("operand must be a memref"); return success(); } void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dealloc(memrefcast) -> dealloc results.insert(getOperationName(), context); results.insert(context); } //===----------------------------------------------------------------------===// // DimOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, DimOp op) { p << "dim " << *op.getOperand() << ", " << op.getIndex(); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); p << " : " << op.getOperand()->getType(); } static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType operandInfo; IntegerAttr indexAttr; Type type; Type indexType = parser.getBuilder().getIndexType(); return failure( parser.parseOperand(operandInfo) || parser.parseComma() || parser.parseAttribute(indexAttr, indexType, "index", result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(operandInfo, type, result.operands) || parser.addTypeToList(indexType, result.types)); } static LogicalResult verify(DimOp op) { // Check that we have an integer index operand. auto indexAttr = op.getAttrOfType("index"); if (!indexAttr) return op.emitOpError("requires an integer attribute named 'index'"); int64_t index = indexAttr.getValue().getSExtValue(); auto type = op.getOperand()->getType(); if (auto tensorType = type.dyn_cast()) { if (index >= tensorType.getRank()) return op.emitOpError("index is out of range"); } else if (auto memrefType = type.dyn_cast()) { if (index >= memrefType.getRank()) return op.emitOpError("index is out of range"); } else if (type.isa()) { // ok, assumed to be in-range. } else { return op.emitOpError("requires an operand with tensor or memref type"); } return success(); } OpFoldResult DimOp::fold(ArrayRef operands) { // Constant fold dim when the size along the index referred to is a constant. auto opType = getOperand()->getType(); int64_t indexSize = -1; if (auto tensorType = opType.dyn_cast()) indexSize = tensorType.getShape()[getIndex()]; else if (auto memrefType = opType.dyn_cast()) indexSize = memrefType.getShape()[getIndex()]; if (indexSize >= 0) return IntegerAttr::get(IndexType::get(getContext()), indexSize); return {}; } void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dim(memrefcast) -> dim results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// // DivISOp //===----------------------------------------------------------------------===// OpFoldResult DivISOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "binary operation takes two operands"); auto lhs = operands.front().dyn_cast_or_null(); auto rhs = operands.back().dyn_cast_or_null(); if (!lhs || !rhs) return {}; // Don't fold if it requires division by zero. if (rhs.getValue().isNullValue()) return {}; // Don't fold if it would overflow. bool overflow; auto result = lhs.getValue().sdiv_ov(rhs.getValue(), overflow); return overflow ? IntegerAttr() : IntegerAttr::get(lhs.getType(), result); } //===----------------------------------------------------------------------===// // DivIUOp //===----------------------------------------------------------------------===// OpFoldResult DivIUOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "binary operation takes two operands"); auto lhs = operands.front().dyn_cast_or_null(); auto rhs = operands.back().dyn_cast_or_null(); if (!lhs || !rhs) return {}; // Don't fold if it requires division by zero. auto rhsValue = rhs.getValue(); if (rhsValue.isNullValue()) return {}; return IntegerAttr::get(lhs.getType(), lhs.getValue().udiv(rhsValue)); } // --------------------------------------------------------------------------- // DmaStartOp // --------------------------------------------------------------------------- void DmaStartOp::build(Builder *builder, OperationState &result, Value *srcMemRef, ArrayRef srcIndices, Value *destMemRef, ArrayRef destIndices, Value *numElements, Value *tagMemRef, ArrayRef tagIndices, Value *stride, Value *elementsPerStride) { result.addOperands(srcMemRef); result.addOperands(srcIndices); result.addOperands(destMemRef); result.addOperands(destIndices); result.addOperands({numElements, tagMemRef}); result.addOperands(tagIndices); if (stride) result.addOperands({stride, elementsPerStride}); } void DmaStartOp::print(OpAsmPrinter &p) { p << "dma_start " << *getSrcMemRef() << '['; p.printOperands(getSrcIndices()); p << "], " << *getDstMemRef() << '['; p.printOperands(getDstIndices()); p << "], " << *getNumElements(); p << ", " << *getTagMemRef() << '['; p.printOperands(getTagIndices()); p << ']'; if (isStrided()) { p << ", " << *getStride(); p << ", " << *getNumElementsPerStride(); } p.printOptionalAttrDict(getAttrs()); p << " : " << getSrcMemRef()->getType(); p << ", " << getDstMemRef()->getType(); p << ", " << getTagMemRef()->getType(); } // Parse DmaStartOp. // Ex: // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, // %tag[%index], %stride, %num_elt_per_stride : // : memref<3076 x f32, 0>, // memref<1024 x f32, 2>, // memref<1 x i32> // ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcMemRefInfo; SmallVector srcIndexInfos; OpAsmParser::OperandType dstMemRefInfo; SmallVector dstIndexInfos; OpAsmParser::OperandType numElementsInfo; OpAsmParser::OperandType tagMemrefInfo; SmallVector tagIndexInfos; SmallVector strideInfo; SmallVector types; auto indexType = parser.getBuilder().getIndexType(); // Parse and resolve the following list of operands: // *) source memref followed by its indices (in square brackets). // *) destination memref followed by its indices (in square brackets). // *) dma size in KiB. if (parser.parseOperand(srcMemRefInfo) || parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(dstMemRefInfo) || parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(numElementsInfo) || parser.parseComma() || parser.parseOperand(tagMemrefInfo) || parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) return failure(); // Parse optional stride and elements per stride. if (parser.parseTrailingOperandList(strideInfo)) return failure(); bool isStrided = strideInfo.size() == 2; if (!strideInfo.empty() && !isStrided) { return parser.emitError(parser.getNameLoc(), "expected two stride related operands"); } if (parser.parseColonTypeList(types)) return failure(); if (types.size() != 3) return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || parser.resolveOperands(srcIndexInfos, indexType, result.operands) || parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || parser.resolveOperands(dstIndexInfos, indexType, result.operands) || // size should be an index. parser.resolveOperand(numElementsInfo, indexType, result.operands) || parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || // tag indices should be index. parser.resolveOperands(tagIndexInfos, indexType, result.operands)) return failure(); auto memrefType0 = types[0].dyn_cast(); if (!memrefType0) return parser.emitError(parser.getNameLoc(), "expected source to be of memref type"); auto memrefType1 = types[1].dyn_cast(); if (!memrefType1) return parser.emitError(parser.getNameLoc(), "expected destination to be of memref type"); auto memrefType2 = types[2].dyn_cast(); if (!memrefType2) return parser.emitError(parser.getNameLoc(), "expected tag to be of memref type"); if (isStrided) { if (parser.resolveOperands(strideInfo, indexType, result.operands)) return failure(); } // Check that source/destination index list size matches associated rank. if (static_cast(srcIndexInfos.size()) != memrefType0.getRank() || static_cast(dstIndexInfos.size()) != memrefType1.getRank()) return parser.emitError(parser.getNameLoc(), "memref rank not equal to indices count"); if (static_cast(tagIndexInfos.size()) != memrefType2.getRank()) return parser.emitError(parser.getNameLoc(), "tag memref rank not equal to indices count"); return success(); } LogicalResult DmaStartOp::verify() { // DMAs from different memory spaces supported. if (getSrcMemorySpace() == getDstMemorySpace()) return emitOpError("DMA should be between different memory spaces"); if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + getDstMemRefRank() + 3 + 1 && getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + getDstMemRefRank() + 3 + 1 + 2) { return emitOpError("incorrect number of operands"); } return success(); } void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_start(memrefcast) -> dma_start results.insert(getOperationName(), context); } // --------------------------------------------------------------------------- // DmaWaitOp // --------------------------------------------------------------------------- void DmaWaitOp::build(Builder *builder, OperationState &result, Value *tagMemRef, ArrayRef tagIndices, Value *numElements) { result.addOperands(tagMemRef); result.addOperands(tagIndices); result.addOperands(numElements); } void DmaWaitOp::print(OpAsmPrinter &p) { p << "dma_wait "; p.printOperand(getTagMemRef()); p << '['; p.printOperands(getTagIndices()); p << "], "; p.printOperand(getNumElements()); p.printOptionalAttrDict(getAttrs()); p << " : " << getTagMemRef()->getType(); } // Parse DmaWaitOp. // Eg: // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> // ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector tagIndexInfos; Type type; auto indexType = parser.getBuilder().getIndexType(); OpAsmParser::OperandType numElementsInfo; // Parse tag memref, its indices, and dma size. if (parser.parseOperand(tagMemrefInfo) || parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(numElementsInfo) || parser.parseColonType(type) || parser.resolveOperand(tagMemrefInfo, type, result.operands) || parser.resolveOperands(tagIndexInfos, indexType, result.operands) || parser.resolveOperand(numElementsInfo, indexType, result.operands)) return failure(); auto memrefType = type.dyn_cast(); if (!memrefType) return parser.emitError(parser.getNameLoc(), "expected tag to be of memref type"); if (static_cast(tagIndexInfos.size()) != memrefType.getRank()) return parser.emitError(parser.getNameLoc(), "tag memref rank not equal to indices count"); return success(); } void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// dma_wait(memrefcast) -> dma_wait results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// // ExtractElementOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ExtractElementOp op) { p << "extract_element " << *op.getAggregate() << '['; p.printOperands(op.getIndices()); p << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getAggregate()->getType(); } static ParseResult parseExtractElementOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType aggregateInfo; SmallVector indexInfo; ShapedType type; auto indexTy = parser.getBuilder().getIndexType(); return failure( parser.parseOperand(aggregateInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(aggregateInfo, type, result.operands) || parser.resolveOperands(indexInfo, indexTy, result.operands) || parser.addTypeToList(type.getElementType(), result.types)); } static LogicalResult verify(ExtractElementOp op) { auto aggregateType = op.getAggregate()->getType().cast(); // This should be possible with tablegen type constraints if (op.getType() != aggregateType.getElementType()) return op.emitOpError("result type must match element type of aggregate"); // Verify the # indices match if we have a ranked type. if (aggregateType.hasRank() && aggregateType.getRank() != op.getNumOperands() - 1) return op.emitOpError("incorrect number of indices for extract_element"); return success(); } OpFoldResult ExtractElementOp::fold(ArrayRef operands) { assert(!operands.empty() && "extract_element takes at least one operand"); // The aggregate operand must be a known constant. Attribute aggregate = operands.front(); if (!aggregate) return {}; // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. if (auto splatAggregate = aggregate.dyn_cast()) return splatAggregate.getSplatValue(); // Otherwise, collect the constant indices into the aggregate. SmallVector indices; for (Attribute indice : llvm::drop_begin(operands, 1)) { if (!indice || !indice.isa()) return {}; indices.push_back(indice.cast().getInt()); } // If this is an elements attribute, query the value at the given indices. auto elementsAttr = aggregate.dyn_cast(); if (elementsAttr && elementsAttr.isValidIndex(indices)) return elementsAttr.getValue(indices); return {}; } //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// // Index cast is applicable from index to integer and backwards. bool IndexCastOp::areCastCompatible(Type a, Type b) { return (a.isIndex() && b.isa()) || (a.isa() && b.isIndex()); } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, LoadOp op) { p << "load " << *op.getMemRef() << '['; p.printOperands(op.getIndices()); p << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getMemRefType(); } static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; MemRefType type; auto indexTy = parser.getBuilder().getIndexType(); return failure( parser.parseOperand(memrefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(indexInfo, indexTy, result.operands) || parser.addTypeToList(type.getElementType(), result.types)); } static LogicalResult verify(LoadOp op) { if (op.getType() != op.getMemRefType().getElementType()) return op.emitOpError("result type must match element type of memref"); if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) return op.emitOpError("incorrect number of indices for load"); return success(); } void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// load(memrefcast) -> load results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// // MemRefCastOp //===----------------------------------------------------------------------===// bool MemRefCastOp::areCastCompatible(Type a, Type b) { auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); if (!aT || !bT) return false; if (aT.getElementType() != bT.getElementType()) return false; if (aT.getAffineMaps() != bT.getAffineMaps()) return false; if (aT.getMemorySpace() != bT.getMemorySpace()) return false; // They must have the same rank, and any specified dimensions must match. if (aT.getRank() != bT.getRank()) return false; for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); if (aDim != -1 && bDim != -1 && aDim != bDim) return false; } return true; } OpFoldResult MemRefCastOp::fold(ArrayRef operands) { return impl::foldCastOp(*this); } //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// OpFoldResult MulFOp::fold(ArrayRef operands) { return constFoldBinaryOp( operands, [](APFloat a, APFloat b) { return a * b; }); } //===----------------------------------------------------------------------===// // MulIOp //===----------------------------------------------------------------------===// OpFoldResult MulIOp::fold(ArrayRef operands) { /// muli(x, 0) -> 0 if (matchPattern(rhs(), m_Zero())) return rhs(); /// muli(x, 1) -> x if (matchPattern(rhs(), m_One())) return getOperand(0); // TODO: Handle the overflow case. return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a * b; }); } //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, RankOp op) { p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType(); } static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType operandInfo; Type type; Type indexType = parser.getBuilder().getIndexType(); return failure(parser.parseOperand(operandInfo) || parser.parseColonType(type) || parser.resolveOperand(operandInfo, type, result.operands) || parser.addTypeToList(indexType, result.types)); } OpFoldResult RankOp::fold(ArrayRef operands) { // Constant fold rank when the rank of the tensor is known. auto type = getOperand()->getType(); if (auto tensorType = type.dyn_cast()) return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank()); return IntegerAttr(); } //===----------------------------------------------------------------------===// // RemISOp //===----------------------------------------------------------------------===// OpFoldResult RemISOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "remis takes two operands"); auto rhs = operands.back().dyn_cast_or_null(); if (!rhs) return {}; auto rhsValue = rhs.getValue(); // x % 1 = 0 if (rhsValue.isOneValue()) return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); // Don't fold if it requires division by zero. if (rhsValue.isNullValue()) return {}; auto lhs = operands.front().dyn_cast_or_null(); if (!lhs) return {}; return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); } //===----------------------------------------------------------------------===// // RemIUOp //===----------------------------------------------------------------------===// OpFoldResult RemIUOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "remiu takes two operands"); auto rhs = operands.back().dyn_cast_or_null(); if (!rhs) return {}; auto rhsValue = rhs.getValue(); // x % 1 = 0 if (rhsValue.isOneValue()) return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); // Don't fold if it requires division by zero. if (rhsValue.isNullValue()) return {}; auto lhs = operands.front().dyn_cast_or_null(); if (!lhs) return {}; return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { SmallVector opInfo; SmallVector types; llvm::SMLoc loc = parser.getCurrentLocation(); return failure(parser.parseOperandList(opInfo) || (!opInfo.empty() && parser.parseColonTypeList(types)) || parser.resolveOperands(opInfo, types, loc, result.operands)); } static void print(OpAsmPrinter &p, ReturnOp op) { p << "return"; if (op.getNumOperands() != 0) { p << ' '; p.printOperands(op.getOperands()); p << " : "; interleaveComma(op.getOperandTypes(), p); } } static LogicalResult verify(ReturnOp op) { auto function = cast(op.getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getType().getResults(); if (op.getNumOperands() != results.size()) return op.emitOpError("has ") << op.getNumOperands() << " operands, but enclosing function returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) if (op.getOperand(i)->getType() != results[i]) return op.emitError() << "type of return operand " << i << " (" << op.getOperand(i)->getType() << ") doesn't match function result type (" << results[i] << ")"; return success(); } //===----------------------------------------------------------------------===// // SIToFPOp //===----------------------------------------------------------------------===// // sitofp is applicable from integer types to float types. bool SIToFPOp::areCastCompatible(Type a, Type b) { return a.isa() && b.isa(); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { SmallVector ops; SmallVector attrs; Type type; if (parser.parseOperandList(ops, 3) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type)) return failure(); auto i1Type = getCheckedI1SameShape(&parser.getBuilder(), type); if (!i1Type) return parser.emitError(parser.getNameLoc(), "expected type with valid i1 shape"); SmallVector types = {i1Type, type, type}; return failure(parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands) || parser.addTypeToList(type, result.types)); } static void print(OpAsmPrinter &p, SelectOp op) { p << "select "; p.printOperands(op.getOperands()); p << " : " << op.getTrueValue()->getType(); p.printOptionalAttrDict(op.getAttrs()); } static LogicalResult verify(SelectOp op) { auto trueType = op.getTrueValue()->getType(); auto falseType = op.getFalseValue()->getType(); if (trueType != falseType) return op.emitOpError( "requires 'true' and 'false' arguments to be of the same type"); return success(); } OpFoldResult SelectOp::fold(ArrayRef operands) { auto *condition = getCondition(); // select true, %0, %1 => %0 if (matchPattern(condition, m_One())) return getTrueValue(); // select false, %0, %1 => %1 if (matchPattern(condition, m_Zero())) return getFalseValue(); return nullptr; } //===----------------------------------------------------------------------===// // SignExtendIOp //===----------------------------------------------------------------------===// static LogicalResult verify(SignExtendIOp op) { // Get the scalar type (which is either directly the type of the operand // or the vector's/tensor's element type. auto srcType = getElementTypeOrSelf(op.getOperand()->getType()); auto dstType = getElementTypeOrSelf(op.getType()); // For now, index is forbidden for the source and the destination type. if (srcType.isa()) return op.emitError() << srcType << " is not a valid operand type"; if (dstType.isa()) return op.emitError() << dstType << " is not a valid result type"; if (srcType.cast().getWidth() >= dstType.cast().getWidth()) return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; return success(); } //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, SplatOp op) { p << "splat " << *op.getOperand(); p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getType(); } static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType splatValueInfo; ShapedType shapedType; return failure(parser.parseOperand(splatValueInfo) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(shapedType) || parser.resolveOperand(splatValueInfo, shapedType.getElementType(), result.operands) || parser.addTypeToList(shapedType, result.types)); } static LogicalResult verify(SplatOp op) { // TODO: we could replace this by a trait. if (op.getOperand()->getType() != op.getType().cast().getElementType()) return op.emitError("operand should be of elemental type of result type"); return success(); } // Constant folding hook for SplatOp. OpFoldResult SplatOp::fold(ArrayRef operands) { assert(operands.size() == 1 && "splat takes one operand"); auto constOperand = operands.front(); if (!constOperand || (!constOperand.isa() && !constOperand.isa())) return {}; auto shapedType = getType().cast(); assert(shapedType.getElementType() == constOperand.getType() && "incorrect input attribute type for folding"); // SplatElementsAttr::get treats single value for second arg as being a splat. return SplatElementsAttr::get(shapedType, {constOperand}); } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, StoreOp op) { p << "store " << *op.getValueToStore(); p << ", " << *op.getMemRef() << '['; p.printOperands(op.getIndices()); p << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getMemRefType(); } static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; MemRefType memrefType; auto indexTy = parser.getBuilder().getIndexType(); return failure( parser.parseOperand(storeValueInfo) || parser.parseComma() || parser.parseOperand(memrefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(memrefType) || parser.resolveOperand(storeValueInfo, memrefType.getElementType(), result.operands) || parser.resolveOperand(memrefInfo, memrefType, result.operands) || parser.resolveOperands(indexInfo, indexTy, result.operands)); } static LogicalResult verify(StoreOp op) { // First operand must have same type as memref element type. if (op.getValueToStore()->getType() != op.getMemRefType().getElementType()) return op.emitOpError( "first operand must have same type memref element type"); if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) return op.emitOpError("store index operand count not equal to memref rank"); return success(); } void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { /// store(memrefcast) -> store results.insert(getOperationName(), context); } //===----------------------------------------------------------------------===// // SubFOp //===----------------------------------------------------------------------===// OpFoldResult SubFOp::fold(ArrayRef operands) { return constFoldBinaryOp( operands, [](APFloat a, APFloat b) { return a - b; }); } //===----------------------------------------------------------------------===// // SubIOp //===----------------------------------------------------------------------===// OpFoldResult SubIOp::fold(ArrayRef operands) { // subi(x,x) -> 0 if (getOperand(0) == getOperand(1)) return Builder(getContext()).getZeroAttr(getType()); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a - b; }); } //===----------------------------------------------------------------------===// // AndOp //===----------------------------------------------------------------------===// OpFoldResult AndOp::fold(ArrayRef operands) { /// and(x, 0) -> 0 if (matchPattern(rhs(), m_Zero())) return rhs(); /// and(x,x) -> x if (lhs() == rhs()) return rhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a & b; }); } //===----------------------------------------------------------------------===// // OrOp //===----------------------------------------------------------------------===// OpFoldResult OrOp::fold(ArrayRef operands) { /// or(x, 0) -> x if (matchPattern(rhs(), m_Zero())) return lhs(); /// or(x,x) -> x if (lhs() == rhs()) return rhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a | b; }); } //===----------------------------------------------------------------------===// // XOrOp //===----------------------------------------------------------------------===// OpFoldResult XOrOp::fold(ArrayRef operands) { /// xor(x, 0) -> x if (matchPattern(rhs(), m_Zero())) return lhs(); /// xor(x,x) -> 0 if (lhs() == rhs()) return Builder(getContext()).getZeroAttr(getType()); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a ^ b; }); } //===----------------------------------------------------------------------===// // TensorCastOp //===----------------------------------------------------------------------===// bool TensorCastOp::areCastCompatible(Type a, Type b) { auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); if (!aT || !bT) return false; if (aT.getElementType() != bT.getElementType()) return false; return succeeded(verifyCompatibleShape(aT, bT)); } OpFoldResult TensorCastOp::fold(ArrayRef operands) { return impl::foldCastOp(*this); } //===----------------------------------------------------------------------===// // Helpers for Tensor[Load|Store]Op //===----------------------------------------------------------------------===// static Type getTensorTypeFromMemRefType(Builder &b, Type type) { if (auto memref = type.dyn_cast()) return RankedTensorType::get(memref.getShape(), memref.getElementType()); return b.getNoneType(); } //===----------------------------------------------------------------------===// // TensorLoadOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, TensorLoadOp op) { p << "tensor_load " << *op.getOperand(); p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getOperand()->getType(); } static ParseResult parseTensorLoadOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType op; Type type; return failure(parser.parseOperand(op) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(op, type, result.operands) || parser.addTypeToList( getTensorTypeFromMemRefType(parser.getBuilder(), type), result.types)); } //===----------------------------------------------------------------------===// // TensorStoreOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, TensorStoreOp op) { p << "tensor_store " << *op.tensor() << ", " << *op.memref(); p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.memref()->getType(); } static ParseResult parseTensorStoreOp(OpAsmParser &parser, OperationState &result) { SmallVector ops; Type type; llvm::SMLoc loc = parser.getCurrentLocation(); return failure( parser.parseOperandList(ops, /*requiredOperandCount=*/2) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperands( ops, {getTensorTypeFromMemRefType(parser.getBuilder(), type), type}, loc, result.operands)); } //===----------------------------------------------------------------------===// // TruncateIOp //===----------------------------------------------------------------------===// static LogicalResult verify(TruncateIOp op) { auto srcType = getElementTypeOrSelf(op.getOperand()->getType()); auto dstType = getElementTypeOrSelf(op.getType()); if (srcType.isa()) return op.emitError() << srcType << " is not a valid operand type"; if (dstType.isa()) return op.emitError() << dstType << " is not a valid result type"; if (srcType.cast().getWidth() <= dstType.cast().getWidth()) return op.emitError("operand type ") << srcType << " must be wider than result type " << dstType; return success(); } //===----------------------------------------------------------------------===// // ViewOp //===----------------------------------------------------------------------===// static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcInfo; SmallVector offsetInfo; SmallVector sizesInfo; auto indexType = parser.getBuilder().getIndexType(); Type srcType, dstType; return failure( parser.parseOperand(srcInfo) || parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || parser.resolveOperand(srcInfo, srcType, result.operands) || parser.resolveOperands(sizesInfo, indexType, result.operands) || parser.resolveOperands(offsetInfo, indexType, result.operands) || parser.parseKeywordType("to", dstType) || parser.addTypeToList(dstType, result.types)); } static void print(OpAsmPrinter &p, ViewOp op) { p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; p.printOperands(op.getDynamicSizes()); p << "]["; auto *dynamicOffset = op.getDynamicOffset(); if (dynamicOffset != nullptr) p.printOperand(dynamicOffset); p << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); } static LogicalResult verify(ViewOp op) { auto baseType = op.getOperand(0)->getType().cast(); auto viewType = op.getResult()->getType().cast(); // The base memref should have identity layout map (or none). if (baseType.getAffineMaps().size() > 1 || (baseType.getAffineMaps().size() == 1 && !baseType.getAffineMaps()[0].isIdentity())) return op.emitError("unsupported map for base memref type ") << baseType; // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != viewType.getMemorySpace()) return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and view memref type " << viewType; // Verify that the result memref type has a strided layout map. is strided int64_t offset; llvm::SmallVector strides; if (failed(getStridesAndOffset(viewType, strides, offset))) return op.emitError("result type ") << viewType << " is not strided"; // Verify that we have the correct number of operands for the result type. unsigned memrefOperandCount = 1; unsigned numDynamicDims = viewType.getNumDynamicDims(); unsigned dynamicOffsetCount = offset == MemRefType::getDynamicStrideOrOffset() ? 1 : 0; if (op.getNumOperands() != memrefOperandCount + numDynamicDims + dynamicOffsetCount) return op.emitError("incorrect number of operands for type ") << viewType; // Verify dynamic strides symbols were added to correct dimensions based // on dynamic sizes. ArrayRef viewShape = viewType.getShape(); unsigned viewRank = viewType.getRank(); assert(viewRank == strides.size()); bool dynamicStrides = false; for (int i = viewRank - 2; i >= 0; --i) { // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag. if (ShapedType::isDynamic(viewShape[i + 1])) dynamicStrides = true; // If stride at dim 'i' is not dynamic, return error. if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset()) return op.emitError("incorrect dynamic strides in view memref type ") << viewType; } return success(); } //===----------------------------------------------------------------------===// // ZeroExtendIOp //===----------------------------------------------------------------------===// static LogicalResult verify(ZeroExtendIOp op) { auto srcType = getElementTypeOrSelf(op.getOperand()->getType()); auto dstType = getElementTypeOrSelf(op.getType()); if (srcType.isa()) return op.emitError() << srcType << " is not a valid operand type"; if (dstType.isa()) return op.emitError() << dstType << " is not a valid result type"; if (srcType.cast().getWidth() >= dstType.cast().getWidth()) return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; return success(); } //===----------------------------------------------------------------------===// // FPExtOp //===----------------------------------------------------------------------===// bool FPExtOp::areCastCompatible(Type a, Type b) { if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() < fb.getWidth(); return false; } //===----------------------------------------------------------------------===// // FPTruncOp //===----------------------------------------------------------------------===// bool FPTruncOp::areCastCompatible(Type a, Type b) { if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() > fb.getWidth(); return false; } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/StandardOps/Ops.cpp.inc"