[MLIR][AMDGPU] Adding Vector transfer_read to load rewrite pattern (#131803)

This PR adds the Vector transfer_read to load rewrite pattern. The
pattern lowers a vector transfer_read op to a combination of
`vector.load`, `arith.select` and `vector.broadcast` if:
 - The transfer op is masked.
 - The memref is in the buffer address space.
 - The remaining preconditions inherited from
   `TransferReadToVectorLoadLowering` hold.

This PR is motivated by the AMDGPU backend's lack of support for masked
loads: `llvm.intr.masked.load` is lowered to a series of conditional
scalar loads (see the `scalarize-masked-mem-intrin` pass). This PR makes
it possible for a masked transfer_read to be lowered to a buffer load
with a hardware bounds check, allowing a more optimized global load
access pattern than the existing lowering of `llvm.intr.masked.load`
on vectors.
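
Concretely, the basic masked case is rewritten roughly as follows (a
sketch mirroring the `transfer_to_maskedload_fatrawbuffer` test added in
this PR; `%mem`, `%idx` and `%mask` are illustrative names):

```mlir
// Before: a masked transfer_read from a fat_raw_buffer memref.
%cf0 = arith.constant 0.0 : f32
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]}
  : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>

// After: an unmasked vector.load plus a select on the mask.
%fill = vector.splat %cf0 : vector<4xf32>
%load = vector.load %mem[%idx, %idx]
  : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
%res = arith.select %mask, %load, %fill : vector<4xi1>, vector<4xf32>
```

The unmasked `vector.load` is safe here because loads from
`fat_raw_buffer` memory are bounds-checked in hardware; the
`arith.select` then substitutes the padding value on masked-off lanes.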
Zhuoran Yin, 2025-03-21 08:42:04 -04:00
commit ea03bdee70 (parent 09feaa9261)
5 changed files with 259 additions and 0 deletions

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h

@@ -22,6 +22,7 @@ namespace amdgpu {
#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
@@ -30,6 +31,9 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
Chipset chipset);
void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns);
void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns);
} // namespace amdgpu
} // namespace mlir

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

@@ -51,4 +51,18 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
];
}
def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
  let summary = "Lower vector.transfer_read ops to vector.load";
  let description = [{
    This pass lowers a masked vector transfer read op to a combination of
    vector.load, arith.select and vector.broadcast.

    The pattern makes it possible for a masked transfer_read to be lowered
    to a buffer load with a bounds check, allowing a more optimized global
    load access pattern compared with the existing lowering of
    llvm.intr.masked.load on vectors.
  }];
  let dependentDialects = ["arith::ArithDialect", "vector::VectorDialect"];
}
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_

mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt

@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
EmulateAtomics.cpp
ResolveStridedMetadata.cpp
TransferReadToLoad.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

@@ -0,0 +1,154 @@
//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

/// This pattern supports lowering of:
/// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
/// `vector.broadcast` if all of the following hold:
/// - The transfer op is masked.
/// - The memref is in buffer address space.
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
/// result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
/// pass.
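///
/// For example (an illustrative sketch mirroring the `transfer_broadcasting`
/// test added in this commit), with `#bcast = affine_map<(d0, d1) -> (0)>`:
///
///   %res = vector.transfer_read %mem[%i, %i], %pad, %mask
///     {in_bounds = [true], permutation_map = #bcast}
///     : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
///
/// is loaded and selected as a vector<1xf32> (the unbroadcasted vector type,
/// where each broadcasted dimension is set to 1) and then broadcast to the
/// final vector<4xf32>.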
static LogicalResult transferPreconditions(
    PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
    bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
  if (!xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");

  // Permutations are handled by VectorToSCF or
  // populateVectorTransferPermutationMapLoweringPatterns.
  // We let the 0-d corner case pass-through as it is supported.
  SmallVector<unsigned> broadcastedDims;
  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
          &broadcastedDims))
    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");

  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!memRefType)
    return rewriter.notifyMatchFailure(xferOp, "not a memref source");

  Attribute addrSpace = memRefType.getMemorySpace();
  auto amdgpuAddrSpace =
      dyn_cast_if_present<amdgpu::AddressSpaceAttr>(addrSpace);
  if (!amdgpuAddrSpace)
    return rewriter.notifyMatchFailure(xferOp, "no AMDGPU address space");
  if (amdgpuAddrSpace.getValue() != amdgpu::AddressSpace::FatRawBuffer)
    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");

  // Non-unit strides are handled by VectorToSCF.
  if (!memRefType.isLastDimUnitStride())
    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");

  // If there is broadcasting involved then we first load the unbroadcasted
  // vector, and then broadcast it with `vector.broadcast`.
  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
  for (unsigned i : broadcastedDims)
    unbroadcastedVectorShape[i] = 1;
  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
  requiresBroadcasting = !broadcastedDims.empty();

  // `vector.load` supports vector types as memref's elements only when the
  // resulting vector type is the same as the element type.
  auto memrefElTy = memRefType.getElementType();
  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");

  // Otherwise, element types of the memref and the vector must match.
  if (!isa<VectorType>(memrefElTy) &&
      memrefElTy != xferOp.getVectorType().getElementType())
    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");

  // Out-of-bounds dims are handled by MaterializeTransferMask.
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");

  // Only 1-D vectors are supported by this lowering; higher ranks are
  // handled by VectorToSCF.
  if (xferOp.getVectorType().getRank() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "vector type is not rank 1, can't create masked load, needs "
                "VectorToSCF");

  return success();
}

namespace {

struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    bool requiresBroadcasting = false;
    VectorType unbroadcastedVectorType;
    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
                                     unbroadcastedVectorType))) {
      return failure();
    }

    Location loc = readOp.getLoc();
    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
                                                  readOp.getPadding());
    Value load = rewriter.create<vector::LoadOp>(
        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
                                                 readOp.getMask(), load, fill);

    // Insert a broadcasting op if required.
    if (requiresBroadcasting) {
      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
                                                 res);
    }

    rewriter.replaceOp(readOp, res);
    return success();
  }
};

} // namespace

void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering>(patterns.getContext());
}

struct AmdgpuTransferReadToLoadPass final
    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
          AmdgpuTransferReadToLoadPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateAmdgpuTransferReadToLoadPatterns(patterns);
    walkAndApplyPatterns(getOperation(), std::move(patterns));
  }
};

mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir

@@ -0,0 +1,86 @@
// RUN: mlir-opt %s --amdgpu-transfer-read-to-load --split-input-file | FileCheck %s
// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: return %[[SELECT]] : vector<4xf32>
// -----
// CHECK-LABEL: func @transfer_to_maskedload_regular(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
// CHECK: return %[[RES]] : vector<4xf32>
// -----
// CHECK-LABEL: func @transfer_to_maskedload_addrspace(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #gpu.address_space<workgroup>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_space<workgroup>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
// CHECK: return %[[RES]] : vector<4xf32>
// -----
// CHECK-LABEL: func @transfer_broadcasting(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
#broadcast_1d = affine_map<(d0, d1) -> (0)>
func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<4xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
    {in_bounds = [true], permutation_map = #broadcast_1d}
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
// CHECK: return %[[BROADCAST]] : vector<4xf32>
// -----
// CHECK-LABEL: func @transfer_scalar(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<1xf32> {
  %cf0 = arith.constant 0.0 : f32
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
    {in_bounds = [true]}
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
  return %res : vector<1xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: return %[[SELECT]] : vector<1xf32>