//===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the MLIR SPIR-V module to SPIR-V binary serializer. // //===----------------------------------------------------------------------===// #include "Serializer.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "spirv-serialization" using namespace mlir; /// Returns the merge block if the given `op` is a structured control flow op. /// Otherwise returns nullptr. static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { if (auto selectionOp = dyn_cast(op)) return selectionOp.getMergeBlock(); if (auto loopOp = dyn_cast(op)) return loopOp.getMergeBlock(); return nullptr; } /// Given a predecessor `block` for a block with arguments, returns the block /// that should be used as the parent block for SPIR-V OpPhi instructions /// corresponding to the block arguments. static Block *getPhiIncomingBlock(Block *block) { // If the predecessor block in question is the entry block for a // spv.mlir.loop, we jump to this spv.mlir.loop from its enclosing block. if (block->isEntryBlock()) { if (auto loopOp = dyn_cast(block->getParentOp())) { // Then the incoming parent block for OpPhi should be the merge block of // the structured control flow op before this loop. Operation *op = loopOp.getOperation(); while ((op = op->getPrevNode()) != nullptr) if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) return incomingBlock; // Or the enclosing block itself if no structured control flow ops // exists before this loop. return loopOp->getBlock(); } } // Otherwise, we jump from the given predecessor block. Try to see if there is // a structured control flow op inside it. for (Operation &op : llvm::reverse(block->getOperations())) { if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) return incomingBlock; } return block; } namespace mlir { namespace spirv { /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into /// the given `binary` vector. LogicalResult encodeInstructionInto(SmallVectorImpl &binary, spirv::Opcode op, ArrayRef operands) { uint32_t wordCount = 1 + operands.size(); binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); binary.append(operands.begin(), operands.end()); return success(); } Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) : module(module), mlirBuilder(module.getContext()), emitDebugInfo(emitDebugInfo) {} LogicalResult Serializer::serialize() { LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); if (failed(module.verify())) return failure(); // TODO: handle the other sections processCapability(); processExtension(); processMemoryModel(); processDebugInfo(); // Iterate over the module body to serialize it. Assumptions are that there is // only one basic block in the moduleOp for (auto &op : *module.getBody()) { if (failed(processOperation(&op))) { return failure(); } } LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); return success(); } void Serializer::collect(SmallVectorImpl &binary) { auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + extensions.size() + extendedSets.size() + memoryModel.size() + entryPoints.size() + executionModes.size() + decorations.size() + typesGlobalValues.size() + functions.size(); binary.clear(); binary.reserve(moduleSize); spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); binary.append(capabilities.begin(), capabilities.end()); binary.append(extensions.begin(), extensions.end()); binary.append(extendedSets.begin(), extendedSets.end()); binary.append(memoryModel.begin(), memoryModel.end()); binary.append(entryPoints.begin(), entryPoints.end()); binary.append(executionModes.begin(), executionModes.end()); binary.append(debug.begin(), debug.end()); binary.append(names.begin(), names.end()); binary.append(decorations.begin(), decorations.end()); binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); binary.append(functions.begin(), functions.end()); } #ifndef NDEBUG void Serializer::printValueIDMap(raw_ostream &os) { os << "\n= Value Map =\n\n"; for (auto valueIDPair : valueIDMap) { Value val = valueIDPair.first; os << " " << val << " " << "id = " << valueIDPair.second << ' '; if (auto *op = val.getDefiningOp()) { os << "from op '" << op->getName() << "'"; } else if (auto arg = val.dyn_cast()) { Block *block = arg.getOwner(); os << "from argument of block " << block << ' '; os << " in op '" << block->getParentOp()->getName() << "'"; } os << '\n'; } } #endif //===----------------------------------------------------------------------===// // Module structure //===----------------------------------------------------------------------===// uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { auto funcID = funcIDMap.lookup(fnName); if (!funcID) { funcID = getNextID(); funcIDMap[fnName] = funcID; } return funcID; } void Serializer::processCapability() { for (auto cap : module.vce_triple()->getCapabilities()) (void)encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, {static_cast(cap)}); } void Serializer::processDebugInfo() { if (!emitDebugInfo) return; auto fileLoc = module.getLoc().dyn_cast(); auto fileName = fileLoc ? fileLoc.getFilename().strref() : ""; fileID = getNextID(); SmallVector operands; operands.push_back(fileID); (void)spirv::encodeStringLiteralInto(operands, fileName); (void)encodeInstructionInto(debug, spirv::Opcode::OpString, operands); // TODO: Encode more debug instructions. } void Serializer::processExtension() { llvm::SmallVector extName; for (spirv::Extension ext : module.vce_triple()->getExtensions()) { extName.clear(); (void)spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); (void)encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); } } void Serializer::processMemoryModel() { uint32_t mm = module->getAttrOfType("memory_model").getInt(); uint32_t am = module->getAttrOfType("addressing_model").getInt(); (void)encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); } LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, NamedAttribute attr) { auto attrName = attr.getName().strref(); auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); auto decoration = spirv::symbolizeDecoration(decorationName); if (!decoration) { return emitError( loc, "non-argument attributes expected to have snake-case-ified " "decoration name, unhandled attribute with name : ") << attrName; } SmallVector args; switch (decoration.getValue()) { case spirv::Decoration::Binding: case spirv::Decoration::DescriptorSet: case spirv::Decoration::Location: if (auto intAttr = attr.getValue().dyn_cast()) { args.push_back(intAttr.getValue().getZExtValue()); break; } return emitError(loc, "expected integer attribute for ") << attrName; case spirv::Decoration::BuiltIn: if (auto strAttr = attr.getValue().dyn_cast()) { auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); if (enumVal) { args.push_back(static_cast(enumVal.getValue())); break; } return emitError(loc, "invalid ") << attrName << " attribute " << strAttr.getValue(); } return emitError(loc, "expected string attribute for ") << attrName; case spirv::Decoration::Aliased: case spirv::Decoration::Flat: case spirv::Decoration::NonReadable: case spirv::Decoration::NonWritable: case spirv::Decoration::NoPerspective: case spirv::Decoration::Restrict: case spirv::Decoration::RelaxedPrecision: // For unit attributes, the args list has no values so we do nothing if (auto unitAttr = attr.getValue().dyn_cast()) break; return emitError(loc, "expected unit attribute for ") << attrName; default: return emitError(loc, "unhandled decoration ") << decorationName; } return emitDecoration(resultID, decoration.getValue(), args); } LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { assert(!name.empty() && "unexpected empty string for OpName"); SmallVector nameOperands; nameOperands.push_back(resultID); if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { return failure(); } return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); } template <> LogicalResult Serializer::processTypeDecoration( Location loc, spirv::ArrayType type, uint32_t resultID) { if (unsigned stride = type.getArrayStride()) { // OpDecorate %arrayTypeSSA ArrayStride strideLiteral return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); } return success(); } template <> LogicalResult Serializer::processTypeDecoration( Location loc, spirv::RuntimeArrayType type, uint32_t resultID) { if (unsigned stride = type.getArrayStride()) { // OpDecorate %arrayTypeSSA ArrayStride strideLiteral return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); } return success(); } LogicalResult Serializer::processMemberDecoration( uint32_t structID, const spirv::StructType::MemberDecorationInfo &memberDecoration) { SmallVector args( {structID, memberDecoration.memberIndex, static_cast(memberDecoration.decoration)}); if (memberDecoration.hasValue) { args.push_back(memberDecoration.decorationValue); } return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args); } //===----------------------------------------------------------------------===// // Type //===----------------------------------------------------------------------===// // According to the SPIR-V spec "Validation Rules for Shader Capabilities": // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and // PushConstant Storage Classes must be explicitly laid out." bool Serializer::isInterfaceStructPtrType(Type type) const { if (auto ptrType = type.dyn_cast()) { switch (ptrType.getStorageClass()) { case spirv::StorageClass::PhysicalStorageBuffer: case spirv::StorageClass::PushConstant: case spirv::StorageClass::StorageBuffer: case spirv::StorageClass::Uniform: return ptrType.getPointeeType().isa(); default: break; } } return false; } LogicalResult Serializer::processType(Location loc, Type type, uint32_t &typeID) { // Maintains a set of names for nested identified struct types. This is used // to properly serialize recursive references. SetVector serializationCtx; return processTypeImpl(loc, type, typeID, serializationCtx); } LogicalResult Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, SetVector &serializationCtx) { typeID = getTypeID(type); if (typeID) { return success(); } typeID = getNextID(); SmallVector operands; operands.push_back(typeID); auto typeEnum = spirv::Opcode::OpTypeVoid; bool deferSerialization = false; if ((type.isa() && succeeded(prepareFunctionType(loc, type.cast(), typeEnum, operands))) || succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, deferSerialization, serializationCtx))) { if (deferSerialization) return success(); typeIDMap[type] = typeID; if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) return failure(); if (recursiveStructInfos.count(type) != 0) { // This recursive struct type is emitted already, now the OpTypePointer // instructions referring to recursive references are emitted as well. for (auto &ptrInfo : recursiveStructInfos[type]) { // TODO: This might not work if more than 1 recursive reference is // present in the struct. SmallVector ptrOperands; ptrOperands.push_back(ptrInfo.pointerTypeID); ptrOperands.push_back(static_cast(ptrInfo.storageClass)); ptrOperands.push_back(typeIDMap[type]); if (failed(encodeInstructionInto( typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) return failure(); } recursiveStructInfos[type].clear(); } return success(); } return failure(); } LogicalResult Serializer::prepareBasicType( Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, SmallVectorImpl &operands, bool &deferSerialization, SetVector &serializationCtx) { deferSerialization = false; if (isVoidType(type)) { typeEnum = spirv::Opcode::OpTypeVoid; return success(); } if (auto intType = type.dyn_cast()) { if (intType.getWidth() == 1) { typeEnum = spirv::Opcode::OpTypeBool; return success(); } typeEnum = spirv::Opcode::OpTypeInt; operands.push_back(intType.getWidth()); // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics // to preserve or validate. // 0 indicates unsigned, or no signedness semantics // 1 indicates signed semantics." operands.push_back(intType.isSigned() ? 1 : 0); return success(); } if (auto floatType = type.dyn_cast()) { typeEnum = spirv::Opcode::OpTypeFloat; operands.push_back(floatType.getWidth()); return success(); } if (auto vectorType = type.dyn_cast()) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeVector; operands.push_back(elementTypeID); operands.push_back(vectorType.getNumElements()); return success(); } if (auto imageType = type.dyn_cast()) { typeEnum = spirv::Opcode::OpTypeImage; uint32_t sampledTypeID = 0; if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) return failure(); operands.push_back(sampledTypeID); operands.push_back(static_cast(imageType.getDim())); operands.push_back(static_cast(imageType.getDepthInfo())); operands.push_back(static_cast(imageType.getArrayedInfo())); operands.push_back(static_cast(imageType.getSamplingInfo())); operands.push_back(static_cast(imageType.getSamplerUseInfo())); operands.push_back(static_cast(imageType.getImageFormat())); return success(); } if (auto arrayType = type.dyn_cast()) { typeEnum = spirv::Opcode::OpTypeArray; uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, serializationCtx))) { return failure(); } operands.push_back(elementTypeID); if (auto elementCountID = prepareConstantInt( loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { operands.push_back(elementCountID); } return processTypeDecoration(loc, arrayType, resultID); } if (auto ptrType = type.dyn_cast()) { uint32_t pointeeTypeID = 0; spirv::StructType pointeeStruct = ptrType.getPointeeType().dyn_cast(); if (pointeeStruct && pointeeStruct.isIdentified() && serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { // A recursive reference to an enclosing struct is found. // // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage // class as operands. SmallVector forwardPtrOperands; forwardPtrOperands.push_back(resultID); forwardPtrOperands.push_back( static_cast(ptrType.getStorageClass())); (void)encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypeForwardPointer, forwardPtrOperands); // 2. Find the pointee (enclosing) struct. auto structType = spirv::StructType::getIdentified( module.getContext(), pointeeStruct.getIdentifier()); if (!structType) return failure(); // 3. Mark the OpTypePointer that is supposed to be emitted by this call // as deferred. deferSerialization = true; // 4. Record the info needed to emit the deferred OpTypePointer // instruction when the enclosing struct is completely serialized. recursiveStructInfos[structType].push_back( {resultID, ptrType.getStorageClass()}); } else { if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, serializationCtx))) return failure(); } typeEnum = spirv::Opcode::OpTypePointer; operands.push_back(static_cast(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); return success(); } if (auto runtimeArrayType = type.dyn_cast()) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), elementTypeID, serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeRuntimeArray; operands.push_back(elementTypeID); return processTypeDecoration(loc, runtimeArrayType, resultID); } if (auto sampledImageType = type.dyn_cast()) { typeEnum = spirv::Opcode::OpTypeSampledImage; uint32_t imageTypeID = 0; if (failed( processType(loc, sampledImageType.getImageType(), imageTypeID))) { return failure(); } operands.push_back(imageTypeID); return success(); } if (auto structType = type.dyn_cast()) { if (structType.isIdentified()) { (void)processName(resultID, structType.getIdentifier()); serializationCtx.insert(structType.getIdentifier()); } bool hasOffset = structType.hasOffset(); for (auto elementIndex : llvm::seq(0, structType.getNumElements())) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), elementTypeID, serializationCtx))) { return failure(); } operands.push_back(elementTypeID); if (hasOffset) { // Decorate each struct member with an offset spirv::StructType::MemberDecorationInfo offsetDecoration{ elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, static_cast(structType.getMemberOffset(elementIndex))}; if (failed(processMemberDecoration(resultID, offsetDecoration))) { return emitError(loc, "cannot decorate ") << elementIndex << "-th member of " << structType << " with its offset"; } } } SmallVector memberDecorations; structType.getMemberDecorations(memberDecorations); for (auto &memberDecoration : memberDecorations) { if (failed(processMemberDecoration(resultID, memberDecoration))) { return emitError(loc, "cannot decorate ") << static_cast(memberDecoration.memberIndex) << "-th member of " << structType << " with " << stringifyDecoration(memberDecoration.decoration); } } typeEnum = spirv::Opcode::OpTypeStruct; if (structType.isIdentified()) serializationCtx.remove(structType.getIdentifier()); return success(); } if (auto cooperativeMatrixType = type.dyn_cast()) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), elementTypeID, serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; auto getConstantOp = [&](uint32_t id) { auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); return prepareConstantInt(loc, attr); }; operands.push_back(elementTypeID); operands.push_back( getConstantOp(static_cast(cooperativeMatrixType.getScope()))); operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); return success(); } if (auto matrixType = type.dyn_cast()) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeMatrix; operands.push_back(elementTypeID); operands.push_back(matrixType.getNumColumns()); return success(); } // TODO: Handle other types. return emitError(loc, "unhandled type in serialization: ") << type; } LogicalResult Serializer::prepareFunctionType(Location loc, FunctionType type, spirv::Opcode &typeEnum, SmallVectorImpl &operands) { typeEnum = spirv::Opcode::OpTypeFunction; assert(type.getNumResults() <= 1 && "serialization supports only a single return value"); uint32_t resultID = 0; if (failed(processType( loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), resultID))) { return failure(); } operands.push_back(resultID); for (auto &res : type.getInputs()) { uint32_t argTypeID = 0; if (failed(processType(loc, res, argTypeID))) { return failure(); } operands.push_back(argTypeID); } return success(); } //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// uint32_t Serializer::prepareConstant(Location loc, Type constType, Attribute valueAttr) { if (auto id = prepareConstantScalar(loc, valueAttr)) { return id; } // This is a composite literal. We need to handle each component separately // and then emit an OpConstantComposite for the whole. if (auto id = getConstantID(valueAttr)) { return id; } uint32_t typeID = 0; if (failed(processType(loc, constType, typeID))) { return 0; } uint32_t resultID = 0; if (auto attr = valueAttr.dyn_cast()) { int rank = attr.getType().dyn_cast().getRank(); SmallVector index(rank); resultID = prepareDenseElementsConstant(loc, constType, attr, /*dim=*/0, index); } else if (auto arrayAttr = valueAttr.dyn_cast()) { resultID = prepareArrayConstant(loc, constType, arrayAttr); } if (resultID == 0) { emitError(loc, "cannot serialize attribute: ") << valueAttr; return 0; } constIDMap[valueAttr] = resultID; return resultID; } uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, ArrayAttr attr) { uint32_t typeID = 0; if (failed(processType(loc, constType, typeID))) { return 0; } uint32_t resultID = getNextID(); SmallVector operands = {typeID, resultID}; operands.reserve(attr.size() + 2); auto elementType = constType.cast().getElementType(); for (Attribute elementAttr : attr) { if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { operands.push_back(elementID); } else { return 0; } } spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; (void)encodeInstructionInto(typesGlobalValues, opcode, operands); return resultID; } // TODO: Turn the below function into iterative function, instead of // recursive function. uint32_t Serializer::prepareDenseElementsConstant(Location loc, Type constType, DenseElementsAttr valueAttr, int dim, MutableArrayRef index) { auto shapedType = valueAttr.getType().dyn_cast(); assert(dim <= shapedType.getRank()); if (shapedType.getRank() == dim) { if (auto attr = valueAttr.dyn_cast()) { return attr.getType().getElementType().isInteger(1) ? prepareConstantBool(loc, attr.getValues()[index]) : prepareConstantInt(loc, attr.getValues()[index]); } if (auto attr = valueAttr.dyn_cast()) { return prepareConstantFp(loc, attr.getValues()[index]); } return 0; } uint32_t typeID = 0; if (failed(processType(loc, constType, typeID))) { return 0; } uint32_t resultID = getNextID(); SmallVector operands = {typeID, resultID}; operands.reserve(shapedType.getDimSize(dim) + 2); auto elementType = constType.cast().getElementType(0); for (int i = 0; i < shapedType.getDimSize(dim); ++i) { index[dim] = i; if (auto elementID = prepareDenseElementsConstant( loc, elementType, valueAttr, dim + 1, index)) { operands.push_back(elementID); } else { return 0; } } spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; (void)encodeInstructionInto(typesGlobalValues, opcode, operands); return resultID; } uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, bool isSpec) { if (auto floatAttr = valueAttr.dyn_cast()) { return prepareConstantFp(loc, floatAttr, isSpec); } if (auto boolAttr = valueAttr.dyn_cast()) { return prepareConstantBool(loc, boolAttr, isSpec); } if (auto intAttr = valueAttr.dyn_cast()) { return prepareConstantInt(loc, intAttr, isSpec); } return 0; } uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec) { if (!isSpec) { // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(boolAttr)) { return id; } } // Process the type for this bool literal uint32_t typeID = 0; if (failed(processType(loc, boolAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); auto opcode = boolAttr.getValue() ? (isSpec ? spirv::Opcode::OpSpecConstantTrue : spirv::Opcode::OpConstantTrue) : (isSpec ? spirv::Opcode::OpSpecConstantFalse : spirv::Opcode::OpConstantFalse); (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); if (!isSpec) { constIDMap[boolAttr] = resultID; } return resultID; } uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec) { if (!isSpec) { // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(intAttr)) { return id; } } // Process the type for this integer literal uint32_t typeID = 0; if (failed(processType(loc, intAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); APInt value = intAttr.getValue(); unsigned bitwidth = value.getBitWidth(); bool isSigned = value.isSignedIntN(bitwidth); auto opcode = isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; switch (bitwidth) { // According to SPIR-V spec, "When the type's bit width is less than // 32-bits, the literal's value appears in the low-order bits of the word, // and the high-order bits must be 0 for a floating-point type, or 0 for an // integer type with Signedness of 0, or sign extended when Signedness // is 1." case 32: case 16: case 8: { uint32_t word = 0; if (isSigned) { word = static_cast(value.getSExtValue()); } else { word = static_cast(value.getZExtValue()); } (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } break; // According to SPIR-V spec: "When the type's bit width is larger than one // word, the literal’s low-order words appear first." case 64: { struct DoubleWord { uint32_t word1; uint32_t word2; } words; if (isSigned) { words = llvm::bit_cast(value.getSExtValue()); } else { words = llvm::bit_cast(value.getZExtValue()); } (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, words.word1, words.word2}); } break; default: { std::string valueStr; llvm::raw_string_ostream rss(valueStr); value.print(rss, /*isSigned=*/false); emitError(loc, "cannot serialize ") << bitwidth << "-bit integer literal: " << rss.str(); return 0; } } if (!isSpec) { constIDMap[intAttr] = resultID; } return resultID; } uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec) { if (!isSpec) { // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(floatAttr)) { return id; } } // Process the type for this float literal uint32_t typeID = 0; if (failed(processType(loc, floatAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); APFloat value = floatAttr.getValue(); APInt intValue = value.bitcastToAPInt(); auto opcode = isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; if (&value.getSemantics() == &APFloat::IEEEsingle()) { uint32_t word = llvm::bit_cast(value.convertToFloat()); (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { struct DoubleWord { uint32_t word1; uint32_t word2; } words = llvm::bit_cast(value.convertToDouble()); (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, words.word1, words.word2}); } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { uint32_t word = static_cast(value.bitcastToAPInt().getZExtValue()); (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } else { std::string valueStr; llvm::raw_string_ostream rss(valueStr); value.print(rss); emitError(loc, "cannot serialize ") << floatAttr.getType() << "-typed float literal: " << rss.str(); return 0; } if (!isSpec) { constIDMap[floatAttr] = resultID; } return resultID; } //===----------------------------------------------------------------------===// // Control flow //===----------------------------------------------------------------------===// uint32_t Serializer::getOrCreateBlockID(Block *block) { if (uint32_t id = getBlockID(block)) return id; return blockIDMap[block] = getNextID(); } LogicalResult Serializer::processBlock(Block *block, bool omitLabel, function_ref actionBeforeTerminator) { LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); LLVM_DEBUG(block->print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); if (!omitLabel) { uint32_t blockID = getOrCreateBlockID(block); LLVM_DEBUG(llvm::dbgs() << "[block] " << block << " (id = " << blockID << ")\n"); // Emit OpLabel for this block. (void)encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); } // Emit OpPhi instructions for block arguments, if any. if (failed(emitPhiForBlockArguments(block))) return failure(); // Process each op in this block except the terminator. for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { if (failed(processOperation(&op))) return failure(); } // Process the terminator. if (actionBeforeTerminator) actionBeforeTerminator(); if (failed(processOperation(&block->back()))) return failure(); return success(); } LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { // Nothing to do if this block has no arguments or it's the entry block, which // always has the same arguments as the function signature. if (block->args_empty() || block->isEntryBlock()) return success(); // If the block has arguments, we need to create SPIR-V OpPhi instructions. // A SPIR-V OpPhi instruction is of the syntax: // OpPhi | result type | result | (value , parent block ) pair // So we need to collect all predecessor blocks and the arguments they send // to this block. SmallVector, 4> predecessors; for (Block *predecessor : block->getPredecessors()) { auto *terminator = predecessor->getTerminator(); // The predecessor here is the immediate one according to MLIR's IR // structure. It does not directly map to the incoming parent block for the // OpPhi instructions at SPIR-V binary level. This is because structured // control flow ops are serialized to multiple SPIR-V blocks. If there is a // spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the // branch op jumping to the OpPhi's block then resides in the previous // structured control flow op's merge block. predecessor = getPhiIncomingBlock(predecessor); if (auto branchOp = dyn_cast(terminator)) { predecessors.emplace_back(predecessor, branchOp.getOperands()); } else if (auto branchCondOp = dyn_cast(terminator)) { Optional blockOperands; for (auto successorIdx : llvm::seq(0, predecessor->getNumSuccessors())) if (predecessor->getSuccessors()[successorIdx] == block) { blockOperands = branchCondOp.getSuccessorOperands(successorIdx); break; } assert(blockOperands && !blockOperands->empty() && "expected non-empty block operand range"); predecessors.emplace_back(predecessor, *blockOperands); } else { return terminator->emitError("unimplemented terminator for Phi creation"); } } // Then create OpPhi instruction for each of the block argument. for (auto argIndex : llvm::seq(0, block->getNumArguments())) { BlockArgument arg = block->getArgument(argIndex); // Get the type and result for this OpPhi instruction. uint32_t phiTypeID = 0; if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) return failure(); uint32_t phiID = getNextID(); LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' << arg << " (id = " << phiID << ")\n"); // Prepare the (value , parent block ) pairs. SmallVector phiArgs; phiArgs.push_back(phiTypeID); phiArgs.push_back(phiID); for (auto predIndex : llvm::seq(0, predecessors.size())) { Value value = predecessors[predIndex].second[argIndex]; uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId << ") value " << value << ' '); // Each pair is a value ... uint32_t valueId = getValueID(value); if (valueId == 0) { // The op generating this value hasn't been visited yet so we don't have // an assigned yet. Record this to fix up later. LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); deferredPhiValues[value].push_back(functionBody.size() + 1 + phiArgs.size()); } else { LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); } phiArgs.push_back(valueId); // ... and a parent block . phiArgs.push_back(predBlockId); } (void)encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); valueIDMap[arg] = phiID; } return success(); } //===----------------------------------------------------------------------===// // Operation //===----------------------------------------------------------------------===// LogicalResult Serializer::encodeExtensionInstruction( Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, ArrayRef operands) { // Check if the extension has been imported. auto &setID = extendedInstSetIDMap[extensionSetName]; if (!setID) { setID = getNextID(); SmallVector importOperands; importOperands.push_back(setID); if (failed( spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || failed(encodeInstructionInto( extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { return failure(); } } // The first two operands are the result type and result . The set // and the opcode need to be insert after this. if (operands.size() < 2) { return op->emitError("extended instructions must have a result encoding"); } SmallVector extInstOperands; extInstOperands.reserve(operands.size() + 2); extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); extInstOperands.push_back(setID); extInstOperands.push_back(extensionOpcode); extInstOperands.append(std::next(operands.begin(), 2), operands.end()); return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, extInstOperands); } LogicalResult Serializer::processOperation(Operation *opInst) { LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); // First dispatch the ops that do not directly mirror an instruction from // the SPIR-V spec. return TypeSwitch(opInst) .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) .Case([&](spirv::BranchConditionalOp op) { return processBranchConditionalOp(op); }) .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) .Case([&](spirv::GlobalVariableOp op) { return processGlobalVariableOp(op); }) .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) .Case([&](spirv::SpecConstantCompositeOp op) { return processSpecConstantCompositeOp(op); }) .Case([&](spirv::SpecConstantOperationOp op) { return processSpecConstantOperationOp(op); }) .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) // Then handle all the ops that directly mirror SPIR-V instructions with // auto-generated methods. .Default( [&](Operation *op) { return dispatchToAutogenSerialization(op); }); } LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet, uint32_t opcode) { SmallVector operands; Location loc = op->getLoc(); uint32_t resultID = 0; if (op->getNumResults() != 0) { uint32_t resultTypeID = 0; if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) return failure(); operands.push_back(resultTypeID); resultID = getNextID(); operands.push_back(resultID); valueIDMap[op->getResult(0)] = resultID; }; for (Value operand : op->getOperands()) operands.push_back(getValueID(operand)); (void)emitDebugLine(functionBody, loc); if (extInstSet.empty()) { (void)encodeInstructionInto(functionBody, static_cast(opcode), operands); } else { (void)encodeExtensionInstruction(op, extInstSet, opcode, operands); } if (op->getNumResults() != 0) { for (auto attr : op->getAttrs()) { if (failed(processDecoration(loc, resultID, attr))) return failure(); } } return success(); } LogicalResult Serializer::emitDecoration(uint32_t target, spirv::Decoration decoration, ArrayRef params) { uint32_t wordCount = 3 + params.size(); decorations.push_back( spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); decorations.push_back(target); decorations.push_back(static_cast(decoration)); decorations.append(params.begin(), params.end()); return success(); } LogicalResult Serializer::emitDebugLine(SmallVectorImpl &binary, Location loc) { if (!emitDebugInfo) return success(); if (lastProcessedWasMergeInst) { lastProcessedWasMergeInst = false; return success(); } auto fileLoc = loc.dyn_cast(); if (fileLoc) (void)encodeInstructionInto( binary, spirv::Opcode::OpLine, {fileID, fileLoc.getLine(), fileLoc.getColumn()}); return success(); } } // namespace spirv } // namespace mlir