2021-02-03 22:19:12 +00:00
|
|
|
//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
|
|
|
|
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
2021-10-12 23:14:57 +00:00
|
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
2022-05-12 13:35:27 +02:00
|
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
2021-02-10 13:53:11 +01:00
|
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
2021-11-25 11:42:16 +01:00
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
2021-02-03 22:19:12 +00:00
|
|
|
#include "mlir/IR/AffineExprVisitor.h"
|
|
|
|
#include "mlir/IR/AffineMap.h"
|
2021-05-31 13:30:56 +00:00
|
|
|
#include "mlir/IR/TypeUtilities.h"
|
2022-02-12 14:19:35 +01:00
|
|
|
#include "llvm/ADT/SmallBitVector.h"
|
2021-02-03 22:19:12 +00:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::linalg;
|
|
|
|
|
|
|
|
/// Include the definitions of the copy operation interface.
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
|
|
|
|
|
2022-05-12 03:50:21 +00:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Interface utility functions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
bool linalg::detail::canOpOperandsBeDroppedImpl(
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  // Gather the indexing maps of every operand that is kept.
  SmallVector<AffineMap> remainingMaps;
  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    bool isDropped = llvm::is_contained(droppedOperands, opOperand);
    if (!isDropped)
      remainingMaps.push_back(linalgOp.getTiedIndexingMap(opOperand));
  }
  // The remaining operands still cover the full iteration space iff the
  // concatenation of their maps is invertible.
  return inversePermutation(concatAffineMaps(remainingMaps)) != AffineMap();
}
|
|
|
|
|
2021-02-04 16:49:09 +00:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ContractionOpInterface implementation
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Return true if the use-def chain from `v` to `from` consists of 0 or more
|
|
|
|
/// unary single-operand operations.
|
|
|
|
// TODO: relax to multi-operands with constants, which are technically unary ops
|
|
|
|
// as needed (e.g. add5).
|
|
|
|
static bool isChainOfUnaryOpsFrom(Value v, Value from) {
  // Follow defining ops backwards from `v`; every hop must be a
  // single-operand op, until `from` is reached.
  for (;;) {
    if (v == from)
      return true;
    Operation *defOp = v.getDefiningOp();
    if (!defOp || defOp->getNumOperands() != 1)
      return false;
    v = defOp->getOperand(0);
  }
}
|
|
|
|
|
|
|
|
/// Return the unique instance of OpType in `block` if it is indeed unique.
|
|
|
|
/// Return null if none or more than 1 instances exist.
|
|
|
|
template <typename OpType>
|
|
|
|
static OpType getSingleOpOfType(Block &block) {
|
|
|
|
OpType res = nullptr;
|
|
|
|
block.walk([&](OpType op) {
|
|
|
|
if (res) {
|
|
|
|
res = nullptr;
|
|
|
|
return WalkResult::interrupt();
|
|
|
|
}
|
|
|
|
res = op;
|
|
|
|
return WalkResult::advance();
|
|
|
|
});
|
|
|
|
return res;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
|
|
|
|
/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
|
|
|
|
/// unary operations that may change the type.
|
|
|
|
template <typename AddOpType, typename MulOpType>
|
|
|
|
static bool isAddMul(Block &block) {
|
|
|
|
if (block.getNumArguments() != 3)
|
|
|
|
return false;
|
|
|
|
Operation *yieldOp = block.getTerminator();
|
|
|
|
if (yieldOp->getNumOperands() != 1)
|
|
|
|
return false;
|
|
|
|
|
|
|
|
AddOpType addOp = getSingleOpOfType<AddOpType>(block);
|
|
|
|
MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
|
|
|
|
if (!addOp || !mulOp)
|
|
|
|
return false;
|
|
|
|
|
|
|
|
Value argA = block.getArgument(0), argB = block.getArgument(1);
|
|
|
|
Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
|
|
|
|
Value mul = mulOp->getResult(0);
|
|
|
|
Value argC = block.getArgument(2);
|
|
|
|
Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
|
|
|
|
Value add = addOp->getResult(0);
|
|
|
|
Value res = yieldOp->getOperand(0);
|
|
|
|
// Result traces back to add.
|
|
|
|
auto un = isChainOfUnaryOpsFrom;
|
|
|
|
bool success = un(res, add);
|
|
|
|
// One of the operands of add traces back to argC, the other to the mul.
|
|
|
|
success |= (un(c1, argC) && un(c2, mul)) || ((un(c1, mul)) && un(c2, argC));
|
|
|
|
// One of the operands of mul traces back to argA, the other to argB.
|
|
|
|
success |= (un(a, argA) && un(b, argB)) || ((un(a, argB)) && un(b, argA));
|
|
|
|
return success;
|
|
|
|
}
|
|
|
|
|
2021-09-20 10:40:31 -07:00
|
|
|
/// Result of matching an op against the contraction interface. `Success` is
/// the positive match; every other value names the first check that failed.
enum class MatchContractionResult {
  Success = 0,
  /// The operation does not implement the LinalgOp interface.
  NotLinalgOp,
  /// The op does not have exactly 2 inputs and 1 output.
  WrongNumOperands,
  /// The op has no reduction loop.
  NoReduction,
  /// At least one indexing map is not a projected permutation.
  NotProjectedPermutations,
  /// The body does not compute a supported (add, mul) pattern.
  NotAddMul
};
|
|
|
|
static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  if (!llvm::all_of(mapRange,
                    [](AffineMap m) { return m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // TODO: more fields than add/mul.
  // Try each supported (add, mul) field in turn: float, integer, complex and
  // boolean (or, and).
  Block &body = linalgOp->getRegion(0).front();
  bool hasAddMulBody = isAddMul<arith::AddFOp, arith::MulFOp>(body) ||
                       isAddMul<arith::AddIOp, arith::MulIOp>(body) ||
                       isAddMul<complex::AddOp, complex::MulOp>(body) ||
                       isAddMul<arith::OrIOp, arith::AndIOp>(body);
  if (!hasAddMulBody)
    return MatchContractionResult::NotAddMul;
  return MatchContractionResult::Success;
}
|
|
|
|
|
|
|
|
bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  // Ops that explicitly implement the interface qualify directly; otherwise
  // fall back to the structural matcher.
  if (isa<ContractionOpInterface>(op))
    return true;
  return isContractionInterfaceImpl(op) == MatchContractionResult::Success;
}
|
|
|
|
|
|
|
|
/// Verify that a LinalgOp `op` is a contraction.
|
|
|
|
/// A Linalg contraction is defined in general terms:
|
|
|
|
/// 1. Has 2 input and 1 output shapes.
|
|
|
|
/// 2. Has at least one reduction dimension.
|
|
|
|
/// 3. Has only projected permutation indexing maps.
|
|
|
|
/// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
|
|
|
|
/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
|
|
|
|
/// operations that may change the type (e.g. for mixed-precision).
|
|
|
|
/// As a consequence, when vectorization of such an op occurs, the only special
|
|
|
|
/// behavior is that the (unique) MulOpType is vectorized into a
|
|
|
|
/// `vector.contract`. All other ops are handled in a generic fashion.
|
|
|
|
/// In the future, we may wish to allow more input arguments and elementwise and
|
|
|
|
/// constant operations that do not involve the reduction dimension(s).
|
|
|
|
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
  // Map each failed matcher check to its diagnostic.
  switch (isContractionInterfaceImpl(op)) {
  case MatchContractionResult::NotLinalgOp:
    return op->emitError("expected a LinalgOp");
  case MatchContractionResult::WrongNumOperands:
    return op->emitError("expected op with 2 inputs and 1 outputs");
  case MatchContractionResult::NoReduction:
    return op->emitError("expected at least a reduction loop");
  case MatchContractionResult::NotProjectedPermutations:
    return op->emitError("expected all indexings to be projected permutations");
  case MatchContractionResult::NotAddMul:
    return op->emitError("(add, mul) operations not found");
  case MatchContractionResult::Success:
    return success();
  }
  llvm_unreachable("unhandled MatchContractionResult case");
}
|
|
|
|
|
2021-09-20 10:40:31 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConvolutionOpInterface implementation
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Of the given two expressions returns one that is of type T (`lhs` gets
|
|
|
|
/// preference over `rhs`)
|
|
|
|
template <typename T>
|
|
|
|
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
|
|
|
|
return lhs.isa<T>() ? lhs.cast<T>()
|
|
|
|
: (rhs.isa<T>() ? rhs.cast<T>() : nullptr);
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace {
/// Walk the indexing expressions for input of a convolution operation to verify
/// its of the right form, either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///   (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// classifies the AffineDimExpr as convolved dimensions or unconvolved
/// dimensions and verifies each dimension occurs only once.
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  // Dimensions that appear (possibly scaled) inside an add expression, i.e.
  // take part in the sliding-window access pattern.
  llvm::SmallDenseSet<unsigned> convolvedDims;
  // Dimensions that appear as a plain top-level AffineDimExpr.
  llvm::SmallDenseSet<unsigned> unConvolvedDims;

  /// A bare dimension at the top level is an unconvolved dimension; each
  /// dimension may be classified at most once.
  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  /// A bare symbol is not a valid convolution access expression.
  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  /// A bare constant is not a valid convolution access expression.
  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In pre-order visit, top level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    return success(succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) &&
                   succeeded(isDimExprOrMulExpr(binaryExpr.getRHS())));
  }

  /// Accepts either `d` or `d * (symbol|constant)` (in either operand order)
  /// for a not-yet-seen dimension `d`, and records `d` as convolved.
  LogicalResult isDimExprOrMulExpr(AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
      unsigned dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      convolvedDims.insert(dim);
      return success();
    }
    if (auto symbolMulExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expr, check for constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      // The multiplication must pair a dimension with a symbol or constant.
      if (!mulExpr || !dimExpr)
        return failure();
      unsigned dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      convolvedDims.insert(dim);
      return success();
    }
    return failure();
  }
};
} // namespace
|
|
|
|
|
|
|
|
static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to have projected permutations");
  // Collect the loop dimension of every map result.
  llvm::SmallDenseSet<unsigned> preserved;
  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i)
    preserved.insert(map.getResult(i).cast<AffineDimExpr>().getPosition());
  return preserved;
}
|
|
|
|
|
|
|
|
/// Result of matching an op against the convolution interface. `Success` is
/// the positive match; every other value names the first check that failed.
enum class MatchConvolutionResult {
  Success = 0,
  /// The operation does not implement the LinalgOp interface.
  NotLinalgOp,
  /// The op does not have at least 2 inputs and exactly 1 output.
  WrongNumOperands,
  /// The input indexing map is not of the expected convolved/unconvolved form.
  WrongInputIndexingMap,
  /// The filter or output indexing map is not a projected permutation.
  NotProjectedPermutations,
  /// A loop dimension could not be classified as a convolution loop kind.
  NonConvolutionLoop,
  /// A loop used to access the output is not a parallel iterator.
  OutputDimsNotParallel,
  /// A loop not used to access the output is not a reduction iterator.
  NonOutputDimNotReduction
};
|
|
|
|
|
|
|
|
/// Structural matcher for the convolution interface. Checks, in order, that:
/// - `op` is a LinalgOp with >= 2 inputs and exactly 1 output;
/// - the first input's indexing map has the convolved/unconvolved form
///   accepted by ConvAccessExprWalker;
/// - the filter (second input) and output maps are projected permutations;
/// - every loop is one of the convolution loop kinds listed below, with
///   output loops parallel and non-output loops reductions.
static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutation.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypesRange =
      linalgOp.iterator_types().getAsValueRange<StringAttr>();
  // Returns the iterator type name of loop dimension `dim`.
  auto getIteratorType = [&iteratorTypesRange](unsigned dim) {
    return *std::next(iteratorTypesRange.begin(), dim);
  };

  llvm::SmallDenseSet<unsigned> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are charecterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present in
  //   filter.
  // - Output image dimension : present in output, convolved dims in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<unsigned> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    unsigned outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
    bool isConvolvedDim = inputExprWalker.convolvedDims.count(outputDim);
    bool isUnConvolvedDim = inputExprWalker.unConvolvedDims.count(outputDim);
    bool isFilterDim = filterDims.count(outputDim);
    // The four admissible output-loop kinds all require the same parallel
    // iterator check, so the classification is folded into one predicate.
    bool isValidOutputDim =
        (isUnConvolvedDim && !isFilterDim) || // Batch dimension.
        (isConvolvedDim && !isFilterDim) ||   // Output image dimension.
        (!isConvolvedDim && !isUnConvolvedDim &&
         isFilterDim) ||                     // Output channel dimension.
        (isUnConvolvedDim && isFilterDim);   // Depth multiplier.
    if (!isValidOutputDim)
      return MatchConvolutionResult::NonConvolutionLoop;
    // All loops used to access the output must be parallel.
    if (getIteratorType(outputDim) != getParallelIteratorTypeName())
      return MatchConvolutionResult::OutputDimsNotParallel;
    allLoopDims.insert(outputDim);
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    unsigned filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
    bool isConvolvedDim = inputExprWalker.convolvedDims.count(filterDim);
    bool isUnConvolvedDim = inputExprWalker.unConvolvedDims.count(filterDim);
    bool isOutputDim = outputDims.count(filterDim);
    if (isOutputDim && !isConvolvedDim && !isUnConvolvedDim) {
      // Output channel dimension. Already recorded in the output loop above.
      continue;
    }
    if (isUnConvolvedDim && isOutputDim) {
      // Depth multiplier. Already recorded in the output loop above.
      continue;
    }
    if ((isConvolvedDim || isUnConvolvedDim) && !isOutputDim) {
      // Filter loop dimension (convolved) or input channel dimension
      // (unconvolved): must be a reduction loop not yet recorded.
      if (getIteratorType(filterDim) != getReductionIteratorTypeName())
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  return MatchConvolutionResult::Success;
}
|
|
|
|
|
|
|
|
LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  // Map each failed matcher check to its diagnostic.
  switch (isConvolutionInterfaceImpl(op)) {
  case MatchConvolutionResult::NotLinalgOp:
    return op->emitError("expected a LinalgOp");
  case MatchConvolutionResult::WrongNumOperands:
    return op->emitError("expected op with 2 inputs and 1 output");
  case MatchConvolutionResult::WrongInputIndexingMap:
    return op->emitError("unexpected input index map for convolutions");
  case MatchConvolutionResult::NotProjectedPermutations:
    return op->emitError(
        "expected output/filter indexing maps to be projected permutations");
  case MatchConvolutionResult::NonConvolutionLoop:
    return op->emitError("unexpected loop dimension for convolution op");
  case MatchConvolutionResult::OutputDimsNotParallel:
    return op->emitError(
        "expected all iterators used to access outputs to be parallel");
  case MatchConvolutionResult::NonOutputDimNotReduction:
    return op->emitError(
        "expected all iterators not used to access outputs to be reduction");
  case MatchConvolutionResult::Success:
    return success();
  }
  llvm_unreachable("unhandled MatchConvolutionResult case");
}
|
2022-03-08 15:33:47 +00:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// FillOpInterface implementation
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Result of matching an op against the fill interface. `Success` is the
/// positive match; every other value names the first check that failed.
enum class MatchFillResult {
  Success = 0,
  /// The operation does not implement the LinalgOp interface.
  NotLinalgOp,
  /// The op does not have exactly 1 input and 1 output.
  WrongNumOperands,
  /// The single input is not a scalar value.
  NotScalarInput
};
|
|
|
|
|
|
|
|
static MatchFillResult isFillInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchFillResult::NotLinalgOp;
  if (linalgOp.getNumInputs() != 1 || linalgOp.getNumOutputs() != 1)
    return MatchFillResult::WrongNumOperands;
  // The single input must be a scalar fill value.
  if (!linalgOp.isScalar(linalgOp.getInputOperand(0)))
    return MatchFillResult::NotScalarInput;
  return MatchFillResult::Success;
}
|
|
|
|
|
|
|
|
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
  // Map each failed matcher check to its diagnostic.
  switch (isFillInterfaceImpl(op)) {
  case MatchFillResult::NotLinalgOp:
    return op->emitError("expected a LinalgOp");
  case MatchFillResult::WrongNumOperands:
    return op->emitError("expected op with 1 input and 1 output");
  case MatchFillResult::NotScalarInput:
    return op->emitError("expected op with scalar input");
  case MatchFillResult::Success:
    return success();
  }
  llvm_unreachable("unhandled MatchFillResult case");
}
|
|
|
|
|
2021-02-04 16:49:09 +00:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// StructuredOpInterface implementation
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2021-05-31 12:46:32 +00:00
|
|
|
OpOperandVector::operator SmallVector<Value>() {
  // Materialize the value of each operand into a new vector.
  SmallVector<Value> values;
  values.reserve(this->size());
  for (OpOperand *opOperand : *this)
    values.push_back(opOperand->get());
  return values;
}
|
|
|
|
|
2021-07-05 10:04:01 +09:00
|
|
|
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
|
|
|
|
/// the type of `source`.
|
|
|
|
static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                               int64_t dim) {
  // Dispatch on the type of `source`: memrefs get memref.dim, tensors get
  // tensor.dim.
  Type sourceType = source.getType();
  if (sourceType.isa<UnrankedMemRefType, MemRefType>())
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (sourceType.isa<UnrankedTensorType, RankedTensorType>())
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}
|
|
|
|
|
2021-02-03 22:19:12 +00:00
|
|
|
SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                            Location loc) {
  // Emit one dim value per dimension of every input/output operand, flattened
  // in operand order.
  SmallVector<Value, 4> dims;
  for (OpOperand *opOperand : getInputAndOutputOperands()) {
    int64_t rank = getRank(opOperand);
    for (int64_t dim = 0; dim < rank; ++dim)
      dims.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
  }
  return dims;
}
|
|
|
|
|
2021-04-20 11:26:44 +00:00
|
|
|
SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  // Concatenate the static shape of every input/output operand.
  SmallVector<int64_t, 4> staticDims;
  for (OpOperand *opOperand : getInputAndOutputOperands())
    llvm::append_range(staticDims, getShape(opOperand));
  return staticDims;
}
|
|
|
|
|
2021-02-03 22:19:12 +00:00
|
|
|
SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> loopRanges(map.getNumDims());
  Value zeroVal = b.create<arith::ConstantIndexOp>(loc, 0);
  Value oneVal = b.create<arith::ConstantIndexOp>(loc, 1);
  // Each plain dim result of the loops-to-shapes map ties a loop to one
  // operand dimension; the first occurrence of a loop wins.
  for (const auto &en : llvm::enumerate(map.getResults())) {
    if (auto d = en.value().dyn_cast<AffineDimExpr>()) {
      if (loopRanges[d.getPosition()].offset)
        continue;
      loopRanges[d.getPosition()] = Range{zeroVal, viewSizes[en.index()], oneVal};
    }
  }
  return loopRanges;
}
|
|
|
|
|
2021-04-20 11:26:44 +00:00
|
|
|
SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
  AffineMap map = getLoopsToShapesMap();
  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
  // For each plain dim result of the loops-to-shapes map, take the matching
  // static operand dimension as the loop size.
  SmallVector<int64_t, 4> loopSizes(map.getNumDims(), 0);
  for (const auto &en : llvm::enumerate(map.getResults()))
    if (auto d = en.value().dyn_cast<AffineDimExpr>())
      loopSizes[d.getPosition()] = allShapeSizes[en.index()];
  return loopSizes;
}
|
|
|
|
|
2021-02-03 22:19:12 +00:00
|
|
|
/// Visitor to check if any of the given set of positions from AffineDimExprs
|
|
|
|
/// are used within an AffineExpr.
|
|
|
|
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  /// `positions` has one bit per loop dimension; a set bit marks a dimension
  /// position to look for.
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  /// A binary expression uses a tracked dimension if either side does.
  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  /// A dimension expression matches iff its position bit is set.
  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(dimExpr.getPosition());
  }

  /// Constants never reference a dimension.
  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  /// Symbols never reference a dimension.
  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallBitVector positions;
};
|
|
|
|
|
2021-07-19 14:35:20 -07:00
|
|
|
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  // Apply the composed map to the flat list of operand dims to obtain a value
  // for every result dimension, indexed by the running position `pos` below.
  auto allResultDimValues =
      applyMapToValues(b, loc, resultShapesFromInputShapesMap,
                       createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand *opOperand : getOutputOperands()) {
    SmallVector<Value> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
      // If the shape expression still references an output loop dimension,
      // it cannot be derived from the inputs; fall back to querying the
      // output operand's dimension directly.
      if (checkDimExpr.visit(shapeExprs[pos]))
        shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
      else
        shapes.push_back(allResultDimValues[pos]);
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}
|
|
|
|
|
|
|
|
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
|
|
|
|
LinalgOp linalgOp = cast<LinalgOp>(op);
|
2021-06-29 06:54:39 +00:00
|
|
|
// Expect at least one output operand.
|
2021-02-03 22:19:12 +00:00
|
|
|
// This means an op that constructs a tensor out of indices cannot be a
|
|
|
|
// LinalgOp at the moment. For now this will have to be a special op until we
|
|
|
|
// have output shape operands that are not tensors.
|
2021-06-29 06:54:39 +00:00
|
|
|
int64_t numInputs = linalgOp.getNumInputs();
|
|
|
|
int64_t numOutputs = linalgOp.getNumOutputs();
|
|
|
|
if (numOutputs == 0)
|
|
|
|
return op->emitOpError("expected at least one output operand");
|
|
|
|
if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
|
2021-02-03 22:19:12 +00:00
|
|
|
return failure();
|
2021-07-08 06:23:55 +00:00
|
|
|
// Verify the number of results matches the number of output tensors.
|
|
|
|
if (op->getNumResults() != linalgOp.getOutputTensorOperands().size())
|
|
|
|
return op->emitOpError("expected the number of results (")
|
|
|
|
<< op->getNumResults()
|
|
|
|
<< ") to be equal to the number of output tensors ("
|
|
|
|
<< linalgOp.getOutputTensorOperands().size() << ")";
|
2021-02-03 22:19:12 +00:00
|
|
|
|
2022-02-28 11:25:12 +00:00
|
|
|
// Check all iterator types are known.
|
|
|
|
auto iteratorTypesRange =
|
|
|
|
linalgOp.iterator_types().getAsValueRange<StringAttr>();
|
|
|
|
for (StringRef iteratorType : iteratorTypesRange) {
|
|
|
|
if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType))
|
|
|
|
return op->emitOpError("unexpected iterator_type (")
|
|
|
|
<< iteratorType << ")";
|
|
|
|
}
|
|
|
|
|
2021-02-09 08:47:21 -05:00
|
|
|
// Before checking indexing maps, we need to make sure the attributes
|
|
|
|
// referenced by it are valid.
|
|
|
|
if (linalgOp.hasDynamicIndexingMaps())
|
|
|
|
if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
|
|
|
|
return failure();
|
|
|
|
|
2021-06-14 05:59:33 +00:00
|
|
|
// All input/output operands must be indexed.
|
2022-07-19 14:42:57 -07:00
|
|
|
if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
|
2021-06-01 10:55:08 +00:00
|
|
|
linalgOp.getNumInputsAndOutputs())
|
2021-05-31 13:30:56 +00:00
|
|
|
return op->emitOpError("expected the number of indexing_map (")
|
2022-07-19 14:42:57 -07:00
|
|
|
<< linalgOp.getIndexingMapsArray().size()
|
2021-05-31 13:30:56 +00:00
|
|
|
<< ") to be equal to the number of input/output operands ("
|
|
|
|
<< linalgOp.getNumInputsAndOutputs() << ")";
|
2021-02-03 22:19:12 +00:00
|
|
|
|
2021-05-31 13:30:56 +00:00
|
|
|
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
|
|
|
|
AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
|
2021-02-03 22:19:12 +00:00
|
|
|
|
|
|
|
// Symbols disallowed.
|
2021-05-31 13:30:56 +00:00
|
|
|
if (indexingMap.getNumSymbols() != 0)
|
|
|
|
return op->emitOpError("unexpected symbols in indexing_map #")
|
|
|
|
<< opOperand->getOperandNumber();
|
2021-02-03 22:19:12 +00:00
|
|
|
|
|
|
|
// Domain must be consistent.
|
2021-05-31 13:30:56 +00:00
|
|
|
unsigned numLoops = linalgOp.getNumLoops();
|
|
|
|
if (indexingMap.getNumDims() != numLoops)
|
|
|
|
return op->emitOpError("expected indexing_map #")
|
|
|
|
<< opOperand->getOperandNumber() << " to have " << numLoops
|
2021-02-03 22:19:12 +00:00
|
|
|
<< " dim(s) to match the number of loops";
|
|
|
|
|
2021-05-31 13:30:56 +00:00
|
|
|
int64_t rank = linalgOp.getRank(opOperand);
|
|
|
|
if (indexingMap.getNumResults() != rank)
|
2021-06-14 05:59:33 +00:00
|
|
|
return op->emitOpError("expected operand rank (")
|
2021-05-31 13:30:56 +00:00
|
|
|
<< rank << ") to match the result rank of indexing_map #"
|
|
|
|
<< opOperand->getOperandNumber() << " ("
|
|
|
|
<< indexingMap.getNumResults() << ")";
|
2021-02-03 22:19:12 +00:00
|
|
|
}
|
|
|
|
|
2021-12-30 14:01:38 -08:00
|
|
|
SmallVector<unsigned> redDims;
|
2021-02-03 22:19:12 +00:00
|
|
|
linalgOp.getReductionDims(redDims);
|
|
|
|
|
|
|
|
// Simplifying assumption: either full tensor or full buffer mode.
|
|
|
|
// This allows simpler verification of output operands vs result types
|
|
|
|
// without premature tracking of which operand is what in mixed-mode.
|
|
|
|
// TODO: relax when mixed-mode needs to pass verification.
|
2021-05-31 13:30:56 +00:00
|
|
|
if (!linalgOp.getOutputBufferOperands().empty() &&
|
|
|
|
!linalgOp.getOutputTensorOperands().empty())
|
|
|
|
return op->emitOpError(
|
|
|
|
"expected output operands to all have tensor type or "
|
|
|
|
"all have buffer type");
|
|
|
|
|
|
|
|
for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) {
|
|
|
|
OpResult result = linalgOp.getTiedOpResult(opOperand);
|
|
|
|
if (result.getType() != opOperand->get().getType())
|
|
|
|
return op->emitOpError("expected type of operand #")
|
|
|
|
<< opOperand->getOperandNumber() << " ("
|
|
|
|
<< opOperand->get().getType() << ")"
|
|
|
|
<< " to match type of corresponding result (" << result.getType()
|
2021-02-03 22:19:12 +00:00
|
|
|
<< ")";
|
|
|
|
}
|
|
|
|
|
|
|
|
// Output tensor indexing map may not depend on reduction indices.
|
2021-05-31 13:30:56 +00:00
|
|
|
for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
|
|
|
|
AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
|
2021-12-30 14:01:38 -08:00
|
|
|
for (AffineExpr expr : indexingMap.getResults()) {
|
|
|
|
for (unsigned pos : redDims) {
|
2021-02-03 22:19:12 +00:00
|
|
|
if (expr.isFunctionOfDim(pos)) {
|
|
|
|
std::string exprStr;
|
|
|
|
{
|
|
|
|
llvm::raw_string_ostream os(exprStr);
|
|
|
|
os << expr;
|
|
|
|
}
|
2021-05-31 13:30:56 +00:00
|
|
|
return op->emitOpError(
|
2021-02-03 22:19:12 +00:00
|
|
|
"unexpected output tensor expression in indexing map #")
|
2021-05-31 13:30:56 +00:00
|
|
|
<< (opOperand->getOperandNumber() - linalgOp.getNumInputs())
|
2021-02-03 22:19:12 +00:00
|
|
|
<< " a.k.a '" << exprStr
|
|
|
|
<< "' is function of reduction iterator 'd" << pos << "'";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-10-14 09:08:39 +00:00
|
|
|
// Check the region has exactly one block.
|
|
|
|
if (linalgOp->getNumRegions() != 1 ||
|
|
|
|
!llvm::hasSingleElement(linalgOp->getRegion(0)))
|
|
|
|
return op->emitOpError("expects to have 1 region with 1 block");
|
2021-02-03 22:19:12 +00:00
|
|
|
|
|
|
|
if (!linalgOp.getShapesToLoopsMap())
|
|
|
|
return op->emitOpError("expected the shape-to-loops map to be non-null");
|
|
|
|
|
|
|
|
// Simplifying assumption: bbargs match 1-1 with shape operands elemental
|
|
|
|
// types.
|
|
|
|
// TODO: once ranked shape types are plugged in, we may want to drop the
|
|
|
|
// corresponding bbargs, that can never be read from. This will be subject to
|
|
|
|
// consistency discussions (i.e. what to do with output tensors whose bbarg is
|
|
|
|
// not used).
|
|
|
|
Block &block = linalgOp->getRegion(0).front();
|
|
|
|
|
2021-06-21 16:28:23 +02:00
|
|
|
if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments())
|
2021-05-31 13:30:56 +00:00
|
|
|
return op->emitOpError("expected as many non-induction variable region "
|
2021-06-14 05:59:33 +00:00
|
|
|
"arguments as the number of input/output operands");
|
2021-02-03 22:19:12 +00:00
|
|
|
|
2021-05-31 13:30:56 +00:00
|
|
|
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
|
2021-06-14 05:59:33 +00:00
|
|
|
Type elementType = getElementTypeOrSelf(opOperand->get());
|
2021-06-21 16:28:23 +02:00
|
|
|
Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
|
2021-05-31 13:30:56 +00:00
|
|
|
if (elementType != argType)
|
|
|
|
return op->emitOpError("expected type of bb argument #")
|
2021-06-21 16:28:23 +02:00
|
|
|
<< opOperand->getOperandNumber() << " (" << argType << ")"
|
2021-06-14 05:59:33 +00:00
|
|
|
<< " to match element or self type of the corresponding operand ("
|
2021-05-31 13:30:56 +00:00
|
|
|
<< elementType << ")";
|
2021-02-03 22:19:12 +00:00
|
|
|
}
|
|
|
|
|
2021-03-30 07:07:12 -07:00
|
|
|
// Check if given shapes match to inferred shapes.
|
2022-05-03 05:12:51 -07:00
|
|
|
SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
|
|
|
|
SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
|
2021-03-30 07:07:12 -07:00
|
|
|
|
|
|
|
// Verify only static cases since we can't get exact dimension sizes and loop
|
|
|
|
// ranges for dynamic cases in this stage.
|
2022-05-03 05:12:51 -07:00
|
|
|
if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
|
|
|
|
for (int64_t &range : endLoopRangeValues)
|
2021-03-30 07:07:12 -07:00
|
|
|
range -= 1;
|
2021-05-31 13:30:56 +00:00
|
|
|
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
|
|
|
|
AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
|
|
|
|
SmallVector<int64_t, 4> startIndices =
|
|
|
|
indexingMap.compose(startLoopRangeValues);
|
|
|
|
SmallVector<int64_t, 4> endIndices =
|
2022-05-03 05:12:51 -07:00
|
|
|
indexingMap.compose(endLoopRangeValues);
|
2021-05-31 13:30:56 +00:00
|
|
|
ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
|
|
|
|
for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
|
2021-05-12 10:28:37 -07:00
|
|
|
// Ignore dynamic dimension or the case that the dimension size is 0
|
2021-05-31 13:30:56 +00:00
|
|
|
if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
|
2021-03-30 07:07:12 -07:00
|
|
|
continue;
|
|
|
|
|
2021-05-12 10:28:37 -07:00
|
|
|
// The first index or last index should be the maximum or the minimum in
|
|
|
|
// the inferred index ranges since the range is increasing or
|
2021-06-14 05:59:33 +00:00
|
|
|
// decreasing. The size of dimensions of input/output operands and the
|
|
|
|
// maximum value + 1 in the inferred range should be the same. But, for
|
|
|
|
// now we check if the inferred ranges are in boundary of input/output
|
|
|
|
// operands' size or not in case that Affine Expressions are complicated
|
|
|
|
// such as d0 * 3
|
2021-05-12 10:28:37 -07:00
|
|
|
// + d1 since it is not easy to handle the issues.
|
|
|
|
// Found the case that this solution can't check, for example, (d0, d1)
|
|
|
|
// -> (d1 - d0)
|
2021-05-31 13:30:56 +00:00
|
|
|
int64_t inferredDimSize =
|
|
|
|
std::max(startIndices[dim], endIndices[dim]) + 1;
|
|
|
|
if (std::min(startIndices[dim], endIndices[dim]) < 0) {
|
2021-05-12 10:28:37 -07:00
|
|
|
std::string mapStr;
|
|
|
|
{
|
|
|
|
llvm::raw_string_ostream os(mapStr);
|
2021-05-31 13:30:56 +00:00
|
|
|
os << indexingMap;
|
2021-05-12 10:28:37 -07:00
|
|
|
}
|
2021-05-31 13:30:56 +00:00
|
|
|
return op->emitOpError(
|
2021-05-12 10:28:37 -07:00
|
|
|
"unexpected result less than 0 at expression #")
|
2021-05-31 13:30:56 +00:00
|
|
|
<< dim << " in " << mapStr;
|
2021-05-12 10:28:37 -07:00
|
|
|
}
|
2021-05-31 13:30:56 +00:00
|
|
|
if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
|
|
|
|
if (inferredDimSize != shape[dim]) {
|
2021-06-14 05:59:33 +00:00
|
|
|
return op->emitOpError("inferred input/output operand #")
|
2021-05-31 13:30:56 +00:00
|
|
|
<< opOperand->getOperandNumber()
|
|
|
|
<< " has shape's dimension #" << dim << " to be "
|
|
|
|
<< inferredDimSize << ", but found " << shape[dim];
|
2021-03-30 07:07:12 -07:00
|
|
|
}
|
|
|
|
} else {
|
2021-05-31 13:30:56 +00:00
|
|
|
if (inferredDimSize > shape[dim]) {
|
2021-06-14 05:59:33 +00:00
|
|
|
return op->emitOpError("inferred input/output operand #")
|
2021-05-31 13:30:56 +00:00
|
|
|
<< opOperand->getOperandNumber()
|
|
|
|
<< " has shape's dimension #" << dim
|
2021-05-12 10:28:37 -07:00
|
|
|
<< " to be greater than or equal to " << inferredDimSize
|
2021-05-31 13:30:56 +00:00
|
|
|
<< ", but found " << shape[dim];
|
2021-03-30 07:07:12 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-02-03 22:19:12 +00:00
|
|
|
return success();
|
|
|
|
}
|