//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::memref;

namespace {
/// Idiomatic saturated operations on offsets, sizes and strides.
namespace saturated_arith {
struct Wrapper {
  static Wrapper stride(int64_t v) {
    return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
                                                    : Wrapper{false, v};
  }
  static Wrapper offset(int64_t v) {
    return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
                                                    : Wrapper{false, v};
  }
  static Wrapper size(int64_t v) {
    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
  }
  int64_t asOffset() {
    return saturated ? ShapedType::kDynamicStrideOrOffset : v;
  }
  int64_t asSize() { return saturated ? ShapedType::kDynamicSize : v; }
  int64_t asStride() {
    return saturated ? ShapedType::kDynamicStrideOrOffset : v;
  }
  bool operator==(Wrapper other) {
    return (saturated && other.saturated) ||
           (!saturated && !other.saturated && v == other.v);
  }
  bool operator!=(Wrapper other) { return !(*this == other); }
  Wrapper operator+(Wrapper other) {
    if (saturated || other.saturated)
      return Wrapper{true, 0};
    return Wrapper{false, other.v + v};
  }
  Wrapper operator*(Wrapper other) {
    if (saturated || other.saturated)
      return Wrapper{true, 0};
    return Wrapper{false, other.v * v};
  }
  bool saturated;
  int64_t v;
};
} // namespace saturated_arith
} // namespace

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
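/// For illustration (hypothetical types, not an exhaustive list):
///   memref<4x?xf32> -> tensor<4x?xf32>
///   memref<*xi8>    -> tensor<*xi8>
///   any other type  -> NoneType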
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.symbolOperands().size();

  return success();
}

LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }

LogicalResult AllocaOp::verify() {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(*this);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimensions operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
      auto *defOp = dynamicSize.getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old
        // memref.
        newShapeConstants.push_back(-1);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
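    // The overall rewrite, for illustration (hypothetical IR, %c16 a constant
    // index):
    //   %0 = memref.alloc(%c16, %n) : memref<?x?xf32>
    // becomes
    //   %1 = memref.alloc(%n) : memref<16x?xf32>
    //   %0 = memref.cast %1 : memref<16x?xf32> to memref<?x?xf32>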
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
        alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!results().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource
static bool isGuaranteedAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op itself could
/// allocate an AutomaticAllocationScopeResource. Note that
/// this will not check whether an operation contained within
/// the op can allocate.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
  // This op itself doesn't create a stack allocation,
  // the inner allocation should be handled separately.
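  // For illustration: a region-holding op whose side effects are defined
  // recursively (e.g. a loop) is not itself treated as an allocation; any
  // memref.alloca inside its regions is inspected separately by the caller.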
if (op->hasTrait()) return false; MemoryEffectOpInterface interface = dyn_cast(op); if (!interface) return true; for (auto res : op->getResults()) { if (auto effect = interface.getEffectOnValue(res)) { if (isa( effect->getResource())) return true; } } return false; } /// Return whether this op is the last non terminating op /// in a region. That is to say, it is in a one-block region /// and is only followed by a terminator. This prevents /// extending the lifetime of allocations. static bool lastNonTerminatorInRegion(Operation *op) { return op->getNextNode() == op->getBlock()->getTerminator() && op->getParentRegion()->getBlocks().size() == 1; } /// Inline an AllocaScopeOp if either the direct parent is an allocation scope /// or it contains no allocation. struct AllocaScopeInliner : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override { bool hasPotentialAlloca = op->walk([&](Operation *alloc) { if (alloc == op) return WalkResult::advance(); if (isOpItselfPotentialAutomaticAllocation(alloc)) return WalkResult::interrupt(); if (alloc->hasTrait()) return WalkResult::skip(); return WalkResult::advance(); }).wasInterrupted(); // If this contains no potential allocation, it is always legal to // inline. Otherwise, consider two conditions: if (hasPotentialAlloca) { // If the parent isn't an allocation scope, or we are not the last // non-terminator op in the parent, we will extend the lifetime. if (!op->getParentOp()->hasTrait()) return failure(); if (!lastNonTerminatorInRegion(op)) return failure(); } Block *block = &op.getRegion().front(); Operation *terminator = block->getTerminator(); ValueRange results = terminator->getOperands(); rewriter.mergeBlockBefore(block, op); rewriter.replaceOp(op, results); rewriter.eraseOp(terminator); return success(); } }; /// Move allocations into an allocation scope, if it is legal to /// move them (e.g. their operands are available at the location /// the op would be moved to). struct AllocaScopeHoister : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override { if (!op->getParentWithTrait()) return failure(); Operation *lastParentWithoutScope = op->getParentOp(); if (!lastParentWithoutScope || lastParentWithoutScope->hasTrait()) return failure(); // Only apply to if this is this last non-terminator // op in the block (lest lifetime be extended) of a one // block region if (!lastNonTerminatorInRegion(op) || !lastNonTerminatorInRegion(lastParentWithoutScope)) return failure(); while (!lastParentWithoutScope->getParentOp() ->hasTrait()) { lastParentWithoutScope = lastParentWithoutScope->getParentOp(); if (!lastParentWithoutScope || !lastNonTerminatorInRegion(lastParentWithoutScope)) return failure(); } assert(lastParentWithoutScope->getParentOp() ->hasTrait()); Region *containingRegion = nullptr; for (auto &r : lastParentWithoutScope->getRegions()) { if (r.isAncestor(op->getParentRegion())) { assert(containingRegion == nullptr && "only one region can contain the op"); containingRegion = &r; } } assert(containingRegion && "op must be contained in a region"); SmallVector toHoist; op->walk([&](Operation *alloc) { if (!isGuaranteedAutomaticAllocation(alloc)) return WalkResult::skip(); // If any operand is not defined before the location of // lastParentWithoutScope (i.e. where we would hoist to), skip. 
if (llvm::any_of(alloc->getOperands(), [&](Value v) { return containingRegion->isAncestor(v.getParentRegion()); })) return WalkResult::skip(); toHoist.push_back(alloc); return WalkResult::advance(); }); if (toHoist.empty()) return failure(); rewriter.setInsertionPoint(lastParentWithoutScope); for (auto *op : toHoist) { auto *cloned = rewriter.clone(*op); rewriter.replaceOp(op, cloned->getResults()); } return success(); } }; void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); } //===----------------------------------------------------------------------===// // AssumeAlignmentOp //===----------------------------------------------------------------------===// LogicalResult AssumeAlignmentOp::verify() { if (!llvm::isPowerOf2_32(alignment())) return emitOpError("alignment must be power of 2"); return success(); } //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// /// Determines whether MemRef_CastOp casts to a more dynamic version of the /// source memref. This is useful to to fold a memref.cast into a consuming op /// and implement canonicalization patterns for ops in different dialects that /// may consume the results of memref.cast operations. Such foldable memref.cast /// operations are typically inserted as `view` and `subview` ops are /// canonicalized, to preserve the type compatibility of their uses. /// /// Returns true when all conditions are met: /// 1. source and result are ranked memrefs with strided semantics and same /// element type and rank. /// 2. each of the source's size, offset or stride has more static information /// than the corresponding result's size, offset or stride. /// /// Example 1: /// ```mlir /// %1 = memref.cast %0 : memref<8x16xf32> to memref /// %2 = consumer %1 ... : memref ... /// ``` /// /// may fold into: /// /// ```mlir /// %2 = consumer %0 ... : memref<8x16xf32> ... /// ``` /// /// Example 2: /// ``` /// %1 = memref.cast %0 : memref(16 * i + j)>> /// to memref /// consumer %1 : memref ... /// ``` /// /// may fold into: /// /// ``` /// consumer %0 ... : memref(16 * i + j)>> /// ``` bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { MemRefType sourceType = castOp.source().getType().dyn_cast(); MemRefType resultType = castOp.getType().dyn_cast(); // Requires ranked MemRefType. if (!sourceType || !resultType) return false; // Requires same elemental type. if (sourceType.getElementType() != resultType.getElementType()) return false; // Requires same rank. if (sourceType.getRank() != resultType.getRank()) return false; // Only fold casts between strided memref forms. int64_t sourceOffset, resultOffset; SmallVector sourceStrides, resultStrides; if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) return false; // If cast is towards more static sizes along any dimension, don't fold. for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { auto ss = std::get<0>(it), st = std::get<1>(it); if (ss != st) if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) return false; } // If cast is towards more static offset along any dimension, don't fold. 
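  // For illustration (hypothetical types): a cast from
  //   memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>  (dynamic offset)
  // to
  //   memref<?xf32>                                     (static offset 0)
  // goes towards a more static offset and is therefore rejected here.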
if (sourceOffset != resultOffset) if (ShapedType::isDynamicStrideOrOffset(sourceOffset) && !ShapedType::isDynamicStrideOrOffset(resultOffset)) return false; // If cast is towards more static strides along any dimension, don't fold. for (auto it : llvm::zip(sourceStrides, resultStrides)) { auto ss = std::get<0>(it), st = std::get<1>(it); if (ss != st) if (ShapedType::isDynamicStrideOrOffset(ss) && !ShapedType::isDynamicStrideOrOffset(st)) return false; } return true; } bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (inputs.size() != 1 || outputs.size() != 1) return false; Type a = inputs.front(), b = outputs.front(); auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); auto uaT = a.dyn_cast(); auto ubT = b.dyn_cast(); if (aT && bT) { if (aT.getElementType() != bT.getElementType()) return false; if (aT.getLayout() != bT.getLayout()) { int64_t aOffset, bOffset; SmallVector aStrides, bStrides; if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || failed(getStridesAndOffset(bT, bStrides, bOffset)) || aStrides.size() != bStrides.size()) return false; // Strides along a dimension/offset are compatible if the value in the // source memref is static and the value in the target memref is the // same. They are also compatible if either one is dynamic (see // description of MemRefCastOp for details). auto checkCompatible = [](int64_t a, int64_t b) { return (a == MemRefType::getDynamicStrideOrOffset() || b == MemRefType::getDynamicStrideOrOffset() || a == b); }; if (!checkCompatible(aOffset, bOffset)) return false; for (const auto &aStride : enumerate(aStrides)) if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) return false; } if (aT.getMemorySpace() != bT.getMemorySpace()) return false; // They must have the same rank, and any specified dimensions must match. if (aT.getRank() != bT.getRank()) return false; for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); if (aDim != -1 && bDim != -1 && aDim != bDim) return false; } return true; } else { if (!aT && !uaT) return false; if (!bT && !ubT) return false; // Unranked to unranked casting is unsupported if (uaT && ubT) return false; auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); if (aEltType != bEltType) return false; auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); return aMemSpace == bMemSpace; } return false; } OpFoldResult CastOp::fold(ArrayRef operands) { return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); } //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// namespace { /// If the source/target of a CopyOp is a CastOp that does not modify the shape /// and element type, the cast can be skipped. Such CastOps only cast the layout /// of the type. struct FoldCopyOfCast : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CopyOp copyOp, PatternRewriter &rewriter) const override { bool modified = false; // Check source. 
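    // For illustration (hypothetical IR): in
    //   %0 = memref.cast %arg : memref<42xf32> to memref<42xf32, #layout>
    //   memref.copy %0, %dst : memref<42xf32, #layout> to memref<42xf32>
    // the copy can read directly from %arg, since only the layout differs.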
if (auto castOp = copyOp.source().getDefiningOp()) { auto fromType = castOp.source().getType().dyn_cast(); auto toType = castOp.source().getType().dyn_cast(); if (fromType && toType) { if (fromType.getShape() == toType.getShape() && fromType.getElementType() == toType.getElementType()) { rewriter.updateRootInPlace( copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); }); modified = true; } } } // Check target. if (auto castOp = copyOp.target().getDefiningOp()) { auto fromType = castOp.source().getType().dyn_cast(); auto toType = castOp.source().getType().dyn_cast(); if (fromType && toType) { if (fromType.getShape() == toType.getShape() && fromType.getElementType() == toType.getElementType()) { rewriter.updateRootInPlace( copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); }); modified = true; } } } return success(modified); } }; /// Fold memref.copy(%x, %x). struct FoldSelfCopy : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CopyOp copyOp, PatternRewriter &rewriter) const override { if (copyOp.source() != copyOp.target()) return failure(); rewriter.eraseOp(copyOp); return success(); } }; } // namespace void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); } LogicalResult CopyOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// copy(memrefcast) -> copy bool folded = false; Operation *op = *this; for (OpOperand &operand : op->getOpOperands()) { auto castOp = operand.get().getDefiningOp(); if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { operand.set(castOp.getOperand()); folded = true; } } return success(folded); } //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// LogicalResult DeallocOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dealloc(memrefcast) -> dealloc return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // DimOp //===----------------------------------------------------------------------===// void DimOp::build(OpBuilder &builder, OperationState &result, Value source, int64_t index) { auto loc = result.location; Value indexValue = builder.create(loc, index); build(builder, result, source, indexValue); } void DimOp::build(OpBuilder &builder, OperationState &result, Value source, Value index) { auto indexTy = builder.getIndexType(); build(builder, result, indexTy, source, index); } Optional DimOp::getConstantIndex() { if (auto constantOp = index().getDefiningOp()) return constantOp.getValue().cast().getInt(); return {}; } LogicalResult DimOp::verify() { // Assume unknown index to be in range. Optional index = getConstantIndex(); if (!index.hasValue()) return success(); // Check that constant index is not knowingly out of range. auto type = source().getType(); if (auto memrefType = type.dyn_cast()) { if (index.getValue() >= memrefType.getRank()) return emitOpError("index is out of range"); } else if (type.isa()) { // Assume index to be in range. } else { llvm_unreachable("expected operand with memref type"); } return success(); } /// Return a map with key being elements in `vals` and data being number of /// occurences of it. Use std::map, since the `vals` here are strides and the /// dynamic stride value is the same as the tombstone value for /// `DenseMap`. 
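/// For illustration: strides [4, 4, 1] map to {1: 1, 4: 2}.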
static std::map getNumOccurences(ArrayRef vals) { std::map numOccurences; for (auto val : vals) numOccurences[val]++; return numOccurences; } /// Given the `originalType` and a `candidateReducedType` whose shape is assumed /// to be a subset of `originalType` with some `1` entries erased, return the /// set of indices that specifies which of the entries of `originalShape` are /// dropped to obtain `reducedShape`. /// This accounts for cases where there are multiple unit-dims, but only a /// subset of those are dropped. For MemRefTypes these can be disambiguated /// using the strides. If a dimension is dropped the stride must be dropped too. static llvm::Optional computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef sizes) { llvm::SmallBitVector unusedDims(originalType.getRank()); if (originalType.getRank() == reducedType.getRank()) return unusedDims; for (const auto &dim : llvm::enumerate(sizes)) if (auto attr = dim.value().dyn_cast()) if (attr.cast().getInt() == 1) unusedDims.set(dim.index()); // Early exit for the case where the number of unused dims matches the number // of ranks reduced. if (static_cast(unusedDims.count()) + reducedType.getRank() == originalType.getRank()) return unusedDims; SmallVector originalStrides, candidateStrides; int64_t originalOffset, candidateOffset; if (failed( getStridesAndOffset(originalType, originalStrides, originalOffset)) || failed( getStridesAndOffset(reducedType, candidateStrides, candidateOffset))) return llvm::None; // For memrefs, a dimension is truly dropped if its corresponding stride is // also dropped. This is particularly important when more than one of the dims // is 1. Track the number of occurences of the strides in the original type // and the candidate type. For each unused dim that stride should not be // present in the candidate type. Note that there could be multiple dimensions // that have the same size. We dont need to exactly figure out which dim // corresponds to which stride, we just need to verify that the number of // reptitions of a stride in the original + number of unused dims with that // stride == number of repititions of a stride in the candidate. std::map currUnaccountedStrides = getNumOccurences(originalStrides); std::map candidateStridesNumOccurences = getNumOccurences(candidateStrides); for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) { if (!unusedDims.test(dim)) continue; int64_t originalStride = originalStrides[dim]; if (currUnaccountedStrides[originalStride] > candidateStridesNumOccurences[originalStride]) { // This dim can be treated as dropped. currUnaccountedStrides[originalStride]--; continue; } if (currUnaccountedStrides[originalStride] == candidateStridesNumOccurences[originalStride]) { // The stride for this is not dropped. Keep as is. unusedDims.reset(dim); continue; } if (currUnaccountedStrides[originalStride] < candidateStridesNumOccurences[originalStride]) { // This should never happen. Cant have a stride in the reduced rank type // that wasnt in the original one. 
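    // For illustration: with original sizes [1, 1, 4] and strides [4, 4, 1],
    // a candidate with strides [4, 1] accounts for only one occurrence of
    // stride 4, so exactly one of the two unit dims is treated as dropped
    // above.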
return llvm::None; } } if ((int64_t)unusedDims.count() + reducedType.getRank() != originalType.getRank()) return llvm::None; return unusedDims; } llvm::SmallBitVector SubViewOp::getDroppedDims() { MemRefType sourceType = getSourceType(); MemRefType resultType = getType(); llvm::Optional unusedDims = computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes()); assert(unusedDims && "unable to find unused dims of subview"); return *unusedDims; } OpFoldResult DimOp::fold(ArrayRef operands) { // All forms of folding require a known index. auto index = operands[1].dyn_cast_or_null(); if (!index) return {}; // Folding for unranked types (UnrankedMemRefType) is not supported. auto memrefType = source().getType().dyn_cast(); if (!memrefType) return {}; // Fold if the shape extent along the given index is known. if (!memrefType.isDynamicDim(index.getInt())) { Builder builder(getContext()); return builder.getIndexAttr(memrefType.getShape()[index.getInt()]); } // The size at the given index is now known to be a dynamic size. unsigned unsignedIndex = index.getValue().getZExtValue(); // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. Operation *definingOp = source().getDefiningOp(); if (auto alloc = dyn_cast_or_null(definingOp)) return *(alloc.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); if (auto alloca = dyn_cast_or_null(definingOp)) return *(alloca.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); if (auto view = dyn_cast_or_null(definingOp)) return *(view.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); if (auto subview = dyn_cast_or_null(definingOp)) { llvm::SmallBitVector unusedDims = subview.getDroppedDims(); unsigned resultIndex = 0; unsigned sourceRank = subview.getSourceType().getRank(); unsigned sourceIndex = 0; for (auto i : llvm::seq(0, sourceRank)) { if (unusedDims.test(i)) continue; if (resultIndex == unsignedIndex) { sourceIndex = i; break; } resultIndex++; } assert(subview.isDynamicSize(sourceIndex) && "expected dynamic subview size"); return subview.getDynamicSize(sourceIndex); } if (auto sizeInterface = dyn_cast_or_null(definingOp)) { assert(sizeInterface.isDynamicSize(unsignedIndex) && "Expected dynamic subview size"); return sizeInterface.getDynamicSize(unsignedIndex); } // dim(memrefcast) -> dim if (succeeded(foldMemRefCast(*this))) return getResult(); return {}; } namespace { /// Fold dim of a memref reshape operation to a load into the reshape's shape /// operand. struct DimOfMemRefReshape : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(DimOp dim, PatternRewriter &rewriter) const override { auto reshape = dim.source().getDefiningOp(); if (!reshape) return failure(); // Place the load directly after the reshape to ensure that the shape memref // was not mutated. 
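    // For illustration (hypothetical IR):
    //   %r = memref.reshape %src(%shape)
    //       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
    //   %d = memref.dim %r, %c1 : memref<?x?xf32>
    // can be rewritten as a load of %shape[%c1] (index_cast'ed if needed).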
rewriter.setInsertionPointAfter(reshape); Location loc = dim.getLoc(); Value load = rewriter.create(loc, reshape.shape(), dim.index()); if (load.getType() != dim.getType()) load = rewriter.create(loc, dim.getType(), load); rewriter.replaceOp(dim, load); return success(); } }; } // namespace void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); } // --------------------------------------------------------------------------- // DmaStartOp // --------------------------------------------------------------------------- void DmaStartOp::build(OpBuilder &builder, OperationState &result, Value srcMemRef, ValueRange srcIndices, Value destMemRef, ValueRange destIndices, Value numElements, Value tagMemRef, ValueRange tagIndices, Value stride, Value elementsPerStride) { result.addOperands(srcMemRef); result.addOperands(srcIndices); result.addOperands(destMemRef); result.addOperands(destIndices); result.addOperands({numElements, tagMemRef}); result.addOperands(tagIndices); if (stride) result.addOperands({stride, elementsPerStride}); } void DmaStartOp::print(OpAsmPrinter &p) { p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], " << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; if (isStrided()) p << ", " << getStride() << ", " << getNumElementsPerStride(); p.printOptionalAttrDict((*this)->getAttrs()); p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() << ", " << getTagMemRef().getType(); } // Parse DmaStartOp. // Ex: // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, // %tag[%index], %stride, %num_elt_per_stride : // : memref<3076 x f32, 0>, // memref<1024 x f32, 2>, // memref<1 x i32> // ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand srcMemRefInfo; SmallVector srcIndexInfos; OpAsmParser::UnresolvedOperand dstMemRefInfo; SmallVector dstIndexInfos; OpAsmParser::UnresolvedOperand numElementsInfo; OpAsmParser::UnresolvedOperand tagMemrefInfo; SmallVector tagIndexInfos; SmallVector strideInfo; SmallVector types; auto indexType = parser.getBuilder().getIndexType(); // Parse and resolve the following list of operands: // *) source memref followed by its indices (in square brackets). // *) destination memref followed by its indices (in square brackets). // *) dma size in KiB. if (parser.parseOperand(srcMemRefInfo) || parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(dstMemRefInfo) || parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(numElementsInfo) || parser.parseComma() || parser.parseOperand(tagMemrefInfo) || parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) return failure(); // Parse optional stride and elements per stride. 
if (parser.parseTrailingOperandList(strideInfo)) return failure(); bool isStrided = strideInfo.size() == 2; if (!strideInfo.empty() && !isStrided) { return parser.emitError(parser.getNameLoc(), "expected two stride related operands"); } if (parser.parseColonTypeList(types)) return failure(); if (types.size() != 3) return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || parser.resolveOperands(srcIndexInfos, indexType, result.operands) || parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || parser.resolveOperands(dstIndexInfos, indexType, result.operands) || // size should be an index. parser.resolveOperand(numElementsInfo, indexType, result.operands) || parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || // tag indices should be index. parser.resolveOperands(tagIndexInfos, indexType, result.operands)) return failure(); if (isStrided) { if (parser.resolveOperands(strideInfo, indexType, result.operands)) return failure(); } return success(); } LogicalResult DmaStartOp::verify() { unsigned numOperands = getNumOperands(); // Mandatory non-variadic operands are: src memref, dst memref, tag memref and // the number of elements. if (numOperands < 4) return emitOpError("expected at least 4 operands"); // Check types of operands. The order of these calls is important: the later // calls rely on some type properties to compute the operand position. // 1. Source memref. if (!getSrcMemRef().getType().isa()) return emitOpError("expected source to be of memref type"); if (numOperands < getSrcMemRefRank() + 4) return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 << " operands"; if (!getSrcIndices().empty() && !llvm::all_of(getSrcIndices().getTypes(), [](Type t) { return t.isIndex(); })) return emitOpError("expected source indices to be of index type"); // 2. Destination memref. if (!getDstMemRef().getType().isa()) return emitOpError("expected destination to be of memref type"); unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; if (numOperands < numExpectedOperands) return emitOpError() << "expected at least " << numExpectedOperands << " operands"; if (!getDstIndices().empty() && !llvm::all_of(getDstIndices().getTypes(), [](Type t) { return t.isIndex(); })) return emitOpError("expected destination indices to be of index type"); // 3. Number of elements. if (!getNumElements().getType().isIndex()) return emitOpError("expected num elements to be of index type"); // 4. Tag memref. if (!getTagMemRef().getType().isa()) return emitOpError("expected tag to be of memref type"); numExpectedOperands += getTagMemRefRank(); if (numOperands < numExpectedOperands) return emitOpError() << "expected at least " << numExpectedOperands << " operands"; if (!getTagIndices().empty() && !llvm::all_of(getTagIndices().getTypes(), [](Type t) { return t.isIndex(); })) return emitOpError("expected tag indices to be of index type"); // Optional stride-related operands must be either both present or both // absent. if (numOperands != numExpectedOperands && numOperands != numExpectedOperands + 2) return emitOpError("incorrect number of operands"); // 5. Strides. 
if (isStrided()) { if (!getStride().getType().isIndex() || !getNumElementsPerStride().getType().isIndex()) return emitOpError( "expected stride and num elements per stride to be of type index"); } return success(); } LogicalResult DmaStartOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dma_start(memrefcast) -> dma_start return foldMemRefCast(*this); } // --------------------------------------------------------------------------- // DmaWaitOp // --------------------------------------------------------------------------- LogicalResult DmaWaitOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dma_wait(memrefcast) -> dma_wait return foldMemRefCast(*this); } LogicalResult DmaWaitOp::verify() { // Check that the number of tag indices matches the tagMemRef rank. unsigned numTagIndices = tagIndices().size(); unsigned tagMemRefRank = getTagMemRefRank(); if (numTagIndices != tagMemRefRank) return emitOpError() << "expected tagIndices to have the same number of " "elements as the tagMemRef rank, expected " << tagMemRefRank << ", but got " << numTagIndices; return success(); } //===----------------------------------------------------------------------===// // GenericAtomicRMWOp //===----------------------------------------------------------------------===// void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, Value memref, ValueRange ivs) { result.addOperands(memref); result.addOperands(ivs); if (auto memrefType = memref.getType().dyn_cast()) { Type elementType = memrefType.getElementType(); result.addTypes(elementType); Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block()); bodyRegion->addArgument(elementType, memref.getLoc()); } } LogicalResult GenericAtomicRMWOp::verify() { auto &body = getRegion(); if (body.getNumArguments() != 1) return emitOpError("expected single number of entry block arguments"); if (getResult().getType() != body.getArgument(0).getType()) return emitOpError("expected block argument of the same type result type"); bool hasSideEffects = body.walk([&](Operation *nestedOp) { if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) return WalkResult::advance(); nestedOp->emitError( "body of 'memref.generic_atomic_rmw' should contain " "only operations with no side effects"); return WalkResult::interrupt(); }) .wasInterrupted(); return hasSideEffects ? 
failure() : success(); } ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand memref; Type memrefType; SmallVector ivs; Type indexType = parser.getBuilder().getIndexType(); if (parser.parseOperand(memref) || parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || parser.parseColonType(memrefType) || parser.resolveOperand(memref, memrefType, result.operands) || parser.resolveOperands(ivs, indexType, result.operands)) return failure(); Region *body = result.addRegion(); if (parser.parseRegion(*body, llvm::None, llvm::None) || parser.parseOptionalAttrDict(result.attributes)) return failure(); result.types.push_back(memrefType.cast().getElementType()); return success(); } void GenericAtomicRMWOp::print(OpAsmPrinter &p) { p << ' ' << memref() << "[" << indices() << "] : " << memref().getType() << ' '; p.printRegion(getRegion()); p.printOptionalAttrDict((*this)->getAttrs()); } //===----------------------------------------------------------------------===// // AtomicYieldOp //===----------------------------------------------------------------------===// LogicalResult AtomicYieldOp::verify() { Type parentType = (*this)->getParentOp()->getResultTypes().front(); Type resultType = result().getType(); if (parentType != resultType) return emitOpError() << "types mismatch between yield op: " << resultType << " and its parent: " << parentType; return success(); } //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue) { p << type; if (!op.isExternal()) { p << " = "; if (op.isUninitialized()) p << "uninitialized"; else p.printAttributeWithoutType(initialValue); } } static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue) { Type type; if (parser.parseType(type)) return failure(); auto memrefType = type.dyn_cast(); if (!memrefType || !memrefType.hasStaticShape()) return parser.emitError(parser.getNameLoc()) << "type should be static shaped memref, but got " << type; typeAttr = TypeAttr::get(type); if (parser.parseOptionalEqual()) return success(); if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { initialValue = UnitAttr::get(parser.getContext()); return success(); } Type tensorType = getTensorTypeFromMemRefType(memrefType); if (parser.parseAttribute(initialValue, tensorType)) return failure(); if (!initialValue.isa()) return parser.emitError(parser.getNameLoc()) << "initial value should be a unit or elements attribute"; return success(); } LogicalResult GlobalOp::verify() { auto memrefType = type().dyn_cast(); if (!memrefType || !memrefType.hasStaticShape()) return emitOpError("type should be static shaped memref, but got ") << type(); // Verify that the initial value, if present, is either a unit attribute or // an elements attribute. if (initial_value().hasValue()) { Attribute initValue = initial_value().getValue(); if (!initValue.isa() && !initValue.isa()) return emitOpError("initial value should be a unit or elements " "attribute, but got ") << initValue; // Check that the type of the initial value is compatible with the type of // the global variable. 
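    // For illustration (hypothetical op): for
    //   memref.global @foo : memref<2xf32> = dense<[0.0, 2.0]>
    // the elements attribute must have type tensor<2xf32>.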
if (initValue.isa()) { Type initType = initValue.getType(); Type tensorType = getTensorTypeFromMemRefType(memrefType); if (initType != tensorType) return emitOpError("initial value expected to be of type ") << tensorType << ", but was of type " << initType; } } if (Optional alignAttr = alignment()) { uint64_t alignment = alignAttr.getValue(); if (!llvm::isPowerOf2_64(alignment)) return emitError() << "alignment attribute value " << alignment << " is not a power of 2"; } // TODO: verify visibility for declarations. return success(); } ElementsAttr GlobalOp::getConstantInitValue() { auto initVal = initial_value(); if (constant() && initVal.hasValue()) return initVal.getValue().cast(); return {}; } //===----------------------------------------------------------------------===// // GetGlobalOp //===----------------------------------------------------------------------===// LogicalResult GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Verify that the result type is same as the type of the referenced // memref.global op. auto global = symbolTable.lookupNearestSymbolFrom(*this, nameAttr()); if (!global) return emitOpError("'") << name() << "' does not reference a valid global memref"; Type resultType = result().getType(); if (global.type() != resultType) return emitOpError("result type ") << resultType << " does not match type " << global.type() << " of the global memref @" << name(); return success(); } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// LogicalResult LoadOp::verify() { if (getNumOperands() != 1 + getMemRefType().getRank()) return emitOpError("incorrect number of indices for load"); return success(); } OpFoldResult LoadOp::fold(ArrayRef cstOperands) { /// load(memrefcast) -> load if (succeeded(foldMemRefCast(*this))) return getResult(); return OpFoldResult(); } //===----------------------------------------------------------------------===// // PrefetchOp //===----------------------------------------------------------------------===// void PrefetchOp::print(OpAsmPrinter &p) { p << " " << memref() << '['; p.printOperands(indices()); p << ']' << ", " << (isWrite() ? "write" : "read"); p << ", locality<" << localityHint(); p << ">, " << (isDataCache() ? 
"data" : "instr"); p.printOptionalAttrDict( (*this)->getAttrs(), /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); p << " : " << getMemRefType(); } ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand memrefInfo; SmallVector indexInfo; IntegerAttr localityHint; MemRefType type; StringRef readOrWrite, cacheType; auto indexTy = parser.getBuilder().getIndexType(); auto i32Type = parser.getBuilder().getIntegerType(32); if (parser.parseOperand(memrefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseKeyword(&readOrWrite) || parser.parseComma() || parser.parseKeyword("locality") || parser.parseLess() || parser.parseAttribute(localityHint, i32Type, "localityHint", result.attributes) || parser.parseGreater() || parser.parseComma() || parser.parseKeyword(&cacheType) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(indexInfo, indexTy, result.operands)) return failure(); if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) return parser.emitError(parser.getNameLoc(), "rw specifier has to be 'read' or 'write'"); result.addAttribute( PrefetchOp::getIsWriteAttrName(), parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); if (!cacheType.equals("data") && !cacheType.equals("instr")) return parser.emitError(parser.getNameLoc(), "cache type has to be 'data' or 'instr'"); result.addAttribute( PrefetchOp::getIsDataCacheAttrName(), parser.getBuilder().getBoolAttr(cacheType.equals("data"))); return success(); } LogicalResult PrefetchOp::verify() { if (getNumOperands() != 1 + getMemRefType().getRank()) return emitOpError("too few indices"); return success(); } LogicalResult PrefetchOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { // prefetch(memrefcast) -> prefetch return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// OpFoldResult RankOp::fold(ArrayRef operands) { // Constant fold rank when the rank of the operand is known. auto type = getOperand().getType(); auto shapedType = type.dyn_cast(); if (shapedType && shapedType.hasRank()) return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); return IntegerAttr(); } //===----------------------------------------------------------------------===// // ReinterpretCastOp //===----------------------------------------------------------------------===// /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, /// `staticSizes` and `staticStrides` are automatically filled with /// source-memref-rank sentinel values that encode dynamic entries. 
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, OpFoldResult offset, ArrayRef sizes, ArrayRef strides, ArrayRef attrs) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamicStrideOrOffset); build(b, result, resultType, source, dynamicOffsets, dynamicSizes, dynamicStrides, b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, int64_t offset, ArrayRef sizes, ArrayRef strides, ArrayRef attrs) { SmallVector sizeValues = llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); })); SmallVector strideValues = llvm::to_vector<4>( llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); })); build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, strideValues, attrs); } void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, Value offset, ValueRange sizes, ValueRange strides, ArrayRef attrs) { SmallVector sizeValues = llvm::to_vector<4>( llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); SmallVector strideValues = llvm::to_vector<4>( llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); } // TODO: ponder whether we want to allow missing trailing sizes/strides that are // completed automatically, like we have for subview and extract_slice. LogicalResult ReinterpretCastOp::verify() { // The source and result memrefs should be in the same memory space. auto srcType = source().getType().cast(); auto resultType = getType().cast(); if (srcType.getMemorySpace() != resultType.getMemorySpace()) return emitError("different memory spaces specified for source type ") << srcType << " and result memref type " << resultType; if (srcType.getElementType() != resultType.getElementType()) return emitError("different element types specified for source type ") << srcType << " and result memref type " << resultType; // Match sizes in result memref type and in static_sizes attribute. for (auto &en : llvm::enumerate(llvm::zip( resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) { int64_t resultSize = std::get<0>(en.value()); int64_t expectedSize = std::get<1>(en.value()); if (!ShapedType::isDynamic(resultSize) && !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) return emitError("expected result type with size = ") << expectedSize << " instead of " << resultSize << " in dim = " << en.index(); } // Match offset and strides in static_offset and static_strides attributes. If // result memref type has no affine map specified, this will assume an // identity layout. int64_t resultOffset; SmallVector resultStrides; if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) return emitError("expected result type to have strided layout but found ") << resultType; // Match offset in result memref type and in static_offsets attribute. 
int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front(); if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && !ShapedType::isDynamicStrideOrOffset(expectedOffset) && resultOffset != expectedOffset) return emitError("expected result type with offset = ") << resultOffset << " instead of " << expectedOffset; // Match strides in result memref type and in static_strides attribute. for (auto &en : llvm::enumerate(llvm::zip( resultStrides, extractFromI64ArrayAttr(static_strides())))) { int64_t resultStride = std::get<0>(en.value()); int64_t expectedStride = std::get<1>(en.value()); if (!ShapedType::isDynamicStrideOrOffset(resultStride) && !ShapedType::isDynamicStrideOrOffset(expectedStride) && resultStride != expectedStride) return emitError("expected result type with stride = ") << expectedStride << " instead of " << resultStride << " in dim = " << en.index(); } return success(); } OpFoldResult ReinterpretCastOp::fold(ArrayRef /*operands*/) { Value src = source(); auto getPrevSrc = [&]() -> Value { // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x). if (auto prev = src.getDefiningOp()) return prev.source(); // reinterpret_cast(cast(x)) -> reinterpret_cast(x). if (auto prev = src.getDefiningOp()) return prev.source(); // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets // are 0. if (auto prev = src.getDefiningOp()) if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) { return isConstantIntValue(val, 0); })) return prev.source(); return nullptr; }; if (auto prevSrc = getPrevSrc()) { sourceMutable().assign(prevSrc); return getResult(); } return nullptr; } //===----------------------------------------------------------------------===// // Reassociative reshape ops //===----------------------------------------------------------------------===// /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp /// result and operand. Layout maps are verified separately. /// /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are /// allowed in a reassocation group. static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef collapsedShape, ArrayRef expandedShape, ArrayRef reassociation, bool allowMultipleDynamicDimsPerGroup) { // There must be one reassociation group per collapsed dimension. if (collapsedShape.size() != reassociation.size()) return op->emitOpError("invalid number of reassociation groups: found ") << reassociation.size() << ", expected " << collapsedShape.size(); // The next expected expanded dimension index (while iterating over // reassociation indices). int64_t nextDim = 0; for (const auto &it : llvm::enumerate(reassociation)) { ReassociationIndices group = it.value(); int64_t collapsedDim = it.index(); bool foundDynamic = false; for (int64_t expandedDim : group) { if (expandedDim != nextDim++) return op->emitOpError("reassociation indices must be contiguous"); if (expandedDim >= static_cast(expandedShape.size())) return op->emitOpError("reassociation index ") << expandedDim << " is out of bounds"; // Check if there are multiple dynamic dims in a reassociation group. if (ShapedType::isDynamic(expandedShape[expandedDim])) { if (foundDynamic && !allowMultipleDynamicDimsPerGroup) return op->emitOpError( "at most one dimension in a reassociation group may be dynamic"); foundDynamic = true; } } // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity. 
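    // For illustration: a group with expanded sizes [2, ?] must collapse into
    // a dynamic dim, while [2, 3] must collapse into a static 6 (checked
    // below).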
if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic) return op->emitOpError("collapsed dim (") << collapsedDim << ") must be dynamic if and only if reassociation group is " "dynamic"; // If all dims in the reassociation group are static, the size of the // collapsed dim can be verified. if (!foundDynamic) { int64_t groupSize = 1; for (int64_t expandedDim : group) groupSize *= expandedShape[expandedDim]; if (groupSize != collapsedShape[collapsedDim]) return op->emitOpError("collapsed dim size (") << collapsedShape[collapsedDim] << ") must equal reassociation group size (" << groupSize << ")"; } } if (collapsedShape.empty()) { // Rank 0: All expanded dimensions must be 1. for (int64_t d : expandedShape) if (d != 1) return op->emitOpError( "rank 0 memrefs can only be extended/collapsed with/from ones"); } else if (nextDim != static_cast(expandedShape.size())) { // Rank >= 1: Number of dimensions among all reassociation groups must match // the result memref rank. return op->emitOpError("expanded rank (") << expandedShape.size() << ") inconsistent with number of reassociation indices (" << nextDim << ")"; } return success(); } SmallVector CollapseShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } SmallVector CollapseShapeOp::getReassociationExprs() { return convertReassociationIndicesToExprs(getContext(), getReassociationIndices()); } SmallVector ExpandShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } SmallVector ExpandShapeOp::getReassociationExprs() { return convertReassociationIndicesToExprs(getContext(), getReassociationIndices()); } /// Compute the layout map after expanding a given source MemRef type with the /// specified reassociation indices. static FailureOr computeExpandedLayoutMap(MemRefType srcType, ArrayRef resultShape, ArrayRef reassociation) { int64_t srcOffset; SmallVector srcStrides; if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) return failure(); assert(srcStrides.size() == reassociation.size() && "invalid reassociation"); // 1-1 mapping between srcStrides and reassociation packs. // Each srcStride starts with the given value and gets expanded according to // the proper entries in resultShape. // Example: // srcStrides = [10000, 1 , 100 ], // reassociations = [ [0], [1], [2, 3, 4]], // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]] // -> For the purpose of stride calculation, the useful sizes are: // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]]. // resultStrides = [10000, 1, 600, 200, 100] // Note that a stride does not get expanded along the first entry of each // shape pack. 
SmallVector reverseResultStrides; reverseResultStrides.reserve(resultShape.size()); unsigned shapeIndex = resultShape.size() - 1; for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) { ReassociationIndices reassoc = std::get<0>(it); int64_t currentStrideToExpand = std::get<1>(it); for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) { using saturated_arith::Wrapper; reverseResultStrides.push_back(currentStrideToExpand); currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) * Wrapper::size(resultShape[shapeIndex--])) .asStride(); } } return makeStridedLinearLayoutMap( llvm::to_vector<8>(llvm::reverse(reverseResultStrides)), srcOffset, srcType.getContext()); } static FailureOr computeExpandedType(MemRefType srcType, ArrayRef resultShape, ArrayRef reassociation) { if (srcType.getLayout().isIdentity()) { // If the source is contiguous (i.e., no layout map specified), so is the // result. MemRefLayoutAttrInterface layout; return MemRefType::get(resultShape, srcType.getElementType(), layout, srcType.getMemorySpace()); } // Source may not be contiguous. Compute the layout map. FailureOr computedLayout = computeExpandedLayoutMap(srcType, resultShape, reassociation); if (failed(computedLayout)) return failure(); auto computedType = MemRefType::get(resultShape, srcType.getElementType(), *computedLayout, srcType.getMemorySpaceAsInt()); return canonicalizeStridedLayout(computedType); } void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, ArrayRef resultShape, Value src, ArrayRef reassociation) { // Only ranked memref source values are supported. auto srcType = src.getType().cast(); FailureOr resultType = computeExpandedType(srcType, resultShape, reassociation); // Failure of this assertion usually indicates a problem with the source // type, e.g., could not get strides/offset. assert(succeeded(resultType) && "could not compute layout"); build(builder, result, *resultType, src, reassociation); } LogicalResult ExpandShapeOp::verify() { MemRefType srcType = getSrcType(); MemRefType resultType = getResultType(); // Verify result shape. if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(), resultType.getShape(), getReassociationIndices(), /*allowMultipleDynamicDimsPerGroup=*/false))) return failure(); // Compute expected result type (including layout map). FailureOr expectedResultType = computeExpandedType( srcType, resultType.getShape(), getReassociationIndices()); if (failed(expectedResultType)) return emitOpError("invalid source layout map"); // Check actual result type. auto canonicalizedResultType = canonicalizeStridedLayout(resultType); if (*expectedResultType != canonicalizedResultType) return emitOpError("expected expanded type to be ") << *expectedResultType << " but found " << canonicalizedResultType; return success(); } void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, ComposeExpandOfCollapseOp>( context); } /// Compute the layout map after collapsing a given source MemRef type with the /// specified reassociation indices. /// /// Note: All collapsed dims in a reassociation group must be contiguous. It is /// not possible to check this by inspecting a MemRefType in the general case. /// If non-contiguity cannot be checked statically, the collapse is assumed to /// be valid (and thus accepted by this function) unless `strict = true`. 
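/// For illustration (static case): collapsing two dims with sizes [4, 8] is
/// contiguous when the outer stride equals 8 * (inner stride); e.g. strides
/// [8, 1] are accepted, while [16, 1] are rejected.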
static FailureOr computeCollapsedLayoutMap(MemRefType srcType, ArrayRef reassociation, bool strict = false) { int64_t srcOffset; SmallVector srcStrides; auto srcShape = srcType.getShape(); if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) return failure(); // The result strides are exactly the strides of the last entry of each // reassociation. SmallVector resultStrides; resultStrides.reserve(reassociation.size()); for (ReassociationIndices reassoc : reassociation) resultStrides.push_back(srcStrides[reassoc.back()]); // Validate that each reassociation group is contiguous. unsigned resultStrideIndex = resultStrides.size() - 1; for (ReassociationIndices reassoc : llvm::reverse(reassociation)) { auto trailingReassocs = ArrayRef(reassoc).drop_front(); using saturated_arith::Wrapper; auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]); for (int64_t idx : llvm::reverse(trailingReassocs)) { stride = stride * Wrapper::size(srcShape[idx]); // Both source and result stride must have the same static value. In that // case, we can be sure, that the dimensions are collapsible (because they // are contiguous). // // One special case is when the srcShape is `1`, in which case it can // never produce non-contiguity. if (srcShape[idx] == 1) continue; // If `strict = false` (default during op verification), we accept cases // where one or both strides are dynamic. This is best effort: We reject // ops where obviously non-contiguous dims are collapsed, but accept ops // where we cannot be sure statically. Such ops may fail at runtime. See // the op documentation for details. auto srcStride = Wrapper::stride(srcStrides[idx - 1]); if (strict && (stride.saturated || srcStride.saturated)) return failure(); if (!stride.saturated && !srcStride.saturated && stride != srcStride) return failure(); } } return makeStridedLinearLayoutMap(resultStrides, srcOffset, srcType.getContext()); } bool ExpandShapeOp::isGuaranteedCollapsible( MemRefType srcType, ArrayRef reassociation) { // MemRefs with standard layout are always collapsible. if (srcType.getLayout().isIdentity()) return true; return succeeded(computeCollapsedLayoutMap(srcType, reassociation, /*strict=*/true)); } static MemRefType computeCollapsedType(MemRefType srcType, ArrayRef reassociation) { SmallVector resultShape; resultShape.reserve(reassociation.size()); for (const ReassociationIndices &group : reassociation) { using saturated_arith::Wrapper; auto groupSize = Wrapper::size(1); for (int64_t srcDim : group) groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim)); resultShape.push_back(groupSize.asSize()); } if (srcType.getLayout().isIdentity()) { // If the source is contiguous (i.e., no layout map specified), so is the // result. MemRefLayoutAttrInterface layout; return MemRefType::get(resultShape, srcType.getElementType(), layout, srcType.getMemorySpace()); } // Source may not be fully contiguous. Compute the layout map. // Note: Dimensions that are collapsed into a single dim are assumed to be // contiguous. 
FailureOr computedLayout = computeCollapsedLayoutMap(srcType, reassociation); assert(succeeded(computedLayout) && "invalid source layout map or collapsing non-contiguous dims"); auto computedType = MemRefType::get(resultShape, srcType.getElementType(), *computedLayout, srcType.getMemorySpaceAsInt()); return canonicalizeStridedLayout(computedType); } void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { auto srcType = src.getType().cast(); MemRefType resultType = computeCollapsedType(srcType, reassociation); build(b, result, resultType, src, attrs); result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); } LogicalResult CollapseShapeOp::verify() { MemRefType srcType = getSrcType(); MemRefType resultType = getResultType(); // Verify result shape. if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(), srcType.getShape(), getReassociationIndices(), /*allowMultipleDynamicDimsPerGroup=*/true))) return failure(); // Compute expected result type (including layout map). MemRefType expectedResultType; if (srcType.getLayout().isIdentity()) { // If the source is contiguous (i.e., no layout map specified), so is the // result. MemRefLayoutAttrInterface layout; expectedResultType = MemRefType::get(resultType.getShape(), srcType.getElementType(), layout, srcType.getMemorySpace()); } else { // Source may not be fully contiguous. Compute the layout map. // Note: Dimensions that are collapsed into a single dim are assumed to be // contiguous. FailureOr computedLayout = computeCollapsedLayoutMap(srcType, getReassociationIndices()); if (failed(computedLayout)) return emitOpError( "invalid source layout map or collapsing non-contiguous dims"); auto computedType = MemRefType::get(resultType.getShape(), srcType.getElementType(), *computedLayout, srcType.getMemorySpaceAsInt()); expectedResultType = canonicalizeStridedLayout(computedType); } auto canonicalizedResultType = canonicalizeStridedLayout(resultType); if (expectedResultType != canonicalizedResultType) return emitOpError("expected collapsed type to be ") << expectedResultType << " but found " << canonicalizedResultType; return success(); } struct CollapseShapeOpMemRefCastFolder : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override { auto cast = op.getOperand().getDefiningOp(); if (!cast) return failure(); if (!CastOp::canFoldIntoConsumerOp(cast)) return failure(); Type newResultType = computeCollapsedType(cast.getOperand().getType().cast(), op.getReassociationIndices()); if (newResultType == op.getResultType()) { rewriter.updateRootInPlace( op, [&]() { op.srcMutable().assign(cast.source()); }); } else { Value newOp = rewriter.create( op->getLoc(), cast.source(), op.getReassociationIndices()); rewriter.replaceOpWithNewOp(op, op.getType(), newOp); } return success(); } }; void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, ComposeCollapseOfExpandOp, CollapseShapeOpMemRefCastFolder>(context); } OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { return foldReshapeOp(*this, operands); } OpFoldResult CollapseShapeOp::fold(ArrayRef operands) { return foldReshapeOp(*this, operands); } //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// 
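/// Verifier for ReshapeOp: the source and result must share an element type,
/// ranked source/result memrefs must use the identity layout, and a statically
/// sized `shape` operand must match the rank of a ranked result.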
LogicalResult ReshapeOp::verify() { Type operandType = source().getType(); Type resultType = result().getType(); Type operandElementType = operandType.cast().getElementType(); Type resultElementType = resultType.cast().getElementType(); if (operandElementType != resultElementType) return emitOpError("element types of source and destination memref " "types should be the same"); if (auto operandMemRefType = operandType.dyn_cast()) if (!operandMemRefType.getLayout().isIdentity()) return emitOpError("source memref type should have identity affine map"); int64_t shapeSize = shape().getType().cast().getDimSize(0); auto resultMemRefType = resultType.dyn_cast(); if (resultMemRefType) { if (!resultMemRefType.getLayout().isIdentity()) return emitOpError("result memref type should have identity affine map"); if (shapeSize == ShapedType::kDynamicSize) return emitOpError("cannot use shape operand with dynamic length to " "reshape to statically-ranked memref type"); if (shapeSize != resultMemRefType.getRank()) return emitOpError( "length of shape operand differs from the result's memref rank"); } return success(); } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// LogicalResult StoreOp::verify() { if (getNumOperands() != 2 + getMemRefType().getRank()) return emitOpError("store index operand count not equal to memref rank"); return success(); } LogicalResult StoreOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// store(memrefcast) -> store return foldMemRefCast(*this, getValueToStore()); } //===----------------------------------------------------------------------===// // SubViewOp //===----------------------------------------------------------------------===// /// A subview result type can be fully inferred from the source type and the /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case. Type SubViewOp::inferResultType(MemRefType sourceMemRefType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides) { unsigned rank = sourceMemRefType.getRank(); (void)rank; assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); assert(staticSizes.size() == rank && "staticSizes length mismatch"); assert(staticStrides.size() == rank && "staticStrides length mismatch"); // Extract source offset and strides. int64_t sourceOffset; SmallVector sourceStrides; auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); assert(succeeded(res) && "SubViewOp expected strided memref type"); (void)res; // Compute target offset whose value is: // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. int64_t targetOffset = sourceOffset; for (auto it : llvm::zip(staticOffsets, sourceStrides)) { auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); using saturated_arith::Wrapper; targetOffset = (Wrapper::offset(targetOffset) + Wrapper::offset(staticOffset) * Wrapper::stride(targetStride)) .asOffset(); } // Compute target stride whose value is: // `sourceStrides_i * staticStrides_i`. SmallVector targetStrides; targetStrides.reserve(staticOffsets.size()); for (auto it : llvm::zip(sourceStrides, staticStrides)) { auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); using saturated_arith::Wrapper; targetStrides.push_back( (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride)) .asStride()); } // The type is now known. 
return MemRefType::get( staticSizes, sourceMemRefType.getElementType(), makeStridedLinearLayoutMap(targetStrides, targetOffset, sourceMemRefType.getContext()), sourceMemRefType.getMemorySpace()); } Type SubViewOp::inferResultType(MemRefType sourceMemRefType, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamicStrideOrOffset); return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, staticSizes, staticStrides); } Type SubViewOp::inferRankReducedResultType(unsigned resultRank, MemRefType sourceRankedTensorType, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { auto inferredType = inferResultType(sourceRankedTensorType, offsets, sizes, strides) .cast(); assert(inferredType.getRank() >= resultRank && "expected "); int rankDiff = inferredType.getRank() - resultRank; if (rankDiff > 0) { auto shape = inferredType.getShape(); llvm::SmallBitVector dimsToProject = getPositionsOfShapeOne(rankDiff, shape); SmallVector projectedShape; for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) if (!dimsToProject.test(pos)) projectedShape.push_back(shape[pos]); AffineMap map = inferredType.getLayout().getAffineMap(); if (!map.isIdentity()) map = getProjectedMap(map, dimsToProject); inferredType = MemRefType::get(projectedShape, inferredType.getElementType(), map, inferredType.getMemorySpace()); } return inferredType; } Type SubViewOp::inferRankReducedResultType(unsigned resultRank, MemRefType sourceRankedTensorType, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamicStrideOrOffset); return SubViewOp::inferRankReducedResultType( resultRank, sourceRankedTensorType, staticOffsets, staticSizes, staticStrides); } // Build a SubViewOp with mixed static and dynamic entries and custom result // type. If the type passed is nullptr, it is inferred. void SubViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, ArrayRef offsets, ArrayRef sizes, ArrayRef strides, ArrayRef attrs) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamicStrideOrOffset); auto sourceMemRefType = source.getType().cast(); // Structuring implementation this way avoids duplication between builders. 
if (!resultType) { resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, staticSizes, staticStrides) .cast(); } build(b, result, resultType, source, dynamicOffsets, dynamicSizes, dynamicStrides, b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } // Build a SubViewOp with mixed static and dynamic entries and inferred result // type. void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef offsets, ArrayRef sizes, ArrayRef strides, ArrayRef attrs) { build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); } // Build a SubViewOp with static entries and inferred result type. void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef offsets, ArrayRef sizes, ArrayRef strides, ArrayRef attrs) { SmallVector offsetValues = llvm::to_vector<4>( llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); })); SmallVector sizeValues = llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); })); SmallVector strideValues = llvm::to_vector<4>( llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); })); build(b, result, source, offsetValues, sizeValues, strideValues, attrs); } // Build a SubViewOp with dynamic entries and custom result type. If the // type passed is nullptr, it is inferred. void SubViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, ArrayRef offsets, ArrayRef sizes, ArrayRef strides, ArrayRef attrs) { SmallVector offsetValues = llvm::to_vector<4>( llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); })); SmallVector sizeValues = llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); })); SmallVector strideValues = llvm::to_vector<4>( llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); })); build(b, result, resultType, source, offsetValues, sizeValues, strideValues, attrs); } // Build a SubViewOp with dynamic entries and custom result type. If the type // passed is nullptr, it is inferred. void SubViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef attrs) { SmallVector offsetValues = llvm::to_vector<4>( llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); SmallVector sizeValues = llvm::to_vector<4>( llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); SmallVector strideValues = llvm::to_vector<4>( llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); build(b, result, resultType, source, offsetValues, sizeValues, strideValues); } // Build a SubViewOp with dynamic entries and inferred result type. void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef attrs) { build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); } /// For ViewLikeOpInterface. Value SubViewOp::getViewSource() { return source(); } /// Return true if t1 and t2 have equal offsets (both dynamic or of same /// static value). 
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { AffineExpr t1Offset, t2Offset; SmallVector t1Strides, t2Strides; auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset); auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset); return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset; } /// Checks if `original` Type type can be rank reduced to `reduced` type. /// This function is slight variant of `is subsequence` algorithm where /// not matching dimension must be 1. static SliceVerificationResult isRankReducedMemRefType(MemRefType originalType, MemRefType candidateRankReducedType, ArrayRef sizes) { auto partialRes = isRankReducedType(originalType, candidateRankReducedType); if (partialRes != SliceVerificationResult::Success) return partialRes; auto optionalUnusedDimsMask = computeMemRefRankReductionMask( originalType, candidateRankReducedType, sizes); // Sizes cannot be matched in case empty vector is returned. if (!optionalUnusedDimsMask.hasValue()) return SliceVerificationResult::LayoutMismatch; if (originalType.getMemorySpace() != candidateRankReducedType.getMemorySpace()) return SliceVerificationResult::MemSpaceMismatch; // No amount of stride dropping can reconcile incompatible offsets. if (!haveCompatibleOffsets(originalType, candidateRankReducedType)) return SliceVerificationResult::LayoutMismatch; return SliceVerificationResult::Success; } template static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, OpTy op, Type expectedType) { auto memrefType = expectedType.cast(); switch (result) { case SliceVerificationResult::Success: return success(); case SliceVerificationResult::RankTooLarge: return op.emitError("expected result rank to be smaller or equal to ") << "the source rank. "; case SliceVerificationResult::SizeMismatch: return op.emitError("expected result type to be ") << expectedType << " or a rank-reduced version. (mismatch of result sizes) "; case SliceVerificationResult::ElemTypeMismatch: return op.emitError("expected result element type to be ") << memrefType.getElementType(); case SliceVerificationResult::MemSpaceMismatch: return op.emitError("expected result and source memory spaces to match."); case SliceVerificationResult::LayoutMismatch: return op.emitError("expected result type to be ") << expectedType << " or a rank-reduced version. (mismatch of result layout) "; } llvm_unreachable("unexpected subview verification result"); } /// Verifier for SubViewOp. LogicalResult SubViewOp::verify() { MemRefType baseType = getSourceType(); MemRefType subViewType = getType(); // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != subViewType.getMemorySpace()) return emitError("different memory spaces specified for base memref " "type ") << baseType << " and subview memref type " << subViewType; // Verify that the base memref type has a strided layout map. if (!isStrided(baseType)) return emitError("base type ") << baseType << " is not strided"; // Verify result type against inferred type. 
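  // The declared result type may be a rank-reduced version of the inferred
  // type: unit dimensions may be dropped, but the element type, memory space,
  // offset and the remaining layout must still be consistent with the
  // inferred type.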
auto expectedType = SubViewOp::inferResultType( baseType, extractFromI64ArrayAttr(static_offsets()), extractFromI64ArrayAttr(static_sizes()), extractFromI64ArrayAttr(static_strides())); auto result = isRankReducedMemRefType(expectedType.cast(), subViewType, getMixedSizes()); return produceSubViewErrorMsg(result, *this, expectedType); } raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { return os << "range " << range.offset << ":" << range.size << ":" << range.stride; } /// Return the list of Range (i.e. offset, size, stride). Each Range /// entry contains either the dynamic value or a ConstantIndexOp constructed /// with `b` at location `loc`. SmallVector mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc) { std::array ranks = op.getArrayAttrMaxRanks(); assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); SmallVector res; unsigned rank = ranks[0]; res.reserve(rank); for (unsigned idx = 0; idx < rank; ++idx) { Value offset = op.isDynamicOffset(idx) ? op.getDynamicOffset(idx) : b.create(loc, op.getStaticOffset(idx)); Value size = op.isDynamicSize(idx) ? op.getDynamicSize(idx) : b.create(loc, op.getStaticSize(idx)); Value stride = op.isDynamicStride(idx) ? op.getDynamicStride(idx) : b.create(loc, op.getStaticStride(idx)); res.emplace_back(Range{offset, size, stride}); } return res; } /// Compute the canonical result type of a SubViewOp. Call `inferResultType` /// to deduce the result type for the given `sourceType`. Additionally, reduce /// the rank of the inferred result type if `currentResultType` is lower rank /// than `currentSourceType`. Use this signature if `sourceType` is updated /// together with the result type. In this case, it is important to compute /// the dropped dimensions using `currentSourceType` whose strides align with /// `currentResultType`. static MemRefType getCanonicalSubViewResultType( MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef mixedOffsets, ArrayRef mixedSizes, ArrayRef mixedStrides) { auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, mixedSizes, mixedStrides) .cast(); llvm::Optional unusedDims = computeMemRefRankReductionMask(currentSourceType, currentResultType, mixedSizes); // Return nullptr as failure mode. if (!unusedDims) return nullptr; SmallVector shape; for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) { if (unusedDims->test(sizes.index())) continue; shape.push_back(sizes.value()); } AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap(); if (!layoutMap.isIdentity()) layoutMap = getProjectedMap(layoutMap, unusedDims.getValue()); return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap, nonRankReducedType.getMemorySpace()); } /// Compute the canonical result type of a SubViewOp. Call `inferResultType` /// to deduce the result type. Additionally, reduce the rank of the inferred /// result type if `currentResultType` is lower rank than `sourceType`. static MemRefType getCanonicalSubViewResultType( MemRefType currentResultType, MemRefType sourceType, ArrayRef mixedOffsets, ArrayRef mixedSizes, ArrayRef mixedStrides) { return getCanonicalSubViewResultType(currentResultType, sourceType, sourceType, mixedOffsets, mixedSizes, mixedStrides); } /// Helper method to check if a `subview` operation is trivially a no-op. 
This /// is the case if the all offsets are zero, all strides are 1, and the source /// shape is same as the size of the subview. In such cases, the subview can /// be folded into its source. static bool isTrivialSubViewOp(SubViewOp subViewOp) { if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank()) return false; auto mixedOffsets = subViewOp.getMixedOffsets(); auto mixedSizes = subViewOp.getMixedSizes(); auto mixedStrides = subViewOp.getMixedStrides(); // Check offsets are zero. if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) { Optional intValue = getConstantIntValue(ofr); return !intValue || intValue.getValue() != 0; })) return false; // Check strides are one. if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) { Optional intValue = getConstantIntValue(ofr); return !intValue || intValue.getValue() != 1; })) return false; // Check all size values are static and matches the (static) source shape. ArrayRef sourceShape = subViewOp.getSourceType().getShape(); for (const auto &size : llvm::enumerate(mixedSizes)) { Optional intValue = getConstantIntValue(size.value()); if (!intValue || intValue.getValue() != sourceShape[size.index()]) return false; } // All conditions met. The `SubViewOp` is foldable as a no-op. return true; } namespace { /// Pattern to rewrite a subview op with MemRefCast arguments. /// This essentially pushes memref.cast past its consuming subview when /// `canFoldIntoConsumerOp` is true. /// /// Example: /// ``` /// %0 = memref.cast %V : memref<16x16xf32> to memref /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : /// memref to memref<3x4xf32, offset:?, strides:[?, 1]> /// ``` /// is rewritten into: /// ``` /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to /// memref<3x4xf32, offset:?, strides:[?, 1]> /// ``` class SubViewOpMemRefCastFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SubViewOp subViewOp, PatternRewriter &rewriter) const override { // Any constant operand, just return to let SubViewOpConstantFolder kick // in. if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { return matchPattern(operand, matchConstantIndex()); })) return failure(); auto castOp = subViewOp.source().getDefiningOp(); if (!castOp) return failure(); if (!CastOp::canFoldIntoConsumerOp(castOp)) return failure(); // Compute the SubViewOp result type after folding the MemRefCastOp. Use // the MemRefCastOp source operand type to infer the result type and the // current SubViewOp source operand type to compute the dropped dimensions // if the operation is rank-reducing. auto resultType = getCanonicalSubViewResultType( subViewOp.getType(), subViewOp.getSourceType(), castOp.source().getType().cast(), subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), subViewOp.getMixedStrides()); if (!resultType) return failure(); Value newSubView = rewriter.create( subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), subViewOp.static_sizes(), subViewOp.static_strides()); rewriter.replaceOpWithNewOp(subViewOp, subViewOp.getType(), newSubView); return success(); } }; /// Canonicalize subview ops that are no-ops. When the source shape is not /// same as a result shape due to use of `affine_map`. 
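/// In that case the subview is not replaced directly by its source; a
/// memref.cast back to the original result type is inserted instead (see the
/// pattern body below).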
class TrivialSubViewOpFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SubViewOp subViewOp, PatternRewriter &rewriter) const override { if (!isTrivialSubViewOp(subViewOp)) return failure(); if (subViewOp.getSourceType() == subViewOp.getType()) { rewriter.replaceOp(subViewOp, subViewOp.source()); return success(); } rewriter.replaceOpWithNewOp(subViewOp, subViewOp.getType(), subViewOp.source()); return success(); } }; } // namespace /// Return the canonical type of the result of a subview. struct SubViewReturnTypeCanonicalizer { MemRefType operator()(SubViewOp op, ArrayRef mixedOffsets, ArrayRef mixedSizes, ArrayRef mixedStrides) { return getCanonicalSubViewResultType(op.getType(), op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides); } }; /// A canonicalizer wrapper to replace SubViewOps. struct SubViewCanonicalizer { void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { rewriter.replaceOpWithNewOp(op, op.getType(), newOp); } }; void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results .add, SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context); } OpFoldResult SubViewOp::fold(ArrayRef operands) { auto resultShapedType = getResult().getType().cast(); auto sourceShapedType = source().getType().cast(); if (resultShapedType.hasStaticShape() && resultShapedType == sourceShapedType) { return getViewSource(); } return {}; } //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// /// Build a strided memref type by applying `permutationMap` tp `memRefType`. static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap) { auto rank = memRefType.getRank(); auto originalSizes = memRefType.getShape(); // Compute permuted sizes. SmallVector sizes(rank, 0); for (const auto &en : llvm::enumerate(permutationMap.getResults())) sizes[en.index()] = originalSizes[en.value().cast().getPosition()]; // Compute permuted strides. int64_t offset; SmallVector strides; auto res = getStridesAndOffset(memRefType, strides, offset); assert(succeeded(res) && strides.size() == static_cast(rank)); (void)res; auto map = makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); map = permutationMap ? map.compose(permutationMap) : map; return MemRefType::Builder(memRefType) .setShape(sizes) .setLayout(AffineMapAttr::get(map)); } void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, AffineMapAttr permutation, ArrayRef attrs) { auto permutationMap = permutation.getValue(); assert(permutationMap); auto memRefType = in.getType().cast(); // Compute result type. 
MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); build(b, result, resultType, in, attrs); result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); } // transpose $in $permutation attr-dict : type($in) `to` type(results) void TransposeOp::print(OpAsmPrinter &p) { p << " " << in() << " " << permutation(); p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); p << " : " << in().getType() << " to " << getType(); } ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand in; AffineMap permutation; MemRefType srcType, dstType; if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || parser.resolveOperand(in, srcType, result.operands) || parser.parseKeywordType("to", dstType) || parser.addTypeToList(dstType, result.types)) return failure(); result.addAttribute(TransposeOp::getPermutationAttrName(), AffineMapAttr::get(permutation)); return success(); } LogicalResult TransposeOp::verify() { if (!permutation().isPermutation()) return emitOpError("expected a permutation map"); if (permutation().getNumDims() != getShapedType().getRank()) return emitOpError("expected a permutation map of same rank as the input"); auto srcType = in().getType().cast(); auto dstType = getType().cast(); auto transposedType = inferTransposeResultType(srcType, permutation()); if (dstType != transposedType) return emitOpError("output type ") << dstType << " does not match transposed input type " << srcType << ", " << transposedType; return success(); } OpFoldResult TransposeOp::fold(ArrayRef) { if (succeeded(foldMemRefCast(*this))) return getResult(); return {}; } //===----------------------------------------------------------------------===// // ViewOp //===----------------------------------------------------------------------===// LogicalResult ViewOp::verify() { auto baseType = getOperand(0).getType().cast(); auto viewType = getType(); // The base memref should have identity layout map (or none). if (!baseType.getLayout().isIdentity()) return emitError("unsupported map for base memref type ") << baseType; // The result memref should have identity layout map (or none). if (!viewType.getLayout().isIdentity()) return emitError("unsupported map for result memref type ") << viewType; // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != viewType.getMemorySpace()) return emitError("different memory spaces specified for base memref " "type ") << baseType << " and view memref type " << viewType; // Verify that we have the correct number of sizes for the result type. unsigned numDynamicDims = viewType.getNumDynamicDims(); if (sizes().size() != numDynamicDims) return emitError("incorrect number of size operands for type ") << viewType; return success(); } Value ViewOp::getViewSource() { return source(); } namespace { struct ViewOpShapeFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { // Return if none of the operands are constants. if (llvm::none_of(viewOp.getOperands(), [](Value operand) { return matchPattern(operand, matchConstantIndex()); })) return failure(); // Get result memref type. auto memrefType = viewOp.getType(); // Get offset from old memref view type 'memRefType'. 
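    // The ViewOp verifier requires an identity layout on the result type, so
    // the extracted offset is expected to be zero (asserted below).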
int64_t oldOffset; SmallVector oldStrides; if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) return failure(); assert(oldOffset == 0 && "Expected 0 offset"); SmallVector newOperands; // Offset cannot be folded into result type. // Fold any dynamic dim operands which are produced by a constant. SmallVector newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); unsigned dynamicDimPos = 0; unsigned rank = memrefType.getRank(); for (unsigned dim = 0, e = rank; dim < e; ++dim) { int64_t dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. if (!ShapedType::isDynamic(dimSize)) { newShapeConstants.push_back(dimSize); continue; } auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.value()); } else { // Dynamic shape dimension not folded; copy operand from old memref. newShapeConstants.push_back(dimSize); newOperands.push_back(viewOp.sizes()[dynamicDimPos]); } dynamicDimPos++; } // Create new memref type with constant folded dims. MemRefType newMemRefType = MemRefType::Builder(memrefType).setShape(newShapeConstants); // Nothing new, don't fold. if (newMemRefType == memrefType) return failure(); // Create new ViewOp. auto newViewOp = rewriter.create(viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), viewOp.byte_shift(), newOperands); // Insert a cast so we have the same type as the old memref type. rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), newViewOp); return success(); } }; struct ViewOpMemrefCastFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { Value memrefOperand = viewOp.getOperand(0); CastOp memrefCastOp = memrefOperand.getDefiningOp(); if (!memrefCastOp) return failure(); Value allocOperand = memrefCastOp.getOperand(); AllocOp allocOp = allocOperand.getDefiningOp(); if (!allocOp) return failure(); rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), allocOperand, viewOp.byte_shift(), viewOp.sizes()); return success(); } }; } // namespace void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); } //===----------------------------------------------------------------------===// // AtomicRMWOp //===----------------------------------------------------------------------===// LogicalResult AtomicRMWOp::verify() { if (getMemRefType().getRank() != getNumOperands() - 2) return emitOpError( "expects the number of subscripts to be equal to memref rank"); switch (kind()) { case arith::AtomicRMWKind::addf: case arith::AtomicRMWKind::maxf: case arith::AtomicRMWKind::minf: case arith::AtomicRMWKind::mulf: if (!value().getType().isa()) return emitOpError() << "with kind '" << arith::stringifyAtomicRMWKind(kind()) << "' expects a floating-point type"; break; case arith::AtomicRMWKind::addi: case arith::AtomicRMWKind::maxs: case arith::AtomicRMWKind::maxu: case arith::AtomicRMWKind::mins: case arith::AtomicRMWKind::minu: case arith::AtomicRMWKind::muli: case arith::AtomicRMWKind::ori: case arith::AtomicRMWKind::andi: if (!value().getType().isa()) return emitOpError() << "with kind '" << arith::stringifyAtomicRMWKind(kind()) << "' expects an integer type"; break; default: break; } return success(); } OpFoldResult AtomicRMWOp::fold(ArrayRef operands) { /// atomicrmw(memrefcast) -> atomicrmw if 
(succeeded(foldMemRefCast(*this, value())))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"