[mlir] Convert expand_shape to more static form (#112265)

Add a pattern that converts a `tensor.expand_shape` op to a more static
form.

The pattern matches a `tensor.cast` -> `tensor.expand_shape` chain when
the `tensor.cast` is foldable into its consumer and some of the
`tensor.expand_shape`'s `output_shape` operands are constants. This makes
the `tensor.expand_shape` more static and allows the static shape
information to be propagated further down the program.
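
For example (mirroring the first new test case below), when the producer
cast is foldable and all `output_shape` operands are constants:

  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>

canonicalization rewrites this to a fully static expand once the casts
created by the pattern fold away:

  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [10, 1, 10]
      : tensor<10x10xf32> into tensor<10x1x10xf32>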
Ian Wood · 2024-10-24 17:04:02 -07:00 · commit 455f71d285 (parent 8c2e8b5124)
2 changed files with 136 additions and 1 deletion

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

@@ -24,6 +24,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -1982,6 +1983,86 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
    return success();
  }
};

/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
/// matching constant output_shape operands of the expand. This makes the
/// `tensor.expand_shape` more static and creates a consumer cast that can be
/// propagated further.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (!ShapedType::isDynamic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires at
        // least one of the expanded dimensions to be dynamic if the input is
        // dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
                                             expandOp.getSrc());
    auto newExpand = rewriter.create<ExpandShapeOp>(
        expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
} // namespace

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1989,7 +2070,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldReshapeWithConstant<ExpandShapeOp>,
+      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
      FoldDimOfCollapseShape>(context);
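
The pattern runs as part of the regular canonicalizer, so no dedicated
pass is needed; the new tests below are driven by the test file's usual
RUN line, which (modulo the exact flags used in-tree) looks something
like:

  // RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s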

mlir/test/Dialect/Tensor/canonicalize.mlir

@@ -2741,3 +2741,57 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128
  %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
  return %pack : tensor<128x?x100x16x1xf16>
}
// -----

func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
    -> tensor<10x1x10xf32> {
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
  return %2 : tensor<10x1x10xf32>
}
// CHECK-LABEL: func.func @fold_expand_of_cast
// CHECK: %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
// CHECK: return %[[RES]]
// -----

func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
    -> tensor<?x?x?xf32> {
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func.func @sink_expand_of_cast
// CHECK-DAG: %[[C10:.*]] = arith.constant 10
// CHECK-DAG: %[[C1:.*]] = arith.constant 1
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: output_shape [%[[C10]], %[[C1]], 10]
// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
// CHECK: return %[[RES]]
// -----

func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
    -> tensor<?x?x?xf32> {
  %c10 = arith.constant 10 : index
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func.func @partial_sink_expand_of_cast
// CHECK: %[[CAST:.+]] = tensor.cast
// CHECK-SAME: tensor<10x10xf32> to tensor<?x10xf32>
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: output_shape [%{{.*}}, %{{.*}}, 10]
// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
// CHECK: return %[[RES]]
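
In this last case only the second reassociation group `[2]` can be made
static: `%c10` is a constant and the cast source's trailing dimension is
10, so the inferred source type becomes tensor<?x10xf32>. The first group
keeps `%arg1` and `%arg2` dynamic (hence the leading `?x?` in the expand's
result), and the trailing `tensor.cast` restores the original
tensor<?x?x?xf32> result type.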