//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
using namespace mlir;
using namespace mlir::linalg;
/// Include the definitions of the copy operation interface.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
//===----------------------------------------------------------------------===//
// Interface utility functions
//===----------------------------------------------------------------------===//
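// For example, with indexing maps [(d0, d1) -> (d0, d1), (d0, d1) -> (d0)],
// the second operand can be dropped: the remaining identity map is still
// invertible. Dropping the first instead leaves only (d0, d1) -> (d0), which
// loses d1 and thus cannot be inverted.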
bool linalg::detail::canOpOperandsBeDroppedImpl(
linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
SmallVector<AffineMap> indexingMaps;
for (auto &opOperand : linalgOp->getOpOperands()) {
if (llvm::is_contained(droppedOperands, &opOperand))
continue;
indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
}
if (indexingMaps.empty()) {
// If there are no indexing maps, the operand can only be dropped
// if the op has no loops.
return linalgOp.getNumLoops() == 0;
}
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}
//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//
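// For example, a generic of the following form is recognized as a copy
// (a sketch, with #id denoting the identity map):
//   linalg.generic {indexing_maps = [#id, #id],
//                   iterator_types = ["parallel", "parallel"]}
//       ins(%src : tensor<?x?xf32>) outs(%dst : tensor<?x?xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   }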
bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
// Structural.
if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
return false;
// Operands and maps.
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
return false;
auto mapRange = linalgOp.getIndexingMapsArray();
if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
!mapRange.back().isIdentity()) {
return false;
}
// Region.
return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
}
//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
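// For example, a fill of a 2-D tensor with a scalar looks like (a sketch):
//   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
//                                    affine_map<(d0, d1) -> (d0, d1)>],
//                   iterator_types = ["parallel", "parallel"]}
//       ins(%cst : f32) outs(%dst : tensor<?x?xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   }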
std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
// Structural.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
return std::nullopt;
// Input should be referenced and init should not.
if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
return std::nullopt;
OpOperand *value = genericOp.getDpsInputOperand(0);
if (!genericOp.isScalar(value))
return std::nullopt;
Block *body = genericOp.getBody();
if (body->getOperations().size() != 1)
return std::nullopt;
auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
if (!yieldOp || yieldOp.getNumOperands() != 1 ||
yieldOp->getOperand(0) != body->getArgument(0))
return std::nullopt;
return value->get();
}
//===----------------------------------------------------------------------===//
// Elementwise Single Unary/Binary-OpInterface implementation
//===----------------------------------------------------------------------===//
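// For example, a single elementwise binary op (arity == 2) looks like
// (a sketch, with #id denoting the identity map; the unary case is analogous
// with one input):
//   linalg.generic {indexing_maps = [#id, #id, #id],
//                   iterator_types = ["parallel"]}
//       ins(%lhs, %rhs : tensor<?xf32>, tensor<?xf32>)
//       outs(%init : tensor<?xf32>) {
//   ^bb0(%a: f32, %b: f32, %out: f32):
//     %0 = arith.addf %a, %b : f32
//     linalg.yield %0 : f32
//   }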
static bool
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
unsigned arity) {
// Check all loops are parallel.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
genericOp.getNumLoops() < 1)
return false;
  // Check there are `arity` inputs, 1 output, and that all indexing maps are
  // identities.
if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
!llvm::all_of(genericOp.getIndexingMapsArray(),
[](AffineMap map) { return map.isIdentity(); }))
return false;
// Init should not be referenced for elementwise operations.
if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
return false;
  // A linalg.generic may encode a series of elementwise ops, e.g. exp(neg(x)),
  // for instance as a result of producer-consumer fusion. Here, we restrict to
  // bodies with exactly two ops: the single elementwise op followed by a
  // yield.
Block *body = genericOp.getBody();
if (body->getOperations().size() != 2)
return false;
Operation *op = &body->front();
if (op->getNumOperands() != arity || op->getNumResults() != 1)
return false;
auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
if (!yieldOp || yieldOp.getNumOperands() != 1 ||
yieldOp->getOperand(0).getDefiningOp() != op)
return false;
return true;
}
bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
// All basic elemwise checks.
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
return false;
  // Check the input is actually used.
if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
return false;
return true;
}
bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2))
return false;
// Check both inputs are used (elementwise).
OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
!genericOp.payloadUsesValueFromOperand(inputOpOperand1))
return false;
return true;
}
//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//
/// If the value is defined by a chain of unary, side effect-free operations,
/// go up the use-def chain until reaching the first value that isn't defined
/// by such an op.
// TODO: relax to multi-operand ops with constants, which are technically
// unary ops, as needed (e.g. add5).
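// For example, given
//   %0 = arith.negf %arg0 : f32
//   %1 = arith.extf %0 : f32 to f64
// starting from %1 this walks back through both unary, side effect-free ops
// and returns %arg0.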
static Value getSourceSkipUnary(Value value) {
Operation *op = value.getDefiningOp();
while (op && op->getNumOperands() == 1) {
auto iface = dyn_cast<MemoryEffectOpInterface>(op);
if (!iface || !iface.hasNoEffect())
break;
value = op->getOperand(0);
op = value.getDefiningOp();
}
return value;
}
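// A canonical contraction body, e.g. that of a matmul, looks like:
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %c, %0 : f32
//     linalg.yield %1 : f32
// with unary, side effect-free casts possibly interposed anywhere.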
bool mlir::linalg::detail::isContractionBody(
Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
llvm::raw_ostream &errs) {
if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
errs << "no terminator in the block";
return false;
}
if (block.getNumArguments() != 3) {
errs << "expected block with 3 arguments";
return false;
}
Operation *terminator = block.getTerminator();
if (terminator->getNumOperands() != 1) {
errs << "expected terminator with 1 operand";
return false;
}
Value yielded = getSourceSkipUnary(terminator->getOperand(0));
  Operation *reductionOp = yielded.getDefiningOp();
  // Guard against the yielded value being a block argument (no defining op).
  if (!reductionOp || reductionOp->getNumResults() != 1 ||
      reductionOp->getNumOperands() != 2) {
errs << "expected reduction op to be binary";
return false;
}
Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
if (reductionLHS != block.getArgument(2) &&
reductionRHS != block.getArgument(2)) {
errs << "expected reduction to take block argument #2 as one of the "
"operands (modulo unary casts)";
return false;
}
Value contributed = getSourceSkipUnary(
isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
Operation *elementwiseOp = contributed.getDefiningOp();
if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
elementwiseOp->getNumOperands() != 2) {
errs << "expected elementwise op to be binary";
return false;
}
if (!isaPair(elementwiseOp, reductionOp)) {
errs << "expected reduction/elementwise op kind not satisfied";
return false;
}
Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
if ((elementwiseLHS == block.getArgument(0) &&
elementwiseRHS == block.getArgument(1)) ||
(elementwiseLHS == block.getArgument(1) &&
elementwiseRHS == block.getArgument(0))) {
return true;
}
errs << "expected elementwise op to apply to block arguments (modulo unary "
"casts)";
return false;
}
/// Returns true if the two operations are of the kinds specified by a pair of
/// consecutive template arguments.
template <typename AddOpTy, typename MulOpTy, typename... Args>
static bool isPairTemplateImpl(Operation *add, Operation *mul) {
static_assert(sizeof...(Args) % 2 == 0,
"expected an even number of template arguments");
if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
return true;
if constexpr (sizeof...(Args) > 0)
return isPairTemplateImpl<Args...>(add, mul);
else
return false;
}
/// Returns true if the block is a body of a contraction with the kinds of
/// operations given pairwise by template arguments.
template <typename... Args>
static bool isContractionBody(Block &block) {
return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}
/// Given an `indexingMap` and its corresponding `iterators`, returns
/// the positions of the iterators of type `iter` that are indexed by
/// the `indexingMap` as a permutation. This is useful to infer various
/// subcomputations on a `LinalgOp`. This is performed by looking up
/// each result in the `indexingMap` and determining whether:
/// - It is a single AffineDimExpr.
/// - It is the only result involving this AffineDimExpr.
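/// For example, with iterators [parallel, parallel, reduction] and indexing
/// map (d0, d1, d2) -> (d0, d2), requesting the parallel positions returns
/// {0}: d0 is parallel and appears in exactly one result, d1 is not indexed,
/// and d2 is a reduction iterator.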
static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(AffineMap indexingMap,
ArrayRef<utils::IteratorType> iterators,
utils::IteratorType iter) {
assert(iterators.size() == indexingMap.getNumDims());
llvm::SmallDenseSet<int64_t> res;
for (AffineExpr e : indexingMap.getResults()) {
if (auto d = dyn_cast<AffineDimExpr>(e)) {
if (iterators[d.getPosition()] == iter &&
llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
return e.isFunctionOfDim(d.getPosition());
}) == 1)
res.insert(d.getPosition());
}
}
return res;
}
namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
/// Infer the iterator types from the init affine map. This looks at which dims
/// are present in the map results, and returns an iterator types array with
/// parallel types for dims that are present, and reduction types for dims that
/// are not present.
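/// For example, (d0, d1, d2) -> (d0, d1) yields [parallel, parallel,
/// reduction], since d2 does not appear in any result.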
static FailureOr<SmallVector<utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
if (!map.isProjectedPermutation())
return failure();
SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
for (auto expr : map.getResults())
if (auto dim = dyn_cast<AffineDimExpr>(expr))
iterators[dim.getPosition()] = par;
return iterators;
}
/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
/// 1. The m dimension is involved in an outer-product along LHS
/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
/// 2. The n dimension is involved in an outer-product along RHS
/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
/// 3. The k dimension appears as a permutation on LHS and RHS.
/// 4. m, n and k appear only once in any given indexing map.
/// 5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
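/// For example, for a plain matmul with indexing maps
///   LHS: (m, n, k) -> (m, k), RHS: (m, n, k) -> (k, n),
///   RES: (m, n, k) -> (m, n)
/// this infers batch = {}, m = {0}, n = {1} and k = {2}.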
static FailureOr<ContractionDimensions>
inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
ArrayRef<utils::IteratorType> iterators) {
llvm::SmallDenseSet<int64_t> a =
findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
llvm::SmallDenseSet<int64_t> b =
findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
llvm::SmallDenseSet<int64_t> c =
findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
// A & C - B are the iterators involved in an outer-product along A (the LHS).
llvm::SmallDenseSet<int64_t> ac = a;
llvm::set_intersect(ac, c);
llvm::set_subtract(ac, b);
// B & C - A are the iterators involved in an outer-product along B (the RHS).
llvm::SmallDenseSet<int64_t> bc = b;
llvm::set_intersect(bc, c);
llvm::set_subtract(bc, a);
// A & B & C are the "batch" dimensions.
llvm::SmallDenseSet<int64_t> batches = a;
llvm::set_intersect(batches, b);
llvm::set_intersect(batches, c);
  // The dims indexed as reductions by both A and B are the reduction (k)
  // dimensions.
llvm::SmallDenseSet<int64_t> ra =
findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
llvm::SmallDenseSet<int64_t> rb =
findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
llvm::set_intersect(ra, rb);
// Return each set in sorted order.
ContractionDimensions dimensions{
SmallVector<unsigned, 2>(batches.begin(), batches.end()),
SmallVector<unsigned, 2>(ac.begin(), ac.end()),
SmallVector<unsigned, 2>(bc.begin(), bc.end()),
SmallVector<unsigned, 2>(ra.begin(), ra.end())};
llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
llvm::sort(dimensions.m.begin(), dimensions.m.end());
llvm::sort(dimensions.n.begin(), dimensions.n.end());
llvm::sort(dimensions.k.begin(), dimensions.k.end());
return dimensions;
}
FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
return failure();
return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
linalgOp.getIteratorTypesArray());
}
FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
if (indexingMaps.size() != 3)
return failure();
auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
if (failed(iterators))
return failure();
return inferContractionDimsImpl(indexingMaps, iterators.value());
}
namespace mlir::linalg::detail {
enum class MatchContractionResult {
Success = 0,
NotLinalgOp,
WrongNumOperands,
NoReduction,
NotProjectedPermutations,
NotAddMul
};
} // namespace mlir::linalg::detail
mlir::linalg::detail::MatchContractionResult
mlir::linalg::detail::isContractionInterfaceImpl(
Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return MatchContractionResult::NotLinalgOp;
if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
return MatchContractionResult::WrongNumOperands;
auto mapRange = linalgOp.getIndexingMapsArray();
if (linalgOp.getNumReductionLoops() == 0)
return MatchContractionResult::NoReduction;
if (llvm::any_of(mapRange,
[](AffineMap m) { return !m.isProjectedPermutation(); }))
return MatchContractionResult::NotProjectedPermutations;
// TODO: more fields than add/mul.
// clang-format off
if (!::isContractionBody<
arith::MulFOp, arith::AddFOp,
arith::MulIOp, arith::AddIOp,
complex::MulOp, complex::AddOp,
arith::AndIOp, arith::OrIOp>(
*linalgOp.getBlock())) {
return MatchContractionResult::NotAddMul;
}
// clang-format on
if (dimensions) {
FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
assert(succeeded(res) && "unexpected failure to infer contraction dims");
*dimensions = *res;
}
return MatchContractionResult::Success;
}
StringRef
mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
switch (res) {
case MatchContractionResult::NotLinalgOp:
return "expected a LinalgOp";
case MatchContractionResult::WrongNumOperands:
return "expected op with 2 inputs and 1 output";
case MatchContractionResult::NoReduction:
return "expected at least 1 reduction";
case MatchContractionResult::NotProjectedPermutations:
return "expected indexing maps to be projected permutations";
case MatchContractionResult::NotAddMul:
return "expected add/mul op in the body";
case MatchContractionResult::Success:
return "";
}
llvm_unreachable("unhandled MatchContractionResult case");
}
bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
if (!linalgOp)
return false;
Operation *op = linalgOp.getOperation();
return isa<ContractionOpInterface>(op) ||
(mlir::linalg::detail::isContractionInterfaceImpl(op) ==
mlir::linalg::detail::MatchContractionResult::Success);
}
/// Verify that a LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
/// 1. Has 2 input and 1 output shapes.
/// 2. Has at least one reduction dimension.
/// 3. Has only projected permutation indexing maps.
/// 4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
/// operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise and
/// constant operations that do not involve the reduction dimension(s).
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
auto res = isContractionInterfaceImpl(op);
if (res != MatchContractionResult::Success)
return op->emitError(getMatchContractionMessage(res));
return success();
}
//===----------------------------------------------------------------------===//
// ConvolutionOpInterface implementation
//===----------------------------------------------------------------------===//
/// Of the two given expressions, returns the one that is of type T (`lhs`
/// takes precedence over `rhs`).
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
}
namespace {
/// Walks the indexing expressions for the input of a convolution operation to
/// verify it is of the right form, i.e. either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///   (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// Classifies each AffineDimExpr as a convolved or an unconvolved dimension
/// and verifies that each dimension occurs only once.
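/// For example, for a strided 1-D convolution input map
///   (d0, d1, d2) -> (d0 * 2 + d2, d1)
/// d0 and d2 become convolved dims (paired with each other, with stride 2 and
/// dilation 1 respectively), and d1 becomes an unconvolved dim.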
struct ConvAccessExprWalker
: public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
// Stores dimensions used in expressions of the above form.
llvm::SmallDenseSet<int64_t> convolvedDims;
// Stores the dual mapping between LHS and RHS of convolution exprs.
llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
// Stores single use dimensions used by an AffineDimExpr.
llvm::SmallDenseSet<int64_t> unConvolvedDims;
// Stores a mapping from convolved dims to their coefficient.
llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
// Removes dims with multiple uses in the source input map from dimension
// sets tracked by this walker.
void clearMultiUseDims(AffineMap map) {
for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
return e.isFunctionOfDim(dimPos);
}) > 1) {
convolvedDims.erase(dimPos);
unConvolvedDims.erase(dimPos);
// If a duplicate dim is marked as convolved, the pair of the duplicate
// dim must be removed from the map as well.
auto it = convolvedDimMapping.find(dimPos);
if (it != convolvedDimMapping.end()) {
int64_t pairedDim = it->second;
convolvedDims.erase(pairedDim);
unConvolvedDims.erase(pairedDim);
strideAndDilationMapping.erase(pairedDim);
convolvedDimMapping.erase(dimPos);
convolvedDimMapping.erase(pairedDim);
}
}
}
}
LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
unsigned position = dimExpr.getPosition();
if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
return failure();
}
unConvolvedDims.insert(position);
return success();
}
LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
// In pre-order visit, top level op has to be an add op.
if (binaryExpr.getKind() != AffineExprKind::Add)
return failure();
auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
if (failed(lhsDimPos) || failed(rhsDimPos))
return failure();
convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
return success();
}
FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
int64_t dim = dimExpr.getPosition();
if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
return failure();
// Stride/dilation for this dim is implicitly 1.
strideAndDilationMapping[dim] =
getAffineConstantExpr(1, expr.getContext());
convolvedDims.insert(dim);
return dim;
}
if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
if (symbolMulExpr.getKind() != AffineExprKind::Mul)
return failure();
auto lhsExpr = symbolMulExpr.getLHS();
auto rhsExpr = symbolMulExpr.getRHS();
// Check for symbol expression.
AffineExpr mulExpr =
getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
// If there was no symbol expr, check for constant expression.
if (!mulExpr) {
mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
}
auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
if (!mulExpr || !dimExpr)
return failure();
int64_t dim = dimExpr.getPosition();
if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
return failure();
strideAndDilationMapping[dim] = mulExpr;
convolvedDims.insert(dim);
return dim;
}
return failure();
}
};
} // namespace
static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
assert(map.isProjectedPermutation() &&
"expected map to have projected permutations");
llvm::SmallDenseSet<int64_t> preservedDims;
for (auto expr : map.getResults())
preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
return preservedDims;
}
static SmallVector<int64_t, 2>
getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) {
SmallVector<int64_t, 2> vals;
for (auto e : exprs) {
auto constantExpr = dyn_cast<AffineConstantExpr>(e);
assert(constantExpr && "Found non-constant stride/dilation");
vals.push_back(constantExpr.getValue());
}
return vals;
}
/// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set, this fails unless there is at least
/// one convolved dimension pair (output image + filter loop). Convolution
/// dimensions are specified in sorted order, and strides match the order of
/// the filter loop dimensions, while the dilations match the order of the
/// output image dimensions.
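/// For example, for linalg.conv_2d_nhwc_hwcf this infers
///   batch = {0}, outputImage = {1, 2}, outputChannel = {3},
///   filterLoop = {4, 5}, inputChannel = {6}, depth = {},
/// with strides and dilations taken from the op's attributes when present.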
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
ConvAccessExprWalker &inputExprWalker,
bool allowEmptyConvolvedDims) {
auto filterMap =
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
auto outputMap =
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
filterMap, linalgOp.getIteratorTypesArray(), par);
llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
outputMap, linalgOp.getIteratorTypesArray(), par);
// unConvolvedDims & outputDims - filterDims are the batch iterators.
llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
llvm::set_intersect(batch, outputDims);
llvm::set_subtract(batch, filterDims);
// convolvedDims & outputDims are the output image iterators.
llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
llvm::set_intersect(oi, outputDims);
// filterDims & outputDims - unConvolvedDims are the output channel iterators.
llvm::SmallDenseSet<int64_t> oc = filterDims;
llvm::set_intersect(oc, outputDims);
llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
// filterDims & outputDims & unConvolvedDims are the depth iterators.
llvm::SmallDenseSet<int64_t> depth = filterDims;
llvm::set_intersect(depth, outputDims);
llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
llvm::SmallDenseSet<int64_t> filterReducedDims =
findPermutationsIndexingOperand(filterMap,
linalgOp.getIteratorTypesArray(), red);
// convolvedDims & filterReducedDims are the filter loop iterators.
llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
llvm::set_intersect(fl, filterReducedDims);
// unConvolvedDims & filterReducedDims are the input channel iterators.
llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
llvm::set_intersect(ic, filterReducedDims);
if (oi.empty() && !allowEmptyConvolvedDims)
return failure();
// Return each set in sorted order.
ConvolutionDimensions dimensions{
SmallVector<unsigned, 2>(batch.begin(), batch.end()),
SmallVector<unsigned, 2>(oi.begin(), oi.end()),
SmallVector<unsigned, 2>(oc.begin(), oc.end()),
SmallVector<unsigned, 2>(fl.begin(), fl.end()),
SmallVector<unsigned, 2>(ic.begin(), ic.end()),
SmallVector<unsigned, 2>(depth.begin(), depth.end()),
/*strides=*/SmallVector<int64_t, 2>{},
/*dilations=*/SmallVector<int64_t, 2>{}};
llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
// Use the op carried strides/dilations attribute if present.
auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
if (!nativeStrides) {
SmallVector<AffineExpr, 2> strideExprs;
for (unsigned oiDim : dimensions.outputImage)
strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
dimensions.strides = getConstantsFromExprList(strideExprs);
} else {
dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
}
auto nativeDilations =
linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
if (!nativeDilations) {
SmallVector<AffineExpr, 2> dilationExprs;
for (unsigned flDim : dimensions.filterLoop)
dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
dimensions.dilations = getConstantsFromExprList(dilationExprs);
} else {
dimensions.dilations =
llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
}
return dimensions;
}
/// Find at least one parallel (output_image) and one reduction (filter_loop)
/// dimension candidate that form a convolution subcomputation within
/// `linalgOp`. The LHS is assumed to be the convolution input while the
/// RHS is assumed to be the filter.
/// These dimensions are such that:
/// 1. Optional batch dimensions that appear in the input and filter.
/// 2. The output_image dimension is involved in a cross-correlation along LHS
/// (i.e. it is a permutation on RES and LHS and has an associated
/// filter_loop in RHS).
/// 3. Optional output_channel dimension is involved in an outer-product along
/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
/// LHS).
/// 4. Optional input_channel dimension appears as a permutation on LHS and
/// RHS.
/// 5. The filter_loop dimension appears as a permutation on the RHS and
/// represents the shape of the kernel cross-correlated along a
/// corresponding output_image dim.
/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
/// 7. All dimensions appear only once in any given indexing map.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification,
/// indices are returned in sorted order.
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
FailureOr<ConvolutionDimensions>
mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
return failure();
auto indexingMaps = linalgOp.getIndexingMapsArray();
// Check the input indexing map has the right form.
ConvAccessExprWalker inputExprWalker;
for (AffineExpr expr : indexingMaps[0].getResults())
(void)inputExprWalker.visit(expr);
inputExprWalker.clearMultiUseDims(indexingMaps[0]);
return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
/*allowEmptyConvolvedDims=*/false);
}
namespace mlir::linalg::detail {
enum class MatchConvolutionResult {
Success = 0,
NotLinalgOp,
WrongNumOperands,
WrongInputIndexingMap,
NotProjectedPermutations,
NonConvolutionLoop,
OutputDimsNotParallel,
NonOutputDimNotReduction,
EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
Operation *op, ConvolutionDimensions *dimensions,
bool allowEmptyConvolvedDims) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return MatchConvolutionResult::NotLinalgOp;
if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
return MatchConvolutionResult::WrongNumOperands;
auto indexingMaps = linalgOp.getIndexingMapsArray();
// Check the input indexing map has the right form.
ConvAccessExprWalker inputExprWalker;
if (llvm::any_of(indexingMaps[0].getResults(),
[&inputExprWalker](AffineExpr expr) {
return failed(inputExprWalker.visit(expr));
})) {
return MatchConvolutionResult::WrongInputIndexingMap;
}
  // Filter and output maps must be projected permutations.
if (!indexingMaps[1].isProjectedPermutation() ||
!indexingMaps.back().isProjectedPermutation())
return MatchConvolutionResult::NotProjectedPermutations;
auto iteratorTypes = linalgOp.getIteratorTypesArray();
llvm::SmallDenseSet<int64_t> outputDims =
getPreservedDims(indexingMaps.back());
llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
// Make sure all loops are characterized as one of:
// - Batch loop : present in output, as non-convolved in input, not present in
// filter.
// - Output image dimension : present in output, convolved dims in input, not
// present in filter.
// - Output channel dimension : present in output, not present in input,
// present in filter.
// - Filter loop dimension : present in filter, convolved in input, not
// present in output.
// - Input channel dimension : unconvolved in input, not present in output,
// present in filter.
// - Depth multiplier : unconvolved in input, present in output, present in
// filter.
llvm::SmallDenseSet<int64_t> allLoopDims;
for (auto outputExpr : indexingMaps.back().getResults()) {
int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
!filterDims.count(outputDim)) {
// Batch dimension.
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
}
if (inputExprWalker.convolvedDims.count(outputDim) &&
!filterDims.count(outputDim)) {
// Output image Loop dimension.
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
}
if (!inputExprWalker.convolvedDims.count(outputDim) &&
!inputExprWalker.unConvolvedDims.count(outputDim) &&
filterDims.count(outputDim)) {
// Output channel dimension.
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
filterDims.count(outputDim)) {
// Depth multiplier.
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
}
return MatchConvolutionResult::NonConvolutionLoop;
}
for (auto filterExpr : indexingMaps[1].getResults()) {
int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
if (outputDims.count(filterDim) &&
!inputExprWalker.unConvolvedDims.count(filterDim) &&
!inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension, already handled above; continue.
continue;
}
if (inputExprWalker.convolvedDims.count(filterDim) &&
!outputDims.count(filterDim)) {
// Filter loop dimension.
if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
return MatchConvolutionResult::NonOutputDimNotReduction;
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
allLoopDims.insert(filterDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(filterDim) &&
!outputDims.count(filterDim)) {
// Input channel dimension.
if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
return MatchConvolutionResult::NonOutputDimNotReduction;
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
allLoopDims.insert(filterDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(filterDim) &&
outputDims.count(filterDim)) {
// Depthwise loop. Already seen.
continue;
}
return MatchConvolutionResult::NonConvolutionLoop;
}
// All loops must be covered now.
if (allLoopDims.size() != linalgOp.getNumLoops())
return MatchConvolutionResult::NonConvolutionLoop;
if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
return MatchConvolutionResult::EmptyConvolvedDims;
if (dimensions) {
FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
linalgOp, inputExprWalker, allowEmptyConvolvedDims);
assert(succeeded(res) && "unexpected failure to infer convolution dims");
*dimensions = *res;
}
return MatchConvolutionResult::Success;
}
StringRef
mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
switch (res) {
case MatchConvolutionResult::NotLinalgOp:
return "expected a LinalgOp";
case MatchConvolutionResult::WrongNumOperands:
return "expected op with 2 inputs and 1 output";
case MatchConvolutionResult::WrongInputIndexingMap:
return "unexpected input index map for convolutions";
case MatchConvolutionResult::NotProjectedPermutations:
return "expected output/filter indexing maps to be projected permutations";
case MatchConvolutionResult::NonConvolutionLoop:
return "unexpected loop dimension for convolution op";
case MatchConvolutionResult::OutputDimsNotParallel:
return "expected all iterators used to access outputs to be parallel";
case MatchConvolutionResult::NonOutputDimNotReduction:
return "expected all iterators not used to access outputs to be reduction";
case MatchConvolutionResult::EmptyConvolvedDims:
return "expected convolved dim to be non-empty";
case MatchConvolutionResult::Success:
return "";
}
llvm_unreachable("unhandled MatchConvolutionResult case");
}
bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp,
bool allowEmptyConvolvedDims) {
return linalg::detail::isConvolutionInterfaceImpl(
linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
linalg::detail::MatchConvolutionResult::Success;
}
LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
if (res != MatchConvolutionResult::Success)
return op->emitError(getMatchConvolutionMessage(res));
return success();
}
//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
enum class MatchFillResult {
Success = 0,
NotLinalgOp,
WrongNumOperands,
NotScalarInput
};
static MatchFillResult isFillInterfaceImpl(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return MatchFillResult::NotLinalgOp;
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
return MatchFillResult::WrongNumOperands;
OpOperand *value = linalgOp.getDpsInputOperand(0);
if (!linalgOp.isScalar(value))
return MatchFillResult::NotScalarInput;
return MatchFillResult::Success;
}
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
auto res = isFillInterfaceImpl(op);
if (res == MatchFillResult::NotLinalgOp)
return op->emitError("expected a LinalgOp");
if (res == MatchFillResult::WrongNumOperands)
return op->emitError("expected op with 1 input and 1 output");
if (res == MatchFillResult::NotScalarInput)
return op->emitError("expected op with scalar input");
return success();
}
//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//
SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
Location loc) {
SmallVector<OpFoldResult> res;
for (OpOperand &opOperand : getOperation()->getOpOperands()) {
for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
}
return res;
}
SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
SmallVector<int64_t, 4> res;
assert(!hasDynamicShape() && "expected operands to have static shapes");
for (OpOperand &opOperand : getOperation()->getOpOperands())
llvm::append_range(res, getShape(&opOperand));
return res;
}
SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
AffineMap map = getLoopsToShapesMap();
unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
auto viewSizes = createFlatListOfOperandDims(b, loc);
SmallVector<Range, 4> res(numDims);
for (unsigned idx = 0; idx < numRes; ++idx) {
auto result = map.getResult(idx);
if (auto d = dyn_cast<AffineDimExpr>(result)) {
if (res[d.getPosition()].offset)
continue;
res[d.getPosition()] =
Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
}
}
return res;
}
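// For example, for a matmul with operand types tensor<4x8xf32>,
// tensor<8x16xf32> and tensor<4x16xf32>, getLoopsToShapesMap() is
//   (m, n, k) -> (m, k, k, n, m, n)
// and the flattened static operand sizes are [4, 8, 8, 16, 4, 16], so the
// static loop sizes come out as [m, n, k] = [4, 16, 8].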
SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
AffineMap map = getLoopsToShapesMap();
unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
SmallVector<int64_t, 4> res(numDims, 0);
for (unsigned idx = 0; idx < numRes; ++idx) {
auto result = map.getResult(idx);
if (auto d = dyn_cast<AffineDimExpr>(result))
res[d.getPosition()] = allShapeSizes[idx];
}
return res;
}
/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
: public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
: positions(std::move(positions)) {}
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
}
bool visitDimExpr(AffineDimExpr dimExpr) {
return positions.test(dimExpr.getPosition());
}
bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
private:
llvm::SmallBitVector positions;
};
static std::pair<int64_t, int64_t>
getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
int64_t inputRankSum = 0;
int64_t outputRankSum = 0;
for (OpOperand *input : op.getDpsInputOperands())
inputRankSum += op.getRank(input);
for (OpOperand &output : op.getDpsInitsMutable())
outputRankSum += op.getRank(&output);
return {inputRankSum, inputRankSum + outputRankSum};
}
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
// An example that helps understand the logic below.
// Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
// We want to express the shape of dim 0 of O in terms of shape of the inputs.
// This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d3, d3)
AffineMap loopsToShapesMap = getLoopsToShapesMap();
// Find the position in the above map that represents the shape of the
// result:dim being inferred.
auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
/// From loopsToShapesMap extract the submap that represents the shape of the
/// (resultIdx, dim) needed.
AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
resultShapesSubMapPos.first,
resultShapesSubMapPos.second - resultShapesSubMapPos.first);
AffineMap resultShapesFromInputShapesMap =
loopToResultsShapeMap.compose(getShapesToLoopsMap());
// Check that the result dim map does not contain the positions corresponding
// to the outputs.
llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
Location loc = getOperation()->getLoc();
IRRewriter rewriter(b);
SmallVector<OpFoldResult> allResultDimValues =
affine::makeComposedFoldedMultiResultAffineApply(
rewriter, loc, resultShapesFromInputShapesMap,
createFlatListOfOperandDims(b, loc));
int64_t pos = 0;
ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
for (OpOperand &opOperand : getDpsInitsMutable()) {
SmallVector<OpFoldResult> shapes;
for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
if (!shapedType.isDynamicDim(dim)) {
// Static dim: Return IntegerAttr.
shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
} else {
// Dynamic dim: Return Value.
OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
? createOrFoldDimOp(b, loc, opOperand.get(), dim)
: allResultDimValues[pos];
shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
}
pos++;
}
reifiedReturnShapes.emplace_back(std::move(shapes));
}
return success();
}
/// Return the index in the indexingMaps vector that corresponds to this
/// `opOperand`.
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
auto operandNumber = opOperand->getOperandNumber();
auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
if (!dpsIface.isDpsInput(opOperand))
return operandNumber;
unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
assert(!dpsIface.isDpsInit(opOperand));
// Account for potential inputs that are not DPS and may not appear in
// `indexingMaps`.
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getNumDpsInputs() +
operandNumber - start;
}
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);
// Mixed tensor/buffer operands are not allowed.
if (!linalgOp.hasPureTensorSemantics() &&
!linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
return op->emitOpError("expected to have pure tensor or buffer semantics");
  // Before checking indexing maps, we need to make sure the attributes
  // referenced by them are valid.
if (linalgOp.hasDynamicIndexingMaps())
if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
return failure();
// All input/output operands must be indexed.
if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
linalgOp->getNumOperands())
return op->emitOpError("expected the number of indexing_map (")
<< linalgOp.getIndexingMapsArray().size()
<< ") to be equal to the number of input/output operands ("
<< linalgOp->getNumOperands() << ")";
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
// Symbols disallowed.
if (indexingMap.getNumSymbols() != 0)
return op->emitOpError("unexpected symbols in indexing_map #")
<< opOperand.getOperandNumber();
// Domain must be consistent.
unsigned numLoops = linalgOp.getNumLoops();
if (indexingMap.getNumDims() != numLoops)
return op->emitOpError("expected indexing_map #")
<< opOperand.getOperandNumber() << " to have " << numLoops
<< " dim(s) to match the number of loops";
int64_t rank = linalgOp.getRank(&opOperand);
if (indexingMap.getNumResults() != rank)
return op->emitOpError("expected operand rank (")
<< rank << ") to match the result rank of indexing_map #"
<< opOperand.getOperandNumber() << " ("
<< indexingMap.getNumResults() << ")";
}
SmallVector<unsigned> redDims;
linalgOp.getReductionDims(redDims);
if (!linalgOp.getShapesToLoopsMap())
return op->emitOpError("expected the shape-to-loops map to be non-null");
// Check if given shapes match to inferred shapes.
SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
// Verify only static cases since we can't get exact dimension sizes and loop
// ranges for dynamic cases in this stage.
if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
for (int64_t &range : endLoopRangeValues)
range -= 1;
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
SmallVector<int64_t, 4> startIndices =
indexingMap.compose(startLoopRangeValues);
SmallVector<int64_t, 4> endIndices =
indexingMap.compose(endLoopRangeValues);
ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimensions and dimensions of size 0.
if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
continue;
        // Since the inferred index range for each dimension is monotonically
        // increasing or decreasing, its extrema occur at the first or last
        // index. Ideally, the operand's dimension size should equal the
        // maximum inferred index value + 1. However, composite affine
        // expressions such as d0 * 3 + d1 make an exact check difficult, so
        // for now we only verify that the inferred range lies within the
        // bounds of the operand's size. Some expressions, e.g.
        // (d0, d1) -> (d1 - d0), still escape this check.
int64_t inferredDimSize =
std::max(startIndices[dim], endIndices[dim]) + 1;
if (std::min(startIndices[dim], endIndices[dim]) < 0) {
std::string mapStr;
{
llvm::raw_string_ostream os(mapStr);
os << indexingMap;
}
return op->emitOpError(
"unexpected result less than 0 at expression #")
<< dim << " in " << mapStr;
}
if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
if (inferredDimSize != shape[dim]) {
return op->emitOpError("inferred input/output operand #")
<< opOperand.getOperandNumber() << " has shape's dimension #"
<< dim << " to be " << inferredDimSize << ", but found "
<< shape[dim];
}
} else {
if (inferredDimSize > shape[dim]) {
return op->emitOpError("inferred input/output operand #")
<< opOperand.getOperandNumber() << " has shape's dimension #"
<< dim << " to be greater than or equal to "
<< inferredDimSize << ", but found " << shape[dim];
}
}
}
}
}
// Check the region has exactly one block.
if (linalgOp->getNumRegions() != 1 ||
!llvm::hasSingleElement(linalgOp->getRegion(0)))
return op->emitOpError("expects to have 1 region with 1 block");
  // Simplifying assumption: bbargs match 1-1 with the shaped operands' element
  // types.
// TODO: once ranked shape types are plugged in, we may want to drop the
// corresponding bbargs, that can never be read from. This will be subject to
// consistency discussions (i.e. what to do with output tensors whose bbarg is
// not used).
Block &block = linalgOp->getRegion(0).front();
if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
return op->emitOpError("expected as many non-induction variable region "
"arguments as the number of input/output operands");
for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
Type elementType = opOperand->get().getType();
if (isa<MemRefType, RankedTensorType>(elementType))
elementType = getElementTypeOrSelf(opOperand->get().getType());
Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
if (elementType != argType)
return op->emitOpError("expected type of bb argument #")
<< opOperand->getOperandNumber() << " (" << argType << ")"
<< " to match element or self type of the corresponding operand ("
<< elementType << ")";
}
return success();
}