mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 05:36:05 +00:00

Integer range analysis will not update the range of an operation when any of the inferred input lattices are uninitialized. In the current behavior, all lattice values for non-integer types are uninitialized. For operations like `arith.cmpf`

```mlir
%3 = arith.cmpf ugt, %arg0, %arg1 : f32
```

this results in the range of the output also being uninitialized, and so on for any consumer of the `arith.cmpf` result. When control-flow ops are involved, the lack of propagation results in incorrect ranges, as the back edges for loop-carried values are not properly joined with the definitions from the body region. For example, consider an `scf.while` loop whose body region produces a value that is in a dataflow relationship with some floating-point values through an `arith.cmpf` operation:

```mlir
func.func @test_bad_range(%arg0: f32, %arg1: f32) -> (index, index) {
  %c4 = arith.constant 4 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %3 = arith.cmpf ugt, %arg0, %arg1 : f32
  %1:2 = scf.while (%arg2 = %c0, %arg3 = %c0) : (index, index) -> (index, index) {
    %2 = arith.cmpi ult, %arg2, %c4 : index
    scf.condition(%2) %arg2, %arg3 : index, index
  } do {
  ^bb0(%arg2: index, %arg3: index):
    %4 = arith.select %3, %arg3, %arg3 : index
    %5 = arith.addi %arg2, %c1 : index
    scf.yield %5, %4 : index, index
  }
  return %1#0, %1#1 : index, index
}
```

The existing behavior results in the control condition `%2` being optimized to true, turning the while loop into an infinite loop. The update to `%arg2` through the body region is never factored into the range calculation, as the ranges for the body ops all test as uninitialized.

This change causes all values initialized with `setToEntryState` to be set to some initialized range, even if the values are not integers.

---------

Co-authored-by: Spenser Bauman <sabauma@fastmail>
177 lines
6.2 KiB
C++
177 lines
6.2 KiB
C++
//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
|
|
#include <optional>
|
|
|
|
using namespace mlir;
|
|
|
|
/// Two ranges are equal when they share a storage bitwidth and all four
/// bounds (unsigned and signed min/max) match.
bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
  // APInt equality requires operands of identical width, so ranges of
  // different widths are rejected before any bound is compared.
  if (umin().getBitWidth() != other.umin().getBitWidth())
    return false;

  const bool unsignedMatch = umin() == other.umin() && umax() == other.umax();
  const bool signedMatch = smin() == other.smin() && smax() == other.smax();
  return unsignedMatch && signedMatch;
}
|
|
|
|
/// Minimum of the range under the unsigned interpretation.
const APInt &ConstantIntRanges::umin() const {
  return this->uminVal;
}

/// Maximum of the range under the unsigned interpretation.
const APInt &ConstantIntRanges::umax() const {
  return this->umaxVal;
}

/// Minimum of the range under the signed interpretation.
const APInt &ConstantIntRanges::smin() const {
  return this->sminVal;
}

