[MLIR][AMDGPU] Fixing word alignment check for bufferload fastpath (#135982)

The condition `delta_bytes % (32 ceilDiv elementBitwidth) != 0` from
https://github.com/llvm/llvm-project/pull/135014 is incorrect.

For example, suppose the last load is issued to load only the final fp16
element. Then `delta_bytes = 2` and `(32 ceilDiv 16) = 2`, so
`2 % 2 == 0` and the access is judged to be word aligned. It is sent to
the fast path, but the fp16 value comes back as all zeros because the
load crosses the word boundary.

In reality the check should simply be `delta_bytes % 4 != 0`, since a
word is 4 bytes. This PR fixes the bug by changing the modulus to 4.
This commit is contained in:
Zhuoran Yin 2025-04-17 08:50:31 -04:00 committed by GitHub
parent 5a993558c5
commit 47f4f39265
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 17 deletions

View File

@ -11,12 +11,10 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
@ -225,15 +223,12 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
// 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
Value deltaBytes = rewriter.create<arith::MulIOp>(
loc, delta,
rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
// 2) check if (detla % elements_per_word != 0)
Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
loc, llvm::divideCeil(32, elementBitWidth));
Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne,
rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
rewriter.create<arith::ConstantIndexOp>(loc, 0));
// We take the fallback of transfer_read default lowering only it is both

View File

@ -10,8 +10,7 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
return %res : vector<4xf32>
}
// CHECK: %[[FALSE:.*]] = arith.constant false
// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
// CHECK: %[[IF:.*]] = scf.if
// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]]
// CHECK: } else {
@ -35,14 +34,13 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
// CHECK-DAG: %[[C0:.*]] = arith.constant 0
// CHECK-DAG: %[[SIZE:.*]] = arith.constant 64
// CHECK-DAG: %[[BYTES:.*]] = arith.constant 2
// CHECK-DAG: %[[VECTORSIZE:.*]] = arith.constant 4
// CHECK-DAG: %[[C4:.*]] = arith.constant 4
// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
// CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
// CHECK: %[[COND1:.*]] = arith.cmpi ult, %[[DELTA]], %[[VECTORSIZE]]
// CHECK: %[[COND1:.*]] = arith.cmpi ult, %[[DELTA]], %[[C4]]
// CHECK: %[[DELTABYTES:.*]] = arith.muli %[[DELTA]], %[[BYTES]]
// CHECK: %[[REM:.*]] = arith.remui %[[DELTABYTES]], %[[BYTES]]
// CHECK: %[[REM:.*]] = arith.remui %[[DELTA]], %[[BYTES]]
// CHECK: %[[COND2:.*]] = arith.cmpi ne, %[[REM]], %[[C0]]
// CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
@ -120,8 +118,7 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
// CHECK: %[[FALSE:.*]] = arith.constant false
// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
// CHECK: %[[IF:.*]] = scf.if
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
@ -140,7 +137,6 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
return %res : vector<1xf32>
}
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
// CHECK: %[[FALSE:.*]] = arith.constant false
// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<1xf32>) {
// CHECK: %[[IF:.*]] = scf.if
// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]

View File

@ -1559,6 +1559,7 @@ cc_library(
hdrs = glob(["include/mlir/Dialect/AMDGPU/Transforms/*.h"]),
includes = ["include"],
deps = [
":AffineDialect",
":AMDGPUDialect",
":AMDGPUPassIncGen",
":AMDGPUUtils",
@ -1569,6 +1570,7 @@ cc_library(
":FuncDialect",
":GPUDialect",
":IR",
":LLVMSupportHeaders",
":MemRefDialect",
":MemRefUtils",
":Pass",