llvm-project/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

1531 lines
54 KiB
C++

//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include <functional>
using namespace mlir;
using namespace mlir::tosa;
//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Tensor Data Engine Operators.
//===----------------------------------------------------------------------===//
// Check that the zero point of the tensor and padding operations are aligned.
bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
// Check that padConst is a constant value and a scalar tensor
DenseElementsAttr padConstAttr;
if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
(padConstAttr.size() != 1)) {
return false;
}
// Check that floating point pad is zero
if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
return padConstVal == 0.0f;
}
// Check that the zp and padConst align for the integer (quantized) case
if (auto padConstIntAttr =
mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
DenseIntElementsAttr zpAttr;
// Check that zp is a constant value and a scalar tensor
if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
return false;
}
// Check equality
int64_t zpVal = (*zpAttr.begin()).getSExtValue();
int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
return zpVal == padConstVal;
}
// Bail-out on unsupported type
return false;
}
namespace {
template <typename OpTy>
struct PoolPadFoldAdaptor;
template <>
struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
using OpTy = tosa::AvgPool2dOp;
static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
const llvm::ArrayRef<int64_t> kernel = op.getKernel();
if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
return false;
return true;
}
static bool checkPadConstCompliance(OpTy op, Value padConst) {
return checkMatchingPadConstAndZp(padConst, op.getInputZp());
}
static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
Value padInput, ArrayRef<int64_t> newPad) {
rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
op.getAccType());
}
};
template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
using OpTy = tosa::MaxPool2dOp;
static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
const llvm::ArrayRef<int64_t> kernel = op.getKernel();
if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
return false;
return true;
}
static bool checkPadConstCompliance(OpTy, Value padConst) {
// Check that padConst is a constant value and a scalar tensor
DenseElementsAttr padConstAttr;
if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
padConstAttr.size() != 1) {
return false;
}
// Pad needs to be in the minimum value to be able to merge
if (auto padConstFpAttr =
mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
const APFloat padConstVal = *padConstFpAttr.begin();
const APFloat lowestVal =
APFloat::getLargest(padConstVal.getSemantics(), true);
return padConstVal == lowestVal;
} else if (auto padConstIntAttr =
mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
const APInt padConstVal = *padConstIntAttr.begin();
const unsigned int bitWidth = padConstVal.getBitWidth();
const APInt lowestVal =
padConstIntAttr.getElementType().isUnsignedInteger()
? APInt::getZero(bitWidth)
: APInt::getSignedMinValue(bitWidth);
return padConstVal == lowestVal;
}
// Bail-out on unsupported type
return false;
}
static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
Value padInput, ArrayRef<int64_t> newPad) {
rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
op, op.getType(), padInput, op.getKernel(), op.getStride(),
rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
}
};
template <typename OpTy>
struct ConvPadFoldAdaptor {
static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
return true;
}
static bool checkPadConstCompliance(OpTy op, Value padConst) {
return checkMatchingPadConstAndZp(padConst, op.getInputZp());
}
static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
Value padInput, ArrayRef<int64_t> newPad) {
rewriter.replaceOpWithNewOp<OpTy>(
op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
op.getDilationAttr(), op.getAccType(), op.getLocalBound());
}
};
// Pattern attempts to fold a `tosa.pad` operator to a following tensor
// operation like `tosa.conv2d` by merging the padding associated with the
// pad operator directly to the implicit padding of the tensor operation.
// This helps eliminate the explicit padding operator if unused.
template <typename OpTy, typename AdaptorTy>
struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy tensorOp,
PatternRewriter &rewriter) const override {
// Check producer is a tosa::PadOp
auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
if (!padOp)
return rewriter.notifyMatchFailure(tensorOp,
"Producer must be a tosa::PadOp.");
// Validate that tensor operation has sane padding
const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
return rewriter.notifyMatchFailure(
tensorOp, "Tensor operation padding shall have 4 elements.");
// Validate tosa::PadOp padding
DenseIntElementsAttr padOpPadding;
if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
return rewriter.notifyMatchFailure(
tensorOp,
"The `padding` input specified on the tosa::PadOp must be constant.");
}
// N_before, N_after, H_before, H_after, W_before, W_after, C_before,
// C_after
if (padOpPadding.size() != 8)
return rewriter.notifyMatchFailure(tensorOp,
"Pad padding should have 8 elements.");
int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
return rewriter.notifyMatchFailure(
tensorOp, "Folding padding in N or C dimensions is not supported.");
// Fold padding from Pad into the tensor operation
// 4 elements - pad_top, pad_bottom, pad_left, pad_right
SmallVector<int64_t> foldedPad(tensorOpPad.size());
foldedPad[0] = padHBefore + tensorOpPad[0];
foldedPad[1] = padHAfter + tensorOpPad[1];
foldedPad[2] = padWBefore + tensorOpPad[2];
foldedPad[3] = padWAfter + tensorOpPad[3];
// Check kernel related restrictions
if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
return rewriter.notifyMatchFailure(
tensorOp, "Padding size not aligned with kernel restrictions.");
}
// Check padding constant restrictions
if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
return rewriter.notifyMatchFailure(
tensorOp,
"Padding constant is not aligned with operator zero-point.");
}
// Check that padding doesn't grow more than 8K level (8192) for now
if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
return rewriter.notifyMatchFailure(
tensorOp, "Padding size more than the 8K level limit.");
}
// Create operator
AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
foldedPad);
return success();
}
};
} // namespace
void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
context);
}
void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
context);
}
void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
context);
}
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();
Value output = op.getOutput();
ShapedType inputType = llvm::cast<ShapedType>(input.getType());
ShapedType outputType = llvm::cast<ShapedType>(output.getType());
if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
return failure();
}
// If the output and input shapes are 1x1, then this is a no op.
ArrayRef<int64_t> outputShape = outputType.getShape();
if (outputShape[1] != 1 || outputShape[2] != 1) {
return failure();
}
ArrayRef<int64_t> inputShape = inputType.getShape();
if (inputShape[1] != 1 || inputShape[2] != 1) {
return failure();
}
rewriter.replaceOp(op, input);
return success();
}
};
void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<MaxPool2dIsNoOp,
FoldPadToTensorOp<tosa::MaxPool2dOp,
PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
context);
}
//===----------------------------------------------------------------------===//
// Data Layout / Memory Reinterpretation.
//===----------------------------------------------------------------------===//
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::ConcatOp op,
PatternRewriter &rewriter) const override {
if (op.getInput1().size() != 1)
return failure();
if (op.getInput1().front().getType() != op.getType()) {
rewriter
.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
op.getInput1().front())
.getResult();
return success();
}
rewriter.replaceOp(op, op.getInput1().front());
return success();
}
};
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConcatOptimization>(context);
}
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
return failure();
rewriter.modifyOpInPlace(op, [&]() {
op.getOperation()->setOperands(
{notOp.getInput1(), op.getInput3(), op.getInput2()});
});
return success();
}
struct ConsolidateTransposeOptimization
: public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
// Input is also TransposeOp - transpose(transpose(A)).
auto innerTranspose =
transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
if (!innerTranspose)
return rewriter.notifyMatchFailure(transposeOp,
"input must be transpose operation");
const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
const llvm::ArrayRef<int32_t> innerTransposePerms =
innerTranspose.getPerms();
if (transposePerms.size() != innerTransposePerms.size())
return rewriter.notifyMatchFailure(
transposeOp,
"transpose and inner transpose perms sizes must be equal");
if (transposePerms.empty())
return rewriter.notifyMatchFailure(
transposeOp, "transpose perms sizes must be positive");
// Consolidate transposes into one transpose.
SmallVector<int32_t> perms(transposePerms.size());
for (int i = 0, s = transposePerms.size(); i < s; ++i)
perms[i] = innerTransposePerms[transposePerms[i]];
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
return success();
}
};
// Determines the case when tosa.transpose is a tosa.reshape operation.
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
return rewriter.notifyMatchFailure(
op, "Src is from transpose, can compose transposes");
Value result = op.getResult();
for (Operation *subop : result.getUsers()) {
if (dyn_cast_or_null<tosa::TransposeOp>(subop))
return rewriter.notifyMatchFailure(
op, "Dest is used by transpose, can compose transposes");
}
auto input = op.getInput1();
auto inputTy = llvm::cast<ShapedType>(input.getType());
if (!inputTy.hasRank())
return rewriter.notifyMatchFailure(op, "Unranked input.");
int64_t numDynDims = 0;
for (int i = 0; i < inputTy.getRank(); ++i)
if (inputTy.isDynamicDim(i))
numDynDims++;
if (numDynDims > 1)
return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
const llvm::ArrayRef<int32_t> permValues = op.getPerms();
SmallVector<int64_t> nonZeroPerms;
nonZeroPerms.reserve(permValues.size());
for (auto idx : permValues) {
auto sz = inputTy.getDimSize(idx);
if (sz != 1)
nonZeroPerms.push_back(idx);
}
for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
if (nonZeroPerms[i - 1] > nonZeroPerms[i])
return rewriter.notifyMatchFailure(op,
"Transpose changes memory layout.");
SmallVector<int64_t> newShape;
newShape.reserve(inputTy.getRank());
for (int i = 0, s = inputTy.getRank(); i < s; ++i)
newShape.push_back(inputTy.getDimSize(permValues[i]));
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, op.getType(), op.getInput1(),
getTosaConstShape(rewriter, op.getLoc(), newShape));
return success();
}
};
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
auto inputElementType = inputType.getElementType();
if (!inputType.hasStaticShape()) {
return failure();
}
if (isa<FloatType>(inputElementType)) {
// Unlike integer types, floating point types can represent infinity.
auto minClamp =
llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
auto maxClamp =
llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
bool isMin = minClamp.isNegInfinity();
bool isMax = maxClamp.isInfinity();
if (isMin && isMax) {
rewriter.replaceOp(op, input);
return success();
}
return failure();
}
if (inputElementType.isUnsignedInteger()) {
int64_t minClamp =
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
int64_t maxClamp =
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
int64_t intMin =
APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
.getZExtValue();
int64_t intMax =
APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
.getZExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
return success();
}
return failure();
}
if (llvm::isa<IntegerType>(inputElementType)) {
int64_t minClamp =
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
int64_t maxClamp =
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
int64_t intMin =
APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
.getSExtValue();
int64_t intMax =
APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
.getSExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
return success();
}
return failure();
}
return failure();
}
};
// Attempts the following transformation:
//
// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
// tensor X the following identity holds:
//
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
//
// subject to the following valid NaN propagation semantics:
// --------------------------------------------
// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
// |-------------|--------------|-------------|
// | PROPAGATE | PROPAGATE | PROPAGATE |
// | PROPAGATE | IGNORE | IGNORE |
// | IGNORE | PROPAGATE | INVALID |
// | IGNORE | IGNORE | IGNORE |
// |------------------------------------------|
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
// Helper structure to describe the range of a clamp operation.
template <typename T>
struct ClampRange {
ClampRange(const T &start, const T &end) : start(start), end(end) {}
T start;
T end;
// Helper function to determine if two Clamp ranges intersect.
bool intersects(const ClampRange<T> &otherRange) {
return start < otherRange.end && otherRange.start < end;
}
};
LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();
// Check the input to the CLAMP op is itself a CLAMP.
auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
if (!clampOp)
return failure();
// Check we have a valid NaN propagation combination.
const auto opNanMode = op.getNanMode();
const auto clampNanMode = clampOp.getNanMode();
if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
return failure();
auto maxValAttr = op.getMaxValAttr();
auto minValAttr = op.getMinValAttr();
auto clampOpMaxValAttr = clampOp.getMaxValAttr();
auto clampOpMinValAttr = clampOp.getMinValAttr();
auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
inputEType = quantType.getStorageType();
}
Attribute newMinValAttr, newMaxValAttr;
if (mlir::isa<FloatType>(inputEType)) {
auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
// Check we have intersecting ranges.
const auto opMinFloat = floatMinValAttr.getValue();
const auto opMaxFloat = floatMaxValAttr.getValue();
const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
clampOpMaxFloat);
if (!opRangeFloatRange.intersects(clampRangeFloatRange))
return failure();
// Run the transformation.
auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
} else {
assert(mlir::isa<IntegerType>(inputEType));
auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
if (inputEType.isUnsignedInteger()) {
// Check we have intersecting ranges.
const auto opMinInt = intMinValAttr.getUInt();
const auto opMaxInt = intMaxValAttr.getUInt();
const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
clampOpMaxInt);
if (!opRangeIntRange.intersects(clampRangeIntRange))
return failure();
// Run the transformation.
auto newMinVal = std::max(opMinInt, clampOpMinInt);
auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
} else {
// Check we have intersecting ranges.
const auto opMinInt = intMinValAttr.getInt();
const auto opMaxInt = intMaxValAttr.getInt();
const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
clampOpMaxInt);
if (!opRangeIntRange.intersects(clampRangeIntRange))
return failure();
// Run the transformation.
auto newMinVal = std::max(opMinInt, clampOpMinInt);
auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
}
}
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
: opNanMode));
return success();
}
};
void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ClampIsNoOp>(context);
results.add<ClampClampOptimization>(context);
}
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
PatternRewriter &rewriter) const override {
Value sliceInput = sliceOp.getInput1();
auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
if (!concatOp)
return rewriter.notifyMatchFailure(
sliceOp, "slice input must be concat operation");
OperandRange inputs = concatOp.getInput1();
auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
if (!concatType || !concatType.hasStaticShape())
return rewriter.notifyMatchFailure(
sliceOp, "slice input must be a static ranked tensor");
int32_t axis = concatOp.getAxis();
DenseElementsAttr startElems;
DenseElementsAttr sizeElems;
if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
return rewriter.notifyMatchFailure(
sliceOp, "start of slice must be a static ranked shape");
if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
return rewriter.notifyMatchFailure(
sliceOp, "size of slice must be a static ranked shape");
llvm::SmallVector<int64_t> sliceStarts =
llvm::to_vector(startElems.getValues<int64_t>());
llvm::SmallVector<int64_t> sliceSizes =
llvm::to_vector(sizeElems.getValues<int64_t>());
// Validate slice on the concatenated axis. Slicing along this
// axis should span only one of the inputs to the concatenate
// operation.
std::optional<Value> replaceWithSlice;
for (auto input : inputs) {
auto inputType = dyn_cast<RankedTensorType>(input.getType());
if (!inputType || !inputType.hasStaticShape())
return rewriter.notifyMatchFailure(
sliceOp, "concat input must be a static ranked tensor");
if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
inputType.getDimSize(axis)) {
auto start_op =
getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
auto size_op =
getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
replaceWithSlice =
rewriter
.create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
input, start_op, size_op)
.getResult();
break;
}
sliceStarts[axis] -= inputType.getDimSize(axis);
}
if (!replaceWithSlice)
return rewriter.notifyMatchFailure(
sliceOp, "corresponding concat input not found for slice");
rewriter.replaceOp(sliceOp, replaceWithSlice.value());
return success();
}
};
// Update size operand of tosa.slice if size has dynamic dims but corresponding
// output dim is static
struct SliceDynamicSizeCanonicalization
: public OpRewritePattern<tosa::SliceOp> {
using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
PatternRewriter &rewriter) const override {
ShapedType resultType = cast<ShapedType>(sliceOp.getType());
ElementsAttr sizeElems;
if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
return rewriter.notifyMatchFailure(
sliceOp, "size of slice must be a static ranked shape");
}
llvm::SmallVector<int64_t> sliceSizes =
llvm::to_vector(sizeElems.getValues<int64_t>());
bool replaceSliceSize{false};
// if size op has -1 indicating dynamic shape but corresponding dim on the
// output is statically known, update size to match with known output dim
// shape
for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
if (size == -1 && !resultType.isDynamicDim(index)) {
sliceSizes[index] = resultType.getDimSize(index);
replaceSliceSize = true;
}
}
if (!replaceSliceSize) {
return rewriter.notifyMatchFailure(
sliceOp, "no dimension of size of slice is dynamic that resolves "
"to static output shape");
}
auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
auto newSliceOp = rewriter.create<tosa::SliceOp>(
sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
sliceOp.getStart(), size_op);
rewriter.replaceOp(sliceOp, newSliceOp.getResult());
// Remove const_shape size op when it no longer has use point.
Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
if (sizeConstShape->getResult(0).hasOneUse())
rewriter.eraseOp(sizeConstShape);
return success();
}
};
void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
context);
}
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType returnTy) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
if (lETy != rETy)
return {};
if (llvm::isa<IntegerType>(lETy)) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
auto result = IntFolder()(l, r);
return DenseElementsAttr::get(returnTy, result);
}
if (llvm::isa<FloatType>(lETy)) {
APFloat l = lhs.getSplatValue<APFloat>();
APFloat r = rhs.getSplatValue<APFloat>();
auto result = FloatFolder()(l, r);
return DenseElementsAttr::get(returnTy, result);
}
}
return {};
}
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
if (llvm::isa<IntegerType>(elemType))
return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
return false;
}
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
if (llvm::isa<FloatType>(elemType))
return val && val.isSplat() &&
val.getSplatValue<APFloat>().isExactlyValue(1.0);
if (llvm::isa<IntegerType>(elemType)) {
const int64_t shifted = 1LL << shift;
return val && val.isSplat() &&
val.getSplatValue<APInt>().getSExtValue() == shifted;
}
return false;
}
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!lhsTy || !rhsTy || !resultTy)
return {};
// Cannot create an ElementsAttr from non-int/float/index types
if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
!rhsTy.getElementType().isIntOrIndexOrFloat())
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
return getInput2();
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
resultTy);
}
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
!outputTy.hasStaticShape())
return {};
if (inputTy.getDimSize(getAxis()) == 1)
return DenseElementsAttr::get(outputTy, 0);
return {};
}
OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!lhsTy || !rhsTy || !resultTy)
return {};
if (lhsTy != rhsTy)
return {};
// IntDivOp inputs must be integer type, no need to check for quantized type
auto resultETy = resultTy.getElementType();
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsAttr && lhsAttr.isSplat()) {
if (llvm::isa<IntegerType>(resultETy) &&
lhsAttr.getSplatValue<APInt>().isZero())
return lhsAttr;
}
if (rhsAttr && rhsAttr.isSplat()) {
if (llvm::isa<IntegerType>(resultETy) &&
rhsAttr.getSplatValue<APInt>().isOne())
return getInput1();
}
if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
llvm::isa<IntegerType>(resultETy)) {
APInt l = lhsAttr.getSplatValue<APInt>();
APInt r = rhsAttr.getSplatValue<APInt>();
if (!r.isZero()) {
APInt result = l.sdiv(r);
return DenseElementsAttr::get(resultTy, result);
}
}
return {};
}
namespace {
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType ty, int32_t shift) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
if (llvm::isa<IntegerType>(ty.getElementType())) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
if (shift == 0) {
return DenseElementsAttr::get(ty, l * r);
}
auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
l = l.sext(bitwidth * 2);
r = r.sext(bitwidth * 2);
auto result = l * r;
result.lshrInPlace(shift);
result = result.trunc(bitwidth);
return DenseElementsAttr::get(ty, result);
}
if (llvm::isa<FloatType>(ty.getElementType())) {
APFloat l = lhs.getSplatValue<APFloat>();
APFloat r = rhs.getSplatValue<APFloat>();
APFloat result = l * r;
return DenseElementsAttr::get(ty, result);
}
}
return {};
}
} // namespace
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto lhs = getInput1();
auto rhs = getInput2();
auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!lhsTy || !rhsTy || !resultTy)
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
// Result right shift on i32_t data type only. For simplification, synthesize
// a zero shift for other data type.
int32_t shift = 0;
if (resultETy.isInteger(32)) {
ElementsAttr shift_elem;
if (getShift().getImpl()) {
if (!matchPattern(getShift(), m_Constant(&shift_elem)))
// cannot be folded when the shift value is unknown.
return {};
shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
}
}
if (rhsTy == resultTy) {
if (isSplatZero(resultETy, lhsAttr))
return lhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, lhsAttr, shift))
return rhs;
}
if (lhsTy == resultTy) {
if (isSplatZero(resultETy, rhsAttr))
return rhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, rhsAttr, shift))
return lhs;
}
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!lhsTy || !rhsTy || !resultTy)
return {};
// Cannot create an ElementsAttr from non-int/float/index types
if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
!rhsTy.getElementType().isIntOrIndexOrFloat())
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
resultTy);
}
namespace {
template <typename Cmp>
struct ComparisonFold {
ComparisonFold() = default;
APInt operator()(const APInt &l, const APInt &r) {
return APInt(1, Cmp()(l, r));
}
APInt operator()(const APFloat &l, const APFloat &r) {
return APInt(1, Cmp()(l, r));
}
};
struct APIntFoldGreater {
APIntFoldGreater() = default;
APInt operator()(const APInt &l, const APInt &r) {
return APInt(1, l.sgt(r));
}
};
struct APIntFoldGreaterEqual {
APIntFoldGreaterEqual() = default;
APInt operator()(const APInt &l, const APInt &r) {
return APInt(1, l.sge(r));
}
};
} // namespace
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<APIntFoldGreaterEqual,
ComparisonFold<std::greater_equal<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
Value lhs = getInput1();
Value rhs = getInput2();
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
// If we are comparing an integer value to itself it is always true. We can
// not do this with float due to float values.
if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
resultTy.hasStaticShape() && lhs == rhs) {
return DenseElementsAttr::get(resultTy, true);
}
if (!lhsAttr || !rhsAttr)
return {};
return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
if (getInput().getType() == getType())
return getInput();
auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
if (!operand)
return {};
auto inTy = llvm::cast<ShapedType>(getInput().getType());
auto outTy = llvm::cast<ShapedType>(getType());
auto inETy = inTy.getElementType();
auto outETy = outTy.getElementType();
if (operand.isSplat()) {
if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
bool overflow;
auto splatVal = operand.getSplatValue<APFloat>();
auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
&overflow);
return SplatElementsAttr::get(outTy, splatVal);
}
if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
llvm::RoundingMode::NearestTiesToEven);
return SplatElementsAttr::get(outTy, splatVal);
}
if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
auto intVal = APSInt(
llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
auto floatVal = operand.getSplatValue<APFloat>();
bool exact;
floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
&exact);
return SplatElementsAttr::get(outTy, intVal);
}
if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
bool trunc =
inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
auto intVal = operand.getSplatValue<APInt>();
auto bitwidth = outETy.getIntOrFloatBitWidth();
if (trunc) {
intVal = intVal.trunc(bitwidth);
} else if (unsignIn) {
intVal = intVal.zext(bitwidth);
} else {
intVal = intVal.sext(bitwidth);
}
return SplatElementsAttr::get(outTy, intVal);
}
}
return {};
}
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
#define REDUCE_FOLDER(OP) \
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
if (!inputTy.hasRank()) \
return {}; \
if (inputTy != getType()) \
return {}; \
if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
return getInput(); \
return {}; \
}
REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProductOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputTy || !outputTy)
return {};
// Fold when the input and output types are the same. This is only safe when
// there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
// there may still be a productive reshape.
if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
return getInput1();
// reshape(reshape(x)) -> reshape(x)
if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
getInput1().getDefiningOp())) {
getInput1Mutable().assign(reshapeOp.getInput1());
return getResult();
}
// Cannot create an ElementsAttr from non-int/float/index types
if (!inputTy.getElementType().isIntOrIndexOrFloat())
return {};
// reshape(const(x)) -> const(reshape-attr(x))
if (auto operand =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
// Constants must have static shape.
if (!outputTy.hasStaticShape())
return {};
// Okay to duplicate splat constants.
if (operand.isSplat())
return SplatElementsAttr::get(outputTy,
operand.getSplatValue<Attribute>());
// Don't duplicate other constants.
if (!getInput1().hasOneUse())
return {};
llvm::SmallVector<int64_t> shapeVec;
if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
return {};
return operand.reshape(
llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
}
return {};
}
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
// If the pad is all zeros we can fold this operation away.
if (adaptor.getPadding() && getInput1().getType() == getType()) {
auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
if (densePad && densePad.isSplat() &&
densePad.getSplatValue<APInt>().isZero()) {
return getInput1();
}
}
return {};
}
// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
auto scaleAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
auto offsetAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
auto borderAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
if (!scaleAttr || !offsetAttr || !borderAttr) {
return {};
}
auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
return {};
}
// Check unit scaling.
if (scale[0] != scale[1] || scale[2] != scale[3]) {
return {};
}
// There should be no offset.
if (offset[0] != 0 || offset[1] != 0) {
return {};
}
// There should be no border.
if (border[0] != 0 || border[1] != 0) {
return {};
}
auto input = getInput();
auto inputTy = llvm::cast<RankedTensorType>(input.getType());
auto resultTy = llvm::cast<RankedTensorType>(getType());
if (inputTy != resultTy)
return {};
return input;
}
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
auto operand = getInput1();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
auto axis = getAxis();
auto operandAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
if (operandAttr)
return operandAttr;
// If the dim-length is 1, tosa.reverse is a no-op.
if (operandTy.hasRank() &&
(operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
return operand;
return {};
}
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputTy || !outputTy)
return {};
if (inputTy == outputTy && inputTy.hasStaticShape())
return getInput1();
if (!adaptor.getInput1())
return {};
// Cannot create an ElementsAttr from non-int/float/index types
if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
!outputTy.getElementType().isIntOrIndexOrFloat())
return {};
auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
if (operand.isSplat() && outputTy.hasStaticShape()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
outputTy.getNumElements() == 1) {
DenseElementsAttr startElems;
if (!matchPattern(getStart(), m_Constant(&startElems)))
return {};
llvm::SmallVector<uint64_t> indices =
llvm::to_vector(startElems.getValues<uint64_t>());
auto value = operand.getValues<Attribute>()[indices];
return SplatElementsAttr::get(outputTy, value);
}
return {};
}
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
if (getInput2() == getInput3())
return getInput2();
auto predicate =
llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
if (!predicate)
return {};
if (!predicate.isSplat())
return {};
return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
: getInput3();
}
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
if (getInput1().getType() == getType()) {
if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
adaptor.getMultiples())) {
if (multiples.isSplat() &&
multiples.getSplatValue<APInt>().getSExtValue() == 1)
return getInput1();
if (auto int_array_attr =
llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
if (llvm::all_of(int_array_attr.getValues<APInt>(),
[](APInt v) { return v.getSExtValue() == 1; }))
return getInput1();
}
}
}
return {};
}
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::cast<ShapedType>(getType());
// Transposing splat values just means reshaping.
if (auto input =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
if (input.isSplat() && resultTy.hasStaticShape() &&
input.getType().getElementType() == resultTy.getElementType())
return input.reshape(resultTy);
}
// Transpose is not the identity transpose.
const llvm::ArrayRef<int32_t> perms = getPerms();
if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
return {};
return getInput1();
}
OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
auto input = getInput1();
// Element-wise log(exp(x)) = x
if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
return op.getInput1();
}
return {};
}
OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
auto input = getInput1();
// Element-wise exp(log(x)) = x
if (auto op = input.getDefiningOp<tosa::LogOp>()) {
return op.getInput1();
}
return {};
}
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
// Element-wise negate(negate(x)) = x
// iff all zero points are constant 0
auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
if (!definingOp) {
// defining op of input1 is not a negate, cannot fold
return {};
}
if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
failed(maybeIZp) || *maybeIZp != 0) {
// input1 zero point is not constant 0, cannot fold
return {};
}
if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
failed(maybeOZp) || *maybeOZp != 0) {
// output zero point is not constant 0, cannot fold
return {};
}
if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
failed(maybeIZp) || *maybeIZp != 0) {
// definingOp's input1 zero point is not constant 0, cannot fold
return {};
}
if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
failed(maybeOZp) || *maybeOZp != 0) {
// definingOp's output zero point is not constant 0, cannot fold
return {};
}
return definingOp.getInput1();
}
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
auto input = getInput1();
// Element-wise abs(abs(x)) = abs(x)
if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
return input;
}
return {};
}
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
// Fold consecutive concats on the same axis into a single op.
// Keep track of the operands so we are able to construct a new concat
// later. Conservatively assume that we double the number of operands when
// folding
SmallVector<Value, 8> concatOperands;
concatOperands.reserve(2 * getNumOperands());
// Find all operands that are foldable concats
bool foundFoldableConcat = false;
for (Value operand : getOperands()) {
concatOperands.emplace_back(operand);
auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
if (!producer)
continue;
// Not foldable if axes are not the same
if (getAxis() != producer.getAxis())
continue;
// Replace the original operand with all incoming operands
foundFoldableConcat = true;
concatOperands.pop_back();
llvm::append_range(concatOperands, producer->getOperands());
}
if (!foundFoldableConcat)
return {};
getOperation()->setOperands(concatOperands);
return getResult();
}
OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
auto input = adaptor.getInput1();
auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
// Fold splat inputs only.
if (!inputAttr || !inputAttr.isSplat())
return {};
auto shapeType = llvm::cast<ShapedType>(getType());
if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
auto floatVal = inputAttr.getSplatValue<APFloat>();
return DenseElementsAttr::get(shapeType,
ReciprocalOp::calcOneElement(floatVal));
}
return {};
}