mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 04:56:07 +00:00
[MLIR][SPIRV] Add (de-)serialization support for SpecConstantOpeation.
This commit adds support for (de-)serializing SpecConstantOpeation. One thing worth noting is that during deserialization, we assign a fake ID to enclosed ops inside SpecConstantOpeation. We need to do this in order for deserialization logic to properly update ID to value map and to later reference the created value from the sibling 'spv::YieldOp'. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D93591
This commit is contained in:
parent
b8d2842088
commit
a40767ec88
@ -3170,6 +3170,7 @@ def SPV_OC_OpSpecConstantTrue : I32EnumAttrCase<"OpSpecConstantTrue", 4
|
||||
def SPV_OC_OpSpecConstantFalse : I32EnumAttrCase<"OpSpecConstantFalse", 49>;
|
||||
def SPV_OC_OpSpecConstant : I32EnumAttrCase<"OpSpecConstant", 50>;
|
||||
def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", 51>;
|
||||
def SPV_OC_OpSpecConstantOperation : I32EnumAttrCase<"OpSpecConstantOperation", 52>;
|
||||
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
|
||||
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
|
||||
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
|
||||
@ -3314,7 +3315,8 @@ def SPV_OpcodeAttr :
|
||||
SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
|
||||
SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
|
||||
SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
|
||||
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
|
||||
SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOperation,
|
||||
SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
|
||||
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
|
||||
SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
|
||||
SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,
|
||||
|
@ -3445,9 +3445,8 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
|
||||
return constOp.emitOpError("invalid enclosed op");
|
||||
|
||||
for (auto operand : enclosedOp.getOperands())
|
||||
if (!isa<spirv::ConstantOp, spirv::SpecConstantOp,
|
||||
spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>(
|
||||
operand.getDefiningOp()))
|
||||
if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
|
||||
spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
|
||||
return constOp.emitOpError(
|
||||
"invalid operand, must be defined by a constant operation");
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include "mlir/Target/SPIRV/Deserialization.h"
|
||||
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
||||
@ -28,6 +29,7 @@
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/bit.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/SaveAndRestore.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
@ -132,6 +134,14 @@ struct DeferredStructTypeInfo {
|
||||
SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
|
||||
};
|
||||
|
||||
/// A struct that collects the info needed to materialize/emit a
|
||||
/// SpecConstantOperation op.
|
||||
struct SpecConstOperationMaterializationInfo {
|
||||
spirv::Opcode enclodesOpcode;
|
||||
uint32_t resultTypeID;
|
||||
SmallVector<uint32_t> enclosedOpOperands;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Deserializer Declaration
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -216,9 +226,14 @@ private:
|
||||
/// Gets the constant's attribute and type associated with the given <id>.
|
||||
Optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
|
||||
|
||||
/// Gets the constant's integer attribute with the given <id>. Returns a null
|
||||
/// IntegerAttr if the given is not registered or does not correspond to an
|
||||
/// integer constant.
|
||||
/// Gets the info needed to materialize the spec constant operation op
|
||||
/// associated with the given <id>.
|
||||
Optional<SpecConstOperationMaterializationInfo>
|
||||
getSpecConstantOperation(uint32_t id);
|
||||
|
||||
/// Gets the constant's integer attribute with the given <id>. Returns a
|
||||
/// null IntegerAttr if the given is not registered or does not correspond
|
||||
/// to an integer constant.
|
||||
IntegerAttr getConstantInt(uint32_t id);
|
||||
|
||||
/// Returns a symbol to be used for the function name with the given
|
||||
@ -305,8 +320,20 @@ private:
|
||||
/// `operands`.
|
||||
LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Processes a SPIR-V OpSpecConstantComposite instruction with the given
|
||||
/// `operands`.
|
||||
LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Processes a SPIR-V OpSpecConstantOperation instruction with the given
|
||||
/// `operands`.
|
||||
LogicalResult processSpecConstantOperation(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Materializes/emits an OpSpecConstantOperation instruction.
|
||||
Value materializeSpecConstantOperation(uint32_t resultID,
|
||||
spirv::Opcode enclosedOpcode,
|
||||
uint32_t resultTypeID,
|
||||
ArrayRef<uint32_t> enclosedOpOperands);
|
||||
|
||||
/// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
|
||||
LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
|
||||
|
||||
@ -534,6 +561,11 @@ private:
|
||||
// Result <id> to composite spec constant mapping.
|
||||
DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
|
||||
|
||||
/// Result <id> to info needed to materialize an OpSpecConstantOperation
|
||||
/// mapping.
|
||||
DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
|
||||
specConstOperationMap;
|
||||
|
||||
// Result <id> to variable mapping.
|
||||
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
|
||||
|
||||
@ -1036,6 +1068,14 @@ Optional<std::pair<Attribute, Type>> Deserializer::getConstant(uint32_t id) {
|
||||
return constIt->getSecond();
|
||||
}
|
||||
|
||||
Optional<SpecConstOperationMaterializationInfo>
|
||||
Deserializer::getSpecConstantOperation(uint32_t id) {
|
||||
auto constIt = specConstOperationMap.find(id);
|
||||
if (constIt == specConstOperationMap.end())
|
||||
return llvm::None;
|
||||
return constIt->getSecond();
|
||||
}
|
||||
|
||||
std::string Deserializer::getFunctionSymbol(uint32_t id) {
|
||||
auto funcName = nameMap.lookup(id).str();
|
||||
if (funcName.empty()) {
|
||||
@ -1745,6 +1785,91 @@ Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() < 3)
|
||||
return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
|
||||
"result <id>, and operand opcode");
|
||||
|
||||
uint32_t resultTypeID = operands[0];
|
||||
|
||||
if (!getType(resultTypeID))
|
||||
return emitError(unknownLoc, "undefined result type from <id> ")
|
||||
<< resultTypeID;
|
||||
|
||||
uint32_t resultID = operands[1];
|
||||
spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
|
||||
auto emplaceResult = specConstOperationMap.try_emplace(
|
||||
resultID,
|
||||
SpecConstOperationMaterializationInfo{
|
||||
enclosedOpcode, resultTypeID,
|
||||
SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
|
||||
|
||||
if (!emplaceResult.second)
|
||||
return emitError(unknownLoc, "value with <id>: ")
|
||||
<< resultID << " is probably defined before.";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
Value Deserializer::materializeSpecConstantOperation(
|
||||
uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
|
||||
ArrayRef<uint32_t> enclosedOpOperands) {
|
||||
|
||||
Type resultType = getType(resultTypeID);
|
||||
|
||||
// Instructions wrapped by OpSpecConstantOp need an ID for their
|
||||
// Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
|
||||
// dialect wrapped op. For that purpose, a new value map is created and "fake"
|
||||
// ID in that map is assigned to the result of the enclosed instruction. Note
|
||||
// that there is no need to update this fake ID since we only need to
|
||||
// reference the created Value for the enclosed op from the spv::YieldOp
|
||||
// created later in this method (both of which are the only values in their
|
||||
// region: the SpecConstantOperation's region). If we encounter another
|
||||
// SpecConstantOperation in the module, we simply re-use the fake ID since the
|
||||
// previous Value assigned to it isn't visible in the current scope anyway.
|
||||
DenseMap<uint32_t, Value> newValueMap;
|
||||
llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap,
|
||||
newValueMap);
|
||||
constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
|
||||
|
||||
SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
|
||||
enclosedOpResultTypeAndOperands.push_back(resultTypeID);
|
||||
enclosedOpResultTypeAndOperands.push_back(fakeID);
|
||||
enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
|
||||
enclosedOpOperands.end());
|
||||
|
||||
// Process enclosed instruction before creating the enclosing
|
||||
// specConstantOperation (and its region). This way, references to constants,
|
||||
// global variables, and spec constants will be materialized outside the new
|
||||
// op's region. For more info, see Deserializer::getValue's implementation.
|
||||
if (failed(
|
||||
processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
|
||||
return Value();
|
||||
|
||||
// Since the enclosed op is emitted in the current block, split it in a
|
||||
// separate new block.
|
||||
Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
|
||||
|
||||
auto loc = createFileLineColLoc(opBuilder);
|
||||
auto specConstOperationOp =
|
||||
opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
|
||||
|
||||
Region &body = specConstOperationOp.body();
|
||||
// Move the new block into SpecConstantOperation's body.
|
||||
body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
|
||||
Region::iterator(enclosedBlock));
|
||||
Block &block = body.back();
|
||||
|
||||
// RAII guard to reset the insertion point to the module's region after
|
||||
// deserializing the body of the specConstantOperation.
|
||||
OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
|
||||
opBuilder.setInsertionPointToEnd(&block);
|
||||
|
||||
opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
|
||||
return specConstOperationOp.getResult();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() != 2) {
|
||||
return emitError(unknownLoc,
|
||||
@ -2378,6 +2503,12 @@ Value Deserializer::getValue(uint32_t id) {
|
||||
opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
|
||||
return referenceOfOp.reference();
|
||||
}
|
||||
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
|
||||
return materializeSpecConstantOperation(
|
||||
id, specConstOperationInfo->enclodesOpcode,
|
||||
specConstOperationInfo->resultTypeID,
|
||||
specConstOperationInfo->enclosedOpOperands);
|
||||
}
|
||||
if (auto undef = getUndefType(id)) {
|
||||
return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
|
||||
}
|
||||
@ -2483,6 +2614,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
||||
return processConstantComposite(operands);
|
||||
case spirv::Opcode::OpSpecConstantComposite:
|
||||
return processSpecConstantComposite(operands);
|
||||
case spirv::Opcode::OpSpecConstantOperation:
|
||||
return processSpecConstantOperation(operands);
|
||||
case spirv::Opcode::OpConstantTrue:
|
||||
return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
|
||||
case spirv::Opcode::OpSpecConstantTrue:
|
||||
|
@ -204,6 +204,9 @@ private:
|
||||
LogicalResult
|
||||
processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
|
||||
|
||||
LogicalResult
|
||||
processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
|
||||
|
||||
/// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
|
||||
/// value to use with other operations. The SPIR-V spec recommends that
|
||||
/// OpUndef be generated at module level. The serialization generates an
|
||||
@ -711,6 +714,49 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
|
||||
return processName(resultID, op.sym_name());
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
|
||||
uint32_t typeID = 0;
|
||||
if (failed(processType(op.getLoc(), op.getType(), typeID))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto resultID = getNextID();
|
||||
|
||||
SmallVector<uint32_t, 8> operands;
|
||||
operands.push_back(typeID);
|
||||
operands.push_back(resultID);
|
||||
|
||||
Block &block = op.getRegion().getBlocks().front();
|
||||
Operation &enclosedOp = block.getOperations().front();
|
||||
|
||||
std::string enclosedOpName;
|
||||
llvm::raw_string_ostream rss(enclosedOpName);
|
||||
rss << "Op" << enclosedOp.getName().stripDialect();
|
||||
auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
|
||||
|
||||
if (!enclosedOpcode) {
|
||||
op.emitError("Couldn't find op code for op ")
|
||||
<< enclosedOp.getName().getStringRef();
|
||||
return failure();
|
||||
}
|
||||
|
||||
operands.push_back(static_cast<uint32_t>(enclosedOpcode.getValue()));
|
||||
|
||||
// Append operands to the enclosed op to the list of operands.
|
||||
for (Value operand : enclosedOp.getOperands()) {
|
||||
uint32_t id = getValueID(operand);
|
||||
assert(id && "use before def!");
|
||||
operands.push_back(id);
|
||||
}
|
||||
|
||||
encodeInstructionInto(typesGlobalValues,
|
||||
spirv::Opcode::OpSpecConstantOperation, operands);
|
||||
valueIDMap[op.getResult()] = resultID;
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
|
||||
auto undefType = op.getType();
|
||||
auto &id = undefValIDMap[undefType];
|
||||
@ -1929,6 +1975,9 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
|
||||
.Case([&](spirv::SpecConstantCompositeOp op) {
|
||||
return processSpecConstantCompositeOp(op);
|
||||
})
|
||||
.Case([&](spirv::SpecConstantOperationOp op) {
|
||||
return processSpecConstantOperationOp(op);
|
||||
})
|
||||
.Case([&](spirv::UndefOp op) { return processUndefOp(op); })
|
||||
.Case([&](spirv::VariableOp op) { return processVariableOp(op); })
|
||||
|
||||
|
@ -780,6 +780,20 @@ spv.module Logical GLSL450 {
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @sc = 42 : i32
|
||||
|
||||
spv.func @foo() -> i32 "None" {
|
||||
// CHECK: [[SC:%.*]] = spv.mlir.referenceof @sc
|
||||
%0 = spv.mlir.referenceof @sc : i32
|
||||
// CHECK: spv.SpecConstantOperation wraps "spv.ISub"([[SC]], [[SC]]) : (i32, i32) -> i32
|
||||
%1 = spv.SpecConstantOperation wraps "spv.ISub"(%0, %0) : (i32, i32) -> i32
|
||||
spv.ReturnValue %1 : i32
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo() -> i32 "None" {
|
||||
%0 = spv.constant 1: i32
|
||||
|
@ -85,3 +85,34 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
// CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32>
|
||||
spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
|
||||
spv.specConstant @sc_i32_1 = 1 : i32
|
||||
|
||||
spv.func @use_composite() -> (i32) "None" {
|
||||
// CHECK: [[USE1:%.*]] = spv.mlir.referenceof @sc_i32_1 : i32
|
||||
// CHECK: [[USE2:%.*]] = spv.constant 0 : i32
|
||||
|
||||
// CHECK: [[RES1:%.*]] = spv.SpecConstantOperation wraps "spv.ISub"([[USE1]], [[USE2]]) : (i32, i32) -> i32
|
||||
|
||||
// CHECK: [[USE3:%.*]] = spv.mlir.referenceof @sc_i32_1 : i32
|
||||
// CHECK: [[USE4:%.*]] = spv.constant 0 : i32
|
||||
|
||||
// CHECK: [[RES2:%.*]] = spv.SpecConstantOperation wraps "spv.ISub"([[USE3]], [[USE4]]) : (i32, i32) -> i32
|
||||
|
||||
%0 = spv.mlir.referenceof @sc_i32_1 : i32
|
||||
%1 = spv.constant 0 : i32
|
||||
%2 = spv.SpecConstantOperation wraps "spv.ISub"(%0, %1) : (i32, i32) -> i32
|
||||
|
||||
// CHECK: [[RES3:%.*]] = spv.SpecConstantOperation wraps "spv.IMul"([[RES1]], [[RES2]]) : (i32, i32) -> i32
|
||||
%3 = spv.SpecConstantOperation wraps "spv.IMul"(%2, %2) : (i32, i32) -> i32
|
||||
|
||||
// Make sure deserialization continues from the right place after creating
|
||||
// the previous op.
|
||||
// CHECK: spv.ReturnValue [[RES3]]
|
||||
spv.ReturnValue %3 : i32
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user