/// Maximum of the range under the signed interpretation.
const APInt &ConstantIntRanges::smax() const {
  return this->smaxVal;
}
|
|
|
|
/// Returns the APInt bitwidth used to hold range bounds for values of `type`:
/// the declared width for integer types, the internal storage width for
/// `index`, and 0 for everything else.
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
  if (auto intTy = dyn_cast<IntegerType>(type))
    return intTy.getWidth();
  if (type.isIndex())
    return IndexType::kInternalStorageBitWidth;
  // Non-integer types have their bounds stored in width 0 `APInt`s.
  return 0;
}
|
|
|
|
/// Returns the full range [0, 2^bitwidth - 1] (unsigned), with the signed
/// bounds derived by `fromUnsigned`.
ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
  APInt lowest = APInt::getZero(bitwidth);
  APInt highest = APInt::getMaxValue(bitwidth);
  return fromUnsigned(lowest, highest);
}
|
|
|
|
/// Returns the singleton range holding exactly `value`; all four bounds are
/// set to `value`.
ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
  return ConstantIntRanges(value, value, value, value);
}
|
|
|
|
/// Builds a range from the bounds [min, max]. They are interpreted as signed
/// when `isSigned` is true and as unsigned otherwise; the bounds for the other
/// interpretation are derived.
ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
                                           bool isSigned) {
  return isSigned ? fromSigned(min, max) : fromUnsigned(min, max);
}
|
|
|
|
/// Builds a range from signed bounds [smin, smax] and derives unsigned bounds.
/// If both signed bounds are on the same side of zero, the unsigned bounds are
/// the two endpoints ordered by unsigned comparison. Otherwise the unsigned
/// bounds span the whole bitwidth.
ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
                                                const APInt &smax) {
  const bool signsDiffer = smin.isNonNegative() != smax.isNonNegative();
  if (signsDiffer) {
    const unsigned bits = smin.getBitWidth();
    return {APInt::getMinValue(bits), APInt::getMaxValue(bits), smin, smax};
  }

  const APInt &lo = smin.ult(smax) ? smin : smax;
  const APInt &hi = smin.ugt(smax) ? smin : smax;
  return {lo, hi, smin, smax};
}
|
|
|
|
/// Builds a range from unsigned bounds [umin, umax] and derives signed bounds.
/// If both unsigned bounds have the same sign bit, the signed bounds are the
/// two endpoints ordered by signed comparison. Otherwise the signed bounds
/// span the whole signed range of the bitwidth.
ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
                                                  const APInt &umax) {
  const bool signsDiffer = umin.isNonNegative() != umax.isNonNegative();
  if (signsDiffer) {
    const unsigned bits = umin.getBitWidth();
    return {umin, umax, APInt::getSignedMinValue(bits),
            APInt::getSignedMaxValue(bits)};
  }

  const APInt &lo = umin.slt(umax) ? umin : umax;
  const APInt &hi = umin.sgt(umax) ? umin : umax;
  return {umin, umax, lo, hi};
}
|
|
|
|
/// Returns the smallest range covering both `*this` and `other`: the lower of
/// each pair of minimums and the higher of each pair of maximums.
ConstantIntRanges
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
  // A width-0 range means "not an integer". It absorbs the result, and it has
  // to be handled first because APInt comparisons need matching widths.
  if (umin().getBitWidth() == 0)
    return *this;
  if (other.umin().getBitWidth() == 0)
    return other;

  // Widen every bound.
  const APInt &lowU = other.umin().ult(umin()) ? other.umin() : umin();
  const APInt &highU = other.umax().ugt(umax()) ? other.umax() : umax();
  const APInt &lowS = other.smin().slt(smin()) ? other.smin() : smin();
  const APInt &highS = other.smax().sgt(smax()) ? other.smax() : smax();
  return ConstantIntRanges(lowU, highU, lowS, highS);
}
|
|
|
|
/// Returns the overlap of `*this` and `other`: the higher of each pair of
/// minimums and the lower of each pair of maximums.
ConstantIntRanges
ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
  // A width-0 range means "not an integer". It absorbs the result, and it has
  // to be handled first because APInt comparisons need matching widths.
  if (umin().getBitWidth() == 0)
    return *this;
  if (other.umin().getBitWidth() == 0)
    return other;

  // Narrow every bound.
  const APInt &lowU = other.umin().ugt(umin()) ? other.umin() : umin();
  const APInt &highU = other.umax().ult(umax()) ? other.umax() : umax();
  const APInt &lowS = other.smin().sgt(smin()) ? other.smin() : smin();
  const APInt &highS = other.smax().slt(smax()) ? other.smax() : smax();
  return ConstantIntRanges(lowU, highU, lowS, highS);
}
|
|
|
|
std::optional<APInt> ConstantIntRanges::getConstantValue() const {
|
|
// Note: we need to exclude the trivially-equal width 0 values here.
|
|
if (umin() == umax() && umin().getBitWidth() != 0)
|
|
return umin();
|
|
if (smin() == smax() && smin().getBitWidth() != 0)
|
|
return smin();
|
|
return std::nullopt;
|
|
}
|
|
|
|
/// Prints the unsigned bounds followed by the signed bounds.
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
  os << "unsigned : [" << range.umin() << ", " << range.umax();
  os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
  return os;
}
|
|
|
|
/// Returns the widest range representable at the storage bitwidth of
/// `value`'s type (see `ConstantIntRanges::getStorageBitwidth`). When that
/// bitwidth is 0, a default-constructed `IntegerValueRange` (`{}`) is returned
/// instead. That is presumably the uninitialized state; confirm against the
/// header.
IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
  unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
  if (width == 0)
    return {};

  // `width` is non-zero past the guard above, so the signed extremes are
  // always well-defined and need no zero-width fallback.
  APInt umin = APInt::getMinValue(width);
  APInt umax = APInt::getMaxValue(width);
  APInt smin = APInt::getSignedMinValue(width);
  APInt smax = APInt::getSignedMaxValue(width);
  return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
}
|
|
|
|
/// Streams `range` using its own `print` method and returns the stream so
/// that insertions can be chained.
raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
  raw_ostream &out = os;
  range.print(out);
  return out;
}
|
|
|
|
/// Default implementation of the `IntegerValueRange`-based inference hook.
/// It unwraps every argument range and forwards to the op's
/// `ConstantIntRanges`-based `inferResultRanges`, wrapping each reported
/// result back into an `IntegerValueRange`. If any argument range is
/// uninitialized, no inference happens and `setResultRanges` is never called.
void mlir::intrange::detail::defaultInferResultRanges(
    InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
    SetIntLatticeFn setResultRanges) {
  llvm::SmallVector<ConstantIntRanges> unpacked;
  unpacked.reserve(argRanges.size());

  for (const IntegerValueRange &range : argRanges) {
    // The ConstantIntRanges-based hook needs a concrete range for every
    // operand, so bail out as soon as one is missing.
    if (range.isUninitialized())
      return;
    unpacked.push_back(range.getValue());
  }

  // The callback receives a *result* range. It is deliberately not named
  // `argRanges`, which would shadow the parameter above.
  interface.inferResultRanges(
      unpacked,
      [&setResultRanges](Value value, const ConstantIntRanges &resultRange) {
        setResultRanges(value, IntegerValueRange{resultRange});
      });
}
|
|
|
|
/// Default implementation of the `ConstantIntRanges`-based inference hook.
/// It wraps every argument in an `IntegerValueRange`, calls the op's
/// `inferResultRangesFromOptional`, and forwards only the initialized result
/// ranges to `setResultRanges`.
void mlir::intrange::detail::defaultInferResultRangesFromOptional(
    InferIntRangeInterface interface, ArrayRef<ConstantIntRanges> argRanges,
    SetIntRangeFn setResultRanges) {
  auto ranges = llvm::to_vector_of<IntegerValueRange>(argRanges);
  // The callback receives a *result* range. It is deliberately not named
  // `argRanges`, which would shadow the parameter above.
  interface.inferResultRangesFromOptional(
      ranges,
      [&setResultRanges](Value value, const IntegerValueRange &resultRange) {
        // An uninitialized result has no ConstantIntRanges to report.
        if (!resultRange.isUninitialized())
          setResultRanges(value, resultRange.getValue());
      });
}
|