mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 16:26:05 +00:00
[mlir][Index] Implement InferIntRangeInterface, re-land
Re-land D140899 to fix a missing dependency in the index dialect's CMakeLists.txt. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D142147
This commit is contained in:
parent
5b4ed49050
commit
5af9d16dae
@ -13,6 +13,7 @@
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
include "mlir/Dialect/Index/IR/IndexDialect.td"
|
||||
include "mlir/Dialect/Index/IR/IndexEnums.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/InferIntRangeInterface.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
@ -23,7 +24,8 @@ include "mlir/IR/OpBase.td"
|
||||
|
||||
/// Base class for Index dialect operations.
|
||||
class IndexOp<string mnemonic, list<Trait> traits = []>
|
||||
: Op<IndexDialect, mnemonic, [Pure] # traits>;
|
||||
: Op<IndexDialect, mnemonic,
|
||||
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>] # traits>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IndexBinaryOp
|
||||
|
126
mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
Normal file
126
mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
Normal file
@ -0,0 +1,126 @@
|
||||
//===- InferIntRangeCommon.cpp - Inference for common ops --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file declares implementations of range inference for operations that are
|
||||
// common to both the `arith` and `index` dialects to facilitate reuse.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
|
||||
#define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
|
||||
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace intrange {
|
||||
/// Function that performs inference on an array of `ConstantIntRanges`,
|
||||
/// abstracted away here to permit writing the function that handles both
|
||||
/// 64- and 32-bit index types.
|
||||
using InferRangeFn =
|
||||
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
|
||||
|
||||
static constexpr unsigned indexMinWidth = 32;
|
||||
static constexpr unsigned indexMaxWidth = 64;
|
||||
|
||||
enum class CmpMode : uint32_t { Both, Signed, Unsigned };
|
||||
|
||||
/// Compute `inferFn` on `ranges`, whose size should be the index storage
|
||||
/// bitwidth. Then, compute the function on `argRanges` again after truncating
|
||||
/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
|
||||
/// equal to the 32-bit result, use it (to preserve compatibility with folders
|
||||
/// and inference precision), and take the union of the results otherwise.
|
||||
///
|
||||
/// The `mode` argument specifies if the unsigned, signed, or both results of
|
||||
/// the inference computation should be used when comparing the results.
|
||||
ConstantIntRanges inferIndexOp(InferRangeFn inferFn,
|
||||
ArrayRef<ConstantIntRanges> argRanges,
|
||||
CmpMode mode);
|
||||
|
||||
/// Independently zero-extend the unsigned values and sign-extend the signed
|
||||
/// values in `range` to `destWidth` bits, returning the resulting range.
|
||||
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth);
|
||||
|
||||
/// Use the unsigned values in `range` to zero-extend it to `destWidth`.
|
||||
ConstantIntRanges extUIRange(const ConstantIntRanges &range,
|
||||
unsigned destWidth);
|
||||
|
||||
/// Use the signed values in `range` to sign-extend it to `destWidth`.
|
||||
ConstantIntRanges extSIRange(const ConstantIntRanges &range,
|
||||
unsigned destWidth);
|
||||
|
||||
/// Truncate `range` to `destWidth` bits, taking care to handle cases such as
|
||||
/// the truncation of [255, 256] to i8 not being a uniform range.
|
||||
ConstantIntRanges truncRange(const ConstantIntRanges &range,
|
||||
unsigned destWidth);
|
||||
|
||||
ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferDivU(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferRemS(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferRemU(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferMaxS(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferMaxU(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferMinS(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferMinU(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferAnd(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
ConstantIntRanges inferShrU(ArrayRef<ConstantIntRanges> argRanges);
|
||||
|
||||
/// Copy of the enum from `arith` and `index` to allow the common integer range
|
||||
/// infrastructure to not depend on either dialect.
|
||||
enum class CmpPredicate : uint64_t {
|
||||
eq,
|
||||
ne,
|
||||
slt,
|
||||
sle,
|
||||
sgt,
|
||||
sge,
|
||||
ult,
|
||||
ule,
|
||||
ugt,
|
||||
uge,
|
||||
};
|
||||
|
||||
/// Returns a boolean value if `pred` is statically true or false for
|
||||
/// anypossible inputs falling within `lhs` and `rhs`, and std::nullopt if the
|
||||
/// value of the predicate cannot be determined.
|
||||
Optional<bool> evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs,
|
||||
const ConstantIntRanges &rhs);
|
||||
|
||||
} // namespace intrange
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
|
@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArithDialect
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRDialect
|
||||
MLIRInferIntRangeCommon
|
||||
MLIRInferIntRangeInterface
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
|
||||
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <optional>
|
||||
@ -16,48 +17,7 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::arith;
|
||||
|
||||
/// Function that evaluates the result of doing something on arithmetic
|
||||
/// constants and returns std::nullopt on overflow.
|
||||
using ConstArithFn =
|
||||
function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
|
||||
|
||||
/// Return the maxmially wide signed or unsigned range for a given bitwidth.
|
||||
|
||||
/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
|
||||
/// If either computation overflows, make the result unbounded.
|
||||
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
|
||||
const APInt &minRight,
|
||||
const APInt &maxLeft,
|
||||
const APInt &maxRight, bool isSigned) {
|
||||
std::optional<APInt> maybeMin = op(minLeft, minRight);
|
||||
std::optional<APInt> maybeMax = op(maxLeft, maxRight);
|
||||
if (maybeMin && maybeMax)
|
||||
return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
|
||||
return ConstantIntRanges::maxRange(minLeft.getBitWidth());
|
||||
}
|
||||
|
||||
/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
|
||||
/// ignoring unbounded values. Returns the maximal range if `op` overflows.
|
||||
static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
|
||||
ArrayRef<APInt> rhs, bool isSigned) {
|
||||
unsigned width = lhs[0].getBitWidth();
|
||||
APInt min =
|
||||
isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
|
||||
APInt max =
|
||||
isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
|
||||
for (const APInt &left : lhs) {
|
||||
for (const APInt &right : rhs) {
|
||||
std::optional<APInt> maybeThisResult = op(left, right);
|
||||
if (!maybeThisResult)
|
||||
return ConstantIntRanges::maxRange(width);
|
||||
APInt result = std::move(*maybeThisResult);
|
||||
min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
|
||||
max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
|
||||
}
|
||||
}
|
||||
return ConstantIntRanges::range(min, max, isSigned);
|
||||
}
|
||||
using namespace mlir::intrange;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOp
|
||||
@ -78,25 +38,7 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
ConstArithFn uadd = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.uadd_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
ConstArithFn sadd = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.sadd_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
|
||||
ConstantIntRanges urange = computeBoundsBy(
|
||||
uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
|
||||
ConstantIntRanges srange = computeBoundsBy(
|
||||
sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
|
||||
setResultRange(getResult(), urange.intersection(srange));
|
||||
setResultRange(getResult(), inferAdd(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -105,25 +47,7 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
ConstArithFn usub = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.usub_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
ConstArithFn ssub = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.ssub_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
ConstantIntRanges urange = computeBoundsBy(
|
||||
usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
|
||||
ConstantIntRanges srange = computeBoundsBy(
|
||||
ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
|
||||
setResultRange(getResult(), urange.intersection(srange));
|
||||
setResultRange(getResult(), inferSub(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -132,96 +56,25 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
ConstArithFn umul = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.umul_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
ConstArithFn smul = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.smul_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
|
||||
ConstantIntRanges urange =
|
||||
minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
|
||||
/*isSigned=*/false);
|
||||
ConstantIntRanges srange =
|
||||
minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
|
||||
/*isSigned=*/true);
|
||||
|
||||
setResultRange(getResult(), urange.intersection(srange));
|
||||
setResultRange(getResult(), inferMul(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DivUIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Fix up division results (ex. for ceiling and floor), returning an APInt
|
||||
/// if there has been no overflow
|
||||
using DivisionFixupFn = function_ref<std::optional<APInt>(
|
||||
const APInt &lhs, const APInt &rhs, const APInt &result)>;
|
||||
|
||||
static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs,
|
||||
const ConstantIntRanges &rhs,
|
||||
DivisionFixupFn fixup) {
|
||||
const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
|
||||
&rhsMax = rhs.umax();
|
||||
|
||||
if (!rhsMin.isZero()) {
|
||||
auto udiv = [&fixup](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
return fixup(a, b, a.udiv(b));
|
||||
};
|
||||
return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
|
||||
/*isSigned=*/false);
|
||||
}
|
||||
// Otherwise, it's possible we might divide by 0.
|
||||
return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
|
||||
}
|
||||
|
||||
void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferDivUIRange(argRanges[0], argRanges[1],
|
||||
[](const APInt &lhs, const APInt &rhs,
|
||||
const APInt &result) { return result; }));
|
||||
setResultRange(getResult(), inferDivU(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DivSIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs,
|
||||
const ConstantIntRanges &rhs,
|
||||
DivisionFixupFn fixup) {
|
||||
const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
|
||||
&rhsMax = rhs.smax();
|
||||
bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
|
||||
|
||||
if (canDivide) {
|
||||
auto sdiv = [&fixup](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.sdiv_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : fixup(a, b, result);
|
||||
};
|
||||
return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
|
||||
/*isSigned=*/true);
|
||||
}
|
||||
return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
|
||||
}
|
||||
|
||||
void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferDivSIRange(argRanges[0], argRanges[1],
|
||||
[](const APInt &lhs, const APInt &rhs,
|
||||
const APInt &result) { return result; }));
|
||||
setResultRange(getResult(), inferDivS(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -230,20 +83,7 @@ void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::CeilDivUIOp::inferResultRanges(
|
||||
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
DivisionFixupFn ceilDivUIFix =
|
||||
[](const APInt &lhs, const APInt &rhs,
|
||||
const APInt &result) -> std::optional<APInt> {
|
||||
if (!lhs.urem(rhs).isZero()) {
|
||||
bool overflowed = false;
|
||||
APInt corrected =
|
||||
result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
|
||||
return overflowed ? std::optional<APInt>() : corrected;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix));
|
||||
setResultRange(getResult(), inferCeilDivU(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -252,20 +92,7 @@ void arith::CeilDivUIOp::inferResultRanges(
|
||||
|
||||
void arith::CeilDivSIOp::inferResultRanges(
|
||||
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
DivisionFixupFn ceilDivSIFix =
|
||||
[](const APInt &lhs, const APInt &rhs,
|
||||
const APInt &result) -> std::optional<APInt> {
|
||||
if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
|
||||
bool overflowed = false;
|
||||
APInt corrected =
|
||||
result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
|
||||
return overflowed ? std::optional<APInt>() : corrected;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix));
|
||||
setResultRange(getResult(), inferCeilDivS(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -274,20 +101,7 @@ void arith::CeilDivSIOp::inferResultRanges(
|
||||
|
||||
void arith::FloorDivSIOp::inferResultRanges(
|
||||
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
DivisionFixupFn floorDivSIFix =
|
||||
[](const APInt &lhs, const APInt &rhs,
|
||||
const APInt &result) -> std::optional<APInt> {
|
||||
if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
|
||||
bool overflowed = false;
|
||||
APInt corrected =
|
||||
result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
|
||||
return overflowed ? std::optional<APInt>() : corrected;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix));
|
||||
return setResultRange(getResult(), inferFloorDivS(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -296,29 +110,7 @@ void arith::FloorDivSIOp::inferResultRanges(
|
||||
|
||||
void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
|
||||
|
||||
unsigned width = rhsMin.getBitWidth();
|
||||
APInt umin = APInt::getZero(width);
|
||||
APInt umax = APInt::getMaxValue(width);
|
||||
|
||||
if (!rhsMin.isZero()) {
|
||||
umax = rhsMax - 1;
|
||||
// Special case: sweeping out a contiguous range in N/[modulus]
|
||||
if (rhsMin == rhsMax) {
|
||||
const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
|
||||
if ((lhsMax - lhsMin).ult(rhsMax)) {
|
||||
APInt minRem = lhsMin.urem(rhsMax);
|
||||
APInt maxRem = lhsMax.urem(rhsMax);
|
||||
if (minRem.ule(maxRem)) {
|
||||
umin = minRem;
|
||||
umax = maxRem;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
|
||||
setResultRange(getResult(), inferRemU(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -327,67 +119,16 @@ void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
|
||||
&rhsMax = rhs.smax();
|
||||
|
||||
unsigned width = rhsMax.getBitWidth();
|
||||
APInt smin = APInt::getSignedMinValue(width);
|
||||
APInt smax = APInt::getSignedMaxValue(width);
|
||||
// No bounds if zero could be a divisor.
|
||||
bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
|
||||
if (canBound) {
|
||||
APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
|
||||
bool canNegativeDividend = lhsMin.isNegative();
|
||||
bool canPositiveDividend = lhsMax.isStrictlyPositive();
|
||||
APInt zero = APInt::getZero(maxDivisor.getBitWidth());
|
||||
APInt maxPositiveResult = maxDivisor - 1;
|
||||
APInt minNegativeResult = -maxPositiveResult;
|
||||
smin = canNegativeDividend ? minNegativeResult : zero;
|
||||
smax = canPositiveDividend ? maxPositiveResult : zero;
|
||||
// Special case: sweeping out a contiguous range in N/[modulus].
|
||||
if (rhsMin == rhsMax) {
|
||||
if ((lhsMax - lhsMin).ult(maxDivisor)) {
|
||||
APInt minRem = lhsMin.srem(maxDivisor);
|
||||
APInt maxRem = lhsMax.srem(maxDivisor);
|
||||
if (minRem.sle(maxRem)) {
|
||||
smin = minRem;
|
||||
smax = maxRem;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
|
||||
setResultRange(getResult(), inferRemS(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AndIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
|
||||
/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
|
||||
/// that both bonuds have in common. This gives us a consertive approximation
|
||||
/// for what values can be passed to bitwise operations.
|
||||
static std::tuple<APInt, APInt>
|
||||
widenBitwiseBounds(const ConstantIntRanges &bound) {
|
||||
APInt leftVal = bound.umin(), rightVal = bound.umax();
|
||||
unsigned bitwidth = leftVal.getBitWidth();
|
||||
unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
|
||||
leftVal.clearLowBits(differingBits);
|
||||
rightVal.setLowBits(differingBits);
|
||||
return std::make_tuple(std::move(leftVal), std::move(rightVal));
|
||||
}
|
||||
|
||||
void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
|
||||
auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
|
||||
auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
|
||||
return a & b;
|
||||
};
|
||||
setResultRange(getResult(),
|
||||
minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
|
||||
/*isSigned=*/false));
|
||||
setResultRange(getResult(), inferAnd(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -396,14 +137,7 @@ void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
|
||||
auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
|
||||
auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
|
||||
return a | b;
|
||||
};
|
||||
setResultRange(getResult(),
|
||||
minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
|
||||
/*isSigned=*/false));
|
||||
setResultRange(getResult(), inferOr(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -412,14 +146,7 @@ void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
|
||||
auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
|
||||
auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
|
||||
return a ^ b;
|
||||
};
|
||||
setResultRange(getResult(),
|
||||
minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
|
||||
/*isSigned=*/false));
|
||||
setResultRange(getResult(), inferXor(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -428,11 +155,7 @@ void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
|
||||
const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
|
||||
setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
|
||||
setResultRange(getResult(), inferMaxS(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -441,11 +164,7 @@ void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
|
||||
const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
|
||||
setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
|
||||
setResultRange(getResult(), inferMaxU(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -454,11 +173,7 @@ void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
|
||||
const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
|
||||
setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
|
||||
setResultRange(getResult(), inferMinS(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -467,94 +182,40 @@ void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
|
||||
const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
|
||||
setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
|
||||
setResultRange(getResult(), inferMinU(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtUIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ConstantIntRanges extUIRange(const ConstantIntRanges &range,
|
||||
Type destType) {
|
||||
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
|
||||
APInt umin = range.umin().zext(destWidth);
|
||||
APInt umax = range.umax().zext(destWidth);
|
||||
return ConstantIntRanges::fromUnsigned(umin, umax);
|
||||
}
|
||||
|
||||
void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
Type destType = getResult().getType();
|
||||
setResultRange(getResult(), extUIRange(argRanges[0], destType));
|
||||
unsigned destWidth =
|
||||
ConstantIntRanges::getStorageBitwidth(getResult().getType());
|
||||
setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtSIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ConstantIntRanges extSIRange(const ConstantIntRanges &range,
|
||||
Type destType) {
|
||||
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
|
||||
APInt smin = range.smin().sext(destWidth);
|
||||
APInt smax = range.smax().sext(destWidth);
|
||||
return ConstantIntRanges::fromSigned(smin, smax);
|
||||
}
|
||||
|
||||
void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
Type destType = getResult().getType();
|
||||
setResultRange(getResult(), extSIRange(argRanges[0], destType));
|
||||
unsigned destWidth =
|
||||
ConstantIntRanges::getStorageBitwidth(getResult().getType());
|
||||
setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TruncIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
|
||||
Type destType) {
|
||||
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
|
||||
// If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
|
||||
// the range of the resulting value is not contiguous ind includes 0.
|
||||
// Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
|
||||
// but you can't truncate [255, 257] similarly.
|
||||
bool hasUnsignedRollover =
|
||||
range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
|
||||
APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
|
||||
: range.umin().trunc(destWidth);
|
||||
APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
|
||||
: range.umax().trunc(destWidth);
|
||||
|
||||
// Signed post-truncation rollover will not occur when either:
|
||||
// - The high parts of the min and max, plus the sign bit, are the same
|
||||
// - The high halves + sign bit of the min and max are either all 1s or all 0s
|
||||
// and you won't create a [positive, negative] range by truncating.
|
||||
// For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
|
||||
// but not [255, 257]_i16 to a range of i8s. You can also truncate
|
||||
// [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
|
||||
// You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
|
||||
// will truncate to 0x7e, which is greater than 0
|
||||
APInt sminHighPart = range.smin().ashr(destWidth - 1);
|
||||
APInt smaxHighPart = range.smax().ashr(destWidth - 1);
|
||||
bool hasSignedOverflow =
|
||||
(sminHighPart != smaxHighPart) &&
|
||||
!(sminHighPart.isAllOnes() &&
|
||||
(smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
|
||||
!(sminHighPart.isZero() && smaxHighPart.isZero());
|
||||
APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
|
||||
: range.smin().trunc(destWidth);
|
||||
APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
|
||||
: range.smax().trunc(destWidth);
|
||||
return {umin, umax, smin, smax};
|
||||
}
|
||||
|
||||
void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
Type destType = getResult().getType();
|
||||
setResultRange(getResult(), truncIRange(argRanges[0], destType));
|
||||
unsigned destWidth =
|
||||
ConstantIntRanges::getStorageBitwidth(getResult().getType());
|
||||
setResultRange(getResult(), truncRange(argRanges[0], destWidth));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -569,9 +230,9 @@ void arith::IndexCastOp::inferResultRanges(
|
||||
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
|
||||
|
||||
if (srcWidth < destWidth)
|
||||
setResultRange(getResult(), extSIRange(argRanges[0], destType));
|
||||
setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
|
||||
else if (srcWidth > destWidth)
|
||||
setResultRange(getResult(), truncIRange(argRanges[0], destType));
|
||||
setResultRange(getResult(), truncRange(argRanges[0], destWidth));
|
||||
else
|
||||
setResultRange(getResult(), argRanges[0]);
|
||||
}
|
||||
@ -588,9 +249,9 @@ void arith::IndexCastUIOp::inferResultRanges(
|
||||
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
|
||||
|
||||
if (srcWidth < destWidth)
|
||||
setResultRange(getResult(), extUIRange(argRanges[0], destType));
|
||||
setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
|
||||
else if (srcWidth > destWidth)
|
||||
setResultRange(getResult(), truncIRange(argRanges[0], destType));
|
||||
setResultRange(getResult(), truncRange(argRanges[0], destWidth));
|
||||
else
|
||||
setResultRange(getResult(), argRanges[0]);
|
||||
}
|
||||
@ -599,51 +260,19 @@ void arith::IndexCastUIOp::inferResultRanges(
|
||||
// CmpIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs,
|
||||
const ConstantIntRanges &rhs) {
|
||||
switch (pred) {
|
||||
case arith::CmpIPredicate::sle:
|
||||
case arith::CmpIPredicate::slt:
|
||||
return (applyCmpPredicate(pred, lhs.smax(), rhs.smin()));
|
||||
case arith::CmpIPredicate::ule:
|
||||
case arith::CmpIPredicate::ult:
|
||||
return applyCmpPredicate(pred, lhs.umax(), rhs.umin());
|
||||
case arith::CmpIPredicate::sge:
|
||||
case arith::CmpIPredicate::sgt:
|
||||
return applyCmpPredicate(pred, lhs.smin(), rhs.smax());
|
||||
case arith::CmpIPredicate::uge:
|
||||
case arith::CmpIPredicate::ugt:
|
||||
return applyCmpPredicate(pred, lhs.umin(), rhs.umax());
|
||||
case arith::CmpIPredicate::eq: {
|
||||
std::optional<APInt> lhsConst = lhs.getConstantValue();
|
||||
std::optional<APInt> rhsConst = rhs.getConstantValue();
|
||||
return lhsConst && rhsConst && lhsConst == rhsConst;
|
||||
}
|
||||
case arith::CmpIPredicate::ne: {
|
||||
// While equality requires that there is an interpration of the preceeding
|
||||
// computations that produces equal constants, whether that be signed or
|
||||
// unsigned, statically determining inequality requires that neither
|
||||
// interpretation produce potentially overlapping ranges.
|
||||
bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) ||
|
||||
isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs);
|
||||
bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) ||
|
||||
isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs);
|
||||
return sne && une;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
arith::CmpIPredicate pred = getPredicate();
|
||||
arith::CmpIPredicate arithPred = getPredicate();
|
||||
intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
APInt min = APInt::getZero(1);
|
||||
APInt max = APInt::getAllOnesValue(1);
|
||||
if (isStaticallyTrue(pred, lhs, rhs))
|
||||
|
||||
Optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
|
||||
if (truthValue.has_value() && *truthValue)
|
||||
min = max;
|
||||
else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
|
||||
else if (truthValue.has_value() && !(*truthValue))
|
||||
max = min;
|
||||
|
||||
setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
|
||||
@ -673,18 +302,7 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
ConstArithFn shl = [](const APInt &l,
|
||||
const APInt &r) -> std::optional<APInt> {
|
||||
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
|
||||
};
|
||||
ConstantIntRanges urange =
|
||||
minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
|
||||
/*isSigned=*/false);
|
||||
ConstantIntRanges srange =
|
||||
minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
|
||||
/*isSigned=*/true);
|
||||
setResultRange(getResult(), urange.intersection(srange));
|
||||
setResultRange(getResult(), inferShl(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -693,15 +311,7 @@ void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
ConstArithFn lshr = [](const APInt &l,
|
||||
const APInt &r) -> std::optional<APInt> {
|
||||
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
|
||||
};
|
||||
setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()},
|
||||
{rhs.umin(), rhs.umax()},
|
||||
/*isSigned=*/false));
|
||||
setResultRange(getResult(), inferShrU(argRanges));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -710,14 +320,5 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
|
||||
void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
ConstArithFn ashr = [](const APInt &l,
|
||||
const APInt &r) -> std::optional<APInt> {
|
||||
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
|
||||
};
|
||||
|
||||
setResultRange(getResult(),
|
||||
minMaxBy(ashr, {lhs.smin(), lhs.smax()},
|
||||
{rhs.umin(), rhs.umax()}, /*isSigned=*/true));
|
||||
setResultRange(getResult(), inferShrS(argRanges));
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIndexDialect
|
||||
IndexAttrs.cpp
|
||||
IndexDialect.cpp
|
||||
IndexOps.cpp
|
||||
InferIntRangeInterfaceImpls.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRIndexOpsIncGen
|
||||
@ -10,6 +11,8 @@ add_mlir_dialect_library(MLIRIndexDialect
|
||||
MLIRDialect
|
||||
MLIRIR
|
||||
MLIRCastInterfaces
|
||||
MLIRInferIntRangeCommon
|
||||
MLIRInferIntRangeInterface
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRSideEffectInterfaces
|
||||
)
|
||||
|
252
mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
Normal file
252
mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
Normal file
@ -0,0 +1,252 @@
|
||||
//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Index/IR/IndexOps.h"
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
|
||||
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "int-range-analysis"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::index;
|
||||
using namespace mlir::intrange;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Constants
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
const APInt &value = getValue();
|
||||
setResultRange(getResult(), ConstantIntRanges::constant(value));
|
||||
}
|
||||
|
||||
void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
bool value = getValue();
|
||||
APInt asInt(/*numBits=*/1, value);
|
||||
setResultRange(getResult(), ConstantIntRanges::constant(asInt));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Arithmec operations. All of these operations will have their results inferred
|
||||
// using both the 64-bit values and truncated 32-bit values of their inputs,
|
||||
// with the results being the union of those inferences, except where the
|
||||
// truncation of the 64-bit result is equal to the 32-bit result (at which time
|
||||
// we take the 64-bit result).
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
|
||||
}
|
||||
|
||||
void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
|
||||
}
|
||||
|
||||
void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
|
||||
}
|
||||
|
||||
void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
|
||||
}
|
||||
|
||||
void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
|
||||
}
|
||||
|
||||
void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
|
||||
}
|
||||
|
||||
void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
|
||||
}
|
||||
|
||||
void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
return setResultRange(
|
||||
getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
|
||||
}
|
||||
|
||||
void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
|
||||
}
|
||||
|
||||
void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
|
||||
}
|
||||
|
||||
void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
|
||||
}
|
||||
|
||||
void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
|
||||
}
|
||||
|
||||
void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
|
||||
}
|
||||
|
||||
void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
|
||||
}
|
||||
|
||||
void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
|
||||
}
|
||||
|
||||
void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
|
||||
}
|
||||
|
||||
void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
|
||||
}
|
||||
|
||||
void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
|
||||
}
|
||||
|
||||
void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
|
||||
}
|
||||
|
||||
void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
setResultRange(getResult(),
|
||||
inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Casts
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
|
||||
unsigned srcWidth, unsigned destWidth,
|
||||
bool isSigned) {
|
||||
if (srcWidth < destWidth)
|
||||
return isSigned ? extSIRange(range, destWidth)
|
||||
: extUIRange(range, destWidth);
|
||||
if (srcWidth > destWidth)
|
||||
return truncRange(range, destWidth);
|
||||
return range;
|
||||
}
|
||||
|
||||
// When casting to `index`, we will take the union of the possible fixed-width
|
||||
// casts.
|
||||
static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
|
||||
Type sourceType, Type destType,
|
||||
bool isSigned) {
|
||||
unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
|
||||
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
|
||||
if (sourceType.isIndex())
|
||||
return makeLikeDest(range, srcWidth, destWidth, isSigned);
|
||||
// We are casting to indexs, so use the union of the 32-bit and 64-bit casts
|
||||
ConstantIntRanges storageRange =
|
||||
makeLikeDest(range, srcWidth, destWidth, isSigned);
|
||||
ConstantIntRanges minWidthRange =
|
||||
makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
|
||||
ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
|
||||
ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
Type sourceType = getOperand().getType();
|
||||
Type destType = getResult().getType();
|
||||
setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
|
||||
/*isSigned=*/true));
|
||||
}
|
||||
|
||||
void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
Type sourceType = getOperand().getType();
|
||||
Type destType = getResult().getType();
|
||||
setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
|
||||
/*isSigned=*/false));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CmpOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
index::IndexCmpPredicate indexPred = getPred();
|
||||
intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
APInt min = APInt::getZero(1);
|
||||
APInt max = APInt::getAllOnesValue(1);
|
||||
|
||||
Optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
|
||||
|
||||
ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
|
||||
rhsTrunc = truncRange(rhs, indexMinWidth);
|
||||
Optional<bool> truthValue32 =
|
||||
intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
|
||||
|
||||
if (truthValue64 == truthValue32) {
|
||||
if (truthValue64.has_value() && *truthValue64)
|
||||
min = max;
|
||||
else if (truthValue64.has_value() && !(*truthValue64))
|
||||
max = min;
|
||||
}
|
||||
setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SizeOf, which is bounded between the two supported bitwidth (32 and 64).
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRange) {
|
||||
unsigned storageWidth =
|
||||
ConstantIntRanges::getStorageBitwidth(getResult().getType());
|
||||
APInt min(/*numBits=*/storageWidth, indexMinWidth);
|
||||
APInt max(/*numBits=*/storageWidth, indexMaxWidth);
|
||||
setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
|
||||
}
|
@ -51,3 +51,5 @@ add_mlir_interface_library(SideEffectInterfaces)
|
||||
add_mlir_interface_library(TilingInterface)
|
||||
add_mlir_interface_library(VectorInterfaces)
|
||||
add_mlir_interface_library(ViewLikeInterface)
|
||||
|
||||
add_subdirectory(Utils)
|
||||
|
13
mlir/lib/Interfaces/Utils/CMakeLists.txt
Normal file
13
mlir/lib/Interfaces/Utils/CMakeLists.txt
Normal file
@ -0,0 +1,13 @@
|
||||
add_mlir_library(MLIRInferIntRangeCommon
|
||||
InferIntRangeCommon.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils
|
||||
|
||||
DEPENDS
|
||||
MLIRInferIntRangeInterfaceIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRInferIntRangeInterface
|
||||
MLIRIR
|
||||
)
|
663
mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
Normal file
663
mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
Normal file
@ -0,0 +1,663 @@
|
||||
//===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains implementations of range inference for operations that are
|
||||
// common to both the `arith` and `index` dialects to facilitate reuse.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
|
||||
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include <iterator>
|
||||
#include <optional>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define DEBUG_TYPE "int-range-analysis"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// General utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Function that evaluates the result of doing something on arithmetic
|
||||
/// constants and returns std::nullopt on overflow.
|
||||
using ConstArithFn =
|
||||
function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
|
||||
|
||||
/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
|
||||
/// If either computation overflows, make the result unbounded.
|
||||
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
|
||||
const APInt &minRight,
|
||||
const APInt &maxLeft,
|
||||
const APInt &maxRight, bool isSigned) {
|
||||
std::optional<APInt> maybeMin = op(minLeft, minRight);
|
||||
std::optional<APInt> maybeMax = op(maxLeft, maxRight);
|
||||
if (maybeMin && maybeMax)
|
||||
return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
|
||||
return ConstantIntRanges::maxRange(minLeft.getBitWidth());
|
||||
}
|
||||
|
||||
/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
|
||||
/// ignoring unbounded values. Returns the maximal range if `op` overflows.
|
||||
static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
|
||||
ArrayRef<APInt> rhs, bool isSigned) {
|
||||
unsigned width = lhs[0].getBitWidth();
|
||||
APInt min =
|
||||
isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
|
||||
APInt max =
|
||||
isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
|
||||
for (const APInt &left : lhs) {
|
||||
for (const APInt &right : rhs) {
|
||||
std::optional<APInt> maybeThisResult = op(left, right);
|
||||
if (!maybeThisResult)
|
||||
return ConstantIntRanges::maxRange(width);
|
||||
APInt result = std::move(*maybeThisResult);
|
||||
min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
|
||||
max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
|
||||
}
|
||||
}
|
||||
return ConstantIntRanges::range(min, max, isSigned);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ext, trunc, index op handling
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferIndexOp(InferRangeFn inferFn,
|
||||
ArrayRef<ConstantIntRanges> argRanges,
|
||||
intrange::CmpMode mode) {
|
||||
ConstantIntRanges sixtyFour = inferFn(argRanges);
|
||||
SmallVector<ConstantIntRanges, 2> truncated;
|
||||
llvm::transform(argRanges, std::back_inserter(truncated),
|
||||
[](const ConstantIntRanges &range) {
|
||||
return truncRange(range, /*destWidth=*/indexMinWidth);
|
||||
});
|
||||
ConstantIntRanges thirtyTwo = inferFn(truncated);
|
||||
ConstantIntRanges thirtyTwoAsSixtyFour =
|
||||
extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
|
||||
ConstantIntRanges sixtyFourAsThirtyTwo =
|
||||
truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
|
||||
<< " 32-bit = " << thirtyTwo << "\n");
|
||||
bool truncEqual = false;
|
||||
switch (mode) {
|
||||
case intrange::CmpMode::Both:
|
||||
truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
|
||||
break;
|
||||
case intrange::CmpMode::Signed:
|
||||
truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
|
||||
thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
|
||||
break;
|
||||
case intrange::CmpMode::Unsigned:
|
||||
truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
|
||||
thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
|
||||
break;
|
||||
}
|
||||
if (truncEqual)
|
||||
// Returing the 64-bit result preserves more information.
|
||||
return sixtyFour;
|
||||
ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
|
||||
return merged;
|
||||
}
|
||||
|
||||
ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
|
||||
unsigned int destWidth) {
|
||||
APInt umin = range.umin().zext(destWidth);
|
||||
APInt umax = range.umax().zext(destWidth);
|
||||
APInt smin = range.smin().sext(destWidth);
|
||||
APInt smax = range.smax().sext(destWidth);
|
||||
return {umin, umax, smin, smax};
|
||||
}
|
||||
|
||||
ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
|
||||
unsigned destWidth) {
|
||||
APInt umin = range.umin().zext(destWidth);
|
||||
APInt umax = range.umax().zext(destWidth);
|
||||
return ConstantIntRanges::fromUnsigned(umin, umax);
|
||||
}
|
||||
|
||||
ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
|
||||
unsigned destWidth) {
|
||||
APInt smin = range.smin().sext(destWidth);
|
||||
APInt smax = range.smax().sext(destWidth);
|
||||
return ConstantIntRanges::fromSigned(smin, smax);
|
||||
}
|
||||
|
||||
ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
|
||||
unsigned int destWidth) {
|
||||
// If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
|
||||
// the range of the resulting value is not contiguous ind includes 0.
|
||||
// Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
|
||||
// but you can't truncate [255, 257] similarly.
|
||||
bool hasUnsignedRollover =
|
||||
range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
|
||||
APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
|
||||
: range.umin().trunc(destWidth);
|
||||
APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
|
||||
: range.umax().trunc(destWidth);
|
||||
|
||||
// Signed post-truncation rollover will not occur when either:
|
||||
// - The high parts of the min and max, plus the sign bit, are the same
|
||||
// - The high halves + sign bit of the min and max are either all 1s or all 0s
|
||||
// and you won't create a [positive, negative] range by truncating.
|
||||
// For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
|
||||
// but not [255, 257]_i16 to a range of i8s. You can also truncate
|
||||
// [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
|
||||
// You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
|
||||
// will truncate to 0x7e, which is greater than 0
|
||||
APInt sminHighPart = range.smin().ashr(destWidth - 1);
|
||||
APInt smaxHighPart = range.smax().ashr(destWidth - 1);
|
||||
bool hasSignedOverflow =
|
||||
(sminHighPart != smaxHighPart) &&
|
||||
!(sminHighPart.isAllOnes() &&
|
||||
(smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
|
||||
!(sminHighPart.isZero() && smaxHighPart.isZero());
|
||||
APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
|
||||
: range.smin().trunc(destWidth);
|
||||
APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
|
||||
: range.smax().trunc(destWidth);
|
||||
return {umin, umax, smin, smax};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Addition
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
ConstArithFn uadd = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.uadd_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
ConstArithFn sadd = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.sadd_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
|
||||
ConstantIntRanges urange = computeBoundsBy(
|
||||
uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
|
||||
ConstantIntRanges srange = computeBoundsBy(
|
||||
sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
|
||||
return urange.intersection(srange);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Subtraction
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
ConstArithFn usub = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.usub_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
ConstArithFn ssub = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.ssub_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
ConstantIntRanges urange = computeBoundsBy(
|
||||
usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
|
||||
ConstantIntRanges srange = computeBoundsBy(
|
||||
ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
|
||||
return urange.intersection(srange);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Multiplication
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
ConstArithFn umul = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.umul_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
ConstArithFn smul = [](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.smul_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : result;
|
||||
};
|
||||
|
||||
ConstantIntRanges urange =
|
||||
minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
|
||||
/*isSigned=*/false);
|
||||
ConstantIntRanges srange =
|
||||
minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
|
||||
/*isSigned=*/true);
|
||||
return urange.intersection(srange);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DivU, CeilDivU (Unsigned division)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Fix up division results (ex. for ceiling and floor), returning an APInt
|
||||
/// if there has been no overflow
|
||||
using DivisionFixupFn = function_ref<std::optional<APInt>(
|
||||
const APInt &lhs, const APInt &rhs, const APInt &result)>;
|
||||
|
||||
static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
|
||||
const ConstantIntRanges &rhs,
|
||||
DivisionFixupFn fixup) {
|
||||
const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
|
||||
&rhsMax = rhs.umax();
|
||||
|
||||
if (!rhsMin.isZero()) {
|
||||
auto udiv = [&fixup](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
return fixup(a, b, a.udiv(b));
|
||||
};
|
||||
return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
|
||||
/*isSigned=*/false);
|
||||
}
|
||||
// Otherwise, it's possible we might divide by 0.
|
||||
return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
return inferDivURange(argRanges[0], argRanges[1],
|
||||
[](const APInt &lhs, const APInt &rhs,
|
||||
const APInt &result) { return result; });
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
DivisionFixupFn ceilDivUIFix =
|
||||
[](const APInt &lhs, const APInt &rhs,
|
||||
const APInt &result) -> std::optional<APInt> {
|
||||
if (!lhs.urem(rhs).isZero()) {
|
||||
bool overflowed = false;
|
||||
APInt corrected =
|
||||
result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
|
||||
return overflowed ? std::optional<APInt>() : corrected;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
return inferDivURange(lhs, rhs, ceilDivUIFix);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DivS, CeilDivS, FloorDivS (Signed division)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
|
||||
const ConstantIntRanges &rhs,
|
||||
DivisionFixupFn fixup) {
|
||||
const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
|
||||
&rhsMax = rhs.smax();
|
||||
bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
|
||||
|
||||
if (canDivide) {
|
||||
auto sdiv = [&fixup](const APInt &a,
|
||||
const APInt &b) -> std::optional<APInt> {
|
||||
bool overflowed = false;
|
||||
APInt result = a.sdiv_ov(b, overflowed);
|
||||
return overflowed ? std::optional<APInt>() : fixup(a, b, result);
|
||||
};
|
||||
return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
|
||||
/*isSigned=*/true);
|
||||
}
|
||||
return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
return inferDivSRange(argRanges[0], argRanges[1],
|
||||
[](const APInt &lhs, const APInt &rhs,
|
||||
const APInt &result) { return result; });
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
DivisionFixupFn ceilDivSIFix =
|
||||
[](const APInt &lhs, const APInt &rhs,
|
||||
const APInt &result) -> std::optional<APInt> {
|
||||
if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
|
||||
bool overflowed = false;
|
||||
APInt corrected =
|
||||
result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
|
||||
return overflowed ? std::optional<APInt>() : corrected;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
return inferDivSRange(lhs, rhs, ceilDivSIFix);
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
DivisionFixupFn floorDivSIFix =
|
||||
[](const APInt &lhs, const APInt &rhs,
|
||||
const APInt &result) -> std::optional<APInt> {
|
||||
if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
|
||||
bool overflowed = false;
|
||||
APInt corrected =
|
||||
result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
|
||||
return overflowed ? std::optional<APInt>() : corrected;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
return inferDivSRange(lhs, rhs, floorDivSIFix);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Signed remainder (RemS)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
|
||||
&rhsMax = rhs.smax();
|
||||
|
||||
unsigned width = rhsMax.getBitWidth();
|
||||
APInt smin = APInt::getSignedMinValue(width);
|
||||
APInt smax = APInt::getSignedMaxValue(width);
|
||||
// No bounds if zero could be a divisor.
|
||||
bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
|
||||
if (canBound) {
|
||||
APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
|
||||
bool canNegativeDividend = lhsMin.isNegative();
|
||||
bool canPositiveDividend = lhsMax.isStrictlyPositive();
|
||||
APInt zero = APInt::getZero(maxDivisor.getBitWidth());
|
||||
APInt maxPositiveResult = maxDivisor - 1;
|
||||
APInt minNegativeResult = -maxPositiveResult;
|
||||
smin = canNegativeDividend ? minNegativeResult : zero;
|
||||
smax = canPositiveDividend ? maxPositiveResult : zero;
|
||||
// Special case: sweeping out a contiguous range in N/[modulus].
|
||||
if (rhsMin == rhsMax) {
|
||||
if ((lhsMax - lhsMin).ult(maxDivisor)) {
|
||||
APInt minRem = lhsMin.srem(maxDivisor);
|
||||
APInt maxRem = lhsMax.srem(maxDivisor);
|
||||
if (minRem.sle(maxRem)) {
|
||||
smin = minRem;
|
||||
smax = maxRem;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ConstantIntRanges::fromSigned(smin, smax);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unsigned remainder (RemU)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
|
||||
|
||||
unsigned width = rhsMin.getBitWidth();
|
||||
APInt umin = APInt::getZero(width);
|
||||
APInt umax = APInt::getMaxValue(width);
|
||||
|
||||
if (!rhsMin.isZero()) {
|
||||
umax = rhsMax - 1;
|
||||
// Special case: sweeping out a contiguous range in N/[modulus]
|
||||
if (rhsMin == rhsMax) {
|
||||
const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
|
||||
if ((lhsMax - lhsMin).ult(rhsMax)) {
|
||||
APInt minRem = lhsMin.urem(rhsMax);
|
||||
APInt maxRem = lhsMax.urem(rhsMax);
|
||||
if (minRem.ule(maxRem)) {
|
||||
umin = minRem;
|
||||
umax = maxRem;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ConstantIntRanges::fromUnsigned(umin, umax);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Max and min (MaxS, MaxU, MinS, MinU)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
|
||||
const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
|
||||
return ConstantIntRanges::fromSigned(smin, smax);
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
|
||||
const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
|
||||
return ConstantIntRanges::fromUnsigned(umin, umax);
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
|
||||
const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
|
||||
return ConstantIntRanges::fromSigned(smin, smax);
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
|
||||
const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
|
||||
return ConstantIntRanges::fromUnsigned(umin, umax);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bitwise operators (And, Or, Xor)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
|
||||
/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
|
||||
/// that both bonuds have in common. This gives us a consertive approximation
|
||||
/// for what values can be passed to bitwise operations.
|
||||
static std::tuple<APInt, APInt>
|
||||
widenBitwiseBounds(const ConstantIntRanges &bound) {
|
||||
APInt leftVal = bound.umin(), rightVal = bound.umax();
|
||||
unsigned bitwidth = leftVal.getBitWidth();
|
||||
unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
|
||||
leftVal.clearLowBits(differingBits);
|
||||
rightVal.setLowBits(differingBits);
|
||||
return std::make_tuple(std::move(leftVal), std::move(rightVal));
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
|
||||
auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
|
||||
auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
|
||||
return a & b;
|
||||
};
|
||||
return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
|
||||
/*isSigned=*/false);
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
|
||||
auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
|
||||
auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
|
||||
return a | b;
|
||||
};
|
||||
return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
|
||||
/*isSigned=*/false);
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
|
||||
auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
|
||||
auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
|
||||
return a ^ b;
|
||||
};
|
||||
return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
|
||||
/*isSigned=*/false);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shifts (Shl, ShrS, ShrU)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
ConstArithFn shl = [](const APInt &l,
|
||||
const APInt &r) -> std::optional<APInt> {
|
||||
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
|
||||
};
|
||||
ConstantIntRanges urange =
|
||||
minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
|
||||
/*isSigned=*/false);
|
||||
ConstantIntRanges srange =
|
||||
minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
|
||||
/*isSigned=*/true);
|
||||
return urange.intersection(srange);
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
ConstArithFn ashr = [](const APInt &l,
|
||||
const APInt &r) -> std::optional<APInt> {
|
||||
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
|
||||
};
|
||||
|
||||
return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
|
||||
/*isSigned=*/true);
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
|
||||
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
|
||||
|
||||
ConstArithFn lshr = [](const APInt &l,
|
||||
const APInt &r) -> std::optional<APInt> {
|
||||
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
|
||||
};
|
||||
return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
|
||||
/*isSigned=*/false);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Comparisons (Cmp)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) {
|
||||
switch (pred) {
|
||||
case intrange::CmpPredicate::eq:
|
||||
return intrange::CmpPredicate::ne;
|
||||
case intrange::CmpPredicate::ne:
|
||||
return intrange::CmpPredicate::eq;
|
||||
case intrange::CmpPredicate::slt:
|
||||
return intrange::CmpPredicate::sge;
|
||||
case intrange::CmpPredicate::sle:
|
||||
return intrange::CmpPredicate::sgt;
|
||||
case intrange::CmpPredicate::sgt:
|
||||
return intrange::CmpPredicate::sle;
|
||||
case intrange::CmpPredicate::sge:
|
||||
return intrange::CmpPredicate::slt;
|
||||
case intrange::CmpPredicate::ult:
|
||||
return intrange::CmpPredicate::uge;
|
||||
case intrange::CmpPredicate::ule:
|
||||
return intrange::CmpPredicate::ugt;
|
||||
case intrange::CmpPredicate::ugt:
|
||||
return intrange::CmpPredicate::ule;
|
||||
case intrange::CmpPredicate::uge:
|
||||
return intrange::CmpPredicate::ult;
|
||||
}
|
||||
llvm_unreachable("unknown cmp predicate value");
|
||||
}
|
||||
|
||||
static bool isStaticallyTrue(intrange::CmpPredicate pred,
|
||||
const ConstantIntRanges &lhs,
|
||||
const ConstantIntRanges &rhs) {
|
||||
switch (pred) {
|
||||
case intrange::CmpPredicate::sle:
|
||||
return lhs.smax().sle(rhs.smin());
|
||||
case intrange::CmpPredicate::slt:
|
||||
return lhs.smax().slt(rhs.smin());
|
||||
case intrange::CmpPredicate::ule:
|
||||
return lhs.umax().ule(rhs.umin());
|
||||
case intrange::CmpPredicate::ult:
|
||||
return lhs.umax().ult(rhs.umin());
|
||||
case intrange::CmpPredicate::sge:
|
||||
return lhs.smin().sge(rhs.smax());
|
||||
case intrange::CmpPredicate::sgt:
|
||||
return lhs.smin().sgt(rhs.smax());
|
||||
case intrange::CmpPredicate::uge:
|
||||
return lhs.umin().uge(rhs.umax());
|
||||
case intrange::CmpPredicate::ugt:
|
||||
return lhs.umin().ugt(rhs.umax());
|
||||
case intrange::CmpPredicate::eq: {
|
||||
std::optional<APInt> lhsConst = lhs.getConstantValue();
|
||||
std::optional<APInt> rhsConst = rhs.getConstantValue();
|
||||
return lhsConst && rhsConst && lhsConst == rhsConst;
|
||||
}
|
||||
case intrange::CmpPredicate::ne: {
|
||||
// While equality requires that there is an interpration of the preceeding
|
||||
// computations that produces equal constants, whether that be signed or
|
||||
// unsigned, statically determining inequality requires that neither
|
||||
// interpretation produce potentially overlapping ranges.
|
||||
bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) ||
|
||||
isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs);
|
||||
bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) ||
|
||||
isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs);
|
||||
return sne && une;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
|
||||
const ConstantIntRanges &lhs,
|
||||
const ConstantIntRanges &rhs) {
|
||||
if (isStaticallyTrue(pred, lhs, rhs))
|
||||
return true;
|
||||
if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
|
||||
return false;
|
||||
return std::nullopt;
|
||||
}
|
66
mlir/test/Dialect/Index/int-range-inference.mlir
Normal file
66
mlir/test/Dialect/Index/int-range-inference.mlir
Normal file
@ -0,0 +1,66 @@
|
||||
// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
|
||||
|
||||
// Most operations are covered by the `arith` tests, which use the same code
|
||||
// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
|
||||
// code is operating as expected.
|
||||
|
||||
// CHECK-LABEL: func @add_same_for_both
|
||||
// CHECK: %[[true:.*]] = index.bool.constant true
|
||||
// CHECK: return %[[true]]
|
||||
func.func @add_same_for_both(%arg0 : index) -> i1 {
|
||||
%c1 = index.constant 1
|
||||
%calmostBig = index.constant 0xfffffffe
|
||||
%0 = index.minu %arg0, %calmostBig
|
||||
%1 = index.add %0, %c1
|
||||
%2 = index.cmp uge(%1, %c1)
|
||||
func.return %2 : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @add_unsigned_ov
|
||||
// CHECK: %[[uge:.*]] = index.cmp uge
|
||||
// CHECK: return %[[uge]]
|
||||
func.func @add_unsigned_ov(%arg0 : index) -> i1 {
|
||||
%c1 = index.constant 1
|
||||
%cu32_max = index.constant 0xffffffff
|
||||
%0 = index.minu %arg0, %cu32_max
|
||||
%1 = index.add %0, %c1
|
||||
// On 32-bit, the add could wrap, so the result doesn't have to be >= 1
|
||||
%2 = index.cmp uge(%1, %c1)
|
||||
func.return %2 : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @add_signed_ov
|
||||
// CHECK: %[[sge:.*]] = index.cmp sge
|
||||
// CHECK: return %[[sge]]
|
||||
func.func @add_signed_ov(%arg0 : index) -> i1 {
|
||||
%c0 = index.constant 0
|
||||
%c1 = index.constant 1
|
||||
%ci32_max = index.constant 0x7fffffff
|
||||
%0 = index.minu %arg0, %ci32_max
|
||||
%1 = index.add %0, %c1
|
||||
// On 32-bit, the add could wrap, so the result doesn't have to be positive
|
||||
%2 = index.cmp sge(%1, %c0)
|
||||
func.return %2 : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @add_big
|
||||
// CHECK: %[[true:.*]] = index.bool.constant true
|
||||
// CHECK: return %[[true]]
|
||||
func.func @add_big(%arg0 : index) -> i1 {
|
||||
%c1 = index.constant 1
|
||||
%cmin = index.constant 0x300000000
|
||||
%cmax = index.constant 0x30000ffff
|
||||
// Note: the order of the clamps matters.
|
||||
// If you go max, then min, you infer the ranges [0x300...0, 0xff..ff]
|
||||
// and then [0x30...0000, 0x30...ffff]
|
||||
// If you switch the order of the below operations, you instead first infer
|
||||
// the range [0,0x3...ffff]. Then, the min inference can't constraint
|
||||
// this intermediate, since in the 32-bit case we could have, for example
|
||||
// trunc(%arg0 = 0x2ffffffff) = 0xffffffff > trunc(0x30000ffff) = 0x0000ffff
|
||||
// which means we can't do any inference.
|
||||
%0 = index.maxu %arg0, %cmin
|
||||
%1 = index.minu %0, %cmax
|
||||
%2 = index.add %1, %c1
|
||||
%3 = index.cmp uge(%1, %cmin)
|
||||
func.return %3 : i1
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user