[MLIR][AMDGPU] Adding Vector transfer_read to load rewrite pattern (#131803)
This PR adds the Vector transfer_read to load rewrite pattern. The pattern creates a transfer read op lowering: a vector `transfer_read` op is lowered to a combination of `vector.load`, `arith.select`, and `vector.broadcast` if:

- The transfer op is masked.
- The memref is in the buffer address space.
- The other conditions introduced by `TransferReadToVectorLoadLowering` hold.

The motivation for this PR is the lack of masked-load support in the AMDGPU backend: `llvm.intr.masked.load` is lowered to a series of conditional scalar loads (see the `scalarize-masked-mem-intrin` pass). This PR makes it possible for a masked `transfer_read` to be lowered to a buffer load with bounds checking, allowing a more optimized global-load access pattern than the existing lowering of `llvm.intr.masked.load` on vectors. A sketch of the rewrite is shown below.
This commit is contained in: parent 09feaa9261, commit ea03bdee70
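For illustration, here is a minimal sketch of the rewrite, with placeholder value names and shapes consistent with the test cases added in this patch (this is an editorial sketch, not an excerpt from the patch):

```mlir
// Before: masked 1-D transfer_read from a memref in the fat_raw_buffer
// address space (names and shapes are illustrative).
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]}
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>

// After: a splat of the padding value, an unmasked vector.load (bounds-checked
// by the buffer addressing), and a select on the original mask.
%splat = vector.splat %cf0 : vector<4xf32>
%load = vector.load %mem[%idx, %idx]
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
%res = arith.select %mask, %load, %splat : vector<4xi1>, vector<4xf32>
```

If the permutation map broadcasts, the selected value is additionally passed through `vector.broadcast` to the original result type, as exercised by the `transfer_broadcasting` test below.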
@@ -22,6 +22,7 @@ namespace amdgpu {

#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"

@@ -30,6 +31,9 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
                                          Chipset chipset);

void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns);

void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns);

} // namespace amdgpu
} // namespace mlir
@@ -51,4 +51,18 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
  ];
}

def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
  let summary = "Lower the operations from the vector transfer_read to vector load";
  let description = [{
    This pass creates a transfer read op lowering. A vector transfer read op
    will be lowered to a combination of vector.load, arith.select and
    vector.broadcast.

    This pattern will make it possible for masked transfer_read to be lowered
    towards buffer load with bounds check, allowing a more optimized global
    load access pattern compared with the existing implementation of
    llvm.intr.masked.load on vectors.
  }];
  let dependentDialects = [];
}
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
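As a usage sketch (mirroring the first case of the test file added later in this patch; the function name is illustrative), the pass defined above is driven through `mlir-opt --amdgpu-transfer-read-to-load`:

```mlir
// Run with: mlir-opt --amdgpu-transfer-read-to-load input.mlir
func.func @example(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>,
                   %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]}
      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  return %res : vector<4xf32>
}
```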
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
  EmulateAtomics.cpp
  ResolveStridedMetadata.cpp
  TransferReadToLoad.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp (new file, 154 lines)
@@ -0,0 +1,154 @@
//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

/// This pattern supports lowering of:
/// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
/// `vector.broadcast` if all of the following hold:
/// - The transfer op is masked.
/// - The memref is in buffer address space.
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
/// pass.
static LogicalResult transferPreconditions(
    PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
    bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
  if (!xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");

  // Permutations are handled by VectorToSCF or
  // populateVectorTransferPermutationMapLoweringPatterns.
  // We let the 0-d corner case pass-through as it is supported.
  SmallVector<unsigned> broadcastedDims;
  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
          &broadcastedDims))
    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");

  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!memRefType)
    return rewriter.notifyMatchFailure(xferOp, "not a memref source");

  Attribute addrSpace = memRefType.getMemorySpace();
  if (!addrSpace || !dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace))
    return rewriter.notifyMatchFailure(xferOp, "no address space");

  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
      amdgpu::AddressSpace::FatRawBuffer)
    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");

  // Non-unit strides are handled by VectorToSCF.
  if (!memRefType.isLastDimUnitStride())
    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");

  // If there is broadcasting involved then we first load the unbroadcasted
  // vector, and then broadcast it with `vector.broadcast`.
  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
  for (unsigned i : broadcastedDims)
    unbroadcastedVectorShape[i] = 1;
  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
  requiresBroadcasting = !broadcastedDims.empty();

  // `vector.load` supports vector types as memref's elements only when the
  // resulting vector type is the same as the element type.
  auto memrefElTy = memRefType.getElementType();
  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");

  // Otherwise, element types of the memref and the vector must match.
  if (!isa<VectorType>(memrefElTy) &&
      memrefElTy != xferOp.getVectorType().getElementType())
    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");

  // Out-of-bounds dims are handled by MaterializeTransferMask.
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");

  if (xferOp.getVectorType().getRank() != 1)
    // vector.maskedload operates on 1-D vectors.
    return rewriter.notifyMatchFailure(
        xferOp, "vector type is not rank 1, can't create masked load, needs "
                "VectorToSCF");

  return success();
}

namespace {

struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {

    bool requiresBroadcasting = false;
    VectorType unbroadcastedVectorType;
    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
                                     unbroadcastedVectorType))) {
      return failure();
    }

    Location loc = readOp.getLoc();
    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
                                                  readOp.getPadding());
    Value load = rewriter.create<vector::LoadOp>(
        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
                                                 readOp.getMask(), load, fill);

    // Insert a broadcasting op if required.
    if (requiresBroadcasting) {
      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
                                                 res);
    }

    rewriter.replaceOp(readOp, res);

    return success();
  }
};

} // namespace

void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering>(patterns.getContext());
}

struct AmdgpuTransferReadToLoadPass final
    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
          AmdgpuTransferReadToLoadPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateAmdgpuTransferReadToLoadPatterns(patterns);
    walkAndApplyPatterns(getOperation(), std::move(patterns));
  }
};
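To make the preconditions in `transferPreconditions` concrete, here is a hedged sketch of reads the pattern deliberately leaves to other lowerings (hypothetical inputs, not taken from the patch's tests; `%mask2d` is assumed to be a `vector<4x4xi1>` value):

```mlir
// Unmasked read: the pattern only fires when a mask is present.
%a = vector.transfer_read %mem[%idx, %idx], %cf0 {in_bounds = [true]}
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>

// Rank-2 result vector: deferred to VectorToSCF rather than lowered here.
%b = vector.transfer_read %mem[%idx, %idx], %cf0, %mask2d {in_bounds = [true, true]}
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4x4xf32>
```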
mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir (new file, 86 lines)
@@ -0,0 +1,86 @@
// RUN: mlir-opt %s --amdgpu-transfer-read-to-load --split-input-file | FileCheck %s

// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: return %[[SELECT]] : vector<4xf32>

// -----

// CHECK-LABEL: func @transfer_to_maskedload_regular(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
// CHECK: return %[[RES]] : vector<4xf32>

// -----

// CHECK-LABEL: func @transfer_to_maskedload_addrspace(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #gpu.address_space<workgroup>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_space<workgroup>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
// CHECK: return %[[RES]] : vector<4xf32>

// -----

// CHECK-LABEL: func @transfer_broadcasting(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
#broadcast_1d = affine_map<(d0, d1) -> (0)>
func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
    {in_bounds = [true], permutation_map = #broadcast_1d}
      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
// CHECK: return %[[BROADCAST]] : vector<4xf32>

// -----

// CHECK-LABEL: func @transfer_scalar(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<1xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
    {in_bounds = [true]}
      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
  return %res : vector<1xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: return %[[SELECT]] : vector<1xf32>