[mlir][memref] Fold copy of cast
If the source/dest is a cast that does not change shape/element type, the cast can be skipped.

Differential Revision: https://reviews.llvm.org/D117215
commit 96acdfa0de
parent c86a982d7d
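For illustration, the effect of the new canonicalization at the IR level looks roughly like the following sketch. It is not part of the commit; the function and value names are hypothetical, and the layout map mirrors the one used in the new test further down.

#layout = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
func @fold_copy_source_cast(%src: memref<?xf32>, %dst: memref<?xf32, #layout>) {
  // The cast only changes the layout; shape and element type are unchanged.
  %casted = memref.cast %src : memref<?xf32> to memref<?xf32, #layout>
  memref.copy %casted, %dst : memref<?xf32, #layout> to memref<?xf32, #layout>
  return
}

// After canonicalization the copy reads %src directly and the cast becomes
// dead:
//   memref.copy %src, %dst : memref<?xf32> to memref<?xf32, #layout>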
@@ -406,6 +406,7 @@ def CopyOp : MemRef_Op<"copy",
     $source `,` $target attr-dict `:` type($source) `to` type($target)
   }];
 
+  let hasCanonicalizer = 1;
   let verifier = ?;
 }
 
@@ -438,6 +438,61 @@ OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+//===----------------------------------------------------------------------===//
+// CopyOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// If the source/target of a CopyOp is a CastOp that does not modify the shape
+/// and element type, the cast can be skipped. Such CastOps only cast the layout
+/// of the type.
+struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+
+    // Check source: compare the type before the cast with the type the copy
+    // actually sees.
+    if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
+      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
+      auto toType = copyOp.source().getType().dyn_cast<MemRefType>();
+
+      if (fromType && toType) {
+        if (fromType.getShape() == toType.getShape() &&
+            fromType.getElementType() == toType.getElementType()) {
+          rewriter.updateRootInPlace(
+              copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
+          modified = true;
+        }
+      }
+    }
+
+    // Check target.
+    if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
+      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
+      auto toType = copyOp.target().getType().dyn_cast<MemRefType>();
+
+      if (fromType && toType) {
+        if (fromType.getShape() == toType.getShape() &&
+            fromType.getElementType() == toType.getElementType()) {
+          rewriter.updateRootInPlace(
+              copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
+          modified = true;
+        }
+      }
+    }
+
+    return success(modified);
+  }
+};
+} // namespace
+
+void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<FoldCopyOfCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
 
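Conversely, the pattern bails out whenever it cannot prove the cast is layout-only, for example when the cast operand is unranked, because dyn_cast<MemRefType> then returns null. A hedged sketch of such a case (hypothetical names, not from the commit):

func @keep_cast_of_unranked(%m: memref<*xf32>, %dst: memref<?xf32>) {
  // memref<*xf32> is not a ranked MemRefType, so FoldCopyOfCast leaves this
  // copy untouched and the cast is kept.
  %ranked = memref.cast %m : memref<*xf32> to memref<?xf32>
  memref.copy %ranked, %dst : memref<?xf32> to memref<?xf32>
  return
}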
@@ -31,7 +31,7 @@ func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
 // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
 // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
 // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
-// CHECK: memref.copy %[[A_memref]], %[[casted]]
+// CHECK: memref.copy %[[A_memref]], %[[alloc]]
 // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
   %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
 
@@ -95,4 +95,4 @@ func @rank_reducing(
     scf.yield %10 : tensor<?x1x6x8xf32>
   }
   return %5: tensor<?x1x6x8xf32>
 }
 }
@@ -159,7 +159,7 @@ func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
 // CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
 // CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
 // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
-// CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
+// CHECK-TENSOR: memref.copy %[[t1_memref]], %[[alloc]]
 // CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
   %0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
 // CHECK-TENSOR: return %[[casted_tensor]]
@@ -177,7 +177,7 @@ func @simple_scf_for(
 // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
 // CHECK-SCF: %[[alloc:.*]] = memref.alloc
 // CHECK-SCF: %[[casted:.*]] = memref.cast %[[alloc]]
-// CHECK-SCF: memref.copy %[[t1_memref]], %[[casted]]
+// CHECK-SCF: memref.copy %[[t1_memref]], %[[alloc]]
 // CHECK-SCF: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) {
   %0 = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t1) -> tensor<?xf32> {
 // CHECK-SCF: %[[arg0_tensor:.*]] = bufferization.to_tensor %[[arg0]]
 
@@ -510,3 +510,18 @@ func @atomicrmw_cast_fold(%arg0 : f32, %arg1 : memref<4xf32>, %c : index) {
 
 // CHECK-LABEL: func @atomicrmw_cast_fold
 // CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32
+
+// -----
+
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @copy_of_cast(%m1: memref<?xf32>, %m2: memref<*xf32>) {
+  %casted1 = memref.cast %m1 : memref<?xf32> to memref<?xf32, #map>
+  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, #map>
+  memref.copy %casted1, %casted2 : memref<?xf32, #map> to memref<?xf32, #map>
+  return
+}
+
+// CHECK-LABEL: func @copy_of_cast(
+// CHECK-SAME:  %[[m1:.*]]: memref<?xf32>, %[[m2:.*]]: memref<*xf32>
+// CHECK:       %[[casted2:.*]] = memref.cast %[[m2]]
+// CHECK:       memref.copy %[[m1]], %[[casted2]]
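Reading the new test together with its CHECK lines: the copy no longer uses %casted1 (its operand is a ranked memref with the same shape and element type, so the cast becomes dead), while %casted2 survives because its operand is unranked. The expected canonicalized form is roughly the following sketch, inferred from the CHECK lines rather than taken from tool output:

#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
func @copy_of_cast(%m1: memref<?xf32>, %m2: memref<*xf32>) {
  // Only the cast of the unranked operand remains.
  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, #map>
  memref.copy %m1, %casted2 : memref<?xf32> to memref<?xf32, #map>
  return
}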