mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 13:06:08 +00:00

PR #112292 added support for vectors to the integer range inference interface and analysis, but did not update the getDestWidth() method. This caused crashes when inferring the ranges of `arith.extsi` with vector inputs, because the code tried to sign-extend an N-bit value to a 0-bit one, which triggered an assertion failure. This commit fixes the issue by adding a getElementTypeOrSelf() call so that the width of the element type is used.
179 lines
6.3 KiB
C++
//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
|
|
#include <optional>
|
|
|
|
using namespace mlir;
|
|
|
|
/// Two ranges are equal when they have the same storage width and all four
/// bounds (unsigned min/max, signed min/max) match.
bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
  // Reject differing widths up front so that the APInt comparisons below only
  // ever see operands of matching width.
  if (umin().getBitWidth() != other.umin().getBitWidth())
    return false;
  return umin() == other.umin() && umax() == other.umax() &&
         smin() == other.smin() && smax() == other.smax();
}
|
|
|
|
/// Lower bound of the range under the unsigned interpretation.
const APInt &ConstantIntRanges::umin() const { return this->uminVal; }
|
|
|
|
/// Upper bound of the range under the unsigned interpretation.
const APInt &ConstantIntRanges::umax() const { return this->umaxVal; }
|
|
|
|
/// Lower bound of the range under the signed interpretation.
const APInt &ConstantIntRanges::smin() const { return this->sminVal; }
|
|
|
|
/// Upper bound of the range under the signed interpretation.
const APInt &ConstantIntRanges::smax() const { return this->smaxVal; }
|
|
|
|
/// Returns the APInt width used to store range bounds for values of `type`.
///
/// The element type is what matters (via getElementTypeOrSelf), so e.g. a
/// vector of i32 uses width 32. Index types use
/// IndexType::kInternalStorageBitWidth, integer types use their own width,
/// and every other type yields 0, the "not an integer" marker.
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
  Type scalarType = getElementTypeOrSelf(type);

  if (scalarType.isIndex())
    return IndexType::kInternalStorageBitWidth;

  auto intType = dyn_cast<IntegerType>(scalarType);
  return intType ? intType.getWidth() : 0;
}
|
|
|
|
/// Builds the range that covers every value of the given width, deriving the
/// signed bounds from the full unsigned interval [0, 2^bitwidth - 1].
ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
  APInt lowest = APInt::getZero(bitwidth);
  APInt highest = APInt::getMaxValue(bitwidth);
  return fromUnsigned(lowest, highest);
}
|
|
|
|
/// Builds the single-point range containing only `value`; all four bounds
/// are set to that value.
ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
  return ConstantIntRanges(/*umin=*/value, /*umax=*/value, /*smin=*/value,
                           /*smax=*/value);
}
|
|
|
|
/// Builds a range from a [min, max] pair. `isSigned` selects whether the pair
/// is read as signed bounds (fromSigned) or unsigned bounds (fromUnsigned).
ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
                                           bool isSigned) {
  return isSigned ? fromSigned(min, max) : fromUnsigned(min, max);
}
|
|
|
|
/// Builds a range from signed bounds [smin, smax] and derives the
/// corresponding unsigned bounds.
ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
                                                const APInt &smax) {
  // When the two bounds have different sign bits, fall back to the full
  // unsigned range. For a well-formed interval (smin <= smax) this case
  // contains both -1 (all ones) and 0.
  if (smin.isNonNegative() != smax.isNonNegative()) {
    unsigned bitwidth = smin.getBitWidth();
    return {APInt::getMinValue(bitwidth), APInt::getMaxValue(bitwidth), smin,
            smax};
  }

  // Same sign bit: the unsigned endpoints are the unsigned min/max of the
  // two bounds.
  const APInt &lowUnsigned = smin.ult(smax) ? smin : smax;
  const APInt &highUnsigned = smin.ugt(smax) ? smin : smax;
  return {lowUnsigned, highUnsigned, smin, smax};
}
|
|
|
|
/// Builds a range from unsigned bounds [umin, umax] and derives the
/// corresponding signed bounds.
ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
                                                  const APInt &umax) {
  // When the two bounds have different sign bits, fall back to the full
  // signed range. For a well-formed interval (umin <= umax) this case
  // contains both the signed maximum and the signed minimum.
  if (umin.isNonNegative() != umax.isNonNegative()) {
    unsigned bitwidth = umin.getBitWidth();
    return {umin, umax, APInt::getSignedMinValue(bitwidth),
            APInt::getSignedMaxValue(bitwidth)};
  }

  // Same sign bit: the signed endpoints are the signed min/max of the two
  // bounds.
  const APInt &lowSigned = umin.slt(umax) ? umin : umax;
  const APInt &highSigned = umin.sgt(umax) ? umin : umax;
  return {umin, umax, lowSigned, highSigned};
}
|
|
|
|
/// Returns the smallest range that contains both `*this` and `other`:
/// the minimum of the lower bounds and the maximum of the upper bounds,
/// computed independently for the unsigned and signed views.
ConstantIntRanges
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
  // A width-0 range means "not an integer". It wins over anything it is
  // combined with, and its bounds must not reach the comparisons below.
  if (umin().getBitWidth() == 0)
    return *this;
  if (other.umin().getBitWidth() == 0)
    return other;

  return {umin().ult(other.umin()) ? umin() : other.umin(),
          umax().ugt(other.umax()) ? umax() : other.umax(),
          smin().slt(other.smin()) ? smin() : other.smin(),
          smax().sgt(other.smax()) ? smax() : other.smax()};
}
|
|
|
|
/// Returns the range bounded by the tighter of each pair of bounds: the
/// maximum of the lower bounds and the minimum of the upper bounds, computed
/// independently for the unsigned and signed views. No check is made here
/// that the result is non-empty.
ConstantIntRanges
ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
  // A width-0 range means "not an integer". It wins over anything it is
  // combined with, and its bounds must not reach the comparisons below.
  if (umin().getBitWidth() == 0)
    return *this;
  if (other.umin().getBitWidth() == 0)
    return other;

  return {umin().ugt(other.umin()) ? umin() : other.umin(),
          umax().ult(other.umax()) ? umax() : other.umax(),
          smin().sgt(other.smin()) ? smin() : other.smin(),
          smax().slt(other.smax()) ? smax() : other.smax()};
}
|
|
|
|
std::optional<APInt> ConstantIntRanges::getConstantValue() const {
|
|
// Note: we need to exclude the trivially-equal width 0 values here.
|
|
if (umin() == umax() && umin().getBitWidth() != 0)
|
|
return umin();
|
|
if (smin() == smax() && smin().getBitWidth() != 0)
|
|
return smin();
|
|
return std::nullopt;
|
|
}
|
|
|
|
/// Prints the range as
/// "unsigned : [umin, umax] signed : [smin, smax]".
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
  // Streaming an APInt with `<<` formats it as a *signed* number. The
  // unsigned bounds would then show up as negative values (e.g. the unsigned
  // maximum prints as -1), so print them explicitly as unsigned.
  os << "unsigned : [";
  range.umin().print(os, /*isSigned=*/false);
  os << ", ";
  range.umax().print(os, /*isSigned=*/false);
  // Signed bounds keep the default (signed) APInt formatting.
  return os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
}
|
|
|
|
/// Returns the widest possible range for `value`, based on the storage width
/// of its type (see ConstantIntRanges::getStorageBitwidth).
///
/// For types that are not integer-like (storage width 0), this returns a
/// default-constructed IntegerValueRange. Presumably that is the
/// uninitialized state; confirm against the IntegerValueRange constructor.
IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
  unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
  if (width == 0)
    return {};

  // width is known to be non-zero here, so the full range can be built with
  // the shared helper: [0, 2^width - 1] unsigned, [-2^(width-1),
  // 2^(width-1) - 1] signed.
  return IntegerValueRange{ConstantIntRanges::maxRange(width)};
}
|
|
|
|
/// Streams an IntegerValueRange by delegating to IntegerValueRange::print and
/// returning the same stream so calls can be chained.
raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
  raw_ostream &out = os;
  range.print(out);
  return out;
}
|
|
|
|
/// Adapts an op's ConstantIntRanges-based inferResultRanges hook to the
/// IntegerValueRange-based (lattice) callback interface.
///
/// @param interface        the op whose inferResultRanges is invoked.
/// @param argRanges        the operand ranges, as lattice values.
/// @param setResultRanges  callback that receives each inferred result range,
///                         wrapped in an IntegerValueRange.
///
/// If any operand range is uninitialized, this returns without calling
/// inferResultRanges, so no result ranges are set.
void mlir::intrange::detail::defaultInferResultRanges(
    InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
    SetIntLatticeFn setResultRanges) {
  llvm::SmallVector<ConstantIntRanges> unpacked;
  unpacked.reserve(argRanges.size());

  for (const IntegerValueRange &range : argRanges) {
    if (range.isUninitialized())
      return;
    unpacked.push_back(range.getValue());
  }

  // The callback argument is the inferred range for result `value`. It is
  // deliberately not named `argRanges`, which would shadow the operand ranges
  // above.
  interface.inferResultRanges(
      unpacked,
      [&setResultRanges](Value value, const ConstantIntRanges &resultRange) {
        setResultRanges(value, IntegerValueRange{resultRange});
      });
}
|
|
|
|
/// Adapts an op's IntegerValueRange-based inferResultRangesFromOptional hook
/// to the plain ConstantIntRanges callback interface.
///
/// @param interface        the op whose inferResultRangesFromOptional is
///                         invoked.
/// @param argRanges        the operand ranges; each is wrapped in an
///                         IntegerValueRange.
/// @param setResultRanges  callback that receives each inferred result range.
///
/// Result ranges reported as uninitialized are dropped; setResultRanges is
/// not called for them.
void mlir::intrange::detail::defaultInferResultRangesFromOptional(
    InferIntRangeInterface interface, ArrayRef<ConstantIntRanges> argRanges,
    SetIntRangeFn setResultRanges) {
  auto ranges = llvm::to_vector_of<IntegerValueRange>(argRanges);

  // The callback argument is the inferred range for result `value`. It is
  // deliberately not named `argRanges`, which would shadow the operand ranges
  // above.
  interface.inferResultRangesFromOptional(
      ranges,
      [&setResultRanges](Value value, const IntegerValueRange &resultRange) {
        if (!resultRange.isUninitialized())
          setResultRanges(value, resultRange.getValue());
      });
}
|