The Diagnostic class contains all of the information necessary to report a
diagnostic to the DiagnosticEngine. It should generally not be constructed
directly, and instead used transitively via InFlightDiagnostic. A diagnostic
is currently comprised of several different elements:

* A severity level.
* A source Location.
* A list of DiagnosticArguments that help compose the output message.
  * A DiagnosticArgument represents any value that may be part of the
    diagnostic, e.g. string, integer, Type, Attribute, etc.
  * Arguments can be added to the diagnostic via the stream (<<) operator.
* (In a future cl) A list of attached notes.
  * These are in the form of other diagnostics that provide supplemental
    information to the main diagnostic, but do not have context on their own.

The InFlightDiagnostic class is an RAII wrapper around a Diagnostic that is
set to be reported with the diagnostic engine. This allows the user to modify
a diagnostic while it is in flight. The internally wrapped diagnostic can be
reported directly or automatically upon destruction.

These classes allow for more natural composition of diagnostics by removing
the restriction that the message of a diagnostic is comprised of a single
Twine. They should also allow for nice incremental improvements to the
diagnostics experience in the future, e.g. formatv-style diagnostics.

Simple example, before:

  emitError(loc, "integer bitwidth is limited to " +
                     Twine(IntegerType::kMaxWidth) + " bits");

and after:

  emitError(loc) << "integer bitwidth is limited to " << IntegerType::kMaxWidth
                 << " bits";

--
PiperOrigin-RevId: 246526439
//===- LowerUniformRealMath.cpp ------------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "UniformKernelUtils.h"
|
|
|
|
#include "mlir/FxpMathOps/FxpMathOps.h"
|
|
#include "mlir/FxpMathOps/Passes.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/StandardOps/Ops.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::fxpmath;
|
|
using namespace mlir::fxpmath::detail;
|
|
using namespace mlir::quant;
|
|
|
|
namespace {
|
|
|
|
struct LowerUniformRealMathPass
|
|
: public FunctionPass<LowerUniformRealMathPass> {
|
|
void runOnFunction() override;
|
|
};
|
|
|
|
struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
|
|
void runOnFunction() override;
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
//===----------------------------------------------------------------------===//
// Dequantize
//===----------------------------------------------------------------------===//

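// The per-layer path below materializes the standard affine dequantization
//   realValue = scale * (storageValue - zeroPoint)
// as a storage cast, an integer zero-point subtraction, an int-to-float
// conversion, and a multiply by the scale constant.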
static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
                                            UniformQuantizedType elementType,
                                            PatternRewriter &rewriter) {
  // Pre-conditions.
  if (!elementType.isSigned()) {
    // TODO: Support unsigned storage type.
    rewriter.getContext()->getDiagEngine().emit(loc,
                                                DiagnosticSeverity::Warning)
        << "unimplemented: dequantize unsigned uniform";
    return nullptr;
  }

  Type storageType = elementType.castToStorageType(input->getType());
  Type realType = elementType.castToExpressedType(input->getType());
  Type intermediateType =
      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
  assert(storageType && "cannot cast to storage type");
  assert(realType && "cannot cast to expressed type");

  // Cast to storage type.
  input = rewriter.create<StorageCastOp>(loc, storageType, input);

  // Promote to intermediate type.
  input = rewriter.create<ConvertISOp>(loc, intermediateType, input);

  // Apply zero-point offset.
  if (elementType.getZeroPoint() != 0) {
    Value *negZeroPointConst = rewriter.create<ConstantOp>(
        loc, broadcastScalarConstIntValue(intermediateType,
                                          -elementType.getZeroPoint()));
    input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
  }

  // Convert to float.
  input = rewriter.create<ConvertISToFOp>(loc, realType, input);

  // Mul by scale.
  Value *scaleConst = rewriter.create<ConstantOp>(
      loc, broadcastScalarConstFloatValue(realType,
                                          APFloat(elementType.getScale())));
  return rewriter.create<MulFOp>(loc, input, scaleConst);
}

static Value *
emitUniformPerAxisDequantize(Location loc, Value *input,
                             UniformQuantizedPerAxisType elementType,
                             PatternRewriter &rewriter) {
  // TODO: Support per-axis dequantize.
  rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
      << "unimplemented: per-axis uniform dequantization";
  return nullptr;
}

static Value *emitDequantize(Location loc, Value *input,
                             PatternRewriter &rewriter) {
  Type inputType = input->getType();
  QuantizedType qElementType =
      QuantizedType::getQuantizedElementType(inputType);
  if (auto uperLayerElementType =
          qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
    return emitUniformPerLayerDequantize(loc, input, uperLayerElementType,
                                         rewriter);
  } else if (auto uperAxisElementType =
                 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
    return emitUniformPerAxisDequantize(loc, input, uperAxisElementType,
                                        rewriter);
  } else {
    return nullptr;
  }
}

namespace {

struct UniformDequantizePattern : public RewritePattern {
  UniformDequantizePattern(MLIRContext *context)
      : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto dcastOp = op->cast<DequantizeCastOp>();
    Type inputType = dcastOp.arg()->getType();
    Type outputType = dcastOp.getResult()->getType();

    QuantizedType inputElementType =
        QuantizedType::getQuantizedElementType(inputType);
    Type expressedOutputType = inputElementType.castToExpressedType(inputType);
    if (expressedOutputType != outputType) {
      // Not a valid uniform cast.
      return matchFailure();
    }

    Value *dequantizedValue =
        emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
    if (!dequantizedValue) {
      return matchFailure();
    }

    rewriter.replaceOp(op, dequantizedValue);
    return matchSuccess();
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Elementwise add
//===----------------------------------------------------------------------===//

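// With lhs, rhs and result sharing one uniform type (scale s, zero point zp),
// real addition rewrites directly on the quantized storage values:
//   s * (lhs_q - zp) + s * (rhs_q - zp) == s * (result_q - zp)
//   => result_q = lhs_q + rhs_q - zp
// For example (illustrative values): with zp = 10, lhs_q = 30, rhs_q = 50,
// result_q = 30 + 50 - 10 = 70, which matches the sum of the two real values.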
static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
                                      PatternRewriter &rewriter) {
  if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
      info.rhsType != info.resultType) {
    return failure();
  }

  // Choose a byte aligned intermediate width big enough to perform the
  // calculation without overflow.
  // TODO: This should probably be made just big enough to avoid overflow and
  // leave the downstream tooling to decide how to align that to machine
  // word sizes.
  unsigned intermediateWidth =
      info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value *lhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.lhsStorageType, info.lhs)
                        .getResult();
  Value *rhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.rhsStorageType, info.rhs)
                        .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Add.
  Value *resultValue =
      rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Zero point offset adjustment.
  // result = (lhs - zp) + (rhs - zp) + zp
  // zpOffset = -zp
  int zpOffset = -1 * info.resultType.getZeroPoint();
  if (zpOffset != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType, zpOffset));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}

//===----------------------------------------------------------------------===//
// Elementwise mul
//===----------------------------------------------------------------------===//

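// With per-operand uniform types (scale s, zero point zp), real multiplication
//   s_l * (lhs_q - zp_l) * s_r * (rhs_q - zp_r) == s_out * (result_q - zp_out)
// rearranges to
//   result_q = (s_l * s_r / s_out) * (lhs_q - zp_l) * (rhs_q - zp_r) + zp_out
// which is what the rewrite below emits: offset each operand by its zero
// point, multiply in a wide integer type, scale by the combined real
// multiplier, then re-apply the result zero point.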
static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
                            PatternRewriter &rewriter) {
  if (!info.resultType.isSigned()) {
    return failure();
  }

  double outputMultiplierReal = info.lhsType.getScale() *
                                info.rhsType.getScale() /
                                info.resultType.getScale();
  if (outputMultiplierReal > 1.0) {
    info.op->emitWarning("unimplemented: cannot multiply with multiplier > 1.0");
    return failure();
  }

  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
  // avoid overflow.
  unsigned intermediateWidth = 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value *lhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.lhsStorageType, info.lhs)
                        .getResult();
  Value *rhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.rhsStorageType, info.rhs)
                        .getResult();

  // Cast to the intermediate sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Apply argument zeroPoints.
  if (info.lhsType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.lhsType.getZeroPoint()));
    lhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
  }

  if (info.rhsType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.rhsType.getZeroPoint()));
    rhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
  }

  // Mul.
  Value *resultValue =
      rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);

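  // Scaling sketch, assuming the gemmlowp-style decomposition performed by
  // QuantizedMultiplierSmallerThanOneExp: the real multiplier in (0, 1] is
  // split into a Q0.31 fixed-point `multiplier` and a non-positive `exponent`,
  // so the high-mul below computes round(x * multiplier / 2^31) and the
  // rounding divide applies the remaining 2^exponent factor.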
  // Scale output.
  QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
  resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
  resultValue = rewriter.create<RoundingDivideByPotISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));

  // Zero point offset adjustment.
  if (info.resultType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType,
                                     info.resultType.getZeroPoint()));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}

namespace {

struct UniformRealAddEwPattern : public RewritePattern {
  UniformRealAddEwPattern(MLIRContext *context)
      : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto addOp = op->cast<RealAddEwOp>();
    const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
                                   addOp.clamp_min(), addOp.clamp_max());
    if (!info.isValid()) {
      return matchFailure();
    }

    // Try all of the permutations we support.
    if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
      return matchSuccess();
    }

    return matchFailure();
  }
};

struct UniformRealMulEwPattern : public RewritePattern {
  UniformRealMulEwPattern(MLIRContext *context)
      : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto mulOp = op->cast<RealMulEwOp>();
    const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
                                   mulOp.clamp_min(), mulOp.clamp_max());
    if (!info.isValid()) {
      return matchFailure();
    }

    // Try all of the permutations we support.
    if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
      return matchSuccess();
    }

    return matchFailure();
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// LowerUniformRealMath pass
//===----------------------------------------------------------------------===//

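// Collects the rewrite patterns above and applies them greedily until no more
// of them match, lowering real-math ops on uniform-quantized values into
// integer arithmetic.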
void LowerUniformRealMathPass::runOnFunction() {
  auto &fn = getFunction();
  OwningRewritePatternList patterns;
  auto *context = &getContext();
  patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
  patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
  applyPatternsGreedily(fn, std::move(patterns));
}

FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() {
  return new LowerUniformRealMathPass();
}

static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
    "fxpmath-lower-uniform-real-math",
    "Lowers uniform-quantized real math ops to integer arithmetic.");

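// Registration exposes the pass under the "-fxpmath-lower-uniform-real-math"
// flag, so (assuming this library is linked in) it can be exercised from a
// pass driver such as mlir-opt.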
//===----------------------------------------------------------------------===//
// LowerUniformCasts pass
//===----------------------------------------------------------------------===//

void LowerUniformCastsPass::runOnFunction() {
  auto &fn = getFunction();
  OwningRewritePatternList patterns;
  auto *context = &getContext();
  patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
  applyPatternsGreedily(fn, std::move(patterns));
}

FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() {
  return new LowerUniformCastsPass();
}

static PassRegistration<LowerUniformCastsPass>
    lowerUniformCastsPass("fxpmath-lower-uniform-casts",
                          "Lowers uniform-quantized casts.");