[mlir][tensor] Fix slice canonicalizer for out-of-bounds cases (#132534)

Since #130487, `tensor.extract_slice` and `tensor.insert_slice` ops that
are statically detected to go out of bounds are rejected by the
verifier.

This commit fixes canonicalization patterns that currently fold
dynamically out-of-bounds ops (valid IR) to statically out-of-bounds ops
(invalid IR).
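
Concretely, the constant-argument folder would previously rewrite an op like the following (valid IR, because the out-of-bounds size is only known dynamically) into roughly `tensor.extract_slice %t[0] [10] [1] : tensor<5xf32> to tensor<10xf32>`, which the verifier now rejects as statically out-of-bounds. A minimal illustration, adapted from the test cases added in this commit:

  func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<?xf32> {
    %c10 = arith.constant 10 : index
    // Valid IR: the size operand is dynamic, even though its value (10) exceeds the dim size (5).
    %r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor<?xf32>
    return %r : tensor<?xf32>
  }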
Author: Matthias Springer, 2025-03-24 14:39:37 +01:00 (committed by GitHub)
Parent: 85974a0537
Commit: 529ee3cf3b
4 changed files with 194 additions and 45 deletions


@@ -45,6 +45,28 @@ unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
namespace mlir {
/// Result for slice bounds verification.
struct SliceBoundsVerificationResult {
/// If set to "true", the slice bounds verification was successful.
bool isValid;
/// An error message that can be printed during op verification.
std::string errorMessage;
};
/// Verify that the offsets/sizes/strides-style access into the given shape
/// is in-bounds. Only static values are verified. If `generateErrorMessage`
/// is set to "true", an error message is produced that can be printed by the
/// op verifier.
SliceBoundsVerificationResult
verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides,
bool generateErrorMessage = false);
SliceBoundsVerificationResult verifyInBoundsSlice(
ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
bool generateErrorMessage = false);
/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
/// constant arguments. This pattern assumes that the op has a suitable builder
/// that takes a result type, a "source" operand and mixed offsets, sizes and
@@ -54,7 +76,8 @@ namespace mlir {
/// returns the new result type of the op, based on the new offsets, sizes and
/// strides. `CastOpFunc` is used to generate a cast op if the result type of
/// the op has changed.
template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
bool CheckInBounds = false>
class OpWithOffsetSizesAndStridesConstantArgumentFolder final
: public OpRewritePattern<OpType> {
public:
@@ -72,11 +95,22 @@ public:
failed(foldDynamicIndexList(mixedStrides)))
return failure();
- // Create the new op in canonical form.
if (CheckInBounds) {
// Pattern does not apply if the produced op would not verify.
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
mixedSizes, mixedStrides);
if (!sliceResult.isValid)
return failure();
}
// Compute the new result type.
auto resultType =
ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
if (!resultType)
return failure();
// Create the new op in canonical form.
auto newOp =
rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
mixedOffsets, mixedSizes, mixedStrides);

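For context, in the in-bounds case this folder still rewrites constant-valued dynamic offsets/sizes/strides into their static form; only folds that would produce a statically out-of-bounds op are now skipped. A hypothetical sketch (not taken from this commit; the cast insertion handled by `CastOpFunc` is omitted):

  %c2 = arith.constant 2 : index
  // The offset %c2 is a known constant and the slice stays in bounds (2 + (3 - 1) * 1 = 4 < 8), so this op ...
  %0 = tensor.extract_slice %t[%c2] [3] [1] : tensor<8xf32> to tensor<3xf32>
  // ... canonicalizes to:
  %0 = tensor.extract_slice %t[2] [3] [1] : tensor<8xf32> to tensor<3xf32>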

@@ -27,6 +27,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
@@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
}
}
- /// Verify that the offsets/sizes/strides-style access into the given tensor
- /// is in-bounds. Only static information is verified.
- static LogicalResult verifyInBoundsSlice(Operation *op,
- RankedTensorType tensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides) {
- for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
- // Nothing to verify for dynamic source dims.
- if (tensorType.isDynamicDim(i))
- continue;
- // Nothing to verify if the offset is dynamic.
- if (ShapedType::isDynamic(staticOffsets[i]))
- continue;
- if (staticOffsets[i] >= tensorType.getDimSize(i))
- return op->emitOpError("offset ")
- << i << " is out-of-bounds: " << staticOffsets[i]
- << " >= " << tensorType.getDimSize(i);
- if (ShapedType::isDynamic(staticSizes[i]) ||
- ShapedType::isDynamic(staticStrides[i]))
- continue;
- int64_t lastPos =
- staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
- if (lastPos >= tensorType.getDimSize(i))
- return op->emitOpError("slice along dimension ")
- << i << " runs out-of-bounds: " << lastPos
- << " >= " << tensorType.getDimSize(i);
- }
- return success();
- }
/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
@@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
// to the source tensor.
- return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
- getStaticSizes(), getStaticStrides());
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
getStaticStrides(), /*generateErrorMessage=*/true);
if (!boundsResult.isValid)
return getOperation()->emitError(boundsResult.errorMessage);
return success();
}
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
@@ -2470,6 +2445,14 @@ public:
if (!canFoldIntoConsumerOp(castOp))
return failure();
// Pattern does not apply if the produced op would not verify.
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
sliceOp.getStaticStrides());
if (!sliceResult.isValid)
return failure();
// Create folded extract.
Location loc = sliceOp.getLoc();
Value newResult = rewriter.create<ExtractSliceOp>(
@@ -2634,10 +2617,10 @@ struct SliceCanonicalizer {
void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<
- OpWithOffsetSizesAndStridesConstantArgumentFolder<
- ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
- ExtractSliceOpCastFolder>(context);
results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
ExtractSliceOp, SliceReturnTypeCanonicalizer,
SliceCanonicalizer, /*CheckInBounds=*/true>,
ExtractSliceOpCastFolder>(context);
}
//
@@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() {
return produceSliceErrorMsg(result, *this, expectedType);
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
- // to the source tensor.
- return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
- getStaticSizes(), getStaticStrides());
// to the destination tensor.
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
getStaticStrides(), /*generateErrorMessage=*/true);
if (!boundsResult.isValid)
return getOperation()->emitError(boundsResult.errorMessage);
return success();
}
/// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -2872,6 +2860,13 @@ public:
failed(foldDynamicStrideList(mixedStrides)))
return failure();
// Pattern does not apply if the produced op would not verify.
SliceBoundsVerificationResult sliceResult =
verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
mixedOffsets, mixedSizes, mixedStrides);
if (!sliceResult.isValid)
return failure();
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
@@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
size = srcType.getDimSize(rankReducedIdx++);
}
}
// Pattern does not apply if the produced op would not verify.
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
staticSizes, insertSliceOp.getStaticStrides()) !=
SliceVerificationResult::Success)
return failure();
SliceBoundsVerificationResult sliceResult =
verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
mixedSizes, insertSliceOp.getMixedStrides());
if (!sliceResult.isValid)
return failure();
Operation *replacement = rewriter.create<InsertOpTy>(
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
@@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
return produceSliceErrorMsg(result, *this, expectedType);
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
- // to the source tensor.
- return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
- getStaticSizes(), getStaticStrides());
// to the destination tensor.
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
getStaticStrides(), /*generateErrorMessage=*/true);
if (!boundsResult.isValid)
return getOperation()->emitError(boundsResult.errorMessage);
return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(


@@ -36,6 +36,64 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
return success();
}
SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
bool generateErrorMessage) {
SliceBoundsVerificationResult result;
result.isValid = true;
for (int64_t i = 0, e = shape.size(); i < e; ++i) {
// Nothing to verify for dynamic source dims.
if (ShapedType::isDynamic(shape[i]))
continue;
// Nothing to verify if the offset is dynamic.
if (ShapedType::isDynamic(staticOffsets[i]))
continue;
if (staticOffsets[i] >= shape[i]) {
result.errorMessage =
std::string("offset ") + std::to_string(i) +
" is out-of-bounds: " + std::to_string(staticOffsets[i]) +
" >= " + std::to_string(shape[i]);
result.isValid = false;
return result;
}
if (ShapedType::isDynamic(staticSizes[i]) ||
ShapedType::isDynamic(staticStrides[i]))
continue;
int64_t lastPos =
staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
if (lastPos >= shape[i]) {
result.errorMessage = std::string("slice along dimension ") +
std::to_string(i) +
" runs out-of-bounds: " + std::to_string(lastPos) +
" >= " + std::to_string(shape[i]);
result.isValid = false;
return result;
}
}
return result;
}
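As a quick sanity check of the bound computation above: for a destination dimension of size 10 with offset 7, size 5, and stride 1, the last accessed position is 7 + (5 - 1) * 1 = 11, which is >= 10, so the slice is reported as out-of-bounds. In IR form (a hypothetical example, not one of the tests added in this commit):

  // Rejected by the verifier: slice along dimension 0 runs out-of-bounds: 11 >= 10.
  %r = tensor.insert_slice %src into %dst[7] [5] [1] : tensor<5xf32> into tensor<10xf32>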
SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
bool generateErrorMessage) {
auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
SmallVector<int64_t> staticValues;
for (OpFoldResult ofr : ofrs) {
if (auto attr = dyn_cast<Attribute>(ofr)) {
staticValues.push_back(cast<IntegerAttr>(attr).getInt());
} else {
staticValues.push_back(ShapedType::kDynamic);
}
}
return staticValues;
};
return verifyInBoundsSlice(
shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes),
getStaticValues(mixedStrides), generateErrorMessage);
}
LogicalResult
mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();


@@ -582,6 +582,56 @@ func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<1
// -----
// CHECK-LABEL: func @out_of_bounds_extract_slice
// CHECK: tensor.extract_slice %{{.*}}[0] [%{{.*}}] [1] : tensor<5xf32> to tensor<?xf32>
func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<?xf32> {
%c10 = arith.constant 10 : index
%r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor<?xf32>
return %r : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @out_of_bounds_extract_slice
// CHECK: tensor.extract_slice %{{.*}}[0] [10] [1] : tensor<?xf32> to tensor<10xf32>
func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<10xf32> {
%t2 = tensor.cast %t : tensor<5xf32> to tensor<?xf32>
%r = tensor.extract_slice %t2 [0][10][1] : tensor<?xf32> to tensor<10xf32>
return %r : tensor<10xf32>
}
// -----
// CHECK-LABEL: func @out_of_bounds_insert_slice
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [5] [1] : tensor<5xf32> into tensor<10xf32>
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>) -> tensor<10xf32> {
%c10 = arith.constant 10 : index
%r = tensor.insert_slice %src into %dst[%c10] [5] [1] : tensor<5xf32> into tensor<10xf32>
return %r : tensor<10xf32>
}
// -----
// CHECK-LABEL: func @out_of_bounds_insert_slice
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [%{{.*}}] [1] : tensor<?xf32> into tensor<10xf32>
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<10xf32> {
%src2 = tensor.cast %src : tensor<5xf32> to tensor<?xf32>
%r = tensor.insert_slice %src2 into %dst[7] [%sz] [1] : tensor<?xf32> into tensor<10xf32>
return %r : tensor<10xf32>
}
// -----
// CHECK-LABEL: func @out_of_bounds_insert_slice
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<?xf32> {
%dst2 = tensor.cast %dst : tensor<10xf32> to tensor<?xf32>
%r = tensor.insert_slice %src into %dst2[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
return %r : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
// CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
// CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>