mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-18 12:16:49 +00:00
[mlir][tensor] Fix slice canonicalizer for out-of-bounds cases (#132534)
Since #130487, `tensor.extract_slice` and `tensor.insert_slice` ops that are statically detected to go out of bounds are rejected by the verifier. This commit fixes canonicalization patterns that currently fold dynamically out-of-bounds ops (valid IR) to statically out-of-bounds ops (invalid IR).
This commit is contained in:
parent
85974a0537
commit
529ee3cf3b
@ -45,6 +45,28 @@ unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Result for slice bounds verification;
|
||||
struct SliceBoundsVerificationResult {
|
||||
/// If set to "true", the slice bounds verification was successful.
|
||||
bool isValid;
|
||||
/// An error message that can be printed during op verification.
|
||||
std::string errorMessage;
|
||||
};
|
||||
|
||||
/// Verify that the offsets/sizes/strides-style access into the given shape
|
||||
/// is in-bounds. Only static values are verified. If `generateErrorMessage`
|
||||
/// is set to "true", an error message is produced that can be printed by the
|
||||
/// op verifier.
|
||||
SliceBoundsVerificationResult
|
||||
verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
|
||||
ArrayRef<int64_t> staticSizes,
|
||||
ArrayRef<int64_t> staticStrides,
|
||||
bool generateErrorMessage = false);
|
||||
SliceBoundsVerificationResult verifyInBoundsSlice(
|
||||
ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
|
||||
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
|
||||
bool generateErrorMessage = false);
|
||||
|
||||
/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
|
||||
/// constant arguments. This pattern assumes that the op has a suitable builder
|
||||
/// that takes a result type, a "source" operand and mixed offsets, sizes and
|
||||
@ -54,7 +76,8 @@ namespace mlir {
|
||||
/// returns the new result type of the op, based on the new offsets, sizes and
|
||||
/// strides. `CastOpFunc` is used to generate a cast op if the result type of
|
||||
/// the op has changed.
|
||||
template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
|
||||
template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
|
||||
bool CheckInBounds = false>
|
||||
class OpWithOffsetSizesAndStridesConstantArgumentFolder final
|
||||
: public OpRewritePattern<OpType> {
|
||||
public:
|
||||
@ -72,11 +95,22 @@ public:
|
||||
failed(foldDynamicIndexList(mixedStrides)))
|
||||
return failure();
|
||||
|
||||
// Create the new op in canonical form.
|
||||
if (CheckInBounds) {
|
||||
// Pattern does not apply if the produced op would not verify.
|
||||
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
|
||||
cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
|
||||
mixedSizes, mixedStrides);
|
||||
if (!sliceResult.isValid)
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Compute the new result type.
|
||||
auto resultType =
|
||||
ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
|
||||
if (!resultType)
|
||||
return failure();
|
||||
|
||||
// Create the new op in canonical form.
|
||||
auto newOp =
|
||||
rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
|
||||
mixedOffsets, mixedSizes, mixedStrides);
|
||||
|
@ -27,6 +27,7 @@
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h"
|
||||
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify that the offsets/sizes/strides-style access into the given tensor
|
||||
/// is in-bounds. Only static information is verified.
|
||||
static LogicalResult verifyInBoundsSlice(Operation *op,
|
||||
RankedTensorType tensorType,
|
||||
ArrayRef<int64_t> staticOffsets,
|
||||
ArrayRef<int64_t> staticSizes,
|
||||
ArrayRef<int64_t> staticStrides) {
|
||||
for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
|
||||
// Nothing to verify for dynamic source dims.
|
||||
if (tensorType.isDynamicDim(i))
|
||||
continue;
|
||||
// Nothing to verify if the offset is dynamic.
|
||||
if (ShapedType::isDynamic(staticOffsets[i]))
|
||||
continue;
|
||||
if (staticOffsets[i] >= tensorType.getDimSize(i))
|
||||
return op->emitOpError("offset ")
|
||||
<< i << " is out-of-bounds: " << staticOffsets[i]
|
||||
<< " >= " << tensorType.getDimSize(i);
|
||||
if (ShapedType::isDynamic(staticSizes[i]) ||
|
||||
ShapedType::isDynamic(staticStrides[i]))
|
||||
continue;
|
||||
int64_t lastPos =
|
||||
staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
|
||||
if (lastPos >= tensorType.getDimSize(i))
|
||||
return op->emitOpError("slice along dimension ")
|
||||
<< i << " runs out-of-bounds: " << lastPos
|
||||
<< " >= " << tensorType.getDimSize(i);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verifier for ExtractSliceOp.
|
||||
LogicalResult ExtractSliceOp::verify() {
|
||||
RankedTensorType sourceType = getSourceType();
|
||||
@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {
|
||||
|
||||
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
|
||||
// to the source tensor.
|
||||
return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
|
||||
getStaticSizes(), getStaticStrides());
|
||||
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
|
||||
sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
|
||||
getStaticStrides(), /*generateErrorMessage=*/true);
|
||||
if (!boundsResult.isValid)
|
||||
return getOperation()->emitError(boundsResult.errorMessage);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
|
||||
@ -2470,6 +2445,14 @@ public:
|
||||
if (!canFoldIntoConsumerOp(castOp))
|
||||
return failure();
|
||||
|
||||
// Pattern does not apply if the produced op would not verify.
|
||||
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
|
||||
cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
|
||||
sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
|
||||
sliceOp.getStaticStrides());
|
||||
if (!sliceResult.isValid)
|
||||
return failure();
|
||||
|
||||
// Create folded extract.
|
||||
Location loc = sliceOp.getLoc();
|
||||
Value newResult = rewriter.create<ExtractSliceOp>(
|
||||
@ -2634,10 +2617,10 @@ struct SliceCanonicalizer {
|
||||
|
||||
void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<
|
||||
OpWithOffsetSizesAndStridesConstantArgumentFolder<
|
||||
ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
|
||||
ExtractSliceOpCastFolder>(context);
|
||||
results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
|
||||
ExtractSliceOp, SliceReturnTypeCanonicalizer,
|
||||
SliceCanonicalizer, /*CheckInBounds=*/true>,
|
||||
ExtractSliceOpCastFolder>(context);
|
||||
}
|
||||
|
||||
//
|
||||
@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() {
|
||||
return produceSliceErrorMsg(result, *this, expectedType);
|
||||
|
||||
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
|
||||
// to the source tensor.
|
||||
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
|
||||
getStaticSizes(), getStaticStrides());
|
||||
// to the destination tensor.
|
||||
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
|
||||
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
|
||||
getStaticStrides(), /*generateErrorMessage=*/true);
|
||||
if (!boundsResult.isValid)
|
||||
return getOperation()->emitError(boundsResult.errorMessage);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// If we have two consecutive InsertSliceOp writing to the same slice, we
|
||||
@ -2872,6 +2860,13 @@ public:
|
||||
failed(foldDynamicStrideList(mixedStrides)))
|
||||
return failure();
|
||||
|
||||
// Pattern does not apply if the produced op would not verify.
|
||||
SliceBoundsVerificationResult sliceResult =
|
||||
verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
|
||||
mixedOffsets, mixedSizes, mixedStrides);
|
||||
if (!sliceResult.isValid)
|
||||
return failure();
|
||||
|
||||
// Create the new op in canonical form.
|
||||
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
|
||||
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
|
||||
@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
|
||||
size = srcType.getDimSize(rankReducedIdx++);
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern does not apply if the produced op would not verify.
|
||||
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
|
||||
staticSizes, insertSliceOp.getStaticStrides()) !=
|
||||
SliceVerificationResult::Success)
|
||||
return failure();
|
||||
SliceBoundsVerificationResult sliceResult =
|
||||
verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
|
||||
mixedSizes, insertSliceOp.getMixedStrides());
|
||||
if (!sliceResult.isValid)
|
||||
return failure();
|
||||
|
||||
Operation *replacement = rewriter.create<InsertOpTy>(
|
||||
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
|
||||
@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
|
||||
return produceSliceErrorMsg(result, *this, expectedType);
|
||||
|
||||
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
|
||||
// to the source tensor.
|
||||
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
|
||||
getStaticSizes(), getStaticStrides());
|
||||
// to the destination tensor.
|
||||
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
|
||||
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
|
||||
getStaticStrides(), /*generateErrorMessage=*/true);
|
||||
if (!boundsResult.isValid)
|
||||
return getOperation()->emitError(boundsResult.errorMessage);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void ParallelInsertSliceOp::getCanonicalizationPatterns(
|
||||
|
@ -36,6 +36,64 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
|
||||
return success();
|
||||
}
|
||||
|
||||
SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
|
||||
ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
|
||||
ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
|
||||
bool generateErrorMessage) {
|
||||
SliceBoundsVerificationResult result;
|
||||
result.isValid = true;
|
||||
for (int64_t i = 0, e = shape.size(); i < e; ++i) {
|
||||
// Nothing to verify for dynamic source dims.
|
||||
if (ShapedType::isDynamic(shape[i]))
|
||||
continue;
|
||||
// Nothing to verify if the offset is dynamic.
|
||||
if (ShapedType::isDynamic(staticOffsets[i]))
|
||||
continue;
|
||||
if (staticOffsets[i] >= shape[i]) {
|
||||
result.errorMessage =
|
||||
std::string("offset ") + std::to_string(i) +
|
||||
" is out-of-bounds: " + std::to_string(staticOffsets[i]) +
|
||||
" >= " + std::to_string(shape[i]);
|
||||
result.isValid = false;
|
||||
return result;
|
||||
}
|
||||
if (ShapedType::isDynamic(staticSizes[i]) ||
|
||||
ShapedType::isDynamic(staticStrides[i]))
|
||||
continue;
|
||||
int64_t lastPos =
|
||||
staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
|
||||
if (lastPos >= shape[i]) {
|
||||
result.errorMessage = std::string("slice along dimension ") +
|
||||
std::to_string(i) +
|
||||
" runs out-of-bounds: " + std::to_string(lastPos) +
|
||||
" >= " + std::to_string(shape[i]);
|
||||
result.isValid = false;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
|
||||
ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
|
||||
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
|
||||
bool generateErrorMessage) {
|
||||
auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
|
||||
SmallVector<int64_t> staticValues;
|
||||
for (OpFoldResult ofr : ofrs) {
|
||||
if (auto attr = dyn_cast<Attribute>(ofr)) {
|
||||
staticValues.push_back(cast<IntegerAttr>(attr).getInt());
|
||||
} else {
|
||||
staticValues.push_back(ShapedType::kDynamic);
|
||||
}
|
||||
}
|
||||
return staticValues;
|
||||
};
|
||||
return verifyInBoundsSlice(
|
||||
shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes),
|
||||
getStaticValues(mixedStrides), generateErrorMessage);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
|
||||
std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
|
||||
|
@ -582,6 +582,56 @@ func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @out_of_bounds_extract_slice
|
||||
// CHECK: tensor.extract_slice %{{.*}}[0] [%{{.*}}] [1] : tensor<5xf32> to tensor<?xf32>
|
||||
func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<?xf32> {
|
||||
%c10 = arith.constant 10 : index
|
||||
%r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor<?xf32>
|
||||
return %r : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @out_of_bounds_extract_slice
|
||||
// CHECK: tensor.extract_slice %{{.*}}[0] [10] [1] : tensor<?xf32> to tensor<10xf32>
|
||||
func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<10xf32> {
|
||||
%t2 = tensor.cast %t : tensor<5xf32> to tensor<?xf32>
|
||||
%r = tensor.extract_slice %t2 [0][10][1] : tensor<?xf32> to tensor<10xf32>
|
||||
return %r : tensor<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @out_of_bounds_insert_slice
|
||||
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [5] [1] : tensor<5xf32> into tensor<10xf32>
|
||||
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>) -> tensor<10xf32> {
|
||||
%c10 = arith.constant 10 : index
|
||||
%r = tensor.insert_slice %src into %dst[%c10] [5] [1] : tensor<5xf32> into tensor<10xf32>
|
||||
return %r : tensor<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @out_of_bounds_insert_slice
|
||||
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [%{{.*}}] [1] : tensor<?xf32> into tensor<10xf32>
|
||||
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<10xf32> {
|
||||
%src2 = tensor.cast %src : tensor<5xf32> to tensor<?xf32>
|
||||
%r = tensor.insert_slice %src2 into %dst[7] [%sz] [1] : tensor<?xf32> into tensor<10xf32>
|
||||
return %r : tensor<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @out_of_bounds_insert_slice
|
||||
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
|
||||
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<?xf32> {
|
||||
%dst2 = tensor.cast %dst : tensor<10xf32> to tensor<?xf32>
|
||||
%r = tensor.insert_slice %src into %dst2[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
|
||||
return %r : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
|
||||
// CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
|
||||
// CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
|
||||
|
Loading…
x
Reference in New Issue
Block a user