Lei Zhang a29fffc475 [mlir][spirv] Migrate to use specalized enum attributes
Previously we are using IntegerAttr to back all SPIR-V enum
attributes. Therefore we all such attributes are showed like
IntegerAttr in IRs, which is barely readable and breaks
roundtripability of the IR. This commit changes to use
`EnumAttr` as the base directly so that we can have separate
attribute definitions and better IR printing.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D131311
2022-08-09 14:14:54 -04:00

1234 lines
44 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
//
//===----------------------------------------------------------------------===//
#include "Serializer.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/Debug.h"
#include <cstdint>
#define DEBUG_TYPE "spirv-serialization"
using namespace mlir;
/// Returns the merge block if the given `op` is a structured control flow op.
/// Otherwise returns nullptr.
static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
return selectionOp.getMergeBlock();
if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
return loopOp.getMergeBlock();
return nullptr;
}
/// Given a predecessor `block` for a block with arguments, returns the block
/// that should be used as the parent block for SPIR-V OpPhi instructions
/// corresponding to the block arguments.
static Block *getPhiIncomingBlock(Block *block) {
// If the predecessor block in question is the entry block for a
// spv.mlir.loop, we jump to this spv.mlir.loop from its enclosing block.
if (block->isEntryBlock()) {
if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
// Then the incoming parent block for OpPhi should be the merge block of
// the structured control flow op before this loop.
Operation *op = loopOp.getOperation();
while ((op = op->getPrevNode()) != nullptr)
if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
return incomingBlock;
// Or the enclosing block itself if no structured control flow ops
// exists before this loop.
return loopOp->getBlock();
}
}
// Otherwise, we jump from the given predecessor block. Try to see if there is
// a structured control flow op inside it.
for (Operation &op : llvm::reverse(block->getOperations())) {
if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
return incomingBlock;
}
return block;
}
namespace mlir {
namespace spirv {
/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
/// the given `binary` vector.
void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
ArrayRef<uint32_t> operands) {
uint32_t wordCount = 1 + operands.size();
binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
binary.append(operands.begin(), operands.end());
}
Serializer::Serializer(spirv::ModuleOp module,
const SerializationOptions &options)
: module(module), mlirBuilder(module.getContext()), options(options) {}
LogicalResult Serializer::serialize() {
LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
if (failed(module.verifyInvariants()))
return failure();
// TODO: handle the other sections
processCapability();
processExtension();
processMemoryModel();
processDebugInfo();
// Iterate over the module body to serialize it. Assumptions are that there is
// only one basic block in the moduleOp
for (auto &op : *module.getBody()) {
if (failed(processOperation(&op))) {
return failure();
}
}
LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
return success();
}
void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
extensions.size() + extendedSets.size() +
memoryModel.size() + entryPoints.size() +
executionModes.size() + decorations.size() +
typesGlobalValues.size() + functions.size();
binary.clear();
binary.reserve(moduleSize);
spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
binary.append(capabilities.begin(), capabilities.end());
binary.append(extensions.begin(), extensions.end());
binary.append(extendedSets.begin(), extendedSets.end());
binary.append(memoryModel.begin(), memoryModel.end());
binary.append(entryPoints.begin(), entryPoints.end());
binary.append(executionModes.begin(), executionModes.end());
binary.append(debug.begin(), debug.end());
binary.append(names.begin(), names.end());
binary.append(decorations.begin(), decorations.end());
binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
binary.append(functions.begin(), functions.end());
}
#ifndef NDEBUG
void Serializer::printValueIDMap(raw_ostream &os) {
os << "\n= Value <id> Map =\n\n";
for (auto valueIDPair : valueIDMap) {
Value val = valueIDPair.first;
os << " " << val << " "
<< "id = " << valueIDPair.second << ' ';
if (auto *op = val.getDefiningOp()) {
os << "from op '" << op->getName() << "'";
} else if (auto arg = val.dyn_cast<BlockArgument>()) {
Block *block = arg.getOwner();
os << "from argument of block " << block << ' ';
os << " in op '" << block->getParentOp()->getName() << "'";
}
os << '\n';
}
}
#endif
//===----------------------------------------------------------------------===//
// Module structure
//===----------------------------------------------------------------------===//
uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
auto funcID = funcIDMap.lookup(fnName);
if (!funcID) {
funcID = getNextID();
funcIDMap[fnName] = funcID;
}
return funcID;
}
void Serializer::processCapability() {
for (auto cap : module.vce_triple()->getCapabilities())
encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
{static_cast<uint32_t>(cap)});
}
void Serializer::processDebugInfo() {
if (!options.emitDebugInfo)
return;
auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
fileID = getNextID();
SmallVector<uint32_t, 16> operands;
operands.push_back(fileID);
spirv::encodeStringLiteralInto(operands, fileName);
encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
// TODO: Encode more debug instructions.
}
void Serializer::processExtension() {
llvm::SmallVector<uint32_t, 16> extName;
for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
extName.clear();
spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
}
}
void Serializer::processMemoryModel() {
auto mm = static_cast<uint32_t>(
module->getAttrOfType<spirv::MemoryModelAttr>("memory_model").getValue());
auto am = static_cast<uint32_t>(
module->getAttrOfType<spirv::AddressingModelAttr>("addressing_model")
.getValue());
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
}
LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr) {
auto attrName = attr.getName().strref();
auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true);
auto decoration = spirv::symbolizeDecoration(decorationName);
if (!decoration) {
return emitError(
loc, "non-argument attributes expected to have snake-case-ified "
"decoration name, unhandled attribute with name : ")
<< attrName;
}
SmallVector<uint32_t, 1> args;
switch (*decoration) {
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
if (auto intAttr = attr.getValue().dyn_cast<IntegerAttr>()) {
args.push_back(intAttr.getValue().getZExtValue());
break;
}
return emitError(loc, "expected integer attribute for ") << attrName;
case spirv::Decoration::BuiltIn:
if (auto strAttr = attr.getValue().dyn_cast<StringAttr>()) {
auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
if (enumVal) {
args.push_back(static_cast<uint32_t>(*enumVal));
break;
}
return emitError(loc, "invalid ")
<< attrName << " attribute " << strAttr.getValue();
}
return emitError(loc, "expected string attribute for ") << attrName;
case spirv::Decoration::Aliased:
case spirv::Decoration::Flat:
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
case spirv::Decoration::Restrict:
case spirv::Decoration::RelaxedPrecision:
// For unit attributes, the args list has no values so we do nothing
if (auto unitAttr = attr.getValue().dyn_cast<UnitAttr>())
break;
return emitError(loc, "expected unit attribute for ") << attrName;
default:
return emitError(loc, "unhandled decoration ") << decorationName;
}
return emitDecoration(resultID, *decoration, args);
}
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
assert(!name.empty() && "unexpected empty string for OpName");
if (!options.emitSymbolName)
return success();
SmallVector<uint32_t, 4> nameOperands;
nameOperands.push_back(resultID);
spirv::encodeStringLiteralInto(nameOperands, name);
encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
return success();
}
template <>
LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
Location loc, spirv::ArrayType type, uint32_t resultID) {
if (unsigned stride = type.getArrayStride()) {
// OpDecorate %arrayTypeSSA ArrayStride strideLiteral
return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
}
return success();
}
template <>
LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
if (unsigned stride = type.getArrayStride()) {
// OpDecorate %arrayTypeSSA ArrayStride strideLiteral
return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
}
return success();
}
LogicalResult Serializer::processMemberDecoration(
uint32_t structID,
const spirv::StructType::MemberDecorationInfo &memberDecoration) {
SmallVector<uint32_t, 4> args(
{structID, memberDecoration.memberIndex,
static_cast<uint32_t>(memberDecoration.decoration)});
if (memberDecoration.hasValue) {
args.push_back(memberDecoration.decorationValue);
}
encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
return success();
}
//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//
// According to the SPIR-V spec "Validation Rules for Shader Capabilities":
// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
// PushConstant Storage Classes must be explicitly laid out."
bool Serializer::isInterfaceStructPtrType(Type type) const {
if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
switch (ptrType.getStorageClass()) {
case spirv::StorageClass::PhysicalStorageBuffer:
case spirv::StorageClass::PushConstant:
case spirv::StorageClass::StorageBuffer:
case spirv::StorageClass::Uniform:
return ptrType.getPointeeType().isa<spirv::StructType>();
default:
break;
}
}
return false;
}
LogicalResult Serializer::processType(Location loc, Type type,
uint32_t &typeID) {
// Maintains a set of names for nested identified struct types. This is used
// to properly serialize recursive references.
SetVector<StringRef> serializationCtx;
return processTypeImpl(loc, type, typeID, serializationCtx);
}
LogicalResult
Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
SetVector<StringRef> &serializationCtx) {
typeID = getTypeID(type);
if (typeID)
return success();
typeID = getNextID();
SmallVector<uint32_t, 4> operands;
operands.push_back(typeID);
auto typeEnum = spirv::Opcode::OpTypeVoid;
bool deferSerialization = false;
if ((type.isa<FunctionType>() &&
succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
operands))) ||
succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
deferSerialization, serializationCtx))) {
if (deferSerialization)
return success();
typeIDMap[type] = typeID;
encodeInstructionInto(typesGlobalValues, typeEnum, operands);
if (recursiveStructInfos.count(type) != 0) {
// This recursive struct type is emitted already, now the OpTypePointer
// instructions referring to recursive references are emitted as well.
for (auto &ptrInfo : recursiveStructInfos[type]) {
// TODO: This might not work if more than 1 recursive reference is
// present in the struct.
SmallVector<uint32_t, 4> ptrOperands;
ptrOperands.push_back(ptrInfo.pointerTypeID);
ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
ptrOperands.push_back(typeIDMap[type]);
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
ptrOperands);
}
recursiveStructInfos[type].clear();
}
return success();
}
return failure();
}
LogicalResult Serializer::prepareBasicType(
Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
SetVector<StringRef> &serializationCtx) {
deferSerialization = false;
if (isVoidType(type)) {
typeEnum = spirv::Opcode::OpTypeVoid;
return success();
}
if (auto intType = type.dyn_cast<IntegerType>()) {
if (intType.getWidth() == 1) {
typeEnum = spirv::Opcode::OpTypeBool;
return success();
}
typeEnum = spirv::Opcode::OpTypeInt;
operands.push_back(intType.getWidth());
// SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
// to preserve or validate.
// 0 indicates unsigned, or no signedness semantics
// 1 indicates signed semantics."
operands.push_back(intType.isSigned() ? 1 : 0);
return success();
}
if (auto floatType = type.dyn_cast<FloatType>()) {
typeEnum = spirv::Opcode::OpTypeFloat;
operands.push_back(floatType.getWidth());
return success();
}
if (auto vectorType = type.dyn_cast<VectorType>()) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
serializationCtx))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeVector;
operands.push_back(elementTypeID);
operands.push_back(vectorType.getNumElements());
return success();
}
if (auto imageType = type.dyn_cast<spirv::ImageType>()) {
typeEnum = spirv::Opcode::OpTypeImage;
uint32_t sampledTypeID = 0;
if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
return failure();
operands.push_back(sampledTypeID);
operands.push_back(static_cast<uint32_t>(imageType.getDim()));
operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
return success();
}
if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
typeEnum = spirv::Opcode::OpTypeArray;
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
serializationCtx))) {
return failure();
}
operands.push_back(elementTypeID);
if (auto elementCountID = prepareConstantInt(
loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
operands.push_back(elementCountID);
}
return processTypeDecoration(loc, arrayType, resultID);
}
if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
uint32_t pointeeTypeID = 0;
spirv::StructType pointeeStruct =
ptrType.getPointeeType().dyn_cast<spirv::StructType>();
if (pointeeStruct && pointeeStruct.isIdentified() &&
serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
// A recursive reference to an enclosing struct is found.
//
// 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
// class as operands.
SmallVector<uint32_t, 2> forwardPtrOperands;
forwardPtrOperands.push_back(resultID);
forwardPtrOperands.push_back(
static_cast<uint32_t>(ptrType.getStorageClass()));
encodeInstructionInto(typesGlobalValues,
spirv::Opcode::OpTypeForwardPointer,
forwardPtrOperands);
// 2. Find the pointee (enclosing) struct.
auto structType = spirv::StructType::getIdentified(
module.getContext(), pointeeStruct.getIdentifier());
if (!structType)
return failure();
// 3. Mark the OpTypePointer that is supposed to be emitted by this call
// as deferred.
deferSerialization = true;
// 4. Record the info needed to emit the deferred OpTypePointer
// instruction when the enclosing struct is completely serialized.
recursiveStructInfos[structType].push_back(
{resultID, ptrType.getStorageClass()});
} else {
if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
serializationCtx)))
return failure();
}
typeEnum = spirv::Opcode::OpTypePointer;
operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
operands.push_back(pointeeTypeID);
if (isInterfaceStructPtrType(ptrType)) {
if (failed(emitDecoration(getTypeID(pointeeStruct),
spirv::Decoration::Block)))
return emitError(loc, "cannot decorate ")
<< pointeeStruct << " with Block decoration";
}
return success();
}
if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
elementTypeID, serializationCtx))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeRuntimeArray;
operands.push_back(elementTypeID);
return processTypeDecoration(loc, runtimeArrayType, resultID);
}
if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) {
typeEnum = spirv::Opcode::OpTypeSampledImage;
uint32_t imageTypeID = 0;
if (failed(
processType(loc, sampledImageType.getImageType(), imageTypeID))) {
return failure();
}
operands.push_back(imageTypeID);
return success();
}
if (auto structType = type.dyn_cast<spirv::StructType>()) {
if (structType.isIdentified()) {
if (failed(processName(resultID, structType.getIdentifier())))
return failure();
serializationCtx.insert(structType.getIdentifier());
}
bool hasOffset = structType.hasOffset();
for (auto elementIndex :
llvm::seq<uint32_t>(0, structType.getNumElements())) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
elementTypeID, serializationCtx))) {
return failure();
}
operands.push_back(elementTypeID);
if (hasOffset) {
// Decorate each struct member with an offset
spirv::StructType::MemberDecorationInfo offsetDecoration{
elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
if (failed(processMemberDecoration(resultID, offsetDecoration))) {
return emitError(loc, "cannot decorate ")
<< elementIndex << "-th member of " << structType
<< " with its offset";
}
}
}
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
structType.getMemberDecorations(memberDecorations);
for (auto &memberDecoration : memberDecorations) {
if (failed(processMemberDecoration(resultID, memberDecoration))) {
return emitError(loc, "cannot decorate ")
<< static_cast<uint32_t>(memberDecoration.memberIndex)
<< "-th member of " << structType << " with "
<< stringifyDecoration(memberDecoration.decoration);
}
}
typeEnum = spirv::Opcode::OpTypeStruct;
if (structType.isIdentified())
serializationCtx.remove(structType.getIdentifier());
return success();
}
if (auto cooperativeMatrixType =
type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
elementTypeID, serializationCtx))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
auto getConstantOp = [&](uint32_t id) {
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
return prepareConstantInt(loc, attr);
};
operands.push_back(elementTypeID);
operands.push_back(
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
return success();
}
if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
serializationCtx))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeMatrix;
operands.push_back(elementTypeID);
operands.push_back(matrixType.getNumColumns());
return success();
}
// TODO: Handle other types.
return emitError(loc, "unhandled type in serialization: ") << type;
}
LogicalResult
Serializer::prepareFunctionType(Location loc, FunctionType type,
spirv::Opcode &typeEnum,
SmallVectorImpl<uint32_t> &operands) {
typeEnum = spirv::Opcode::OpTypeFunction;
assert(type.getNumResults() <= 1 &&
"serialization supports only a single return value");
uint32_t resultID = 0;
if (failed(processType(
loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
resultID))) {
return failure();
}
operands.push_back(resultID);
for (auto &res : type.getInputs()) {
uint32_t argTypeID = 0;
if (failed(processType(loc, res, argTypeID))) {
return failure();
}
operands.push_back(argTypeID);
}
return success();
}
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
uint32_t Serializer::prepareConstant(Location loc, Type constType,
Attribute valueAttr) {
if (auto id = prepareConstantScalar(loc, valueAttr)) {
return id;
}
// This is a composite literal. We need to handle each component separately
// and then emit an OpConstantComposite for the whole.
if (auto id = getConstantID(valueAttr)) {
return id;
}
uint32_t typeID = 0;
if (failed(processType(loc, constType, typeID))) {
return 0;
}
uint32_t resultID = 0;
if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
int rank = attr.getType().dyn_cast<ShapedType>().getRank();
SmallVector<uint64_t, 4> index(rank);
resultID = prepareDenseElementsConstant(loc, constType, attr,
/*dim=*/0, index);
} else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
resultID = prepareArrayConstant(loc, constType, arrayAttr);
}
if (resultID == 0) {
emitError(loc, "cannot serialize attribute: ") << valueAttr;
return 0;
}
constIDMap[valueAttr] = resultID;
return resultID;
}
uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
ArrayAttr attr) {
uint32_t typeID = 0;
if (failed(processType(loc, constType, typeID))) {
return 0;
}
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
operands.reserve(attr.size() + 2);
auto elementType = constType.cast<spirv::ArrayType>().getElementType();
for (Attribute elementAttr : attr) {
if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
operands.push_back(elementID);
} else {
return 0;
}
}
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
encodeInstructionInto(typesGlobalValues, opcode, operands);
return resultID;
}
// TODO: Turn the below function into iterative function, instead of
// recursive function.
uint32_t
Serializer::prepareDenseElementsConstant(Location loc, Type constType,
DenseElementsAttr valueAttr, int dim,
MutableArrayRef<uint64_t> index) {
auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
assert(dim <= shapedType.getRank());
if (shapedType.getRank() == dim) {
if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
return attr.getType().getElementType().isInteger(1)
? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
: prepareConstantInt(loc,
attr.getValues<IntegerAttr>()[index]);
}
if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
}
return 0;
}
uint32_t typeID = 0;
if (failed(processType(loc, constType, typeID))) {
return 0;
}
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
operands.reserve(shapedType.getDimSize(dim) + 2);
auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
index[dim] = i;
if (auto elementID = prepareDenseElementsConstant(
loc, elementType, valueAttr, dim + 1, index)) {
operands.push_back(elementID);
} else {
return 0;
}
}
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
encodeInstructionInto(typesGlobalValues, opcode, operands);
return resultID;
}
uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
bool isSpec) {
if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
return prepareConstantFp(loc, floatAttr, isSpec);
}
if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
return prepareConstantBool(loc, boolAttr, isSpec);
}
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
return prepareConstantInt(loc, intAttr, isSpec);
}
return 0;
}
uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
bool isSpec) {
if (!isSpec) {
// We can de-duplicate normal constants, but not specialization constants.
if (auto id = getConstantID(boolAttr)) {
return id;
}
}
// Process the type for this bool literal
uint32_t typeID = 0;
if (failed(
processType(loc, boolAttr.cast<IntegerAttr>().getType(), typeID))) {
return 0;
}
auto resultID = getNextID();
auto opcode = boolAttr.getValue()
? (isSpec ? spirv::Opcode::OpSpecConstantTrue
: spirv::Opcode::OpConstantTrue)
: (isSpec ? spirv::Opcode::OpSpecConstantFalse
: spirv::Opcode::OpConstantFalse);
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
if (!isSpec) {
constIDMap[boolAttr] = resultID;
}
return resultID;
}
uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
bool isSpec) {
if (!isSpec) {
// We can de-duplicate normal constants, but not specialization constants.
if (auto id = getConstantID(intAttr)) {
return id;
}
}
// Process the type for this integer literal
uint32_t typeID = 0;
if (failed(processType(loc, intAttr.getType(), typeID))) {
return 0;
}
auto resultID = getNextID();
APInt value = intAttr.getValue();
unsigned bitwidth = value.getBitWidth();
bool isSigned = value.isSignedIntN(bitwidth);
auto opcode =
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
switch (bitwidth) {
// According to SPIR-V spec, "When the type's bit width is less than
// 32-bits, the literal's value appears in the low-order bits of the word,
// and the high-order bits must be 0 for a floating-point type, or 0 for an
// integer type with Signedness of 0, or sign extended when Signedness
// is 1."
case 32:
case 16:
case 8: {
uint32_t word = 0;
if (isSigned) {
word = static_cast<int32_t>(value.getSExtValue());
} else {
word = static_cast<uint32_t>(value.getZExtValue());
}
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
} break;
// According to SPIR-V spec: "When the type's bit width is larger than one
// word, the literals low-order words appear first."
case 64: {
struct DoubleWord {
uint32_t word1;
uint32_t word2;
} words;
if (isSigned) {
words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
} else {
words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
}
encodeInstructionInto(typesGlobalValues, opcode,
{typeID, resultID, words.word1, words.word2});
} break;
default: {
std::string valueStr;
llvm::raw_string_ostream rss(valueStr);
value.print(rss, /*isSigned=*/false);
emitError(loc, "cannot serialize ")
<< bitwidth << "-bit integer literal: " << rss.str();
return 0;
}
}
if (!isSpec) {
constIDMap[intAttr] = resultID;
}
return resultID;
}
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
bool isSpec) {
if (!isSpec) {
// We can de-duplicate normal constants, but not specialization constants.
if (auto id = getConstantID(floatAttr)) {
return id;
}
}
// Process the type for this float literal
uint32_t typeID = 0;
if (failed(processType(loc, floatAttr.getType(), typeID))) {
return 0;
}
auto resultID = getNextID();
APFloat value = floatAttr.getValue();
APInt intValue = value.bitcastToAPInt();
auto opcode =
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
if (&value.getSemantics() == &APFloat::IEEEsingle()) {
uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
} else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
struct DoubleWord {
uint32_t word1;
uint32_t word2;
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
encodeInstructionInto(typesGlobalValues, opcode,
{typeID, resultID, words.word1, words.word2});
} else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
uint32_t word =
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
} else {
std::string valueStr;
llvm::raw_string_ostream rss(valueStr);
value.print(rss);
emitError(loc, "cannot serialize ")
<< floatAttr.getType() << "-typed float literal: " << rss.str();
return 0;
}
if (!isSpec) {
constIDMap[floatAttr] = resultID;
}
return resultID;
}
//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
uint32_t Serializer::getOrCreateBlockID(Block *block) {
if (uint32_t id = getBlockID(block))
return id;
return blockIDMap[block] = getNextID();
}
#ifndef NDEBUG
void Serializer::printBlock(Block *block, raw_ostream &os) {
os << "block " << block << " (id = ";
if (uint32_t id = getBlockID(block))
os << id;
else
os << "unknown";
os << ")\n";
}
#endif
LogicalResult
Serializer::processBlock(Block *block, bool omitLabel,
function_ref<LogicalResult()> emitMerge) {
LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
LLVM_DEBUG(block->print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
if (!omitLabel) {
uint32_t blockID = getOrCreateBlockID(block);
LLVM_DEBUG(printBlock(block, llvm::dbgs()));
// Emit OpLabel for this block.
encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
}
// Emit OpPhi instructions for block arguments, if any.
if (failed(emitPhiForBlockArguments(block)))
return failure();
// If we need to emit merge instructions, it must happen in this block. Check
// whether we have other structured control flow ops, which will be expanded
// into multiple basic blocks. If that's the case, we need to emit the merge
// right now and then create new blocks for further serialization of the ops
// in this block.
if (emitMerge && llvm::any_of(block->getOperations(), [](Operation &op) {
return isa<spirv::LoopOp, spirv::SelectionOp>(op);
})) {
if (failed(emitMerge()))
return failure();
emitMerge = nullptr;
// Start a new block for further serialization.
uint32_t blockID = getNextID();
encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
}
// Process each op in this block except the terminator.
for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
if (failed(processOperation(&op)))
return failure();
}
// Process the terminator.
if (emitMerge)
if (failed(emitMerge()))
return failure();
if (failed(processOperation(&block->back())))
return failure();
return success();
}
LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
// Nothing to do if this block has no arguments or it's the entry block, which
// always has the same arguments as the function signature.
if (block->args_empty() || block->isEntryBlock())
return success();
LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
// If the block has arguments, we need to create SPIR-V OpPhi instructions.
// A SPIR-V OpPhi instruction is of the syntax:
// OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
// So we need to collect all predecessor blocks and the arguments they send
// to this block.
SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
for (Block *mlirPredecessor : block->getPredecessors()) {
auto *terminator = mlirPredecessor->getTerminator();
LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
// The predecessor here is the immediate one according to MLIR's IR
// structure. It does not directly map to the incoming parent block for the
// OpPhi instructions at SPIR-V binary level. This is because structured
// control flow ops are serialized to multiple SPIR-V blocks. If there is a
// spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the
// branch op jumping to the OpPhi's block then resides in the previous
// structured control flow op's merge block.
Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
} else if (auto branchCondOp =
dyn_cast<spirv::BranchConditionalOp>(terminator)) {
Optional<OperandRange> blockOperands;
if (branchCondOp.trueTarget() == block) {
blockOperands = branchCondOp.trueTargetOperands();
} else {
assert(branchCondOp.falseTarget() == block);
blockOperands = branchCondOp.falseTargetOperands();
}
assert(!blockOperands->empty() &&
"expected non-empty block operand range");
predecessors.emplace_back(spirvPredecessor, *blockOperands);
} else {
return terminator->emitError("unimplemented terminator for Phi creation");
}
LLVM_DEBUG({
llvm::dbgs() << " block arguments:\n";
for (Value v : predecessors.back().second)
llvm::dbgs() << " " << v << "\n";
});
}
// Then create OpPhi instruction for each of the block argument.
for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
BlockArgument arg = block->getArgument(argIndex);
// Get the type <id> and result <id> for this OpPhi instruction.
uint32_t phiTypeID = 0;
if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
return failure();
uint32_t phiID = getNextID();
LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
<< arg << " (id = " << phiID << ")\n");
// Prepare the (value <id>, parent block <id>) pairs.
SmallVector<uint32_t, 8> phiArgs;
phiArgs.push_back(phiTypeID);
phiArgs.push_back(phiID);
for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
Value value = predecessors[predIndex].second[argIndex];
uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
<< ") value " << value << ' ');
// Each pair is a value <id> ...
uint32_t valueId = getValueID(value);
if (valueId == 0) {
// The op generating this value hasn't been visited yet so we don't have
// an <id> assigned yet. Record this to fix up later.
LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
deferredPhiValues[value].push_back(functionBody.size() + 1 +
phiArgs.size());
} else {
LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
}
phiArgs.push_back(valueId);
// ... and a parent block <id>.
phiArgs.push_back(predBlockId);
}
encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
valueIDMap[arg] = phiID;
}
return success();
}
//===----------------------------------------------------------------------===//
// Operation
//===----------------------------------------------------------------------===//
LogicalResult Serializer::encodeExtensionInstruction(
Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
ArrayRef<uint32_t> operands) {
// Check if the extension has been imported.
auto &setID = extendedInstSetIDMap[extensionSetName];
if (!setID) {
setID = getNextID();
SmallVector<uint32_t, 16> importOperands;
importOperands.push_back(setID);
spirv::encodeStringLiteralInto(importOperands, extensionSetName);
encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
importOperands);
}
// The first two operands are the result type <id> and result <id>. The set
// <id> and the opcode need to be insert after this.
if (operands.size() < 2) {
return op->emitError("extended instructions must have a result encoding");
}
SmallVector<uint32_t, 8> extInstOperands;
extInstOperands.reserve(operands.size() + 2);
extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
extInstOperands.push_back(setID);
extInstOperands.push_back(extensionOpcode);
extInstOperands.append(std::next(operands.begin(), 2), operands.end());
encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
extInstOperands);
return success();
}
LogicalResult Serializer::processOperation(Operation *opInst) {
LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
// First dispatch the ops that do not directly mirror an instruction from
// the SPIR-V spec.
return TypeSwitch<Operation *, LogicalResult>(opInst)
.Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
.Case([&](spirv::BranchOp op) { return processBranchOp(op); })
.Case([&](spirv::BranchConditionalOp op) {
return processBranchConditionalOp(op);
})
.Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
.Case([&](spirv::FuncOp op) { return processFuncOp(op); })
.Case([&](spirv::GlobalVariableOp op) {
return processGlobalVariableOp(op);
})
.Case([&](spirv::LoopOp op) { return processLoopOp(op); })
.Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
.Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
.Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
.Case([&](spirv::SpecConstantCompositeOp op) {
return processSpecConstantCompositeOp(op);
})
.Case([&](spirv::SpecConstantOperationOp op) {
return processSpecConstantOperationOp(op);
})
.Case([&](spirv::UndefOp op) { return processUndefOp(op); })
.Case([&](spirv::VariableOp op) { return processVariableOp(op); })
// Then handle all the ops that directly mirror SPIR-V instructions with
// auto-generated methods.
.Default(
[&](Operation *op) { return dispatchToAutogenSerialization(op); });
}
LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
StringRef extInstSet,
uint32_t opcode) {
SmallVector<uint32_t, 4> operands;
Location loc = op->getLoc();
uint32_t resultID = 0;
if (op->getNumResults() != 0) {
uint32_t resultTypeID = 0;
if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
return failure();
operands.push_back(resultTypeID);
resultID = getNextID();
operands.push_back(resultID);
valueIDMap[op->getResult(0)] = resultID;
};
for (Value operand : op->getOperands())
operands.push_back(getValueID(operand));
if (failed(emitDebugLine(functionBody, loc)))
return failure();
if (extInstSet.empty()) {
encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
operands);
} else {
if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
return failure();
}
if (op->getNumResults() != 0) {
for (auto attr : op->getAttrs()) {
if (failed(processDecoration(loc, resultID, attr)))
return failure();
}
}
return success();
}
LogicalResult Serializer::emitDecoration(uint32_t target,
spirv::Decoration decoration,
ArrayRef<uint32_t> params) {
uint32_t wordCount = 3 + params.size();
decorations.push_back(
spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
decorations.push_back(target);
decorations.push_back(static_cast<uint32_t>(decoration));
decorations.append(params.begin(), params.end());
return success();
}
LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
Location loc) {
if (!options.emitDebugInfo)
return success();
if (lastProcessedWasMergeInst) {
lastProcessedWasMergeInst = false;
return success();
}
auto fileLoc = loc.dyn_cast<FileLineColLoc>();
if (fileLoc)
encodeInstructionInto(binary, spirv::Opcode::OpLine,
{fileID, fileLoc.getLine(), fileLoc.getColumn()});
return success();
}
} // namespace spirv
} // namespace mlir