mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-28 04:46:07 +00:00
[mlir][utils] Fix invalid reshapes in ComposeCollapseOfExpandOp
Do not generate CollapseShapeOps/ExpandShapeOps that have the same source and result shape. Generate casts instead. Such reshapes became invalid with D138498. Differential Revision: https://reviews.llvm.org/D138557
This commit is contained in:
parent
19ab2a671e
commit
f2d91a7ae1
@ -225,7 +225,7 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
|
||||
//
|
||||
/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
|
||||
/// `reassociation_2` and produce `expand_shape`.
|
||||
template <typename CollapseOpTy, typename ExpandOpTy>
|
||||
template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
|
||||
struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
|
||||
using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
|
||||
@ -250,8 +250,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
|
||||
SmallVector<ReassociationIndices, 4> higherRankReassociation,
|
||||
lowerRankReassociation;
|
||||
|
||||
bool isResultCollapsed = srcRank > resultRank;
|
||||
if (isResultCollapsed) {
|
||||
if (srcRank > resultRank) {
|
||||
higherRankReassociation = expandOp.getReassociationIndices();
|
||||
lowerRankReassociation = collapseOp.getReassociationIndices();
|
||||
} else {
|
||||
@ -274,12 +273,20 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
|
||||
}
|
||||
composedReassociation.push_back(composedIndices);
|
||||
}
|
||||
if (isResultCollapsed)
|
||||
if (srcRank > resultRank) {
|
||||
rewriter.replaceOpWithNewOp<CollapseOpTy>(
|
||||
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
|
||||
else
|
||||
} else if (srcRank < resultRank) {
|
||||
rewriter.replaceOpWithNewOp<ExpandOpTy>(
|
||||
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
|
||||
} else {
|
||||
// Collapses/expansions that do not change the rank are not allowed. Use
|
||||
// a cast instead.
|
||||
assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
|
||||
"expected same shape");
|
||||
rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
|
||||
expandOp.getSrc());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -2447,7 +2447,7 @@ public:
|
||||
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
|
||||
ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
|
||||
ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
|
||||
CollapseShapeOpMemRefCastFolder>(context);
|
||||
}
|
||||
|
||||
|
@ -1586,7 +1586,7 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results
|
||||
.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
|
||||
ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
|
||||
ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
|
||||
FoldReshapeWithConstant<CollapseShapeOp>,
|
||||
FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
|
||||
context);
|
||||
|
@ -859,3 +859,19 @@ func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32>{
|
||||
memref.store %v, %0[%i2] : memref<4xf32>
|
||||
return %src : memref<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @collapse_expand_fold_to_cast(
|
||||
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[1]>, 3>
|
||||
// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[1]>, 3> to memref<?xf32, 3
|
||||
// CHECK: return %[[casted]]
|
||||
func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>)
|
||||
-> (memref<?xf32, 3>)
|
||||
{
|
||||
%0 = memref.expand_shape %m [[0, 1]]
|
||||
: memref<?xf32, strided<[1]>, 3> into memref<1x?xf32, 3>
|
||||
%1 = memref.collapse_shape %0 [[0, 1]]
|
||||
: memref<1x?xf32, 3> into memref<?xf32, 3>
|
||||
return %1 : memref<?xf32, 3>
|
||||
}
|
||||
|
@ -1666,3 +1666,15 @@ func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
|
||||
%1 = tensor.dim %0, %c1 : tensor<?x?xf32>
|
||||
return %1 : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @collapse_expand_fold_to_cast(
|
||||
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
|
||||
// CHECK: return %[[t]]
|
||||
func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
|
||||
{
|
||||
%0 = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
|
||||
%1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
|
||||
return %1 : tensor<?xf32>
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user