0
0
mirror of https://github.com/llvm/llvm-project.git synced 2025-04-21 09:56:48 +00:00

[mlir][linalg] Add a folder for transpose(fill) -> fill ()

This is similar to the existing folder for a linalg.copy. Transposing a
filled tensor is the same as filling the destination of the transpose.
This commit is contained in:
Quinn Dawkins 2024-03-02 17:47:16 -05:00 committed by GitHub
parent f505a92fc2
commit 205dce6029
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 1 deletions
mlir
lib/Dialect/Linalg/IR
test/Dialect/Linalg

@ -815,6 +815,22 @@ struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
}
};
/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
rewriter.replaceOpWithNewOp<FillOp>(
transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
transposeOp.getDpsInitOperand(0)->get());
return success();
}
return failure();
}
};
} // namespace
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
@ -823,7 +839,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
.add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
FoldInsertPadIntoFill>(context);
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
//===----------------------------------------------------------------------===//

@ -993,6 +993,20 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
// -----
// CHECK-LABEL: func @canonicalize_fill_to_transpose_input(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK: %[[ZERO:.+]] = arith.constant 0.0
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
func.func @canonicalize_fill_to_transpose_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0.0 : f32
%fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
%transpose = linalg.transpose ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
return %transpose : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @broadcast_same_shape(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)