mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-29 10:26:06 +00:00
[NVPTX] support immediate values in st.param instructions (#91523)
Add support for generating `st.param` instructions with direct use of immediates. This eliminates the need for a `mov` instruction prior to the `st.param` resulting in more concise emitted PTX.
This commit is contained in:
parent
58c778565c
commit
c5b11a710e
@ -2182,6 +2182,100 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
|
||||
#define getOpcV2H(ty, opKind0, opKind1) \
|
||||
NVPTX::StoreParamV2##ty##_##opKind0##opKind1
|
||||
|
||||
#define getOpcV2H1(ty, opKind0, isImm1) \
|
||||
(isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
|
||||
|
||||
#define getOpcodeForVectorStParamV2(ty, isimm) \
|
||||
(isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])
|
||||
|
||||
#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3) \
|
||||
NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
|
||||
|
||||
#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3) \
|
||||
(isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \
|
||||
: getOpcV4H(ty, opKind0, opKind1, opKind2, r)
|
||||
|
||||
#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3) \
|
||||
(isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \
|
||||
: getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
|
||||
|
||||
#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3) \
|
||||
(isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \
|
||||
: getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
|
||||
|
||||
#define getOpcodeForVectorStParamV4(ty, isimm) \
|
||||
(isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3]) \
|
||||
: getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])
|
||||
|
||||
#define getOpcodeForVectorStParam(n, ty, isimm) \
|
||||
(n == 2) ? getOpcodeForVectorStParamV2(ty, isimm) \
|
||||
: getOpcodeForVectorStParamV4(ty, isimm)
|
||||
|
||||
static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
|
||||
unsigned NumElts,
|
||||
MVT::SimpleValueType MemTy,
|
||||
SelectionDAG *CurDAG, SDLoc DL) {
|
||||
// Determine which inputs are registers and immediates make new operators
|
||||
// with constant values
|
||||
SmallVector<bool, 4> IsImm(NumElts, false);
|
||||
for (unsigned i = 0; i < NumElts; i++) {
|
||||
IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
|
||||
if (IsImm[i]) {
|
||||
SDValue Imm = Ops[i];
|
||||
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
|
||||
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
|
||||
const ConstantFP *CF = ConstImm->getConstantFPValue();
|
||||
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
|
||||
} else {
|
||||
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
|
||||
const ConstantInt *CI = ConstImm->getConstantIntValue();
|
||||
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
|
||||
}
|
||||
Ops[i] = Imm;
|
||||
}
|
||||
}
|
||||
|
||||
// Get opcode for MemTy, size, and register/immediate operand ordering
|
||||
switch (MemTy) {
|
||||
case MVT::i8:
|
||||
return getOpcodeForVectorStParam(NumElts, I8, IsImm);
|
||||
case MVT::i16:
|
||||
return getOpcodeForVectorStParam(NumElts, I16, IsImm);
|
||||
case MVT::i32:
|
||||
return getOpcodeForVectorStParam(NumElts, I32, IsImm);
|
||||
case MVT::i64:
|
||||
assert(NumElts == 2 && "MVT too large for NumElts > 2");
|
||||
return getOpcodeForVectorStParamV2(I64, IsImm);
|
||||
case MVT::f32:
|
||||
return getOpcodeForVectorStParam(NumElts, F32, IsImm);
|
||||
case MVT::f64:
|
||||
assert(NumElts == 2 && "MVT too large for NumElts > 2");
|
||||
return getOpcodeForVectorStParamV2(F64, IsImm);
|
||||
|
||||
// These cases don't support immediates, just use the all register version
|
||||
// and generate moves.
|
||||
case MVT::i1:
|
||||
return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
|
||||
: NVPTX::StoreParamV4I8_rrrr;
|
||||
case MVT::f16:
|
||||
case MVT::bf16:
|
||||
return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
|
||||
: NVPTX::StoreParamV4I16_rrrr;
|
||||
case MVT::v2f16:
|
||||
case MVT::v2bf16:
|
||||
case MVT::v2i16:
|
||||
case MVT::v4i8:
|
||||
return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
|
||||
: NVPTX::StoreParamV4I32_rrrr;
|
||||
default:
|
||||
llvm_unreachable("Cannot select st.param for unknown MemTy");
|
||||
}
|
||||
}
|
||||
|
||||
bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
|
||||
SDLoc DL(N);
|
||||
SDValue Chain = N->getOperand(0);
|
||||
@ -2193,10 +2287,10 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
|
||||
SDValue Glue = N->getOperand(N->getNumOperands() - 1);
|
||||
|
||||
// How many elements do we have?
|
||||
unsigned NumElts = 1;
|
||||
unsigned NumElts;
|
||||
switch (N->getOpcode()) {
|
||||
default:
|
||||
return false;
|
||||
llvm_unreachable("Unexpected opcode");
|
||||
case NVPTXISD::StoreParamU32:
|
||||
case NVPTXISD::StoreParamS32:
|
||||
case NVPTXISD::StoreParam:
|
||||
@ -2222,18 +2316,40 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
|
||||
// Determine target opcode
|
||||
// If we have an i1, use an 8-bit store. The lowering code in
|
||||
// NVPTXISelLowering will have already emitted an upcast.
|
||||
std::optional<unsigned> Opcode = 0;
|
||||
std::optional<unsigned> Opcode;
|
||||
switch (N->getOpcode()) {
|
||||
default:
|
||||
switch (NumElts) {
|
||||
default:
|
||||
return false;
|
||||
case 1:
|
||||
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
|
||||
NVPTX::StoreParamI8, NVPTX::StoreParamI16,
|
||||
NVPTX::StoreParamI32, NVPTX::StoreParamI64,
|
||||
NVPTX::StoreParamF32, NVPTX::StoreParamF64);
|
||||
if (Opcode == NVPTX::StoreParamI8) {
|
||||
llvm_unreachable("Unexpected NumElts");
|
||||
case 1: {
|
||||
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
|
||||
SDValue Imm = Ops[0];
|
||||
if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
|
||||
(isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
|
||||
// Convert immediate to target constant
|
||||
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
|
||||
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
|
||||
const ConstantFP *CF = ConstImm->getConstantFPValue();
|
||||
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
|
||||
} else {
|
||||
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
|
||||
const ConstantInt *CI = ConstImm->getConstantIntValue();
|
||||
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
|
||||
}
|
||||
Ops[0] = Imm;
|
||||
// Use immediate version of store param
|
||||
Opcode = pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i,
|
||||
NVPTX::StoreParamI16_i, NVPTX::StoreParamI32_i,
|
||||
NVPTX::StoreParamI64_i, NVPTX::StoreParamF32_i,
|
||||
NVPTX::StoreParamF64_i);
|
||||
} else
|
||||
Opcode =
|
||||
pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
|
||||
NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
|
||||
NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r,
|
||||
NVPTX::StoreParamF32_r, NVPTX::StoreParamF64_r);
|
||||
if (Opcode == NVPTX::StoreParamI8_r) {
|
||||
// Fine tune the opcode depending on the size of the operand.
|
||||
// This helps to avoid creating redundant COPY instructions in
|
||||
// InstrEmitter::AddRegisterOperand().
|
||||
@ -2241,35 +2357,28 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
|
||||
default:
|
||||
break;
|
||||
case MVT::i32:
|
||||
Opcode = NVPTX::StoreParamI8TruncI32;
|
||||
Opcode = NVPTX::StoreParamI8TruncI32_r;
|
||||
break;
|
||||
case MVT::i64:
|
||||
Opcode = NVPTX::StoreParamI8TruncI64;
|
||||
Opcode = NVPTX::StoreParamI8TruncI64_r;
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 2:
|
||||
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
|
||||
NVPTX::StoreParamV2I8, NVPTX::StoreParamV2I16,
|
||||
NVPTX::StoreParamV2I32, NVPTX::StoreParamV2I64,
|
||||
NVPTX::StoreParamV2F32, NVPTX::StoreParamV2F64);
|
||||
break;
|
||||
case 4:
|
||||
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
|
||||
NVPTX::StoreParamV4I8, NVPTX::StoreParamV4I16,
|
||||
NVPTX::StoreParamV4I32, std::nullopt,
|
||||
NVPTX::StoreParamV4F32, std::nullopt);
|
||||
case 4: {
|
||||
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
|
||||
Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
|
||||
break;
|
||||
}
|
||||
if (!Opcode)
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
// Special case: if we have a sign-extend/zero-extend node, insert the
|
||||
// conversion instruction first, and use that as the value operand to
|
||||
// the selected StoreParam node.
|
||||
case NVPTXISD::StoreParamU32: {
|
||||
Opcode = NVPTX::StoreParamI32;
|
||||
Opcode = NVPTX::StoreParamI32_r;
|
||||
SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
|
||||
MVT::i32);
|
||||
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL,
|
||||
@ -2278,7 +2387,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
|
||||
break;
|
||||
}
|
||||
case NVPTXISD::StoreParamS32: {
|
||||
Opcode = NVPTX::StoreParamI32;
|
||||
Opcode = NVPTX::StoreParamI32_r;
|
||||
SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
|
||||
MVT::i32);
|
||||
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL,
|
||||
|
@ -2637,25 +2637,46 @@ class LoadParamRegInst<NVPTXRegClass regclass, string opstr> :
|
||||
[(set regclass:$dst, (LoadParam (i32 0), (i32 imm:$b)))]>;
|
||||
|
||||
let mayStore = true in {
|
||||
class StoreParamInst<NVPTXRegClass regclass, string opstr> :
|
||||
NVPTXInst<(outs), (ins regclass:$val, i32imm:$a, i32imm:$b),
|
||||
!strconcat("st.param", opstr, " \t[param$a+$b], $val;"),
|
||||
[]>;
|
||||
|
||||
class StoreParamV2Inst<NVPTXRegClass regclass, string opstr> :
|
||||
NVPTXInst<(outs), (ins regclass:$val, regclass:$val2,
|
||||
multiclass StoreParamInst<NVPTXRegClass regclass, Operand IMMType, string opstr, bit support_imm = true> {
|
||||
foreach op = [IMMType, regclass] in
|
||||
if !or(support_imm, !isa<NVPTXRegClass>(op)) then
|
||||
def _ # !if(!isa<NVPTXRegClass>(op), "r", "i")
|
||||
: NVPTXInst<(outs),
|
||||
(ins op:$val, i32imm:$a, i32imm:$b),
|
||||
"st.param" # opstr # " \t[param$a+$b], $val;",
|
||||
[]>;
|
||||
}
|
||||
|
||||
multiclass StoreParamV2Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
|
||||
foreach op1 = [IMMType, regclass] in
|
||||
foreach op2 = [IMMType, regclass] in
|
||||
def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
|
||||
# !if(!isa<NVPTXRegClass>(op2), "r", "i")
|
||||
: NVPTXInst<(outs),
|
||||
(ins op1:$val1, op2:$val2,
|
||||
i32imm:$a, i32imm:$b),
|
||||
"st.param.v2" # opstr # " \t[param$a+$b], {{$val1, $val2}};",
|
||||
[]>;
|
||||
}
|
||||
|
||||
multiclass StoreParamV4Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
|
||||
foreach op1 = [IMMType, regclass] in
|
||||
foreach op2 = [IMMType, regclass] in
|
||||
foreach op3 = [IMMType, regclass] in
|
||||
foreach op4 = [IMMType, regclass] in
|
||||
def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
|
||||
# !if(!isa<NVPTXRegClass>(op2), "r", "i")
|
||||
# !if(!isa<NVPTXRegClass>(op3), "r", "i")
|
||||
# !if(!isa<NVPTXRegClass>(op4), "r", "i")
|
||||
|
||||
: NVPTXInst<(outs),
|
||||
(ins op1:$val1, op2:$val2, op3:$val3, op4:$val4,
|
||||
i32imm:$a, i32imm:$b),
|
||||
!strconcat("st.param.v2", opstr,
|
||||
" \t[param$a+$b], {{$val, $val2}};"),
|
||||
[]>;
|
||||
|
||||
class StoreParamV4Inst<NVPTXRegClass regclass, string opstr> :
|
||||
NVPTXInst<(outs), (ins regclass:$val, regclass:$val2, regclass:$val3,
|
||||
regclass:$val4, i32imm:$a,
|
||||
i32imm:$b),
|
||||
!strconcat("st.param.v4", opstr,
|
||||
" \t[param$a+$b], {{$val, $val2, $val3, $val4}};"),
|
||||
[]>;
|
||||
"st.param.v4" # opstr #
|
||||
" \t[param$a+$b], {{$val1, $val2, $val3, $val4}};",
|
||||
[]>;
|
||||
}
|
||||
|
||||
class StoreRetvalInst<NVPTXRegClass regclass, string opstr> :
|
||||
NVPTXInst<(outs), (ins regclass:$val, i32imm:$a),
|
||||
@ -2735,27 +2756,30 @@ def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">;
|
||||
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">;
|
||||
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".f32">;
|
||||
|
||||
def StoreParamI64 : StoreParamInst<Int64Regs, ".b64">;
|
||||
def StoreParamI32 : StoreParamInst<Int32Regs, ".b32">;
|
||||
defm StoreParamI64 : StoreParamInst<Int64Regs, i64imm, ".b64">;
|
||||
defm StoreParamI32 : StoreParamInst<Int32Regs, i32imm, ".b32">;
|
||||
defm StoreParamI16 : StoreParamInst<Int16Regs, i16imm, ".b16">;
|
||||
defm StoreParamI8 : StoreParamInst<Int16Regs, i8imm, ".b8">;
|
||||
|
||||
def StoreParamI16 : StoreParamInst<Int16Regs, ".b16">;
|
||||
def StoreParamI8 : StoreParamInst<Int16Regs, ".b8">;
|
||||
def StoreParamI8TruncI32 : StoreParamInst<Int32Regs, ".b8">;
|
||||
def StoreParamI8TruncI64 : StoreParamInst<Int64Regs, ".b8">;
|
||||
def StoreParamV2I64 : StoreParamV2Inst<Int64Regs, ".b64">;
|
||||
def StoreParamV2I32 : StoreParamV2Inst<Int32Regs, ".b32">;
|
||||
def StoreParamV2I16 : StoreParamV2Inst<Int16Regs, ".b16">;
|
||||
def StoreParamV2I8 : StoreParamV2Inst<Int16Regs, ".b8">;
|
||||
defm StoreParamI8TruncI32 : StoreParamInst<Int32Regs, i8imm, ".b8", /* support_imm */ false>;
|
||||
defm StoreParamI8TruncI64 : StoreParamInst<Int64Regs, i8imm, ".b8", /* support_imm */ false>;
|
||||
|
||||
def StoreParamV4I32 : StoreParamV4Inst<Int32Regs, ".b32">;
|
||||
def StoreParamV4I16 : StoreParamV4Inst<Int16Regs, ".b16">;
|
||||
def StoreParamV4I8 : StoreParamV4Inst<Int16Regs, ".b8">;
|
||||
defm StoreParamV2I64 : StoreParamV2Inst<Int64Regs, i64imm, ".b64">;
|
||||
defm StoreParamV2I32 : StoreParamV2Inst<Int32Regs, i32imm, ".b32">;
|
||||
defm StoreParamV2I16 : StoreParamV2Inst<Int16Regs, i16imm, ".b16">;
|
||||
defm StoreParamV2I8 : StoreParamV2Inst<Int16Regs, i8imm, ".b8">;
|
||||
|
||||
def StoreParamF32 : StoreParamInst<Float32Regs, ".f32">;
|
||||
def StoreParamF64 : StoreParamInst<Float64Regs, ".f64">;
|
||||
def StoreParamV2F32 : StoreParamV2Inst<Float32Regs, ".f32">;
|
||||
def StoreParamV2F64 : StoreParamV2Inst<Float64Regs, ".f64">;
|
||||
def StoreParamV4F32 : StoreParamV4Inst<Float32Regs, ".f32">;
|
||||
defm StoreParamV4I32 : StoreParamV4Inst<Int32Regs, i32imm, ".b32">;
|
||||
defm StoreParamV4I16 : StoreParamV4Inst<Int16Regs, i16imm, ".b16">;
|
||||
defm StoreParamV4I8 : StoreParamV4Inst<Int16Regs, i8imm, ".b8">;
|
||||
|
||||
defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".f32">;
|
||||
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".f64">;
|
||||
|
||||
defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".f32">;
|
||||
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".f64">;
|
||||
|
||||
defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".f32">;
|
||||
|
||||
def StoreRetvalI64 : StoreRetvalInst<Int64Regs, ".b64">;
|
||||
def StoreRetvalI32 : StoreRetvalInst<Int32Regs, ".b32">;
|
||||
|
2002
llvm/test/CodeGen/NVPTX/st-param-imm.ll
Normal file
2002
llvm/test/CodeGen/NVPTX/st-param-imm.ll
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user