AMDGPU: Convert vector 64-bit shl to 32-bit if shift amt >= 32 (#132964)

Convert vector 64-bit shl to 32-bit if shift amt is known to be >= 32.

---------

Signed-off-by: John Lu <John.Lu@amd.com>
This commit is contained in:
LU-JOHN 2025-03-28 11:46:35 -05:00 committed by GitHub
parent 6b1acdb818
commit 827f2ad643
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 25 deletions

View File

@ -4084,7 +4084,7 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
}
}
if (VT != MVT::i64)
if (VT.getScalarType() != MVT::i64)
return SDValue();
// i64 (shl x, C) -> (build_pair 0, (shl x, C -32))
@ -4092,21 +4092,24 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
// On some subtargets, 64-bit shift is a quarter rate instruction. In the
// common case, splitting this into a move and a 32-bit shift is faster and
// the same code size.
EVT TargetType = VT.getHalfSizedIntegerVT(*DAG.getContext());
EVT TargetVecPairType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
KnownBits Known = DAG.computeKnownBits(RHS);
if (Known.getMinValue().getZExtValue() < TargetType.getSizeInBits())
EVT ElementType = VT.getScalarType();
EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
EVT TargetType = VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
: TargetScalarType;
if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
return SDValue();
SDValue ShiftAmt;
if (CRHS) {
ShiftAmt =
DAG.getConstant(RHSVal - TargetType.getSizeInBits(), SL, TargetType);
ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
TargetType);
} else {
SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
const SDValue ShiftMask =
DAG.getConstant(TargetType.getSizeInBits() - 1, SL, TargetType);
DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
// This AND instruction will clamp out of bounds shift values.
// It will also be removed during later instruction selection.
ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
@ -4116,9 +4119,23 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
SDValue NewShift =
DAG.getNode(ISD::SHL, SL, TargetType, Lo, ShiftAmt, N->getFlags());
const SDValue Zero = DAG.getConstant(0, SL, TargetType);
const SDValue Zero = DAG.getConstant(0, SL, TargetScalarType);
SDValue Vec;
SDValue Vec = DAG.getBuildVector(TargetVecPairType, SL, {Zero, NewShift});
if (VT.isVector()) {
EVT ConcatType = TargetType.getDoubleNumVectorElementsVT(*DAG.getContext());
unsigned NElts = TargetType.getVectorNumElements();
SmallVector<SDValue, 8> HiOps;
SmallVector<SDValue, 16> HiAndLoOps(NElts * 2, Zero);
DAG.ExtractVectorElements(NewShift, HiOps, 0, NElts);
for (unsigned I = 0; I != NElts; ++I)
HiAndLoOps[2 * I + 1] = HiOps[I];
Vec = DAG.getNode(ISD::BUILD_VECTOR, SL, ConcatType, HiAndLoOps);
} else {
EVT ConcatType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
Vec = DAG.getBuildVector(ConcatType, SL, {Zero, NewShift});
}
return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
}
@ -5182,9 +5199,14 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case ISD::SHL: {
if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
// Range metadata can be invalidated when loads are converted to legal types
// (e.g. v2i64 -> v4i32).
// Try to convert vector shl before type legalization so that range metadata
// can be utilized.
if (!(N->getValueType(0).isVector() &&
DCI.getDAGCombineLevel() == BeforeLegalizeTypes) &&
DCI.getDAGCombineLevel() < AfterLegalizeDAG)
break;
return performShlCombine(N, DCI);
}
case ISD::SRL: {

View File

@ -72,39 +72,41 @@ define i64 @shl_metadata_cant_be_narrowed_to_i32(i64 %arg0, ptr %arg1.ptr) {
ret i64 %shl
}
; FIXME: This case should be reduced
define <2 x i64> @shl_v2_metadata(<2 x i64> %arg0, ptr %arg1.ptr) {
; CHECK-LABEL: shl_v2_metadata:
; CHECK: ; %bb.0:
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; CHECK-NEXT: flat_load_dwordx4 v[4:7], v[4:5]
; CHECK-NEXT: flat_load_dwordx4 v[3:6], v[4:5]
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
; CHECK-NEXT: v_lshlrev_b64 v[0:1], v4, v[0:1]
; CHECK-NEXT: v_lshlrev_b64 v[2:3], v6, v[2:3]
; CHECK-NEXT: v_lshlrev_b32_e32 v1, v3, v0
; CHECK-NEXT: v_lshlrev_b32_e32 v3, v5, v2
; CHECK-NEXT: v_mov_b32_e32 v0, 0
; CHECK-NEXT: v_mov_b32_e32 v2, 0
; CHECK-NEXT: s_setpc_b64 s[30:31]
%shift.amt = load <2 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
%shl = shl <2 x i64> %arg0, %shift.amt
ret <2 x i64> %shl
}
; FIXME: This case should be reduced
define <3 x i64> @shl_v3_metadata(<3 x i64> %arg0, ptr %arg1.ptr) {
; CHECK-LABEL: shl_v3_metadata:
; CHECK: ; %bb.0:
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; CHECK-NEXT: flat_load_dword v12, v[6:7] offset:16
; CHECK-NEXT: flat_load_dword v1, v[6:7] offset:16
; CHECK-NEXT: flat_load_dwordx4 v[8:11], v[6:7]
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
; CHECK-NEXT: v_lshlrev_b64 v[4:5], v12, v[4:5]
; CHECK-NEXT: v_lshlrev_b64 v[0:1], v8, v[0:1]
; CHECK-NEXT: v_lshlrev_b64 v[2:3], v10, v[2:3]
; CHECK-NEXT: v_lshlrev_b32_e32 v5, v1, v4
; CHECK-NEXT: v_lshlrev_b32_e32 v1, v8, v0
; CHECK-NEXT: v_lshlrev_b32_e32 v3, v10, v2
; CHECK-NEXT: v_mov_b32_e32 v0, 0
; CHECK-NEXT: v_mov_b32_e32 v2, 0
; CHECK-NEXT: v_mov_b32_e32 v4, 0
; CHECK-NEXT: s_setpc_b64 s[30:31]
%shift.amt = load <3 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
%shl = shl <3 x i64> %arg0, %shift.amt
ret <3 x i64> %shl
}
; FIXME: This case should be reduced
define <4 x i64> @shl_v4_metadata(<4 x i64> %arg0, ptr %arg1.ptr) {
; CHECK-LABEL: shl_v4_metadata:
; CHECK: ; %bb.0:
@ -113,11 +115,15 @@ define <4 x i64> @shl_v4_metadata(<4 x i64> %arg0, ptr %arg1.ptr) {
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
; CHECK-NEXT: flat_load_dwordx4 v[13:16], v[8:9] offset:16
; CHECK-NEXT: ; kill: killed $vgpr8 killed $vgpr9
; CHECK-NEXT: v_lshlrev_b64 v[0:1], v10, v[0:1]
; CHECK-NEXT: v_lshlrev_b64 v[2:3], v12, v[2:3]
; CHECK-NEXT: v_lshlrev_b32_e32 v1, v10, v0
; CHECK-NEXT: v_lshlrev_b32_e32 v3, v12, v2
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
; CHECK-NEXT: v_lshlrev_b64 v[4:5], v13, v[4:5]
; CHECK-NEXT: v_lshlrev_b64 v[6:7], v15, v[6:7]
; CHECK-NEXT: v_lshlrev_b32_e32 v5, v13, v4
; CHECK-NEXT: v_lshlrev_b32_e32 v7, v15, v6
; CHECK-NEXT: v_mov_b32_e32 v0, 0
; CHECK-NEXT: v_mov_b32_e32 v2, 0
; CHECK-NEXT: v_mov_b32_e32 v4, 0
; CHECK-NEXT: v_mov_b32_e32 v6, 0
; CHECK-NEXT: s_setpc_b64 s[30:31]
%shift.amt = load <4 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
%shl = shl <4 x i64> %arg0, %shift.amt