mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-21 09:56:48 +00:00
[mlir][linalg] Add a folder for transpose(fill) -> fill (#83623)
This is similar to the existing folder for a linalg.copy. Transposing a filled tensor is the same as filling the destination of the transpose.
This commit is contained in:
parent
f505a92fc2
commit
205dce6029
mlir
@ -815,6 +815,22 @@ struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Fold fill with transpose.
|
||||
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
|
||||
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
|
||||
rewriter.replaceOpWithNewOp<FillOp>(
|
||||
transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
|
||||
transposeOp.getDpsInitOperand(0)->get());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
@ -823,7 +839,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
.add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
|
||||
FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
|
||||
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
|
||||
FoldInsertPadIntoFill>(context);
|
||||
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -993,6 +993,20 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @canonicalize_fill_to_transpose_input(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
|
||||
// CHECK: %[[ZERO:.+]] = arith.constant 0.0
|
||||
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
|
||||
func.func @canonicalize_fill_to_transpose_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0.0 : f32
|
||||
%fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%transpose = linalg.transpose ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
|
||||
return %transpose : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @broadcast_same_shape(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)
|
||||
|
Loading…
x
Reference in New Issue
Block a user