llvm-project/clang/lib/Support/RISCVVIntrinsicUtils.cpp
Zakk Chen dffdca85ec [RISCV][Clang] Support policy functions for Vector Reduction
Instructions.

We will switch all UndefValue to PoisonValue in follow up patches.

Thanks to Kito for helping verify this with their internal test suite.

Reviewed By: kito-cheng

Differential Revision: https://reviews.llvm.org/D126748
2022-08-02 17:27:56 +00:00

1130 lines
34 KiB
C++

//===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "clang/Support/RISCVVIntrinsicUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include <set>
#include <unordered_map>
using namespace llvm;
namespace clang {
namespace RISCV {
// Shared descriptors for operands that computeBuiltinTypes splices into
// intrinsic prototypes: a mask vector operand, a size_t (vl) operand, and a
// plain vector operand.
const PrototypeDescriptor
    PrototypeDescriptor::Mask(BaseTypeModifier::Vector,
                              VectorTypeModifier::MaskVector);
const PrototypeDescriptor PrototypeDescriptor::VL(BaseTypeModifier::SizeT);
const PrototypeDescriptor
    PrototypeDescriptor::Vector(BaseTypeModifier::Vector);
//===----------------------------------------------------------------------===//
// Type implementation
//===----------------------------------------------------------------------===//
LMULType::LMULType(int NewLog2LMUL) : Log2LMUL(NewLog2LMUL) {
  // Only LMUL = 1/8 .. 8, i.e. log2 values -3 .. 3, are representable.
  assert(NewLog2LMUL >= -3 && NewLog2LMUL <= 3 && "Bad LMUL number!");
}
std::string LMULType::str() const {
  // Fractional LMULs are spelled "mf<N>" (LMUL = 1/N); integral ones "m<N>".
  const bool IsFractional = Log2LMUL < 0;
  const int Shift = IsFractional ? -Log2LMUL : Log2LMUL;
  return (IsFractional ? "mf" : "m") + utostr(1ULL << Shift);
}
VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
  // For SEW in {8, 16, 32, 64} the number of elements per vscale unit is
  // LMUL * 64 / SEW, computed here in log2 form. Any other width yields a
  // log2 scale of 0 (scale 1) regardless of LMUL.
  int Log2Scale = 0;
  if (ElementBitwidth == 8)
    Log2Scale = Log2LMUL + 3;
  else if (ElementBitwidth == 16)
    Log2Scale = Log2LMUL + 2;
  else if (ElementBitwidth == 32)
    Log2Scale = Log2LMUL + 1;
  else if (ElementBitwidth == 64)
    Log2Scale = Log2LMUL;

