[mlir][utils] Fix invalid reshapes in ComposeCollapseOfExpandOp

Do not generate CollapseShapeOps/ExpandShapeOps that have the same source and result shape. Generate casts instead. Such reshapes became invalid with D138498.

Differential Revision: https://reviews.llvm.org/D138557
Matthias Springer 2022-11-23 11:56:07 +01:00
parent 19ab2a671e
commit f2d91a7ae1
5 changed files with 42 additions and 7 deletions
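
Not part of the commit, but a minimal sketch of the problem being fixed: when the source and result ranks of a collapse-of-expand pair are equal, composing the two reassociations yields an identity mapping, and emitting that as a reshape is rejected by the verifier since D138498. Roughly:

    // Input: a collapse of an expand that is rank-preserving overall.
    %0 = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
    %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
    // Pre-patch, the pattern composed this into a rank-preserving reshape
    // along the lines of (invalid since D138498):
    //   %1 = tensor.expand_shape %t [[0]] : tensor<?xf32> into tensor<?xf32>
    // Post-patch, the collapse is replaced with a cast instead; a cast
    // between identical types folds away, leaving just %t.
    %1 = tensor.cast %t : tensor<?xf32> to tensor<?xf32>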

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

@@ -225,7 +225,7 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 ///
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -250,8 +250,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
     SmallVector<ReassociationIndices, 4> higherRankReassociation,
         lowerRankReassociation;
-    bool isResultCollapsed = srcRank > resultRank;
-    if (isResultCollapsed) {
+    if (srcRank > resultRank) {
       higherRankReassociation = expandOp.getReassociationIndices();
       lowerRankReassociation = collapseOp.getReassociationIndices();
     } else {
@@ -274,12 +273,20 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
       }
       composedReassociation.push_back(composedIndices);
     }
-    if (isResultCollapsed)
+    if (srcRank > resultRank) {
       rewriter.replaceOpWithNewOp<CollapseOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
-    else
+    } else if (srcRank < resultRank) {
       rewriter.replaceOpWithNewOp<ExpandOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+    } else {
+      // Collapses/expansions that do not change the rank are not allowed. Use
+      // a cast instead.
+      assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
+             "expected same shape");
+      rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
+                                            expandOp.getSrc());
+    }
     return success();
   }
 };
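
A worked example of the reassociation composition above (hypothetical IR, not from the commit). With `srcRank > resultRank`, the expand's and collapse's reassociations fold into a single collapse:

    %e = tensor.expand_shape %t [[0, 1], [2]] : tensor<?x?xf32> into tensor<?x4x?xf32>
    %c = tensor.collapse_shape %e [[0, 1, 2]] : tensor<?x4x?xf32> into tensor<?xf32>
    // Source dim 0 maps to expanded dims {0, 1} and source dim 1 to {2};
    // the collapse merges expanded dims {0, 1, 2} into result dim 0, so the
    // composed reassociation is [[0, 1]]:
    %c = tensor.collapse_shape %t [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>

For `srcRank < resultRank` the roles of the two maps are swapped and a single expand_shape is produced; the new `else` branch covers the remaining `srcRank == resultRank` case, where any composed reshape would be rank-preserving and hence invalid.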

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

@@ -2447,7 +2447,7 @@ public:
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
               CollapseShapeOpMemRefCastFolder>(context);
 }

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

@@ -1586,7 +1586,7 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results
       .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
            FoldReshapeWithConstant<CollapseShapeOp>,
            FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
            context);

mlir/test/Dialect/MemRef/canonicalize.mlir

@@ -859,3 +859,19 @@ func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32>{
   memref.store %v, %0[%i2] : memref<4xf32>
   return %src : memref<2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[1]>, 3>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[1]>, 3> to memref<?xf32, 3>
+// CHECK: return %[[casted]]
+func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>)
+    -> (memref<?xf32, 3>)
+{
+  %0 = memref.expand_shape %m [[0, 1]]
+      : memref<?xf32, strided<[1]>, 3> into memref<1x?xf32, 3>
+  %1 = memref.collapse_shape %0 [[0, 1]]
+      : memref<1x?xf32, 3> into memref<?xf32, 3>
+  return %1 : memref<?xf32, 3>
+}
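
Note that the memref test needs an actual cast in the output: the fold's source type `memref<?xf32, strided<[1]>, 3>` and its result type `memref<?xf32, 3>` differ only in how the identity layout is spelled, which is exactly the kind of type difference `memref.cast` bridges. In the tensor test below, the cast is between identical types and folds away, so the function returns its argument directly.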

mlir/test/Dialect/Tensor/canonicalize.mlir

@@ -1666,3 +1666,15 @@ func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
   %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @collapse_expand_fold_to_cast(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
+// CHECK: return %[[t]]
+func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
+{
+  %0 = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+  return %1 : tensor<?xf32>
+}