[mlir][Vector] Remove Vector{Load|Store}ToMemrefLoadLowering (#121454)
0-d vectors are now supported, so these patterns are no longer required. This addresses part of https://github.com/llvm/llvm-project/issues/112913. Additionally, in mlir/test/Conversion/GPUCommon/transfer_write.mlir, the unused %arg2 is removed and %arg3 is renamed to %arg2.
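For context, here is a minimal hand-written sketch (not part of the commit; function names are illustrative) of the rewrite the removed patterns used to perform on single-element loads and stores. Since 0-d and single-element vectors are now lowered directly, this detour through the scalar domain is no longer needed.

// Input handled by the removed patterns: a 0-d load/store pair.
func.func @load_store_0d(%mem : memref<f32>) {
  %v = vector.load %mem[] : memref<f32>, vector<f32>
  vector.store %v, %mem[] : memref<f32>, vector<f32>
  return
}

// What VectorLoadToMemrefLoadLowering / VectorStoreToMemrefStoreLowering
// used to produce: a round trip through the scalar domain.
func.func @load_store_0d_rewritten(%mem : memref<f32>) {
  %s = memref.load %mem[] : memref<f32>
  %v = vector.broadcast %s : f32 to vector<f32>
  %e = vector.extractelement %v[] : vector<f32>
  memref.store %e, %mem[] : memref<f32>
  return
}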
parent: c745ece254
commit: 099fd018d1
@@ -492,60 +492,6 @@ struct TransferReadToVectorLoadLowering
   std::optional<unsigned> maxTransferRank;
 };
 
-/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
-// TODO: we shouldn't cross the vector/scalar domains just for this
-// but atm we lack the infra to avoid it. Possible solutions include:
-// - go directly to LLVM + bitcast
-// - introduce a bitcast op and likely a new pointer dialect
-// - let memref.load/store additionally support the 0-d vector case
-// There are still deeper data layout issues lingering even in this
-// trivial case (for architectures for which this matters).
-struct VectorLoadToMemrefLoadLowering
-    : public OpRewritePattern<vector::LoadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
-                                PatternRewriter &rewriter) const override {
-    auto vecType = loadOp.getVectorType();
-    if (vecType.getNumElements() != 1)
-      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");
-
-    auto memrefLoad = rewriter.create<memref::LoadOp>(
-        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
-                                                     memrefLoad);
-    return success();
-  }
-};
-
-/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
-struct VectorStoreToMemrefStoreLowering
-    : public OpRewritePattern<vector::StoreOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
-                                PatternRewriter &rewriter) const override {
-    auto vecType = storeOp.getVectorType();
-    if (vecType.getNumElements() != 1)
-      return rewriter.notifyMatchFailure(storeOp, "not single element vector");
-
-    Value extracted;
-    if (vecType.getRank() == 0) {
-      // TODO: Unify once ExtractOp supports 0-d vectors.
-      extracted = rewriter.create<vector::ExtractElementOp>(
-          storeOp.getLoc(), storeOp.getValueToStore());
-    } else {
-      SmallVector<int64_t> indices(vecType.getRank(), 0);
-      extracted = rewriter.create<vector::ExtractOp>(
-          storeOp.getLoc(), storeOp.getValueToStore(), indices);
-    }
-
-    rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
-    return success();
-  }
-};
-
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
 /// - Stride of most minor memref dimension must be 1.
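As a reminder of the two cases the removed store pattern distinguished (a hand-written sketch, not taken from the commit; names and shapes are illustrative): a true 0-d vector goes through vector.extractelement, while a rank-n single-element vector is extracted with all-zero indices.

func.func @store_single_element(%mem : memref<f32>, %v0 : vector<f32>, %v3 : vector<1x1x1xf32>) {
  // Rank-0 case: the pattern used vector.extractelement (see the TODO above).
  %e0 = vector.extractelement %v0[] : vector<f32>
  memref.store %e0, %mem[] : memref<f32>
  // Rank-3 single-element case: vector.extract with all-zero indices.
  %e3 = vector.extract %v3[0, 0, 0] : f32 from vector<1x1x1xf32>
  memref.store %e3, %mem[] : memref<f32>
  return
}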
@@ -645,7 +591,4 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
   patterns.add<TransferReadToVectorLoadLowering,
                TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                    maxTransferRank, benefit);
-  patterns
-      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
-          patterns.getContext(), benefit);
 }
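Note that TransferReadToVectorLoadLowering and TransferWriteToVectorStoreLowering stay registered; only the extra memref-level patterns are dropped. A minimal sketch of what the retained write lowering does (hand-written; the function name and shapes are illustrative, not from the commit):

func.func @write_to_store(%v : vector<4xf32>, %mem : memref<8x8xf32>, %i : index) {
  // An in-bounds transfer_write on a unit-stride minor dimension ...
  vector.transfer_write %v, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32>
  // ... is rewritten to a plain store:
  //   vector.store %v, %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>
  return
}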
mlir/test/Conversion/GPUCommon/transfer_write.mlir:
@@ -1,13 +1,15 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
-func.func @warp_extract(%arg0: index, %arg1: memref<1024x1024xf32>, %arg2: index, %arg3: vector<1xf32>) {
+// CHECK-LABEL: @warp_extract
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9_]+]]: vector<1xf32>
+// CHECK:%[[BASE:[0-9]+]] = llvm.extractvalue
+// CHECK:%[[PTR:[0-9]+]] = llvm.getelementptr %[[BASE]]
+// CHECK:llvm.store %[[VEC]], %[[PTR]] {alignment = 4 : i64} : vector<1xf32>, !llvm.ptr
+
+func.func @warp_extract(%arg0: index, %arg1: memref<1024x1024xf32>, %arg2: vector<1xf32>) {
   %c0 = arith.constant 0 : index
   gpu.warp_execute_on_lane_0(%arg0)[32] {
-    // CHECK:%[[val:[0-9]+]] = llvm.extractelement
-    // CHECK:%[[base:[0-9]+]] = llvm.extractvalue
-    // CHECK:%[[ptr:[0-9]+]] = llvm.getelementptr %[[base]]
-    // CHECK:llvm.store %[[val]], %[[ptr]]
-    vector.transfer_write %arg3, %arg1[%c0, %c0] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32>
+    vector.transfer_write %arg2, %arg1[%c0, %c0] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32>
   }
   return
 }
@@ -3292,13 +3292,17 @@ func.func @load_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vec
 }
 
 // CHECK-LABEL: func @load_0d
-// CHECK: %[[LOAD:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}]
-// CHECK: %[[VEC:.*]] = llvm.mlir.undef : vector<1xf32>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[INSERTED:.*]] = llvm.insertelement %[[LOAD]], %[[VEC]][%[[C0]] : i32] : vector<1xf32>
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[INSERTED]] : vector<1xf32> to vector<f32>
-// CHECK: return %[[CAST]] : vector<f32>
-
+// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
+// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
+// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
+// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr %[[REF]][%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[LOAD:.*]] = llvm.load %[[ADDR]] {alignment = 4 : i64} : !llvm.ptr -> vector<1xf32>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[LOAD]] : vector<1xf32> to vector<f32>
+// CHECK: return %[[RES]] : vector<f32>
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -3392,11 +3396,18 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
 }
 
 // CHECK-LABEL: func @store_0d
-// CHECK: %[[VAL:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<f32> to vector<1xf32>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[EXTRACTED:.*]] = llvm.extractelement %[[CAST]][%[[C0]] : i64] : vector<1xf32>
-// CHECK: memref.store %[[EXTRACTED]], %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
+// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
+// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
+// CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
+// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
+// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr %[[REF]][%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: llvm.store %[[VAL]], %[[ADDR]] {alignment = 4 : i64} : vector<1xf32>, !llvm.ptr
+// CHECK: return
 
 // -----
 
@@ -6,16 +6,13 @@
 func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
   %f0 = arith.constant 0.0 : f32
 
-  // CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-  // CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
+  // CHECK-NEXT: %[[S:.*]] = vector.load %[[MEM]][] : memref<f32>, vector<f32>
   %0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
 
-  // CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
-  // CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref<f32>
+  // CHECK-NEXT: vector.store %[[S]], %[[MEM]][] : memref<f32>, vector<f32>
   vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
 
-  // CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-  // CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref<f32>
+  // CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<1x1x1xf32>
   vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
 
   return
@@ -191,8 +188,8 @@ func.func @transfer_perm_map(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf
 // CHECK-LABEL: func @transfer_broadcasting(
 // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
 // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
-// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1xf32> to vector<4xf32>
 // CHECK-NEXT: return %[[RES]] : vector<4xf32>
 // CHECK-NEXT: }
 
@@ -208,8 +205,7 @@ func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32
 // CHECK-LABEL: func @transfer_scalar(
 // CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
 // CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> {
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>
-// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>, vector<1xf32>
 // CHECK-NEXT: return %[[RES]] : vector<1xf32>
 // CHECK-NEXT: }
 func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32> {
@@ -222,8 +218,8 @@ func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32
 // CHECK-LABEL: func @transfer_broadcasting_2D(
 // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
 // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4xf32> {
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
-// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1x1xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1x1xf32> to vector<4x4xf32>
 // CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
 // CHECK-NEXT: }
 
@@ -322,8 +318,8 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref<
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
   %6 = vector.transfer_read %mem_0[%c0, %c0], %cst {in_bounds = [true], permutation_map = #map6} : memref<?x?xf32>, vector<8xf32>
-// CHECK: memref.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>
-// CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32>
+// CHECK: vector.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<1xf32> to vector<8xf32>
 
   return %0, %1, %2, %3, %4, %5, %6 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
       vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,