[mlir] Replace dynamic sizes in insert_slice of tensor.cast canonicalization (#91352)
In some cases this pattern may ignore static information because the insert_slice size operands are dynamic SSA values, e.g.:

```
%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
%1 = tensor.insert_slice %0 into %arg1[...] [%s0, %s1] [...] : tensor<?x?xf32> into tensor<?x?xf32>
```

can be rewritten into:

```
%1 = tensor.insert_slice %arg0 into %arg1[...] [1, %s1] [...] : tensor<1x?xf32> into tensor<?x?xf32>
```

This PR updates the matching in the pattern to allow rewrites like this.
Commit: 7e35a9a0e7 (parent: 2f956a35ed)
```
@@ -360,9 +360,15 @@ private:
 /// which dimensions must be kept when e.g. compute MemRef strides under
 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be
 /// obtained by dropping only `1` entries in `originalShape`.
+/// If `matchDynamic` is true, then dynamic dims in `originalShape` and
+/// `reducedShape` will be considered matching with non-dynamic dims, unless
+/// the non-dynamic dim is from `originalShape` and equal to 1. For example,
+/// in ([1, 3, ?], [?, 5]), the mask would be {1, 0, 0}, since 3 and 5 will
+/// match with the corresponding dynamic dims.
 std::optional<llvm::SmallDenseSet<unsigned>>
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
-                         ArrayRef<int64_t> reducedShape);
+                         ArrayRef<int64_t> reducedShape,
+                         bool matchDynamic = false);

 /// Enum that captures information related to verifier error conditions on
 /// slice insert/extract type of ops.
```
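As a concrete check of the example in the new doc comment, here is a small illustrative sketch; it is not part of the change. It assumes a translation unit compiled and linked against MLIR, and it assumes the declaration is reachable through `mlir/IR/BuiltinTypes.h`.

```
// Sketch only: exercises computeRankReductionMask on the doc-comment example.
// Assumes an MLIR build; the include path is an assumption.
#include "mlir/IR/BuiltinTypes.h"

#include <cassert>
#include <cstdint>

int main() {
  const int64_t kDyn = mlir::ShapedType::kDynamic;

  // originalShape = [1, 3, ?], reducedShape = [?, 5].
  auto mask = mlir::computeRankReductionMask({1, 3, kDyn}, {kDyn, 5},
                                             /*matchDynamic=*/true);
  // 3 matches '?' and '?' matches 5, so only the unit dim at index 0 is
  // dropped (the "{1, 0, 0}" bitmask from the comment, expressed as a set of
  // dropped indices).
  assert(mask && mask->size() == 1 && mask->contains(0));

  // With the default matchDynamic = false the two shapes cannot be
  // reconciled, so the function returns std::nullopt.
  assert(!mlir::computeRankReductionMask({1, 3, kDyn}, {kDyn, 5}));
  return 0;
}
```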
```
@@ -2713,15 +2713,38 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
     auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
     if (!srcType || !dstType)
       return failure();
+
+    // The tensor.cast source could have additional static information not seen
+    // in the insert slice op static sizes, so we ignore dynamic dims when
+    // computing the rank reduction mask.
+    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
+    auto rankReductionMask = computeRankReductionMask(
+        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
+    if (!rankReductionMask.has_value())
+      return failure();
+    // Replace dimensions in the insert slice op with corresponding static dims
+    // from the cast source type. If the insert slice sizes have static dims
+    // that are not static in the tensor.cast source (i.e., when the cast op
+    // casts a dynamic dim to static), the dim should not be replaced, and the
+    // pattern will fail later in `verifyInsertSliceOp`.
+    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+    int64_t rankReducedIdx = 0;
+    for (auto [idx, size] : enumerate(staticSizes)) {
+      if (!rankReductionMask.value().contains(idx) &&
+          !srcType.isDynamicDim(rankReducedIdx)) {
+        mixedSizes[idx] = getAsIndexOpFoldResult(
+            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
+        size = srcType.getDimSize(rankReducedIdx++);
+      }
+    }
     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
-                            insertSliceOp.getStaticSizes(),
-                            insertSliceOp.getStaticStrides()) !=
+                            staticSizes, insertSliceOp.getStaticStrides()) !=
         SliceVerificationResult::Success)
       return failure();

     Operation *replacement = rewriter.create<InsertOpTy>(
         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
-        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+        mixedSizes, insertSliceOp.getMixedStrides());

     // In the parallel case there is no result and so nothing to cast.
     bool isParallelInsert =
```
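The replacement loop above indexes two different shapes: `idx` walks the full-rank insert_slice sizes, while `rankReducedIdx` walks the possibly rank-reduced cast source and advances only when a static source size is substituted. Below is a standalone sketch of that pairing with no MLIR dependencies; the names (`replaceSizesFromCastSource`, the `kDynamic` stand-in for `ShapedType::kDynamic`) are illustrative only.

```
// Standalone sketch of the size-replacement loop above: dims dropped by the
// rank-reduction mask keep their size, and a size is only replaced when the
// corresponding cast-source dim is static. Illustrative names, not MLIR API.
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

std::vector<int64_t>
replaceSizesFromCastSource(std::vector<int64_t> sizes,           // insert_slice sizes
                           const std::vector<int64_t> &srcShape, // cast source shape
                           const std::set<unsigned> &droppedDims) {
  unsigned rankReducedIdx = 0;
  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (droppedDims.count(idx) == 0 && srcShape[rankReducedIdx] != kDynamic) {
      // Take the static size from the cast source, as the pattern does.
      sizes[idx] = srcShape[rankReducedIdx];
      ++rankReducedIdx;
    }
  }
  return sizes;
}

int main() {
  // Commit-message example: sizes [%s0, %s1] (both dynamic), cast source
  // tensor<1x?xf32>. Only the first size becomes the static 1.
  std::vector<int64_t> sizes = replaceSizesFromCastSource(
      {kDynamic, kDynamic}, {1, kDynamic}, /*droppedDims=*/{});
  assert(sizes[0] == 1 && sizes[1] == kDynamic);
  return 0;
}
```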
```
@@ -408,24 +408,24 @@ unsigned BaseMemRefType::getMemorySpaceAsInt() const {
 // MemRefType
 //===----------------------------------------------------------------------===//

 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
 /// `originalShape` with some `1` entries erased, return the set of indices
 /// that specifies which of the entries of `originalShape` are dropped to obtain
 /// `reducedShape`. The returned mask can be applied as a projection to
 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
 /// which dimensions must be kept when e.g. compute MemRef strides under
 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be
 /// obtained by dropping only `1` entries in `originalShape`.
 std::optional<llvm::SmallDenseSet<unsigned>>
 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
-                               ArrayRef<int64_t> reducedShape) {
+                               ArrayRef<int64_t> reducedShape,
+                               bool matchDynamic) {
   size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
   llvm::SmallDenseSet<unsigned> unusedDims;
   unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
     // Greedily insert `originalIdx` if match.
-    if (reducedIdx < reducedRank &&
-        originalShape[originalIdx] == reducedShape[reducedIdx]) {
+    int64_t origSize = originalShape[originalIdx];
+    // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
+    if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
+        (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
+         ShapedType::isDynamic(origSize))) {
+      reducedIdx++;
+      continue;
+    }
+    if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
       reducedIdx++;
       continue;
     }
@@ -433,7 +433,7 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
     unusedDims.insert(originalIdx);
     // If no match on `originalIdx`, the `originalShape` at this dimension
     // must be 1, otherwise we bail.
-    if (originalShape[originalIdx] != 1)
+    if (origSize != 1)
       return std::nullopt;
   }
   // The whole reducedShape must be scanned, otherwise we bail.
```
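The two bail-out paths in the function above (a non-unit static dim with no partner, and a `reducedShape` that is never fully consumed) can be illustrated the same way as the earlier sketch; again this assumes an MLIR build and that the declaration is visible via `mlir/IR/BuiltinTypes.h`.

```
// Sketch only: bail-out behaviour of computeRankReductionMask; assumes an
// MLIR build as in the earlier snippet.
#include "mlir/IR/BuiltinTypes.h"

#include <cassert>
#include <cstdint>

int main() {
  const int64_t kDyn = mlir::ShapedType::kDynamic;

  // A non-unit static dim that finds no partner cannot be dropped, so the
  // function returns std::nullopt.
  assert(!mlir::computeRankReductionMask({3, 4}, {4}, /*matchDynamic=*/true));

  // A static `1` in `originalShape` never matches a dynamic dim: the `1` is
  // dropped, the `?` of `reducedShape` is never consumed, and the function
  // bails because the whole reducedShape must be scanned.
  assert(!mlir::computeRankReductionMask({1}, {kDyn}, /*matchDynamic=*/true));
  return 0;
}
```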
```
@@ -755,6 +755,34 @@ func.func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {

 // -----

+// CHECK-LABEL: func @insert_slice_cast
+func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
+  // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
+  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
+  // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
+  // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
+  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
+  // CHECK: return %[[RES]] : tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_cast_no_fold
+func.func @insert_slice_cast_no_fold(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
+  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x5xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
+  // CHECK-SAME: [{{.*}}, {{.*}}] [{{.*}}, 5] [{{.*}}, {{.*}}]
+  // CHECK-SAME: : tensor<?x5xf32> into tensor<?x?xf32>
+  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, 5] [%arg6, %arg7] : tensor<?x5xf32> into tensor<?x?xf32>
+  // CHECK: return %[[RES]] : tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
 // CHECK-SAME:  %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
 // CHECK:       %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
```
```
@@ -1890,21 +1918,6 @@ func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {

 // -----

-// There was an issue in cast + insert_slice folding generating invalid ir.
-// https://github.com/llvm/llvm-project/issues/53099
-// CHECK-LABEL: func @insert_slice_cast
-func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
-  // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
-  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
-  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
-  // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
-  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
-  // CHECK: return %[[RES]] : tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @cast_extract_slice
 func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
     -> tensor<16x512xf32> {
```