[mlir][tensor] Fold tensor.cast into tensor.collapse_shape op
This commit folds a `tensor.cast` op into a `tensor.collapse_shape` op when the following two conditions are met:
1. The `tensor.collapse_shape` op consumes the result of the `tensor.cast` op.
2. The `tensor.cast` op casts to a more dynamic version of the source tensor.

This is added as a canonicalization pattern on the `tensor.collapse_shape` op.

Signed-Off-By: Gaurav Shukla <gaurav@nod-labs.com>

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D130650
parent 8a61749f76
commit 7d6ef5caef
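To illustrate the fold described in the commit message, here is a minimal sketch of the IR before and after canonicalization; the shapes and value names are invented for this note and are not taken from the patch.

// Before: the cast only makes the operand more dynamic, so both conditions hold.
%0 = tensor.cast %arg0 : tensor<4x8x16xf32> to tensor<?x?x?xf32>
%1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>

// After canonicalization: the collapse consumes the static source directly, and a
// cast back to the original (more dynamic) result type preserves the op's users.
%c = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<4x8x16xf32> into tensor<32x16xf32>
%1 = tensor.cast %c : tensor<32x16xf32> to tensor<?x?xf32>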
@@ -928,6 +928,36 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
   }
 };
 
+// Fold CastOp into CollapseShapeOp when adding static information.
+struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
+  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
+    if (!tensor::canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    RankedTensorType srcType =
+        castOp.getSource().getType().cast<RankedTensorType>();
+    RankedTensorType newResultType = computeTensorReshapeCollapsedType(
+        srcType, collapseShapeOp.getReassociationMaps());
+
+    if (newResultType == collapseShapeOp.getResultType()) {
+      rewriter.updateRootInPlace(collapseShapeOp, [&]() {
+        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
+      });
+    } else {
+      auto newOp = rewriter.create<CollapseShapeOp>(
+          collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
+          collapseShapeOp.getReassociation());
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
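For the first branch of `matchAndRewrite` above (the `updateRootInPlace` case), a hedged sketch with invented shapes: when the result type recomputed from the cast's more-static source is identical to the op's existing result type, the pattern only rewires the operand and no extra `tensor.cast` is created.

// Before: dropping the static extent 4 does not change the collapsed result type,
// since 4 x ? still collapses to a dynamic dimension.
%0 = tensor.cast %arg0 : tensor<4x?x16xf32> to tensor<?x?x16xf32>
%1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x16xf32> into tensor<?x16xf32>

// After: the collapse is updated in place to consume the cast's source.
%1 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<4x?x16xf32> into tensor<?x16xf32>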
@@ -940,10 +970,12 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
-              FoldReshapeWithConstant<CollapseShapeOp>,
-              FoldReshapeWithFromElements<CollapseShapeOp>>(context);
+  results
+      .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+           FoldReshapeWithConstant<CollapseShapeOp>,
+           FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
+          context);
 }
 
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
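With `FoldCollapseOfCastOp` registered in this pattern list, the fold is picked up wherever `tensor.collapse_shape` canonicalization patterns are applied, most commonly through the generic canonicalizer pass (for example `mlir-opt -canonicalize`).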
@@ -673,6 +673,20 @@ func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
 
 // -----
 
+// CHECK-LABEL: func.func @collapse_of_cast(
+// CHECK-SAME:    %[[IN:.*]]: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
+// CHECK-NEXT:    %[[COLLAPSE:.*]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1], [2]] : tensor<8x12x32xf32> into tensor<96x32xf32>
+// CHECK-NEXT:    %[[CAST:.*]] = tensor.cast %[[COLLAPSE]] : tensor<96x32xf32> to tensor<?x32xf32>
+// CHECK-NEXT:    return %[[CAST]] : tensor<?x32xf32>
+func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
+  %0 = tensor.cast %t : tensor<8x12x32xf32> to tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x32xf32>
+  return %2 : tensor<?x32xf32>
+}
+
+// -----
+
 func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<12x4xf32> into tensor<3x4x4xf32>
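One note on the expected output in this test: the trailing `tensor.cast` of the input composes with the cast that `FoldCollapseOfCastOp` re-introduces, and the existing chained-cast canonicalization collapses the two, which is why the CHECK lines expect a single `tensor.cast` from `tensor<96x32xf32>` to `tensor<?x32xf32>` after the now fully static `tensor.collapse_shape`.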