  // Less than one element per vscale unit is not a legal type.
  if (Log2Scale < 0)
    return llvm::None;
  return 1 << Log2Scale;
}
// Multiply LMUL by 2^log2LMUL, which in log2 space is an addition.
void LMULType::MulLog2LMUL(int log2LMUL) {
  Log2LMUL = Log2LMUL + log2LMUL;
}
RVVType::RVVType(BasicType BT, int Log2LMUL,
                 const PrototypeDescriptor &prototype)
    : BT(BT), LMUL(Log2LMUL) {
  // Start from the basic type, then let the prototype descriptor reshape it.
  applyBasicType();
  applyModifier(prototype);

  Valid = verifyType();
  // The cached spellings are only built for valid types.
  if (!Valid)
    return;

  initBuiltinStr();
  initTypeStr();
  // Only vector types have a __rvv_* clang builtin type name.
  if (isVector())
    initClangBuiltinStr();
}
// clang-format off
// boolean type are encoded the ratio of n (SEW/LMUL)
// SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64
// c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
// IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
// type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
// -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
// i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
// i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
// i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
// i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
// clang-format on
bool RVVType::verifyType() const {
if (ScalarType == Invalid)
return false;
if (isScalar())
return true;
if (!Scale)
return false;
if (isFloat() && ElementBitwidth == 8)
return false;
unsigned V = Scale.value();
switch (ElementBitwidth) {
case 1:
case 8:
// Check Scale is 1,2,4,8,16,32,64
return (V <= 64 && isPowerOf2_32(V));
case 16:
// Check Scale is 1,2,4,8,16,32
return (V <= 32 && isPowerOf2_32(V));
case 32:
// Check Scale is 1,2,4,8,16
return (V <= 16 && isPowerOf2_32(V));
case 64:
// Check Scale is 1,2,4,8
return (V <= 8 && isPowerOf2_32(V));
}
return false;
}
void RVVType::initBuiltinStr() {
assert(isValid() && "RVVType is invalid");
switch (ScalarType) {
case ScalarTypeKind::Void:
BuiltinStr = "v";
return;
case ScalarTypeKind::Size_t:
BuiltinStr = "z";
if (IsImmediate)
BuiltinStr = "I" + BuiltinStr;
if (IsPointer)
BuiltinStr += "*";
return;
case ScalarTypeKind::Ptrdiff_t:
BuiltinStr = "Y";
return;
case ScalarTypeKind::UnsignedLong:
BuiltinStr = "ULi";
return;
case ScalarTypeKind::SignedLong:
BuiltinStr = "Li";
return;
case ScalarTypeKind::Boolean:
assert(ElementBitwidth == 1);
BuiltinStr += "b";
break;
case ScalarTypeKind::SignedInteger:
case ScalarTypeKind::UnsignedInteger:
switch (ElementBitwidth) {
case 8:
BuiltinStr += "c";
break;
case 16:
BuiltinStr += "s";
break;
case 32:
BuiltinStr += "i";
break;
case 64:
BuiltinStr += "Wi";
break;
default:
llvm_unreachable("Unhandled ElementBitwidth!");
}
if (isSignedInteger())
BuiltinStr = "S" + BuiltinStr;
else
BuiltinStr = "U" + BuiltinStr;
break;
case ScalarTypeKind::Float:
switch (ElementBitwidth) {
case 16:
BuiltinStr += "x";
break;
case 32:
BuiltinStr += "f";
break;
case 64:
BuiltinStr += "d";
break;
default:
llvm_unreachable("Unhandled ElementBitwidth!");
}
break;
default:
llvm_unreachable("ScalarType is invalid!");
}
if (IsImmediate)
BuiltinStr = "I" + BuiltinStr;
if (isScalar()) {
if (IsConstant)
BuiltinStr += "C";
if (IsPointer)
BuiltinStr += "*";
return;
}
BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
// Pointer to vector types. Defined for segment load intrinsics.
// segment load intrinsics have pointer type arguments to store the loaded
// vector values.
if (IsPointer)
BuiltinStr += "*";
}
void RVVType::initClangBuiltinStr() {
  assert(isValid() && "RVVType is invalid");
  assert(isVector() && "Handle Vector type only");

  ClangBuiltinStr = "__rvv_";

  // Mask vectors are named __rvv_bool<64/Scale>_t instead of after an element
  // type and LMUL.
  if (ScalarType == ScalarTypeKind::Boolean) {
    ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
    return;
  }

  const char *ElementKind = nullptr;
  switch (ScalarType) {
  case ScalarTypeKind::Float:
    ElementKind = "float";
    break;
  case ScalarTypeKind::SignedInteger:
    ElementKind = "int";
    break;
  case ScalarTypeKind::UnsignedInteger:
    ElementKind = "uint";
    break;
  default:
    llvm_unreachable("ScalarTypeKind is invalid");
  }
  // e.g. __rvv_int32m1_t
  ClangBuiltinStr += ElementKind;
  ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
}
void RVVType::initTypeStr() {
assert(isValid() && "RVVType is invalid");
if (IsConstant)
Str += "const ";
auto getTypeString = [&](StringRef TypeStr) {
if (isScalar())
return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
.str();
};
switch (ScalarType) {
case ScalarTypeKind::Void:
Str = "void";
return;
case ScalarTypeKind::Size_t:
Str = "size_t";
if (IsPointer)
Str += " *";
return;
case ScalarTypeKind::Ptrdiff_t:
Str = "ptrdiff_t";
return;
case ScalarTypeKind::UnsignedLong:
Str = "unsigned long";
return;
case ScalarTypeKind::SignedLong:
Str = "long";
return;
case ScalarTypeKind::Boolean:
if (isScalar())
Str += "bool";
else
// Vector bool is special case, the formulate is
// `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
Str += "vbool" + utostr(64 / *Scale) + "_t";
break;
case ScalarTypeKind::Float:
if (isScalar()) {
if (ElementBitwidth == 64)
Str += "double";
else if (ElementBitwidth == 32)
Str += "float";
else if (ElementBitwidth == 16)
Str += "_Float16";
else
llvm_unreachable("Unhandled floating type.");
} else
Str += getTypeString("float");
break;
case ScalarTypeKind::SignedInteger:
Str += getTypeString("int");
break;
case ScalarTypeKind::UnsignedInteger:
Str += getTypeString("uint");
break;
default:
llvm_unreachable("ScalarType is invalid!");
}
if (IsPointer)
Str += " *";
}
void RVVType::initShortStr() {
  // Mask vectors: "b<64/Scale>", with no LMUL part.
  if (ScalarType == ScalarTypeKind::Boolean) {
    assert(isVector());
    ShortStr = "b" + utostr(64 / *Scale);
    return;
  }

  // Otherwise: kind letter + SEW, plus the LMUL spelling for vectors,
  // e.g. "i32m1" or "f16".
  char KindLetter;
  switch (ScalarType) {
  case ScalarTypeKind::Float:
    KindLetter = 'f';
    break;
  case ScalarTypeKind::SignedInteger:
    KindLetter = 'i';
    break;
  case ScalarTypeKind::UnsignedInteger:
    KindLetter = 'u';
    break;
  default:
    llvm_unreachable("Unhandled case!");
  }
  ShortStr = KindLetter + utostr(ElementBitwidth);

  if (isVector())
    ShortStr += LMUL.str();
}
void RVVType::applyBasicType() {
switch (BT) {
case BasicType::Int8:
ElementBitwidth = 8;
ScalarType = ScalarTypeKind::SignedInteger;
break;
case BasicType::Int16:
ElementBitwidth = 16;
ScalarType = ScalarTypeKind::SignedInteger;
break;
case BasicType::Int32:
ElementBitwidth = 32;
ScalarType = ScalarTypeKind::SignedInteger;
break;
case BasicType::Int64:
ElementBitwidth = 64;
ScalarType = ScalarTypeKind::SignedInteger;
break;
case BasicType::Float16:
ElementBitwidth = 16;
ScalarType = ScalarTypeKind::Float;
break;
case BasicType::Float32:
ElementBitwidth = 32;
ScalarType = ScalarTypeKind::Float;
break;
case BasicType::Float64:
ElementBitwidth = 64;
ScalarType = ScalarTypeKind::Float;
break;
default:
llvm_unreachable("Unhandled type code!");
}
assert(ElementBitwidth != 0 && "Bad element bitwidth!");
}
/// Parse one prototype descriptor string into a PrototypeDescriptor.
///
/// The string has the shape `[(complex modifier)]{type modifiers}<base>`:
///  * `<base>` is the final character and selects the base type modifier
///    ('e', 'v', 'w', 'q', 'o', 'm', '0', 'z', 't', 'u', 'l');
///  * an optional leading parenthesized vector type modifier, one of
///    `Log2EEW:<n>`, `FixedSEW:<n>`, `LFixedLog2LMUL:<n>` or
///    `SFixedLog2LMUL:<n>`;
///  * any number of single-letter type modifiers in between
///    ('P', 'C', 'K', 'U', 'I', 'F', 'S').
///
/// An empty string yields a default-constructed descriptor. Malformed input
/// is treated as a programming error (llvm_unreachable).
Optional<PrototypeDescriptor> PrototypeDescriptor::parsePrototypeDescriptor(
    llvm::StringRef PrototypeDescriptorStr) {
  PrototypeDescriptor PD;
  BaseTypeModifier PT = BaseTypeModifier::Invalid;
  VectorTypeModifier VTM = VectorTypeModifier::NoModifier;

  if (PrototypeDescriptorStr.empty())
    return PD;

  // Handle base type modifier (always the last character).
  auto PType = PrototypeDescriptorStr.back();
  switch (PType) {
  case 'e':
    PT = BaseTypeModifier::Scalar;
    break;
  case 'v':
    PT = BaseTypeModifier::Vector;
    break;
  case 'w':
    PT = BaseTypeModifier::Vector;
    VTM = VectorTypeModifier::Widening2XVector;
    break;
  case 'q':
    PT = BaseTypeModifier::Vector;
    VTM = VectorTypeModifier::Widening4XVector;
    break;
  case 'o':
    PT = BaseTypeModifier::Vector;
    VTM = VectorTypeModifier::Widening8XVector;
    break;
  case 'm':
    PT = BaseTypeModifier::Vector;
    VTM = VectorTypeModifier::MaskVector;
    break;
  case '0':
    PT = BaseTypeModifier::Void;
    break;
  case 'z':
    PT = BaseTypeModifier::SizeT;
    break;
  case 't':
    PT = BaseTypeModifier::Ptrdiff;
    break;
  case 'u':
    PT = BaseTypeModifier::UnsignedLong;
    break;
  case 'l':
    PT = BaseTypeModifier::SignedLong;
    break;
  default:
    llvm_unreachable("Illegal primitive type transformers!");
  }
  PD.PT = static_cast<uint8_t>(PT);
  PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();

  // Compute the vector type transformers, it can only appear one time.
  if (PrototypeDescriptorStr.startswith("(")) {
    assert(VTM == VectorTypeModifier::NoModifier &&
           "VectorTypeModifier should only have one modifier");
    size_t Idx = PrototypeDescriptorStr.find(')');
    assert(Idx != StringRef::npos);
    StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
    PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
    assert(!PrototypeDescriptorStr.contains('(') &&
           "Only allow one vector type modifier");

    // ComplexType is "<Name>:<Value>".
    auto ComplexTT = ComplexType.split(":");
    if (ComplexTT.first == "Log2EEW") {
      uint32_t Log2EEW;
      if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
        llvm_unreachable("Invalid Log2EEW value!");
        return None;
      }
      switch (Log2EEW) {
      case 3:
        VTM = VectorTypeModifier::Log2EEW3;
        break;
      case 4:
        VTM = VectorTypeModifier::Log2EEW4;
        break;
      case 5:
        VTM = VectorTypeModifier::Log2EEW5;
        break;
      case 6:
        VTM = VectorTypeModifier::Log2EEW6;
        break;
      default:
        llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
        return None;
      }
    } else if (ComplexTT.first == "FixedSEW") {
      uint32_t NewSEW;
      if (ComplexTT.second.getAsInteger(10, NewSEW)) {
        llvm_unreachable("Invalid FixedSEW value!");
        return None;
      }
      switch (NewSEW) {
      case 8:
        VTM = VectorTypeModifier::FixedSEW8;
        break;
      case 16:
        VTM = VectorTypeModifier::FixedSEW16;
        break;
      case 32:
        VTM = VectorTypeModifier::FixedSEW32;
        break;
      case 64:
        VTM = VectorTypeModifier::FixedSEW64;
        break;
      default:
        llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
        return None;
      }
    } else if (ComplexTT.first == "LFixedLog2LMUL") {
      int32_t Log2LMUL;
      if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
        llvm_unreachable("Invalid LFixedLog2LMUL value!");
        return None;
      }
      switch (Log2LMUL) {
      case -3:
        VTM = VectorTypeModifier::LFixedLog2LMULN3;
        break;
      case -2:
        VTM = VectorTypeModifier::LFixedLog2LMULN2;
        break;
      case -1:
        VTM = VectorTypeModifier::LFixedLog2LMULN1;
        break;
      case 0:
        VTM = VectorTypeModifier::LFixedLog2LMUL0;
        break;
      case 1:
        VTM = VectorTypeModifier::LFixedLog2LMUL1;
        break;
      case 2:
        VTM = VectorTypeModifier::LFixedLog2LMUL2;
        break;
      case 3:
        VTM = VectorTypeModifier::LFixedLog2LMUL3;
        break;
      default:
        llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
        return None;
      }
    } else if (ComplexTT.first == "SFixedLog2LMUL") {
      int32_t Log2LMUL;
      if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
        llvm_unreachable("Invalid SFixedLog2LMUL value!");
        return None;
      }
      switch (Log2LMUL) {
      case -3:
        VTM = VectorTypeModifier::SFixedLog2LMULN3;
        break;
      case -2:
        VTM = VectorTypeModifier::SFixedLog2LMULN2;
        break;
      case -1:
        VTM = VectorTypeModifier::SFixedLog2LMULN1;
        break;
      case 0:
        VTM = VectorTypeModifier::SFixedLog2LMUL0;
        break;
      case 1:
        VTM = VectorTypeModifier::SFixedLog2LMUL1;
        break;
      case 2:
        VTM = VectorTypeModifier::SFixedLog2LMUL2;
        break;
      case 3:
        VTM = VectorTypeModifier::SFixedLog2LMUL3;
        break;
      default:
        // Report the modifier that was actually being parsed.
        llvm_unreachable("Invalid SFixedLog2LMUL value, should be [-3, 3]");
        return None;
      }
    } else {
      llvm_unreachable("Illegal complex type transformers!");
    }
  }
  PD.VTM = static_cast<uint8_t>(VTM);

  // Compute the remaining single-letter type transformers.
  TypeModifier TM = TypeModifier::NoModifier;
  for (char I : PrototypeDescriptorStr) {
    switch (I) {
    case 'P':
      if ((TM & TypeModifier::Const) == TypeModifier::Const)
        llvm_unreachable("'P' transformer cannot be used after 'C'");
      if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
        llvm_unreachable("'P' transformer cannot be used twice");
      TM |= TypeModifier::Pointer;
      break;
    case 'C':
      TM |= TypeModifier::Const;
      break;
    case 'K':
      TM |= TypeModifier::Immediate;
      break;
    case 'U':
      TM |= TypeModifier::UnsignedInteger;
      break;
    case 'I':
      TM |= TypeModifier::SignedInteger;
      break;
    case 'F':
      TM |= TypeModifier::Float;
      break;
    case 'S':
      TM |= TypeModifier::LMUL1;
      break;
    default:
      llvm_unreachable("Illegal non-primitive type transformer!");
    }
  }
  PD.TM = static_cast<uint8_t>(TM);

  return PD;
}
// Apply a parsed prototype descriptor to the type derived from the basic type.
// Three stages run in order: the base type modifier (PT), the vector type
// modifier (VTM), and then each type-modifier bit set in TM, from bit 0 up to
// TypeModifier::MaxOffset. Each stage sees the ElementBitwidth, LMUL and Scale
// left by the previous one, so the order is significant.
void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
  // Handle primitive type transformer
  switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
  case BaseTypeModifier::Scalar:
    // NOTE(review): a Scale of 0 is presumably what isScalar() tests for;
    // confirm against the header.
    Scale = 0;
    break;
  case BaseTypeModifier::Vector:
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case BaseTypeModifier::Void:
    ScalarType = ScalarTypeKind::Void;
    break;
  case BaseTypeModifier::SizeT:
    ScalarType = ScalarTypeKind::Size_t;
    break;
  case BaseTypeModifier::Ptrdiff:
    ScalarType = ScalarTypeKind::Ptrdiff_t;
    break;
  case BaseTypeModifier::UnsignedLong:
    ScalarType = ScalarTypeKind::UnsignedLong;
    break;
  case BaseTypeModifier::SignedLong:
    ScalarType = ScalarTypeKind::SignedLong;
    break;
  case BaseTypeModifier::Invalid:
    // Stop here: verifyType() rejects an Invalid scalar type, so the remaining
    // modifiers are irrelevant.
    ScalarType = ScalarTypeKind::Invalid;
    return;
  }

  // Handle the single vector type modifier.
  switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
  // Widening multiplies SEW and LMUL by the same factor, then recomputes
  // Scale for the new pair.
  case VectorTypeModifier::Widening2XVector:
    ElementBitwidth *= 2;
    LMUL.MulLog2LMUL(1);
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case VectorTypeModifier::Widening4XVector:
    ElementBitwidth *= 4;
    LMUL.MulLog2LMUL(2);
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case VectorTypeModifier::Widening8XVector:
    ElementBitwidth *= 8;
    LMUL.MulLog2LMUL(3);
    Scale = LMUL.getScale(ElementBitwidth);
    break;
  case VectorTypeModifier::MaskVector:
    // Scale is taken from the data element width *before* ElementBitwidth is
    // replaced by 1, so the mask has the data vector's element count.
    ScalarType = ScalarTypeKind::Boolean;
    Scale = LMUL.getScale(ElementBitwidth);
    ElementBitwidth = 1;
    break;
  case VectorTypeModifier::Log2EEW3:
    applyLog2EEW(3);
    break;
  case VectorTypeModifier::Log2EEW4:
    applyLog2EEW(4);
    break;
  case VectorTypeModifier::Log2EEW5:
    applyLog2EEW(5);
    break;
  case VectorTypeModifier::Log2EEW6:
    applyLog2EEW(6);
    break;
  case VectorTypeModifier::FixedSEW8:
    applyFixedSEW(8);
    break;
  case VectorTypeModifier::FixedSEW16:
    applyFixedSEW(16);
    break;
  case VectorTypeModifier::FixedSEW32:
    applyFixedSEW(32);
    break;
  case VectorTypeModifier::FixedSEW64:
    applyFixedSEW(64);
    break;
  // LFixed*: the fixed LMUL must be >= the current LMUL.
  case VectorTypeModifier::LFixedLog2LMULN3:
    applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMULN2:
    applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMULN1:
    applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL0:
    applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL1:
    applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL2:
    applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
    break;
  case VectorTypeModifier::LFixedLog2LMUL3:
    applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
    break;
  // SFixed*: the fixed LMUL must be <= the current LMUL.
  case VectorTypeModifier::SFixedLog2LMULN3:
    applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMULN2:
    applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMULN1:
    applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL0:
    applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL1:
    applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL2:
    applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::SFixedLog2LMUL3:
    applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
    break;
  case VectorTypeModifier::NoModifier:
    break;
  }

  // Apply each type-modifier bit set in Transformer.TM, lowest bit first.
  for (unsigned TypeModifierMaskShift = 0;
       TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
       ++TypeModifierMaskShift) {
    unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
    if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
        TypeModifierMask)
      continue;
    switch (static_cast<TypeModifier>(TypeModifierMask)) {
    case TypeModifier::Pointer:
      IsPointer = true;
      break;
    case TypeModifier::Const:
      IsConstant = true;
      break;
    case TypeModifier::Immediate:
      // Immediates are always constant as well.
      IsImmediate = true;
      IsConstant = true;
      break;
    // The signedness / float bits override the scalar kind.
    case TypeModifier::UnsignedInteger:
      ScalarType = ScalarTypeKind::UnsignedInteger;
      break;
    case TypeModifier::SignedInteger:
      ScalarType = ScalarTypeKind::SignedInteger;
      break;
    case TypeModifier::Float:
      ScalarType = ScalarTypeKind::Float;
      break;
    case TypeModifier::LMUL1:
      // Force LMUL = 1, keeping the current element width.
      LMUL = LMULType(0);
      // Update ElementBitwidth need to update Scale too.
      Scale = LMUL.getScale(ElementBitwidth);
      break;
    default:
      llvm_unreachable("Unknown type modifier mask!");
    }
  }
}
void RVVType::applyLog2EEW(unsigned Log2EEW) {
  // Keep the SEW/LMUL ratio: EMUL = (EEW / SEW) * LMUL, so in log2 form LMUL
  // moves by Log2EEW - log2(SEW).
  const unsigned Log2SEW = Log2_32(ElementBitwidth);
  LMUL.MulLog2LMUL(Log2EEW - Log2SEW);

  ElementBitwidth = 1 << Log2EEW;
  // The resulting type is always a signed integer, whatever it was before.
  ScalarType = ScalarTypeKind::SignedInteger;
  Scale = LMUL.getScale(ElementBitwidth);
}
void RVVType::applyFixedSEW(unsigned NewSEW) {
  // Switching to a different SEW keeps LMUL and recomputes the scale.
  if (ElementBitwidth != NewSEW) {
    ElementBitwidth = NewSEW;
    Scale = LMUL.getScale(ElementBitwidth);
    return;
  }
  // Asking for the SEW the type already has is rejected.
  ScalarType = ScalarTypeKind::Invalid;
}
void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
  // LargerThan requires the fixed LMUL to be >= the current one; SmallerThan
  // requires it to be <=. A violation invalidates the type.
  bool Rejected = false;
  switch (Type) {
  case FixedLMULType::LargerThan:
    Rejected = Log2LMUL < LMUL.Log2LMUL;
    break;
  case FixedLMULType::SmallerThan:
    Rejected = Log2LMUL > LMUL.Log2LMUL;
    break;
  }

  if (Rejected) {
    ScalarType = ScalarTypeKind::Invalid;
    return;
  }

  LMUL = LMULType(Log2LMUL);
  Scale = LMUL.getScale(ElementBitwidth);
}
Optional<RVVTypes>
RVVType::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
                      ArrayRef<PrototypeDescriptor> Prototype) {
  // LMUL x NF must not exceed 8. The check is only applied for LMUL >= 2.
  const bool GroupTooLarge = Log2LMUL >= 1 && (1 << Log2LMUL) * NF > 8;
  if (GroupTooLarge)
    return llvm::None;

  RVVTypes Types;
  for (const PrototypeDescriptor &Proto : Prototype) {
    auto Computed = computeType(BT, Log2LMUL, Proto);
    // One illegal operand makes the whole signature illegal.
    if (!Computed)
      return llvm::None;
    Types.push_back(*Computed);
  }
  return Types;
}
// Build the cache key for computeType by packing every input into its own
// byte, so distinct (BT, Log2LMUL, Proto) combinations never collide:
//   bits  0- 7: Log2LMUL + 3 (maps -3..3 to 0..6)
//   bits  8-15: BT
//   bits 16-23: Proto.PT
//   bits 24-31: Proto.TM
//   bits 32-39: Proto.VTM
static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
                                        PrototypeDescriptor Proto) {
  assert(Log2LMUL >= -3 && Log2LMUL <= 3);
  uint64_t Key = static_cast<uint64_t>(Proto.VTM & 0xff) << 32;
  Key |= static_cast<uint64_t>(Proto.TM & 0xff) << 24;
  Key |= static_cast<uint64_t>(Proto.PT & 0xff) << 16;
  Key |= (static_cast<uint64_t>(BT) & 0xff) << 8;
  Key |= static_cast<uint64_t>(Log2LMUL + 3);
  return Key;
}
/// Return a pointer to the (cached) RVVType for (BT, Log2LMUL, Proto), or
/// None if that combination is not a legal type.
///
/// Legal types are stored by value in a function-local cache and returned by
/// pointer. std::unordered_map never relocates its elements on rehash, so the
/// returned pointers stay valid as the cache grows. Illegal keys are remembered
/// separately so they are not recomputed.
///
/// NOTE(review): the caches are unsynchronized function-local statics; callers
/// presumably use this from a single thread. Confirm before calling
/// concurrently.
Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL,
                                          PrototypeDescriptor Proto) {
  // Concat BasicType, LMUL and Proto as key
  static std::unordered_map<uint64_t, RVVType> LegalTypes;
  static std::set<uint64_t> IllegalTypes;
  uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);

  // Search first
  auto It = LegalTypes.find(Idx);
  if (It != LegalTypes.end())
    return &(It->second);
  if (IllegalTypes.count(Idx))
    return llvm::None;

  // Compute type and record the result.
  RVVType T(BT, Log2LMUL, Proto);
  if (T.isValid()) {
    // Take the cached element straight from the insertion result instead of
    // looking the key up again with operator[]. That second lookup rehashes
    // the key and also requires RVVType to be default-constructible.
    auto Inserted = LegalTypes.insert({Idx, T});
    return &(Inserted.first->second);
  }

  // Record illegal type index.
  IllegalTypes.insert(Idx);
  return llvm::None;
}
//===----------------------------------------------------------------------===//
// RVVIntrinsic implementation
//===----------------------------------------------------------------------===//
RVVIntrinsic::RVVIntrinsic(
    StringRef NewName, StringRef Suffix, StringRef NewOverloadedName,
    StringRef OverloadedSuffix, StringRef IRName, bool IsMasked,
    bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
    bool SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen,
    const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
    const std::vector<StringRef> &RequiredFeatures, unsigned NF,
    Policy NewDefaultPolicy, bool IsPrototypeDefaultTU)
    : IRName(IRName), IsMasked(IsMasked),
      HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
      SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
      ManualCodegen(ManualCodegen.str()), NF(NF),
      DefaultPolicy(NewDefaultPolicy) {
  // The builtin name is the bare intrinsic name; the user-facing name adds the
  // type suffix. Without an explicit overloaded name, the part of NewName
  // before the first '_' is used.
  BuiltinName = NewName.str();
  Name = Suffix.empty() ? BuiltinName : BuiltinName + "_" + Suffix.str();
  OverloadedName = NewOverloadedName.empty()
                       ? NewName.split("_").first.str()
                       : NewOverloadedName.str();
  if (!OverloadedSuffix.empty())
    OverloadedName += "_" + OverloadedSuffix.str();

  // Policy suffixes are appended, and DefaultPolicy normalized, afterwards.
  updateNamesAndPolicy(IsMasked, hasPolicy(), IsPrototypeDefaultTU, Name,
                       BuiltinName, OverloadedName, DefaultPolicy);

  // The first entry is the result type; the rest are the operands.
  OutputType = OutInTypes[0];
  InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());

  // IntrinsicTypes indexes the unmasked tail-agnostic operand list. A leading
  // merge operand (masked-off, or passthru for non-TU-default prototypes)
  // shifts every non-negative index by NF.
  IntrinsicTypes = NewIntrinsicTypes;
  const bool HasLeadingMergeOperand =
      IsMasked ? hasMaskedOffOperand()
               : (hasPassthruOperand() && !IsPrototypeDefaultTU);
  if (HasLeadingMergeOperand) {
    for (auto &TypeIndex : IntrinsicTypes)
      if (TypeIndex >= 0)
        TypeIndex += NF;
  }
}
std::string RVVIntrinsic::getBuiltinTypeStr() const {
  // The builtin type string is the result type's encoding followed by the
  // encoding of each operand type, in order.
  std::string Encoded = OutputType->getBuiltinStr();
  for (const auto &Input : InputTypes)
    Encoded += Input->getBuiltinStr();
  return Encoded;
}
std::string RVVIntrinsic::getSuffixStr(
    BasicType Type, int Log2LMUL,
    llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
  // Join the short spelling of each descriptor's type with '_'.
  SmallVector<std::string> Parts;
  for (const PrototypeDescriptor &Descriptor : PrototypeDescriptors) {
    // NOTE(review): each descriptor is assumed to produce a valid type; the
    // result of computeType is dereferenced without a check.
    auto Ty = *RVVType::computeType(Type, Log2LMUL, Descriptor);
    Parts.push_back(Ty->getShortStr());
  }
  return join(Parts, "_");
}
// Build the builtin prototype for one intrinsic variant, starting from the base
// Prototype (element 0 is the result type, the rest are operands):
//  * normalize the local copy of DefaultPolicy (MA and TAM become TAMA;
//    PolicyNone becomes TU or TA for unmasked variants),
//  * for masked variants, add masked-off operand(s) and the mask operand,
//  * for unmasked NF == 1 variants, add or drop the passthru operand according
//    to the tail policy and the policy scheme,
//  * append the vl operand when HasVL is set.
llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
    llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
    bool HasMaskedOffOperand, bool HasVL, unsigned NF,
    bool IsPrototypeDefaultTU, PolicyScheme DefaultScheme,
    Policy DefaultPolicy) {
  SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
                                                Prototype.end());
  // Update DefaultPolicy if need (TA or TAMA) for compute builtin types.
  switch (DefaultPolicy) {
  case Policy::MA:
    DefaultPolicy = Policy::TAMA;
    break;
  case Policy::TAM:
    DefaultPolicy = Policy::TAMA;
    break;
  case Policy::PolicyNone:
    // Masked with no policy would not be TAMA.
    if (!IsMasked) {
      if (IsPrototypeDefaultTU)
        DefaultPolicy = Policy::TU;
      else
        DefaultPolicy = Policy::TA;
    }
    break;
  default:
    break;
  }
  bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
  if (IsMasked) {
    // If HasMaskedOffOperand, insert result type as first input operand if
    // need.
    if (HasMaskedOffOperand) {
      if (NF == 1 && DefaultPolicy != Policy::TAMA) {
        // One masked-off operand, a copy of the result type, right after the
        // result. No masked-off operand is added for TAMA.
        NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
      } else if (NF > 1) {
        // Convert
        // (void, op0 address, op1 address, ...)
        // to
        // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
        // Each masked-off operand has the type of the first address operand
        // with its pointer modifier removed.
        PrototypeDescriptor MaskoffType = NewPrototype[1];
        MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
        for (unsigned I = 0; I < NF; ++I)
          NewPrototype.insert(NewPrototype.begin() + NF + 1, MaskoffType);
      }
    }
    // Erase passthru operand for TAM
    // (applies to TU-default NF == 1 prototypes with a passthru operand at
    // index 1 and no masked-off operand, once the policy is TAMA).
    if (NF == 1 && IsPrototypeDefaultTU && DefaultPolicy == Policy::TAMA &&
        HasPassthruOp && !HasMaskedOffOperand)
      NewPrototype.erase(NewPrototype.begin() + 1);
    if (HasMaskedOffOperand && NF > 1) {
      // Convert
      // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
      // to
      // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
      // ...)
      NewPrototype.insert(NewPrototype.begin() + NF + 1,
                          PrototypeDescriptor::Mask);
    } else {
      // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
      NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
    }
  } else if (NF == 1) {
    // Unmasked: add a passthru operand for TU when the prototype does not
    // already have one (inserting a copy of element 0 at the front leaves the
    // result at index 0 followed by an identical operand at index 1), or drop
    // the existing one at index 1 for TA when the prototype is TU-default.
    if (DefaultPolicy == Policy::TU && HasPassthruOp && !IsPrototypeDefaultTU)
      NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
    else if (DefaultPolicy == Policy::TA && HasPassthruOp &&
             IsPrototypeDefaultTU)
      NewPrototype.erase(NewPrototype.begin() + 1);
    if (DefaultScheme == PolicyScheme::HasPassthruOperandAtIdx1) {
      // Same idea, but the passthru is operand 1, i.e. NewPrototype index 2
      // (index 0 is the result type).
      if (DefaultPolicy == Policy::TU && !IsPrototypeDefaultTU) {
        // Insert undisturbed output to index 1
        NewPrototype.insert(NewPrototype.begin() + 2, NewPrototype[0]);
      } else if (DefaultPolicy == Policy::TA && IsPrototypeDefaultTU) {
        // Erase passthru for TA policy
        NewPrototype.erase(NewPrototype.begin() + 2);
      }
    }
  }
  // If HasVL, append PrototypeDescriptor:VL to last operand
  if (HasVL)
    NewPrototype.push_back(PrototypeDescriptor::VL);
  return NewPrototype;
}
llvm::SmallVector<Policy>
RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy,
                                         bool HasMaskPolicy) {
  // Without a tail policy only the mask policy varies.
  if (!HasTailPolicy)
    return {Policy::MA, Policy::MU};
  // With a tail policy but no mask policy only the tail policy varies.
  if (!HasMaskPolicy)
    return {Policy::TUM, Policy::TAM};
  // Both vary: every tail/mask combination.
  return {Policy::TUMA, Policy::TAMA, Policy::TUMU, Policy::TAMU};
}
void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy,
                                        bool IsPrototypeDefaultTU,
                                        std::string &Name,
                                        std::string &BuiltinName,
                                        std::string &OverloadedName,
                                        Policy &DefaultPolicy) {
  // For an explicit policy: append its suffix to all three names and store
  // the fully specified tail+mask policy it stands for.
  auto Tag = [&](const char *Suffix, Policy Resolved) {
    Name += Suffix;
    BuiltinName += Suffix;
    OverloadedName += Suffix;
    DefaultPolicy = Resolved;
  };

  switch (DefaultPolicy) {
  case Policy::TU:
    Tag("_tu", Policy::TU);
    return;
  case Policy::TA:
    Tag("_ta", Policy::TA);
    return;
  case Policy::MU:
    Tag("_mu", Policy::TAMU);
    return;
  case Policy::MA:
    Tag("_ma", Policy::TAMA);
    return;
  case Policy::TUM:
    Tag("_tum", Policy::TUMA);
    return;
  case Policy::TAM:
    Tag("_tam", Policy::TAMA);
    return;
  case Policy::TUMU:
    Tag("_tumu", Policy::TUMU);
    return;
  case Policy::TAMU:
    Tag("_tamu", Policy::TAMU);
    return;
  case Policy::TUMA:
    Tag("_tuma", Policy::TUMA);
    return;
  case Policy::TAMA:
    Tag("_tama", Policy::TAMA);
    return;
  default:
    break;
  }

  // No explicit policy: derive the default. The overloaded name is left
  // untouched.
  if (IsMasked) {
    Name += "_m";
    // FIXME: The current default policy for _m differs from the RVV
    // intrinsic spec (TUMA).
    DefaultPolicy = Policy::TUMU;
    BuiltinName += HasPolicy ? "_tumu" : "_m";
    return;
  }

  DefaultPolicy = IsPrototypeDefaultTU ? Policy::TU : Policy::TA;
  if (HasPolicy)
    BuiltinName += IsPrototypeDefaultTU ? "_tu" : "_ta";
}
SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
  // Split a concatenated prototype string into descriptors. Each descriptor
  // ends at the first primitive type character.
  SmallVector<PrototypeDescriptor> Result;
  const StringRef PrimitiveChars("evwqom0ztul");

  while (!Prototypes.empty()) {
    // A leading complex modifier such as "(Log2EEW:3)" may contain characters
    // that are also primitive type characters, so search after its ')'.
    const size_t SearchFrom =
        Prototypes.front() == '(' ? Prototypes.find_first_of(')') : 0;
    const size_t End = Prototypes.find_first_of(PrimitiveChars, SearchFrom);
    assert(End != StringRef::npos);

    auto Parsed = PrototypeDescriptor::parsePrototypeDescriptor(
        Prototypes.slice(0, End + 1));
    if (!Parsed)
      llvm_unreachable("Error during parsing prototype.");
    Result.push_back(*Parsed);

    Prototypes = Prototypes.drop_front(End + 1);
  }
  return Result;
}
raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
  // Emit the record as one brace-enclosed, comma-terminated initializer line:
  // {"name",<overloaded>,field,field,...,},\n
  OS << "{";
  OS << "\"" << Record.Name << "\",";

  // A null or empty overloaded name is written as nullptr.
  const bool HasOverloadedName = Record.OverloadedName != nullptr &&
                                 !StringRef(Record.OverloadedName).empty();
  if (HasOverloadedName)
    OS << "\"" << Record.OverloadedName << "\",";
  else
    OS << "nullptr,";

  // Table offsets are printed as-is.
  OS << Record.PrototypeIndex << "," << Record.SuffixIndex << ","
     << Record.OverloadedSuffixIndex << ",";

  // The remaining fields go through an int conversion (presumably so narrow
  // integer/enum fields print as numbers rather than characters).
  for (int Field : {static_cast<int>(Record.PrototypeLength),
                    static_cast<int>(Record.SuffixLength),
                    static_cast<int>(Record.OverloadedSuffixSize),
                    static_cast<int>(Record.RequiredExtensions),
                    static_cast<int>(Record.TypeRangeMask),
                    static_cast<int>(Record.Log2LMULMask),
                    static_cast<int>(Record.NF),
                    static_cast<int>(Record.HasMasked),
                    static_cast<int>(Record.HasVL),
                    static_cast<int>(Record.HasMaskedOffOperand),
                    static_cast<int>(Record.IsPrototypeDefaultTU),
                    static_cast<int>(Record.HasTailPolicy),
                    static_cast<int>(Record.HasMaskPolicy),
                    static_cast<int>(Record.UnMaskedPolicyScheme),
                    static_cast<int>(Record.MaskedPolicyScheme)})
    OS << Field << ",";

  OS << "},\n";
  return OS;
}
} // end namespace RISCV
} // end namespace clang