2019-05-26 05:43:20 -07:00
|
|
|
//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
|
|
|
|
//
|
2020-01-26 03:58:30 +00:00
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
2019-12-23 09:35:36 -08:00
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2019-05-26 05:43:20 -07:00
|
|
|
//
|
2019-12-23 09:35:36 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-05-26 05:43:20 -07:00
|
|
|
//
|
|
|
|
// This file defines the operations in the SPIR-V dialect.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-07-16 05:06:57 -07:00
|
|
|
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
2019-05-26 05:43:20 -07:00
|
|
|
|
2020-03-11 16:04:25 -04:00
|
|
|
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
|
2019-08-20 13:33:41 -07:00
|
|
|
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
2019-07-16 05:06:57 -07:00
|
|
|
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
2019-05-29 10:47:16 -07:00
|
|
|
#include "mlir/IR/Builders.h"
|
2019-07-11 11:41:04 -07:00
|
|
|
#include "mlir/IR/Function.h"
|
2020-02-07 11:30:19 -05:00
|
|
|
#include "mlir/IR/FunctionImplementation.h"
|
2019-05-29 10:47:16 -07:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
2019-05-26 05:43:20 -07:00
|
|
|
#include "mlir/IR/StandardTypes.h"
|
2020-03-10 12:20:24 -07:00
|
|
|
#include "mlir/Interfaces/CallInterfaces.h"
|
2019-07-30 14:14:28 -07:00
|
|
|
#include "mlir/Support/StringExtras.h"
|
2019-09-24 19:24:33 -07:00
|
|
|
#include "llvm/ADT/bit.h"
|
2019-05-26 05:43:20 -07:00
|
|
|
|
2019-05-29 10:47:16 -07:00
|
|
|
using namespace mlir;
|
|
|
|
|
2019-07-02 06:02:20 -07:00
|
|
|
// TODO(antiagainst): generate these strings using ODS.
//
// Attribute names used by the custom parsers, printers and verifiers in this
// file. `constexpr` already makes each array const, so the element type is
// `const char` just as if it were spelled out.
static constexpr char kAlignmentAttrName[] = "alignment";
static constexpr char kBranchWeightAttrName[] = "branch_weights";
static constexpr char kCallee[] = "callee";
static constexpr char kClusterSize[] = "cluster_size";
static constexpr char kDefaultValueAttrName[] = "default_value";
static constexpr char kExecutionScopeAttrName[] = "execution_scope";
static constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr char kFnNameAttrName[] = "fn";
static constexpr char kGroupOperationAttrName[] = "group_operation";
static constexpr char kIndicesAttrName[] = "indices";
static constexpr char kInitializerAttrName[] = "initializer";
static constexpr char kInterfaceAttrName[] = "interface";
static constexpr char kMemoryScopeAttrName[] = "memory_scope";
static constexpr char kSemanticsAttrName[] = "semantics";
static constexpr char kSpecIdAttrName[] = "spec_id";
static constexpr char kTypeAttrName[] = "type";
static constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr char kValueAttrName[] = "value";
static constexpr char kValuesAttrName[] = "values";
|
2019-06-17 14:47:22 -07:00
|
|
|
|
2019-06-04 14:03:30 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Common utility functions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-01-25 09:16:29 -05:00
|
|
|
/// Returns true if the given op is a function-like op or nested in a
|
|
|
|
/// function-like op without a module-like op in the middle.
|
|
|
|
static bool isNestedInFunctionLikeOp(Operation *op) {
|
|
|
|
if (!op)
|
|
|
|
return false;
|
|
|
|
if (op->hasTrait<OpTrait::SymbolTable>())
|
|
|
|
return false;
|
|
|
|
if (op->hasTrait<OpTrait::FunctionLike>())
|
|
|
|
return true;
|
|
|
|
return isNestedInFunctionLikeOp(op->getParentOp());
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns true if the given op is an module-like op that maintains a symbol
|
|
|
|
/// table.
|
|
|
|
static bool isDirectInModuleLikeOp(Operation *op) {
|
|
|
|
return op && op->hasTrait<OpTrait::SymbolTable>();
|
|
|
|
}
|
|
|
|
|
2020-01-26 10:19:24 -05:00
|
|
|
/// Writes the integer held by `op` into `value` when `op` is a
/// spirv::ConstantOp whose value attribute is an IntegerAttr. Fails for a null
/// `op`, a non-constant op, or a non-integer constant; `value` is left
/// untouched on failure.
static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
  if (!constOp)
    return failure();

  auto intAttr = constOp.value().dyn_cast<IntegerAttr>();
  if (!intAttr)
    return failure();

  // NOTE: getInt() produces an int64_t which is narrowed to int32_t here.
  value = static_cast<int32_t>(intAttr.getInt());
  return success();
}
|
|
|
|
|
2019-12-09 09:51:25 -08:00
|
|
|
template <typename Ty>
|
|
|
|
static ArrayAttr
|
|
|
|
getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
|
2019-12-18 09:28:48 -08:00
|
|
|
function_ref<StringRef(Ty)> stringifyFn) {
|
2019-12-09 09:51:25 -08:00
|
|
|
if (enumValues.empty()) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
SmallVector<StringRef, 1> enumValStrs;
|
|
|
|
enumValStrs.reserve(enumValues.size());
|
|
|
|
for (auto val : enumValues) {
|
|
|
|
enumValStrs.emplace_back(stringifyFn(val));
|
|
|
|
}
|
|
|
|
return builder.getStrArrayAttr(enumValStrs);
|
|
|
|
}
|
|
|
|
|
2020-03-11 16:04:25 -04:00
|
|
|
/// Parses the next string attribute in `parser` as an enumerant of the given
|
|
|
|
/// `EnumClass`.
|
2019-07-03 18:12:52 -07:00
|
|
|
template <typename EnumClass>
|
2019-09-21 10:18:00 -07:00
|
|
|
static ParseResult
|
2020-03-11 16:04:25 -04:00
|
|
|
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
|
|
|
|
StringRef attrName = spirv::attributeName<EnumClass>()) {
|
2019-07-03 18:12:52 -07:00
|
|
|
Attribute attrVal;
|
|
|
|
SmallVector<NamedAttribute, 1> attr;
|
2019-09-20 11:36:49 -07:00
|
|
|
auto loc = parser.getCurrentLocation();
|
|
|
|
if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
|
2019-09-21 10:18:00 -07:00
|
|
|
attrName, attr)) {
|
2019-06-24 10:59:05 -07:00
|
|
|
return failure();
|
|
|
|
}
|
2019-07-03 18:12:52 -07:00
|
|
|
if (!attrVal.isa<StringAttr>()) {
|
2019-09-20 11:36:49 -07:00
|
|
|
return parser.emitError(loc, "expected ")
|
2019-09-21 10:18:00 -07:00
|
|
|
<< attrName << " attribute specified as string";
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
2019-07-03 18:12:52 -07:00
|
|
|
auto attrOptional =
|
|
|
|
spirv::symbolizeEnum<EnumClass>()(attrVal.cast<StringAttr>().getValue());
|
|
|
|
if (!attrOptional) {
|
2019-09-20 11:36:49 -07:00
|
|
|
return parser.emitError(loc, "invalid ")
|
2019-09-21 10:18:00 -07:00
|
|
|
<< attrName << " attribute specification: " << attrVal;
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
2019-07-03 18:12:52 -07:00
|
|
|
value = attrOptional.getValue();
|
2019-07-30 14:14:28 -07:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2020-03-11 16:04:25 -04:00
|
|
|
/// Parses the next string attribute in `parser` as an enumerant of the given
|
|
|
|
/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
|
|
|
|
/// attribute with the enum class's name as attribute name.
|
2019-07-30 14:14:28 -07:00
|
|
|
template <typename EnumClass>
|
2019-09-21 10:18:00 -07:00
|
|
|
static ParseResult
|
2020-03-11 16:04:25 -04:00
|
|
|
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
|
|
|
|
StringRef attrName = spirv::attributeName<EnumClass>()) {
|
|
|
|
if (parseEnumStrAttr(value, parser)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
|
|
|
|
llvm::bit_cast<int32_t>(value)));
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Parses the next keyword in `parser` as an enumerant of the given
|
|
|
|
/// `EnumClass`.
|
|
|
|
template <typename EnumClass>
|
|
|
|
static ParseResult
|
|
|
|
parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
|
|
|
|
StringRef attrName = spirv::attributeName<EnumClass>()) {
|
|
|
|
StringRef keyword;
|
|
|
|
SmallVector<NamedAttribute, 1> attr;
|
|
|
|
auto loc = parser.getCurrentLocation();
|
|
|
|
if (parser.parseKeyword(&keyword))
|
|
|
|
return failure();
|
|
|
|
if (Optional<EnumClass> attr = spirv::symbolizeEnum<EnumClass>()(keyword)) {
|
|
|
|
value = attr.getValue();
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
return parser.emitError(loc, "invalid ")
|
|
|
|
<< attrName << " attribute specification: " << keyword;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
|
|
|
|
/// and inserts the enumerant into `state` as an 32-bit integer attribute with
|
|
|
|
/// the enum class's name as attribute name.
|
|
|
|
template <typename EnumClass>
|
|
|
|
static ParseResult
|
|
|
|
parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
|
|
|
|
OperationState &state,
|
|
|
|
StringRef attrName = spirv::attributeName<EnumClass>()) {
|
|
|
|
if (parseEnumKeywordAttr(value, parser)) {
|
2019-07-30 14:14:28 -07:00
|
|
|
return failure();
|
|
|
|
}
|
2019-09-21 10:18:00 -07:00
|
|
|
state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
|
2019-09-24 19:24:33 -07:00
|
|
|
llvm::bit_cast<int32_t>(value)));
|
2019-06-24 10:59:05 -07:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parses the optional memory access suffix of a load/store op:
///   `[` "<memory-access>" (`,` <i32-alignment>)? `]`
/// The memory access enumerant and, for `Aligned` accesses, the alignment are
/// appended to `state.attributes`. Succeeds without consuming anything when no
/// `[` is present.
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
                                               OperationState &state) {
  // The whole list is optional and starts with '['.
  if (failed(parser.parseOptionalLSquare()))
    return success();

  spirv::MemoryAccess memoryAccessAttr;
  if (parseEnumStrAttr(memoryAccessAttr, parser, state))
    return failure();

  if (!spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned))
    return parser.parseRSquare();

  // An `Aligned` access carries an i32 alignment after a comma.
  Attribute alignmentAttr;
  Type i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseComma() ||
      parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
                            state.attributes))
    return failure();
  return parser.parseRSquare();
}
|
|
|
|
|
|
|
|
template <typename LoadStoreOpTy>
|
|
|
|
static void
|
2019-09-20 20:43:02 -07:00
|
|
|
printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer,
|
2019-06-24 10:59:05 -07:00
|
|
|
SmallVectorImpl<StringRef> &elidedAttrs) {
|
|
|
|
// Print optional memory access attribute.
|
2019-07-02 06:02:20 -07:00
|
|
|
if (auto memAccess = loadStoreOp.memory_access()) {
|
2019-07-03 18:12:52 -07:00
|
|
|
elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
|
2019-06-24 10:59:05 -07:00
|
|
|
|
|
|
|
// Print integer alignment attribute.
|
|
|
|
if (auto alignment = loadStoreOp.alignment()) {
|
2019-07-03 18:12:52 -07:00
|
|
|
elidedAttrs.push_back(kAlignmentAttrName);
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << ", " << alignment;
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << "]";
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
2019-07-03 18:12:52 -07:00
|
|
|
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
|
|
|
|
2019-10-30 14:41:26 -07:00
|
|
|
/// Verifies the bit-width relationship of a cast op's single operand and
/// result. With `requireSameBitWidth` the (element) bit widths must match;
/// otherwise they must differ.
static LogicalResult verifyCastOp(Operation *op,
                                  bool requireSameBitWidth = true) {
  Type srcType = op->getOperand(0).getType();
  Type dstType = op->getResult(0).getType();

  // ODS already guarantees operand and result have the same shape, so for
  // vectors it is enough to compare element types.
  if (auto srcVecType = srcType.dyn_cast<VectorType>()) {
    srcType = srcVecType.getElementType();
    dstType = dstType.cast<VectorType>().getElementType();
  }

  bool widthsMatch =
      srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
  if (widthsMatch == requireSameBitWidth)
    return success();

  if (requireSameBitWidth)
    return op->emitOpError(
               "expected the same bit widths for operand type and result "
               "type, but provided ")
           << srcType << " and " << dstType;
  return op->emitOpError(
             "expected the different bit widths for operand type and result "
             "type, but provided ")
         << srcType << " and " << dstType;
}
|
|
|
|
|
2019-06-24 10:59:05 -07:00
|
|
|
template <typename LoadStoreOpTy>
|
|
|
|
static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
|
|
|
|
// ODS checks for attributes values. Just need to verify that if the
|
|
|
|
// memory-access attribute is Aligned, then the alignment attribute must be
|
|
|
|
// present.
|
|
|
|
auto *op = loadStoreOp.getOperation();
|
2019-07-03 18:12:52 -07:00
|
|
|
auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
|
2019-07-02 06:02:20 -07:00
|
|
|
if (!memAccessAttr) {
|
|
|
|
// Alignment attribute shouldn't be present if memory access attribute is
|
|
|
|
// not present.
|
2019-07-03 18:12:52 -07:00
|
|
|
if (op->getAttr(kAlignmentAttrName)) {
|
2019-06-24 10:59:05 -07:00
|
|
|
return loadStoreOp.emitOpError(
|
|
|
|
"invalid alignment specification without aligned memory access "
|
|
|
|
"specification");
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-07-02 06:02:20 -07:00
|
|
|
auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
|
|
|
|
auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
|
|
|
|
|
|
|
|
if (!memAccess) {
|
|
|
|
return loadStoreOp.emitOpError("invalid memory access specifier: ")
|
|
|
|
<< memAccessVal;
|
|
|
|
}
|
|
|
|
|
2019-09-16 09:22:43 -07:00
|
|
|
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
|
2019-07-03 18:12:52 -07:00
|
|
|
if (!op->getAttr(kAlignmentAttrName)) {
|
2019-06-24 10:59:05 -07:00
|
|
|
return loadStoreOp.emitOpError("missing alignment value");
|
|
|
|
}
|
|
|
|
} else {
|
2019-07-03 18:12:52 -07:00
|
|
|
if (op->getAttr(kAlignmentAttrName)) {
|
2019-06-24 10:59:05 -07:00
|
|
|
return loadStoreOp.emitOpError(
|
|
|
|
"invalid alignment specification with non-aligned memory access "
|
|
|
|
"specification");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-09-21 10:18:00 -07:00
|
|
|
template <typename BarrierOp>
|
|
|
|
static LogicalResult verifyMemorySemantics(BarrierOp op) {
|
|
|
|
// According to the SPIR-V specification:
|
|
|
|
// "Despite being a mask and allowing multiple bits to be combined, it is
|
|
|
|
// invalid for more than one of these four bits to be set: Acquire, Release,
|
|
|
|
// AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
|
|
|
|
// Release semantics is done by setting the AcquireRelease bit, not by setting
|
|
|
|
// two bits."
|
|
|
|
auto memorySemantics = op.memory_semantics();
|
|
|
|
auto atMostOneInSet = spirv::MemorySemantics::Acquire |
|
|
|
|
spirv::MemorySemantics::Release |
|
|
|
|
spirv::MemorySemantics::AcquireRelease |
|
|
|
|
spirv::MemorySemantics::SequentiallyConsistent;
|
|
|
|
|
|
|
|
auto bitCount = llvm::countPopulation(
|
|
|
|
static_cast<uint32_t>(memorySemantics & atMostOneInSet));
|
|
|
|
if (bitCount > 1) {
|
|
|
|
return op.emitError("expected at most one of these four memory constraints "
|
|
|
|
"to be set: `Acquire`, `Release`,"
|
|
|
|
"`AcquireRelease` or `SequentiallyConsistent`");
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-06-24 10:59:05 -07:00
|
|
|
template <typename LoadStoreOpTy>
|
2019-12-23 14:45:01 -08:00
|
|
|
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
|
|
|
|
Value val) {
|
2019-06-24 10:59:05 -07:00
|
|
|
// ODS already checks ptr is spirv::PointerType. Just check that the pointee
|
|
|
|
// type of the pointer and the type of the value are the same
|
|
|
|
//
|
|
|
|
// TODO(ravishankarm): Check that the value type satisfies restrictions of
|
|
|
|
// SPIR-V OpLoad/OpStore operations
|
2020-01-11 08:54:04 -08:00
|
|
|
if (val.getType() !=
|
|
|
|
ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
|
2019-06-24 10:59:05 -07:00
|
|
|
return op.emitOpError("mismatch in result type and pointer type");
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parses the optional decorations of a variable op into `state.attributes`.
///
/// Accepts at most one of `bind(<i32 set>, <i32 binding>)` or the snake_case
/// BuiltIn decoration keyword followed by `("<name>")`. Either form may be
/// followed by an optional attribute dictionary.
static ParseResult parseVariableDecorations(OpAsmParser &parser,
                                            OperationState &state) {
  auto builtInName =
      convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));

  if (succeeded(parser.parseOptionalKeyword("bind"))) {
    // Descriptor binding: both values are parsed as i32 attributes.
    auto setAttrName = convertToSnakeCase(
        stringifyDecoration(spirv::Decoration::DescriptorSet));
    auto bindingAttrName =
        convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
    Type i32Type = parser.getBuilder().getIntegerType(32);
    Attribute setAttr, bindingAttr;
    bool bindFailed =
        parser.parseLParen() ||
        parser.parseAttribute(setAttr, i32Type, setAttrName,
                              state.attributes) ||
        parser.parseComma() ||
        parser.parseAttribute(bindingAttr, i32Type, bindingAttrName,
                              state.attributes) ||
        parser.parseRParen();
    if (bindFailed)
      return failure();
  } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
    StringAttr builtInAttr;
    bool builtInFailed =
        parser.parseLParen() ||
        parser.parseAttribute(builtInAttr, builtInName, state.attributes) ||
        parser.parseRParen();
    if (builtInFailed)
      return failure();
  }

  // Any remaining attributes come from an optional trailing dictionary.
  return parser.parseOptionalAttrDict(state.attributes);
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
|
2019-08-17 10:19:48 -07:00
|
|
|
SmallVectorImpl<StringRef> &elidedAttrs) {
|
|
|
|
// Print optional descriptor binding
|
|
|
|
auto descriptorSetName =
|
|
|
|
convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
|
|
|
|
auto bindingName =
|
|
|
|
convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
|
|
|
|
auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
|
|
|
|
auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
|
|
|
|
if (descriptorSet && binding) {
|
|
|
|
elidedAttrs.push_back(descriptorSetName);
|
|
|
|
elidedAttrs.push_back(bindingName);
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
|
|
|
|
<< ")";
|
2019-08-17 10:19:48 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
// Print BuiltIn attribute if present
|
|
|
|
auto builtInName =
|
|
|
|
convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
|
|
|
|
if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
|
2019-08-17 10:19:48 -07:00
|
|
|
elidedAttrs.push_back(builtInName);
|
|
|
|
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
2019-08-17 10:19:48 -07:00
|
|
|
}
|
|
|
|
|
2019-09-25 19:01:18 -07:00
|
|
|
// Get bit width of types.
|
|
|
|
static unsigned getBitWidth(Type type) {
|
|
|
|
if (type.isa<spirv::PointerType>()) {
|
|
|
|
// Just return 64 bits for pointer types for now.
|
|
|
|
// TODO: Make sure not caller relies on the actual pointer width value.
|
|
|
|
return 64;
|
|
|
|
}
|
2020-03-04 15:12:33 -05:00
|
|
|
|
|
|
|
if (type.isIntOrFloat())
|
2019-09-25 19:01:18 -07:00
|
|
|
return type.getIntOrFloatBitWidth();
|
2020-03-04 15:12:33 -05:00
|
|
|
|
2019-09-25 19:01:18 -07:00
|
|
|
if (auto vectorType = type.dyn_cast<VectorType>()) {
|
2020-03-04 15:12:33 -05:00
|
|
|
assert(vectorType.getElementType().isIntOrFloat());
|
2019-09-25 19:01:18 -07:00
|
|
|
return vectorType.getNumElements() *
|
|
|
|
vectorType.getElementType().getIntOrFloatBitWidth();
|
|
|
|
}
|
|
|
|
llvm_unreachable("unhandled bit width computation for type");
|
|
|
|
}
|
|
|
|
|
2019-12-05 13:10:10 -08:00
|
|
|
/// Walks the given type hierarchy with the given indices, potentially down
|
|
|
|
/// to component granularity, to select an element type. Returns null type and
|
|
|
|
/// emits errors with the given loc on failure.
|
2019-12-10 10:11:19 -08:00
|
|
|
static Type
|
|
|
|
getElementType(Type type, ArrayRef<int32_t> indices,
|
2019-12-18 09:28:48 -08:00
|
|
|
function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
|
2019-12-10 10:11:19 -08:00
|
|
|
if (indices.empty()) {
|
|
|
|
emitErrorFn("expected at least one index for spv.CompositeExtract");
|
2019-12-05 13:10:10 -08:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2019-12-10 10:11:19 -08:00
|
|
|
for (auto index : indices) {
|
2019-12-05 13:10:10 -08:00
|
|
|
if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
|
|
|
|
if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
|
2019-12-10 10:11:19 -08:00
|
|
|
emitErrorFn("index ") << index << " out of bounds for " << type;
|
2019-12-05 13:10:10 -08:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
type = cType.getElementType(index);
|
|
|
|
} else {
|
2019-12-10 10:11:19 -08:00
|
|
|
emitErrorFn("cannot extract from non-composite type ")
|
2019-12-05 13:10:10 -08:00
|
|
|
<< type << " with index " << index;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return type;
|
|
|
|
}
|
|
|
|
|
2019-12-10 10:11:19 -08:00
|
|
|
/// Attribute-based overload: `indices` must be a non-empty ArrayAttr of
/// integer attributes. The indices are converted to int32_t and forwarded to
/// the ArrayRef overload. Errors are reported via `emitErrorFn` and yield a
/// null type.
static Type
getElementType(Type type, Attribute indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  auto indexArray = indices.dyn_cast<ArrayAttr>();
  if (!indexArray) {
    emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
    return nullptr;
  }
  if (indexArray.size() == 0) {
    emitErrorFn("expected at least one index for spv.CompositeExtract");
    return nullptr;
  }

  SmallVector<int32_t, 2> indexValues;
  for (Attribute element : indexArray) {
    auto intElement = element.dyn_cast<IntegerAttr>();
    if (!intElement) {
      emitErrorFn("expected an 32-bit integer for index, but found '")
          << element << "'";
      return nullptr;
    }
    indexValues.push_back(intElement.getInt());
  }
  return getElementType(type, indexValues, emitErrorFn);
}
|
|
|
|
|
|
|
|
/// Convenience overload that reports errors at `loc` via mlir::emitError.
static Type getElementType(Type type, Attribute indices, Location loc) {
  return getElementType(type, indices,
                        [&](StringRef message) -> InFlightDiagnostic {
                          return ::mlir::emitError(loc, message);
                        });
}
|
|
|
|
|
|
|
|
/// Convenience overload for parsers: reports errors at `loc` through `parser`.
static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
                           llvm::SMLoc loc) {
  return getElementType(type, indices,
                        [&](StringRef message) -> InFlightDiagnostic {
                          return parser.emitError(loc, message);
                        });
}
|
|
|
|
|
2019-10-02 11:00:50 -07:00
|
|
|
/// Returns true if the given `block` only contains one `spv._merge` op.
|
|
|
|
static inline bool isMergeBlock(Block &block) {
|
|
|
|
return !block.empty() && std::next(block.begin()) == block.end() &&
|
|
|
|
isa<spirv::MergeOp>(block.front());
|
|
|
|
}
|
|
|
|
|
2019-09-25 16:34:37 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Common parsers and printers
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-12-16 15:05:21 -08:00
|
|
|
// Parses an atomic update op. If the update op does not take a value (like
|
|
|
|
// AtomicIIncrement) `hasValue` must be false.
|
|
|
|
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
|
|
|
|
OperationState &state, bool hasValue) {
|
|
|
|
spirv::Scope scope;
|
|
|
|
spirv::MemorySemantics memoryScope;
|
|
|
|
SmallVector<OpAsmParser::OperandType, 2> operandInfo;
|
|
|
|
OpAsmParser::OperandType ptrInfo, valueInfo;
|
|
|
|
Type type;
|
|
|
|
llvm::SMLoc loc;
|
2020-03-11 16:04:25 -04:00
|
|
|
if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
|
|
|
|
parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
|
2019-12-16 15:05:21 -08:00
|
|
|
parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
|
|
|
|
parser.getCurrentLocation(&loc) || parser.parseColonType(type))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
auto ptrType = type.dyn_cast<spirv::PointerType>();
|
|
|
|
if (!ptrType)
|
|
|
|
return parser.emitError(loc, "expected pointer type");
|
|
|
|
|
|
|
|
SmallVector<Type, 2> operandTypes;
|
|
|
|
operandTypes.push_back(ptrType);
|
|
|
|
if (hasValue)
|
|
|
|
operandTypes.push_back(ptrType.getPointeeType());
|
|
|
|
if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
|
|
|
|
state.operands))
|
|
|
|
return failure();
|
|
|
|
return parser.addTypeToList(ptrType.getPointeeType(), state.types);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Prints an atomic update op.
|
|
|
|
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
|
|
|
|
printer << op->getName() << " \"";
|
|
|
|
auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
|
|
|
|
printer << spirv::stringifyScope(
|
|
|
|
static_cast<spirv::Scope>(scopeAttr.getInt()))
|
|
|
|
<< "\" \"";
|
|
|
|
auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
|
|
|
|
printer << spirv::stringifyMemorySemantics(
|
|
|
|
static_cast<spirv::MemorySemantics>(
|
|
|
|
memorySemanticsAttr.getInt()))
|
2020-01-11 08:54:04 -08:00
|
|
|
<< "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
|
2019-12-16 15:05:21 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Verifies an atomic update op.
|
|
|
|
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
|
2020-01-11 08:54:04 -08:00
|
|
|
auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
|
2019-12-16 15:05:21 -08:00
|
|
|
auto elementType = ptrType.getPointeeType();
|
2020-03-04 15:12:33 -05:00
|
|
|
if (!elementType.isa<IntegerType>())
|
2019-12-16 15:05:21 -08:00
|
|
|
return op->emitOpError(
|
|
|
|
"pointer operand must point to an integer value, found ")
|
|
|
|
<< elementType;
|
|
|
|
|
|
|
|
if (op->getNumOperands() > 1) {
|
2020-01-11 08:54:04 -08:00
|
|
|
auto valueType = op->getOperand(1).getType();
|
2019-12-16 15:05:21 -08:00
|
|
|
if (valueType != elementType)
|
|
|
|
return op->emitOpError("expected value to have the same type as the "
|
|
|
|
"pointer operand's pointee type ")
|
|
|
|
<< elementType << ", but found " << valueType;
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2020-01-28 09:36:01 -05:00
|
|
|
/// Parses a GroupNonUniform arithmetic op:
///   "<execution-scope>" "<group-operation>" %value
///   (cluster_size(%size))? : <type>
/// The value operand and result share `<type>`; the optional cluster size
/// operand is resolved as i32.
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
                                                    OperationState &state) {
  spirv::Scope executionScope;
  spirv::GroupOperation groupOperation;
  OpAsmParser::OperandType valueInfo;
  if (parseEnumStrAttr(executionScope, parser, state,
                       kExecutionScopeAttrName) ||
      parseEnumStrAttr(groupOperation, parser, state,
                       kGroupOperationAttrName) ||
      parser.parseOperand(valueInfo))
    return failure();

  OpAsmParser::OperandType clusterSizeInfo{};
  bool hasClusterSize = succeeded(parser.parseOptionalKeyword(kClusterSize));
  if (hasClusterSize &&
      (parser.parseLParen() || parser.parseOperand(clusterSizeInfo) ||
       parser.parseRParen()))
    return failure();

  Type resultType;
  if (parser.parseColonType(resultType) ||
      parser.resolveOperand(valueInfo, resultType, state.operands))
    return failure();

  if (hasClusterSize) {
    Type i32Type = parser.getBuilder().getIntegerType(32);
    if (parser.resolveOperand(clusterSizeInfo, i32Type, state.operands))
      return failure();
  }

  return parser.addTypeToList(resultType, state.types);
}
|
|
|
|
|
|
|
|
/// Prints a GroupNonUniform arithmetic op in the form accepted by
/// parseGroupNonUniformArithmeticOp.
static void printGroupNonUniformArithmeticOp(Operation *groupOp,
                                             OpAsmPrinter &printer) {
  auto scopeAttr = groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName);
  auto operationAttr =
      groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName);
  auto scope = static_cast<spirv::Scope>(scopeAttr.getInt());
  auto operation = static_cast<spirv::GroupOperation>(operationAttr.getInt());

  printer << groupOp->getName() << " \"" << stringifyScope(scope) << "\" \""
          << stringifyGroupOperation(operation) << "\" "
          << groupOp->getOperand(0);

  // The second operand, when present, is the cluster size.
  bool hasClusterSize = groupOp->getNumOperands() > 1;
  if (hasClusterSize)
    printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';

  printer << " : " << groupOp->getResult(0).getType();
}
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniform arithmetic op:
/// - the execution scope must be Workgroup or Subgroup;
/// - `ClusteredReduce` requires a cluster size operand;
/// - a cluster size operand, when present, must come from a constant op and
///   be a positive power of two.
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
  spirv::Scope scope = static_cast<spirv::Scope>(
      groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
    return groupOp->emitOpError(
        "execution scope must be 'Workgroup' or 'Subgroup'");

  spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
      groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
  if (operation == spirv::GroupOperation::ClusteredReduce &&
      groupOp->getNumOperands() == 1)
    return groupOp->emitOpError("cluster size operand must be provided for "
                                "'ClusteredReduce' group operation");
  if (groupOp->getNumOperands() > 1) {
    Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
    int32_t clusterSize = 0;

    // TODO(antiagainst): support specialization constant here.
    if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
      return groupOp->emitOpError(
          "cluster size operand must come from a constant op");

    // isPowerOf2_32 takes an unsigned argument, so INT32_MIN would convert to
    // 0x80000000 and be accepted as a power of two. SPIR-V requires
    // ClusterSize to be at least 1, so reject non-positive values explicitly.
    if (clusterSize <= 0 || !llvm::isPowerOf2_32(clusterSize))
      return groupOp->emitOpError(
          "cluster size operand must be a power of two");
  }
  return success();
}
|
|
|
|
|
2019-09-25 16:34:37 -07:00
|
|
|
/// Parses `<op-name> %operand : <type>` and gives the op a single result of
/// the same type as its operand.
static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) {
  OpAsmParser::OperandType operand;
  Type operandType;
  bool parseFailed =
      parser.parseOperand(operand) || parser.parseColonType(operandType) ||
      parser.resolveOperand(operand, operandType, state.operands);
  if (parseFailed)
    return failure();

  state.addTypes(operandType);
  return success();
}
|
|
|
|
|
|
|
|
/// Prints a unary op as `<op-name> %operand : <operand-type>`.
static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
  Value operand = unaryOp->getOperand(0);
  printer << unaryOp->getName() << ' ' << operand << " : "
          << operand.getType();
}
|
|
|
|
|
2019-09-30 10:40:07 -07:00
|
|
|
/// Result of a logical op must be a scalar or vector of boolean type.
|
|
|
|
static Type getUnaryOpResultType(Builder &builder, Type operandType) {
|
|
|
|
Type resultType = builder.getIntegerType(1);
|
|
|
|
if (auto vecType = operandType.dyn_cast<VectorType>()) {
|
|
|
|
return VectorType::get(vecType.getNumElements(), resultType);
|
|
|
|
}
|
|
|
|
return resultType;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Parses `<op-name> %operand : <type>` for a logical unary op. The result
/// type is the boolean counterpart of the operand type.
static ParseResult parseLogicalUnaryOp(OpAsmParser &parser,
                                       OperationState &state) {
  OpAsmParser::OperandType operand;
  Type operandType;
  bool parseFailed =
      parser.parseOperand(operand) || parser.parseColonType(operandType) ||
      parser.resolveOperand(operand, operandType, state.operands);
  if (parseFailed)
    return failure();

  Type resultType = getUnaryOpResultType(parser.getBuilder(), operandType);
  state.addTypes(resultType);
  return success();
}
|
|
|
|
|
|
|
|
/// Parses `<op-name> %lhs, %rhs : <type>` for a logical binary op. Both
/// operands are resolved against the single parsed type, and the result type
/// is its boolean counterpart.
static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
                                        OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operands;
  Type operandType;
  bool parseFailed =
      parser.parseOperandList(operands, 2) ||
      parser.parseColonType(operandType) ||
      parser.resolveOperands(operands, operandType, result.operands);
  if (parseFailed)
    return failure();

  Type resultType = getUnaryOpResultType(parser.getBuilder(), operandType);
  result.addTypes(resultType);
  return success();
}
|
|
|
|
|
2019-09-30 10:40:07 -07:00
|
|
|
/// Prints a logical op as `<op-name> <operands> : <type>`. Only the first
/// operand's type is printed; the logical parsers resolve every operand
/// against that one type.
static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
  Type operandType = logicalOp->getOperand(0).getType();
  printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
          << operandType;
}
|
|
|
|
|
2019-11-08 11:05:32 -08:00
|
|
|
/// Parses `<op-name> %base, %shift : <base-type>, <shift-type>`. The result
/// has the base operand's type.
static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 2> operands;
  Type baseType, shiftType;
  auto operandsLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(operands, 2) || parser.parseColon() ||
      parser.parseType(baseType) || parser.parseComma() ||
      parser.parseType(shiftType))
    return failure();

  // Each operand is resolved against its own type.
  if (parser.resolveOperands(operands, {baseType, shiftType}, operandsLoc,
                             state.operands))
    return failure();

  state.addTypes(baseType);
  return success();
}
|
|
|
|
|
|
|
|
/// Prints a shift op as `<op-name> %base, %shift : <base-type>, <shift-type>`.
static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
  Value baseValue = op->getOperand(0);
  Value shiftValue = op->getOperand(1);
  Type baseType = baseValue.getType();
  Type shiftType = shiftValue.getType();
  printer << op->getName() << ' ' << baseValue << ", " << shiftValue << " : "
          << baseType << ", " << shiftType;
}
|
|
|
|
|
|
|
|
/// Verifies that a shift op's result has the same type as its base (first)
/// operand.
static LogicalResult verifyShiftOp(Operation *op) {
  Type baseType = op->getOperand(0).getType();
  Type resultType = op->getResult(0).getType();
  if (baseType == resultType)
    return success();

  return op->emitError("expected the same type for the first operand and "
                       "result, but provided ")
         << baseType << " and " << resultType;
}
|
|
|
|
|
2019-07-25 15:42:41 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AccessChainOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-12-07 10:35:01 -08:00
|
|
|
/// Walks the pointee of `type` (which must be a spv.ptr) through `indices`
/// and returns a pointer, in the base pointer's storage class, to the element
/// reached. On error a diagnostic is emitted at `baseLoc` and a null type is
/// returned.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
  if (indices.empty()) {
    emitError(baseLoc, "'spv.AccessChain' op expected at least "
                       "one index ");
    return nullptr;
  }

  auto basePtrType = type.dyn_cast<spirv::PointerType>();
  if (!basePtrType) {
    emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
                       "to composite type, but provided ")
        << type;
    return nullptr;
  }

  Type currentType = basePtrType.getPointeeType();
  // Index used on the previous step: 0 unless that step selected a struct
  // member. The non-composite diagnostic below reports this value.
  int32_t index = 0;
  for (Value indexValue : indices) {
    auto compositeType = currentType.dyn_cast<spirv::CompositeType>();
    if (!compositeType) {
      emitError(baseLoc,
                "'spv.AccessChain' op cannot extract from non-composite type ")
          << currentType << " with index " << index;
      return nullptr;
    }

    index = 0;
    if (currentType.isa<spirv::StructType>()) {
      // A struct member is selected by a constant index, which is bounds
      // checked against the struct's member count.
      Operation *indexOp = indexValue.getDefiningOp();
      if (!indexOp) {
        emitError(baseLoc, "'spv.AccessChain' op index must be an "
                           "integer spv.constant to access "
                           "element of spv.struct");
        return nullptr;
      }

      // TODO(denis0x0D): this should be relaxed to allow
      // integer literals of other bitwidths.
      if (failed(extractValueFromConstOp(indexOp, index))) {
        emitError(baseLoc,
                  "'spv.AccessChain' index must be an integer spv.constant to "
                  "access element of spv.struct, but provided ")
            << indexOp->getName();
        return nullptr;
      }

      if (index < 0 ||
          static_cast<uint64_t>(index) >= compositeType.getNumElements()) {
        emitError(baseLoc, "'spv.AccessChain' op index ")
            << index << " out of bounds for " << currentType;
        return nullptr;
      }
    }

    // For non-struct composites the element type is queried at position 0,
    // whatever the (possibly dynamic) index value is.
    currentType = compositeType.getElementType(index);
  }

  return spirv::PointerType::get(currentType, basePtrType.getStorageClass());
}
|
|
|
|
|
2019-09-20 19:47:05 -07:00
|
|
|
void spirv::AccessChainOp::build(Builder *builder, OperationState &state,
|
2019-12-23 14:45:01 -08:00
|
|
|
Value basePtr, ValueRange indices) {
|
2020-01-11 08:54:04 -08:00
|
|
|
auto type = getElementPtrType(basePtr.getType(), indices, state.location);
|
2019-08-27 10:49:53 -07:00
|
|
|
assert(type && "Unable to deduce return type based on basePtr and indices");
|
|
|
|
build(builder, state, type, basePtr, indices);
|
|
|
|
}
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parses `spv.AccessChain %base[%i0, %i1, ...] : <base-pointer-type>` and
/// deduces the result pointer type from the base type and the indices.
static ParseResult parseAccessChainOp(OpAsmParser &parser,
                                      OperationState &state) {
  OpAsmParser::OperandType basePtrOperand;
  SmallVector<OpAsmParser::OperandType, 4> indexOperands;
  Type basePtrType;
  // TODO(denis0x0D): according to the spec an index may be of any integer
  // type; figure out how to use resolveOperand with a range of types and not
  // fail on the first attempt.
  Type i32Type = parser.getBuilder().getIntegerType(32);

  bool parseFailed =
      parser.parseOperand(basePtrOperand) ||
      parser.parseOperandList(indexOperands, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(basePtrType) ||
      parser.resolveOperand(basePtrOperand, basePtrType, state.operands) ||
      parser.resolveOperands(indexOperands, i32Type, state.operands);
  if (parseFailed)
    return failure();

  // The base pointer was resolved first, so dropping the first operand leaves
  // the indices (assumes `state.operands` was empty on entry).
  ArrayRef<Value> indexValues = llvm::makeArrayRef(state.operands).drop_front();
  Type resultType =
      getElementPtrType(basePtrType, indexValues, state.location);
  if (!resultType)
    return failure();

  state.addTypes(resultType);
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints an access chain as `spv.AccessChain %base[<indices>] : <base-type>`.
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
  Value basePtr = op.base_ptr();
  printer << spirv::AccessChainOp::getOperationName() << ' ' << basePtr << '['
          << op.indices() << "] : " << basePtr.getType();
}
|
|
|
|
|
|
|
|
/// Verifies that an access chain's result type is a pointer and equals the
/// pointer type deduced by walking the base pointer's pointee type with the
/// indices.
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
  // OperandRange converts directly to ValueRange; no copy is needed.
  auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
                                      accessChainOp.indices(),
                                      accessChainOp.getLoc());
  if (!resultType)
    return failure();

  auto providedResultType =
      accessChainOp.getType().dyn_cast<spirv::PointerType>();
  if (!providedResultType) {
    // `providedResultType` is null on this path, so report the op's actual
    // result type instead.
    return accessChainOp.emitOpError(
               "result type must be a pointer, but provided ")
           << accessChainOp.getType();
  }

  if (resultType != providedResultType)
    return accessChainOp.emitOpError("invalid result type: expected ")
           << resultType << ", but provided " << providedResultType;

  return success();
}
|
|
|
|
|
2019-08-17 10:19:48 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv._address_of
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-10-04 14:02:14 -07:00
|
|
|
/// Builds an address-of op that references `var` by symbol and whose result
/// type is the global variable's type.
void spirv::AddressOfOp::build(Builder *builder, OperationState &state,
                               spirv::GlobalVariableOp var) {
  auto symbolRef = builder->getSymbolRefAttr(var);
  build(builder, state, var.type(), symbolRef);
}
|
|
|
|
|
2019-08-17 10:19:48 -07:00
|
|
|
/// Verifies that the referenced symbol is a spv.globalVariable whose type
/// matches this op's result type.
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
  Operation *symbolOp = SymbolTable::lookupNearestSymbolFrom(
      addressOfOp.getParentOp(), addressOfOp.variable());
  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(symbolOp);
  if (!varOp)
    return addressOfOp.emitOpError("expected spv.globalVariable symbol");

  if (addressOfOp.pointer().getType() == varOp.type())
    return success();

  return addressOfOp.emitOpError(
      "result type mismatch with the referenced global variable's type");
}
|
|
|
|
|
2019-12-05 10:05:54 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicCompareExchangeWeak
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Parses
///   spv.AtomicCompareExchangeWeak "<scope>" "<equal>" "<unequal>"
///       %ptr, %value, %comparator : <pointer-type>
/// The value and comparator are resolved against the pointee type, which is
/// also the result type.
static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
                                                    OperationState &state) {
  spirv::Scope memoryScope;
  spirv::MemorySemantics equalSemantics, unequalSemantics;
  SmallVector<OpAsmParser::OperandType, 3> operands;

  if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
      parseEnumStrAttr(equalSemantics, parser, state,
                       kEqualSemanticsAttrName) ||
      parseEnumStrAttr(unequalSemantics, parser, state,
                       kUnequalSemanticsAttrName) ||
      parser.parseOperandList(operands, 3))
    return failure();

  auto typeLoc = parser.getCurrentLocation();
  Type type;
  if (parser.parseColonType(type))
    return failure();

  auto pointerType = type.dyn_cast<spirv::PointerType>();
  if (!pointerType)
    return parser.emitError(typeLoc, "expected pointer type");

  Type valueType = pointerType.getPointeeType();
  if (parser.resolveOperands(operands, {pointerType, valueType, valueType},
                             parser.getNameLoc(), state.operands))
    return failure();

  return parser.addTypeToList(valueType, state.types);
}
|
|
|
|
|
|
|
|
/// Prints the op as
///   spv.AtomicCompareExchangeWeak "<scope>" "<equal>" "<unequal>"
///       <operands> : <pointer-type>
static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
                  OpAsmPrinter &printer) {
  auto scopeStr = stringifyScope(atomOp.memory_scope());
  auto equalStr = stringifyMemorySemantics(atomOp.equal_semantics());
  auto unequalStr = stringifyMemorySemantics(atomOp.unequal_semantics());

  printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \""
          << scopeStr << "\" \"" << equalStr << "\" \"" << unequalStr << "\" "
          << atomOp.getOperands() << " : " << atomOp.pointer().getType();
}
|
|
|
|
|
|
|
|
/// Verifies the type constraints of spv.AtomicCompareExchangeWeak.
///
/// According to the spec: "The type of Value must be the same as Result
/// Type. The type of the value pointed to by Pointer must be the same as
/// Result Type. This type must also match the type of Comparator."
static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
  Type resultType = atomOp.getType();

  Type valueType = atomOp.value().getType();
  if (valueType != resultType)
    return atomOp.emitOpError("value operand must have the same type as the op "
                              "result, but found ")
           << valueType << " vs " << resultType;

  Type comparatorType = atomOp.comparator().getType();
  if (comparatorType != resultType)
    return atomOp.emitOpError(
               "comparator operand must have the same type as the op "
               "result, but found ")
           << comparatorType << " vs " << resultType;

  Type pointeeType =
      atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
  if (pointeeType != resultType)
    return atomOp.emitOpError(
               "pointer operand's pointee type must have the same "
               "as the op result type, but found ")
           << pointeeType << " vs " << resultType;

  // TODO(antiagainst): Unequal cannot be set to Release or Acquire and Release.
  // In addition, Unequal cannot be set to a stronger memory-order than Equal.

  return success();
}
|
|
|
|
|
2019-09-25 19:01:18 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.BitcastOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies spv.Bitcast. The operand and result types must differ, either
/// both or neither must be pointers, and their bit widths must match.
static LogicalResult verify(spirv::BitcastOp bitcastOp) {
  // TODO: The SPIR-V spec validation rules are different for different
  // versions.
  Type operandType = bitcastOp.operand().getType();
  Type resultType = bitcastOp.result().getType();
  if (operandType == resultType)
    return bitcastOp.emitError(
        "result type must be different from operand type");

  // Casts between pointer and non-pointer types are rejected here.
  bool operandIsPointer = operandType.isa<spirv::PointerType>();
  bool resultIsPointer = resultType.isa<spirv::PointerType>();
  if (operandIsPointer && !resultIsPointer)
    return bitcastOp.emitError(
        "unhandled bit cast conversion from pointer type to non-pointer type");
  if (!operandIsPointer && resultIsPointer)
    return bitcastOp.emitError(
        "unhandled bit cast conversion from non-pointer type to pointer type");

  auto operandBitWidth = getBitWidth(operandType);
  auto resultBitWidth = getBitWidth(resultType);
  if (operandBitWidth == resultBitWidth)
    return success();

  return bitcastOp.emitOpError("mismatch in result type bitwidth ")
         << resultBitWidth << " and operand type bitwidth " << operandBitWidth;
}
|
|
|
|
|
2020-03-05 12:40:23 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.BranchOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Returns the operands forwarded to the branch's only successor, which is
/// every operand of the op.
Optional<OperandRange> spirv::BranchOp::getSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  OperandRange forwardedOperands = getOperands();
  return forwardedOperands;
}
|
|
|
|
|
|
|
|
/// Successor operands of spv.Branch may always be erased.
bool spirv::BranchOp::canEraseSuccessorOperand() {
  return true;
}
|
|
|
|
|
2019-08-30 12:17:21 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.BranchConditionalOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-03-05 12:40:23 -08:00
|
|
|
/// Returns the operands forwarded to successor `index`: the true block's
/// arguments for `kTrueIndex`, otherwise the false block's arguments.
Optional<OperandRange>
spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
  assert(index < 2 && "invalid successor index");
  if (index == kTrueIndex)
    return getTrueBlockArguments();
  return getFalseBlockArguments();
}
|
|
|
|
|
|
|
|
/// Successor operands of spv.BranchConditional may always be erased.
bool spirv::BranchConditionalOp::canEraseSuccessorOperand() {
  return true;
}
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parses
///   spv.BranchConditional %cond [true_weight, false_weight]?,
///       ^true_dest(args...), ^false_dest(args...)
/// and records the operand segment sizes (condition, true args, false args).
static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
                                            OperationState &state) {
  auto &builder = parser.getBuilder();

  // The condition is resolved as an i1 value.
  OpAsmParser::OperandType condition;
  if (parser.parseOperand(condition) ||
      parser.resolveOperand(condition, builder.getI1Type(), state.operands))
    return failure();

  // Optional `[true-weight, false-weight]` list of i32 attributes.
  if (succeeded(parser.parseOptionalLSquare())) {
    IntegerAttr trueWeight, falseWeight;
    SmallVector<NamedAttribute, 2> weights;
    auto i32Type = builder.getIntegerType(32);
    if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
        parser.parseComma() ||
        parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
        parser.parseRSquare())
      return failure();

    state.addAttribute(kBranchWeightAttrName,
                       builder.getArrayAttr({trueWeight, falseWeight}));
  }

  // Each successor is preceded by a comma and may forward operands; both the
  // block and its operands are appended to `state`.
  auto parseSuccessor =
      [&](SmallVector<Value, 4> &successorOperands) -> ParseResult {
    Block *successor;
    if (parser.parseComma() ||
        parser.parseSuccessorAndUseList(successor, successorOperands))
      return failure();
    state.addSuccessors(successor);
    state.addOperands(successorOperands);
    return success();
  };

  SmallVector<Value, 4> trueOperands;
  if (parseSuccessor(trueOperands))
    return failure();
  SmallVector<Value, 4> falseOperands;
  if (parseSuccessor(falseOperands))
    return failure();

  state.addAttribute(
      spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
                                static_cast<int32_t>(falseOperands.size())}));
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints
///   spv.BranchConditional %cond [w0, w1]?, ^true(args), ^false(args)
/// where the weight list appears only when branch weights are set.
static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
  printer << spirv::BranchConditionalOp::getOperationName() << ' '
          << branchOp.condition();

  if (auto weights = branchOp.branch_weights()) {
    printer << " [";
    interleaveComma(weights->getValue(), printer, [&](Attribute weight) {
      printer << weight.cast<IntegerAttr>().getInt();
    });
    printer << "]";
  }

  // Each successor is printed after a ", " separator.
  auto printSuccessor = [&](Block *successor, auto successorOperands) {
    printer << ", ";
    printer.printSuccessorAndUseList(successor, successorOperands);
  };
  printSuccessor(branchOp.getTrueBlock(), branchOp.getTrueBlockArguments());
  printSuccessor(branchOp.getFalseBlock(), branchOp.getFalseBlockArguments());
}
|
|
|
|
|
|
|
|
/// Verifies optional branch weights. When present there must be exactly two,
/// and they may not both be zero.
static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
  auto weights = branchOp.branch_weights();
  if (!weights)
    return success();

  if (weights->getValue().size() != 2)
    return branchOp.emitOpError("must have exactly two branch weights");

  bool allZero = llvm::all_of(weights->getValue(), [](Attribute weight) {
    return weight.cast<IntegerAttr>().getValue().isNullValue();
  });
  if (allZero)
    return branchOp.emitOpError("branch weights cannot both be zero");

  return success();
}
|
|
|
|
|
2019-12-09 12:43:23 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.CompositeConstruct
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Parses `spv.CompositeConstruct <operands> : <composite-type>`. Each operand
/// is resolved against the matching element type of the composite.
static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
                                             OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 4> constituents;
  Type type;
  auto loc = parser.getCurrentLocation();

  if (parser.parseOperandList(constituents) || parser.parseColonType(type))
    return failure();

  auto compositeType = type.dyn_cast<spirv::CompositeType>();
  if (!compositeType)
    return parser.emitError(
               loc, "result type must be a composite type, but provided ")
           << type;

  auto numElements = compositeType.getNumElements();
  if (constituents.size() != numElements)
    return parser.emitError(loc, "has incorrect number of operands: expected ")
           << numElements << ", but provided " << constituents.size();

  // TODO: Add support for constructing a vector type from the vector operands.
  // According to the spec: "for constructing a vector, the operands may
  // also be vectors with the same component type as the Result Type component
  // type".
  SmallVector<Type, 4> elementTypes;
  elementTypes.reserve(numElements);
  for (unsigned i = 0; i < numElements; ++i)
    elementTypes.push_back(compositeType.getElementType(i));

  state.addTypes(type);
  return parser.resolveOperands(constituents, elementTypes, loc,
                                state.operands);
}
|
|
|
|
|
|
|
|
/// Prints `spv.CompositeConstruct <constituents> : <result-type>`.
static void print(spirv::CompositeConstructOp compositeConstructOp,
                  OpAsmPrinter &printer) {
  Type resultType = compositeConstructOp.getResult().getType();
  printer << spirv::CompositeConstructOp::getOperationName() << " "
          << compositeConstructOp.constituents() << " : " << resultType;
}
|
|
|
|
|
|
|
|
/// Verifies that there is one constituent per element of the result composite
/// and that each constituent has the corresponding element type.
static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
  auto compositeType =
      compositeConstructOp.getType().cast<spirv::CompositeType>();
  SmallVector<Value, 4> constituents(compositeConstructOp.constituents());

  if (constituents.size() != compositeType.getNumElements())
    return compositeConstructOp.emitError(
               "has incorrect number of operands: expected ")
           << compositeType.getNumElements() << ", but provided "
           << constituents.size();

  for (unsigned i = 0, e = constituents.size(); i != e; ++i) {
    Type expectedType = compositeType.getElementType(i);
    Type providedType = constituents[i].getType();
    if (providedType != expectedType)
      return compositeConstructOp.emitError(
                 "operand type mismatch: expected operand type ")
             << expectedType << ", but provided " << providedType;
  }

  return success();
}
|
|
|
|
|
2019-07-12 06:14:53 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.CompositeExtractOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-12-10 10:11:19 -08:00
|
|
|
void spirv::CompositeExtractOp::build(Builder *builder, OperationState &state,
|
2019-12-23 14:45:01 -08:00
|
|
|
Value composite,
|
2019-12-10 10:11:19 -08:00
|
|
|
ArrayRef<int32_t> indices) {
|
|
|
|
auto indexAttr = builder->getI32ArrayAttr(indices);
|
|
|
|
auto elementType =
|
2020-01-11 08:54:04 -08:00
|
|
|
getElementType(composite.getType(), indexAttr, state.location);
|
2019-12-10 10:11:19 -08:00
|
|
|
if (!elementType) {
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
build(builder, state, elementType, composite, indexAttr);
|
|
|
|
}
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parses `spv.CompositeExtract %composite[<indices>] : <composite-type>` and
/// derives the result type from the composite type and the indices.
static ParseResult parseCompositeExtractOp(OpAsmParser &parser,
                                           OperationState &state) {
  OpAsmParser::OperandType composite;
  Attribute indicesAttr;
  Type compositeType;
  llvm::SMLoc indicesLoc;

  bool parseFailed =
      parser.parseOperand(composite) ||
      parser.getCurrentLocation(&indicesLoc) ||
      parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
      parser.parseColonType(compositeType) ||
      parser.resolveOperand(composite, compositeType, state.operands);
  if (parseFailed)
    return failure();

  Type resultType =
      getElementType(compositeType, indicesAttr, parser, indicesLoc);
  if (!resultType)
    return failure();

  state.addTypes(resultType);
  return success();
}
|
|
|
|
|
|
|
|
static void print(spirv::CompositeExtractOp compositeExtractOp,
|
2019-09-20 20:43:02 -07:00
|
|
|
OpAsmPrinter &printer) {
|
|
|
|
printer << spirv::CompositeExtractOp::getOperationName() << ' '
|
2020-01-11 08:54:04 -08:00
|
|
|
<< compositeExtractOp.composite() << compositeExtractOp.indices()
|
|
|
|
<< " : " << compositeExtractOp.composite().getType();
|
2019-07-12 06:14:53 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Verifies that the result type equals the element type selected from the
/// composite by the indices.
static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
  auto indices = compExOp.indices().dyn_cast<ArrayAttr>();
  auto expectedType = getElementType(compExOp.composite().getType(), indices,
                                     compExOp.getLoc());
  if (!expectedType)
    return failure();

  if (expectedType == compExOp.getType())
    return success();

  return compExOp.emitOpError("invalid result type: expected ")
         << expectedType << " but provided " << compExOp.getType();
}
|
|
|
|
|
2019-12-05 13:10:10 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.CompositeInsert
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Parses
///   spv.CompositeInsert %object, %composite<indices> : <object-type> into
///       <composite-type>
/// The result has the composite's type.
static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
                                          OperationState &state) {
  SmallVector<OpAsmParser::OperandType, 2> operands;
  Type objectType, compositeType;
  Attribute indicesAttr;
  auto operandsLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(operands, 2) ||
      parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
      parser.parseColonType(objectType) ||
      parser.parseKeywordType("into", compositeType))
    return failure();

  if (parser.resolveOperands(operands, {objectType, compositeType},
                             operandsLoc, state.operands))
    return failure();

  return parser.addTypesToList(compositeType, state.types);
}
|
|
|
|
|
|
|
|
/// Verifies that the inserted object has the element type selected by the
/// indices and that the result type equals the composite's type.
static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
  auto indices = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
  Type compositeType = compositeInsertOp.composite().getType();
  auto expectedObjectType =
      getElementType(compositeType, indices, compositeInsertOp.getLoc());
  if (!expectedObjectType)
    return failure();

  Type objectType = compositeInsertOp.object().getType();
  if (expectedObjectType != objectType)
    return compositeInsertOp.emitOpError("object operand type should be ")
           << expectedObjectType << ", but found " << objectType;

  if (compositeType != compositeInsertOp.getType())
    return compositeInsertOp.emitOpError("result type should be the same as "
                                         "the composite type, but found ")
           << compositeType << " vs " << compositeInsertOp.getType();

  return success();
}
|
|
|
|
|
|
|
|
/// Prints
///   spv.CompositeInsert %object, %composite<indices> : <object-type> into
///       <composite-type>
static void print(spirv::CompositeInsertOp compositeInsertOp,
                  OpAsmPrinter &printer) {
  Value object = compositeInsertOp.object();
  Value composite = compositeInsertOp.composite();
  printer << spirv::CompositeInsertOp::getOperationName() << " " << object
          << ", " << composite << compositeInsertOp.indices() << " : "
          << object.getType() << " into " << composite.getType();
}
|
|
|
|
|
2019-06-17 14:47:22 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.constant
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-09-20 19:47:05 -07:00
|
|
|
/// Parses `spv.constant <value-attr> (: <type>)?`. The explicit type is read
/// only when the attribute's own type is `none` or a tensor type; otherwise
/// the attribute's type is used as the result type.
static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
  Attribute value;
  if (parser.parseAttribute(value, kValueAttrName, state.attributes))
    return failure();

  Type type = value.getType();
  bool needsExplicitType = type.isa<NoneType>() || type.isa<TensorType>();
  if (needsExplicitType && parser.parseColonType(type))
    return failure();

  return parser.addTypeToList(type, state.types);
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints spv.constant as `spv.constant <value>`, appending `: <type>` only
/// when the result is a spv.array.
static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
  printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();

  Type resultType = constOp.getType();
  if (!resultType.isa<spirv::ArrayType>())
    return;
  printer << " : " << resultType;
}
|
|
|
|
|
2019-06-18 11:15:55 -07:00
|
|
|
/// Verifies that spv.constant's `value` attribute is consistent with the op's
/// result type.
///
/// ODS already checks that the result type itself is valid; this only checks
/// the attribute against it:
///  - bool/integer/float attributes must have exactly the result type;
///  - dense/sparse elements attributes must either have exactly the result
///    type, or the result must be a (possibly nested) spv.array of int/float
///    whose scalar element type and total element count match the attribute;
///  - array attributes require a spv.array result whose element type matches
///    every element's type;
///  - any other attribute kind is rejected.
static LogicalResult verify(spirv::ConstantOp constOp) {
  Type opType = constOp.getType();
  Attribute value = constOp.value();
  Type valueType = value.getType();

  switch (value.getKind()) {
  case StandardAttributes::Bool:
  case StandardAttributes::Integer:
  case StandardAttributes::Float:
    if (valueType == opType)
      return success();
    return constOp.emitOpError("result type (")
           << opType << ") does not match value type (" << valueType << ")";

  case StandardAttributes::DenseElements:
  case StandardAttributes::SparseElements: {
    if (valueType == opType)
      return success();

    auto arrayType = opType.dyn_cast<spirv::ArrayType>();
    if (!arrayType)
      return constOp.emitOpError(
          "must have spv.array result type for array value");

    // Flatten nested spv.array types into a total count and a scalar type.
    int numElements = arrayType.getNumElements();
    Type opElemType = arrayType.getElementType();
    for (auto nested = opElemType.dyn_cast<spirv::ArrayType>(); nested;
         nested = opElemType.dyn_cast<spirv::ArrayType>()) {
      numElements *= nested.getNumElements();
      opElemType = nested.getElementType();
    }
    if (!opElemType.isIntOrFloat())
      return constOp.emitOpError("only support nested array result type");

    auto shapedType = valueType.dyn_cast<ShapedType>();
    Type valueElemType = shapedType.getElementType();
    if (valueElemType != opElemType)
      return constOp.emitOpError("result element type (")
             << opElemType << ") does not match value element type ("
             << valueElemType << ")";

    if (numElements != shapedType.getNumElements())
      return constOp.emitOpError("result number of elements (")
             << numElements << ") does not match value number of elements ("
             << shapedType.getNumElements() << ")";
    return success();
  }

  case StandardAttributes::Array: {
    auto arrayType = opType.dyn_cast<spirv::ArrayType>();
    if (!arrayType)
      return constOp.emitOpError(
          "must have spv.array result type for array value");

    Type elemType = arrayType.getElementType();
    for (Attribute element : value.cast<ArrayAttr>().getValue()) {
      Type actualType = element.getType();
      if (actualType == elemType)
        continue;
      return constOp.emitOpError("has array element whose type (")
             << actualType << ") does not match the result element type ("
             << elemType << ')';
    }
    return success();
  }

  default:
    return constOp.emitOpError("cannot have value of type ") << valueType;
  }
}
|
|
|
|
|
2019-09-03 12:09:07 -07:00
|
|
|
/// Returns true if a spv.constant can be built with result type `type`.
///
/// The type must be a valid SPIR-V type. Valid non-dialect (builtin) types are
/// accepted. Among the SPIR-V dialect's own types only spv.array is accepted.
bool spirv::ConstantOp::isBuildableWith(Type type) {
  if (!SPIRVDialect::isValidType(type))
    return false;

  unsigned kind = type.getKind();
  bool isSPIRVDialectType = kind >= Type::FIRST_SPIRV_TYPE &&
                            kind <= spirv::TypeKind::LAST_SPIRV_TYPE;
  // TODO(antiagainst): support constant struct
  return !isSPIRVDialectType || type.isa<spirv::ArrayType>();
}
|
|
|
|
|
2019-11-27 14:12:32 -08:00
|
|
|
/// Creates a spv.constant holding the zero value of `type` at `loc`.
///
/// - i1 produces a BoolAttr `false`, since the verifier accepts bool
///   attributes for it.
/// - Other integer widths produce an IntegerAttr of value 0.
/// - Float types produce a FloatAttr of value 0.0.
/// Any other type is a programming error and reaches llvm_unreachable.
spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
                                             OpBuilder *builder) {
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    if (width == 1)
      return builder->create<spirv::ConstantOp>(loc, type,
                                                builder->getBoolAttr(false));
    return builder->create<spirv::ConstantOp>(
        loc, type, builder->getIntegerAttr(type, APInt(width, 0)));
  }

  // Float scalars are valid spv.constant result types (see isBuildableWith
  // and the Float case of the verifier), so support them here as well.
  if (auto floatType = type.dyn_cast<FloatType>())
    return builder->create<spirv::ConstantOp>(
        loc, type, builder->getFloatAttr(floatType, 0.0));

  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
}
|
|
|
|
|
|
|
|
/// Creates a spv.constant holding the value one of `type` at `loc`.
///
/// - i1 produces a BoolAttr `true`.
/// - Other integer widths produce an IntegerAttr of value 1.
/// - Float types produce a FloatAttr of value 1.0.
/// Any other type is a programming error and reaches llvm_unreachable.
spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
                                            OpBuilder *builder) {
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    if (width == 1)
      return builder->create<spirv::ConstantOp>(loc, type,
                                                builder->getBoolAttr(true));
    return builder->create<spirv::ConstantOp>(
        loc, type, builder->getIntegerAttr(type, APInt(width, 1)));
  }

  // Mirror getZero(): float scalars are legal spv.constant result types.
  if (auto floatType = type.dyn_cast<FloatType>())
    return builder->create<spirv::ConstantOp>(
        loc, type, builder->getFloatAttr(floatType, 1.0));

  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
}
|
|
|
|
|
2019-07-08 10:56:20 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.EntryPoint
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-10-04 14:02:14 -07:00
|
|
|
void spirv::EntryPointOp::build(Builder *builder, OperationState &state,
|
|
|
|
spirv::ExecutionModel executionModel,
|
2020-02-07 11:30:19 -05:00
|
|
|
spirv::FuncOp function,
|
2019-10-04 14:02:14 -07:00
|
|
|
ArrayRef<Attribute> interfaceVars) {
|
|
|
|
build(builder, state,
|
|
|
|
builder->getI32IntegerAttr(static_cast<int32_t>(executionModel)),
|
|
|
|
builder->getSymbolRefAttr(function),
|
|
|
|
builder->getArrayAttr(interfaceVars));
|
|
|
|
}
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parses spv.EntryPoint: an execution model string, the entry function
/// symbol (stored under kFnNameAttrName), and an optional comma-separated
/// list of interface variable symbols (stored as an array under
/// kInterfaceAttrName; empty when none are given).
static ParseResult parseEntryPointOp(OpAsmParser &parser,
                                     OperationState &state) {
  spirv::ExecutionModel execModel;
  FlatSymbolRefAttr fn;
  if (parseEnumStrAttr(execModel, parser, state) ||
      parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
    return failure();
  }

  SmallVector<Attribute, 4> interfaceVars;
  if (!parser.parseOptionalComma()) {
    // Parse the interface variables.
    do {
      // Each symbol is parsed into a throwaway attribute list. The name used
      // there isn't important because only `var` itself is kept.
      auto attrName = "var_symbol";
      FlatSymbolRefAttr var;
      SmallVector<NamedAttribute, 1> attrs;
      if (parser.parseAttribute(var, Type(), attrName, attrs)) {
        return failure();
      }
      interfaceVars.push_back(var);
    } while (!parser.parseOptionalComma());
  }
  state.addAttribute(kInterfaceAttrName,
                     parser.getBuilder().getArrayAttr(interfaceVars));
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints spv.EntryPoint as `spv.EntryPoint "<model>" @fn[, @var...]`.
static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
  printer << spirv::EntryPointOp::getOperationName() << " \""
          << stringifyExecutionModel(entryPointOp.execution_model()) << "\" ";
  printer.printSymbolName(entryPointOp.fn());

  // Interface variables, when present, follow the function after a comma.
  ArrayRef<Attribute> vars = entryPointOp.interface().getValue();
  if (vars.empty())
    return;
  printer << ", ";
  interleaveComma(vars, printer);
}
|
|
|
|
|
|
|
|
/// spv.EntryPoint has no op-local checks. Resolution of the `fn` and
/// interface symbol references is verified by the enclosing spirv::ModuleOp.
static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
  return success();
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ExecutionMode
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-10-04 14:02:14 -07:00
|
|
|
/// Convenience builder for spv.ExecutionMode.
///
/// It converts the typed arguments into a symbol reference to `function`, an
/// i32 attribute for the mode, and an i32 array attribute for `params`.
void spirv::ExecutionModeOp::build(Builder *builder, OperationState &state,
                                   spirv::FuncOp function,
                                   spirv::ExecutionMode executionMode,
                                   ArrayRef<int32_t> params) {
  auto fnAttr = builder->getSymbolRefAttr(function);
  auto modeAttr =
      builder->getI32IntegerAttr(static_cast<int32_t>(executionMode));
  auto paramsAttr = builder->getI32ArrayAttr(params);
  build(builder, state, fnAttr, modeAttr, paramsAttr);
}
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parses spv.ExecutionMode: the function symbol attribute (under
/// kFnNameAttrName), the execution mode string, and zero or more
/// comma-prefixed integer values parsed with i32 type. The values are stored
/// as an i32 array under kValuesAttrName.
static ParseResult parseExecutionModeOp(OpAsmParser &parser,
                                        OperationState &state) {
  spirv::ExecutionMode execMode;
  Attribute fn;
  if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
      parseEnumStrAttr(execMode, parser, state)) {
    return failure();
  }

  SmallVector<int32_t, 4> values;
  Type i32Type = parser.getBuilder().getIntegerType(32);
  while (!parser.parseOptionalComma()) {
    SmallVector<NamedAttribute, 1> attr;
    Attribute value;
    auto loc = parser.getCurrentLocation();
    if (parser.parseAttribute(value, i32Type, "value", attr)) {
      return failure();
    }
    // The parsed attribute is not guaranteed to be an IntegerAttr (e.g. a
    // string literal in the input). Diagnose it instead of asserting in a
    // cast.
    auto intAttr = value.dyn_cast<IntegerAttr>();
    if (!intAttr)
      return parser.emitError(loc, "expected integer value");
    values.push_back(intAttr.getInt());
  }
  state.addAttribute(kValuesAttrName,
                     parser.getBuilder().getI32ArrayAttr(values));
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints spv.ExecutionMode as `spv.ExecutionMode @fn "<mode>"[, v0, v1...]`.
/// The values attribute entries are printed as plain integers.
static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
  printer << spirv::ExecutionModeOp::getOperationName() << " ";
  printer.printSymbolName(execModeOp.fn());
  printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode())
          << "\"";

  auto modeValues = execModeOp.values();
  if (modeValues.size() == 0)
    return;
  printer << ", ";
  interleaveComma(modeValues, printer, [&](Attribute entry) {
    printer << entry.cast<IntegerAttr>().getInt();
  });
}
|
|
|
|
|
2020-02-07 11:30:19 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.func
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Parses spv.func in this order:
/// - symbol name;
/// - non-variadic function signature, with per-argument and per-result
///   attributes;
/// - function control enum;
/// - optional `attributes {...}` dictionary;
/// - optional body region.
/// A body-less function is external.
static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
  auto &builder = parser.getBuilder();

  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  SmallVector<OpAsmParser::OperandType, 4> entryArgs;
  SmallVector<Type, 4> argTypes;
  SmallVector<Type, 4> resultTypes;
  SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
  SmallVector<SmallVector<NamedAttribute, 2>, 4> resultAttrs;
  bool isVariadic = false;
  if (impl::parseFunctionSignature(parser, /*allowVariadic=*/false, entryArgs,
                                   argTypes, argAttrs, isVariadic, resultTypes,
                                   resultAttrs))
    return failure();

  FunctionType signature = builder.getFunctionType(argTypes, resultTypes);
  state.addAttribute(impl::getTypeAttrName(), TypeAttr::get(signature));

  spirv::FunctionControl fnControl;
  if (parseEnumStrAttr(fnControl, parser, state) ||
      parser.parseOptionalAttrDictWithKeyword(state.attributes))
    return failure();

  // Attach the per-argument and per-result attributes to the op.
  assert(argAttrs.size() == argTypes.size());
  assert(resultAttrs.size() == resultTypes.size());
  impl::addArgAndResultAttrs(builder, state, argAttrs, resultAttrs);

  // Entry block argument types are only supplied when the signature named
  // its arguments.
  auto *body = state.addRegion();
  ArrayRef<Type> regionArgTypes;
  if (!entryArgs.empty())
    regionArgTypes = argTypes;
  return parser.parseOptionalRegion(*body, entryArgs, regionArgTypes);
}
|
|
|
|
|
|
|
|
/// Prints spv.func in this order:
/// - name and signature;
/// - the quoted function control;
/// - remaining attributes, with function control elided;
/// - the body, if the function is not external.
static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) {
  FunctionType signature = fnOp.getType();

  printer << spirv::FuncOp::getOperationName() << " ";
  printer.printSymbolName(fnOp.sym_name());
  impl::printFunctionSignature(printer, fnOp, signature.getInputs(),
                               /*isVariadic=*/false, signature.getResults());
  printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control())
          << "\"";
  impl::printFunctionAttributes(
      printer, fnOp, signature.getNumInputs(), signature.getNumResults(),
      {spirv::attributeName<spirv::FunctionControl>()});

  Region &body = fnOp.body();
  if (body.empty())
    return;
  printer.printRegion(body, /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}
|
|
|
|
|
|
|
|
/// Checks that the type attribute holds a FunctionType with at most one
/// result.
LogicalResult spirv::FuncOp::verifyType() {
  if (!getTypeAttr().getValue().isa<FunctionType>())
    return emitOpError("requires '" + getTypeAttrName() +
                       "' attribute of function type");

  if (getType().getNumResults() <= 1)
    return success();
  return emitOpError("cannot have more than one result");
}
|
|
|
|
|
|
|
|
/// Verifies return ops nested in the function body against its signature:
/// - spv.Return is only allowed when the function has no results;
/// - spv.ReturnValue requires exactly one result, and the returned value's
///   type must equal that result type.
/// The walk stops at the first offending op.
LogicalResult spirv::FuncOp::verifyBody() {
  FunctionType fnType = getType();

  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
    if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
      if (fnType.getNumResults() != 0)
        return retOp.emitOpError("cannot be used in functions returning value");
    } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
      if (fnType.getNumResults() != 1)
        return retOp.emitOpError(
                   "returns 1 value but enclosing function requires ")
               << fnType.getNumResults() << " results";

      auto retOperandType = retOp.value().getType();
      auto fnResultType = fnType.getResult(0);
      // No leading space: emitOpError already separates the op name from the
      // message.
      if (retOperandType != fnResultType)
        return retOp.emitOpError("return value's type (")
               << retOperandType << ") mismatch with function's result type ("
               << fnResultType << ")";
    }
    return WalkResult::advance();
  });

  // TODO(antiagainst): verify other bits like linkage type.

  return failure(walkResult.wasInterrupted());
}
|
|
|
|
|
|
|
|
/// Builds a spv.func and attaches, in this order:
/// - its symbol name;
/// - its function type;
/// - its function control (as i32);
/// - any extra `attrs`.
/// It then adds one empty body region.
void spirv::FuncOp::build(Builder *builder, OperationState &state,
                          StringRef name, FunctionType type,
                          spirv::FunctionControl control,
                          ArrayRef<NamedAttribute> attrs) {
  auto nameAttr = builder->getStringAttr(name);
  auto controlAttr =
      builder->getI32IntegerAttr(static_cast<uint32_t>(control));

  state.addAttribute(SymbolTable::getSymbolAttrName(), nameAttr);
  state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
                     controlAttr);
  state.attributes.append(attrs.begin(), attrs.end());
  state.addRegion();
}
|
|
|
|
|
|
|
|
// CallableOpInterface
|
|
|
|
Region *spirv::FuncOp::getCallableRegion() {
|
|
|
|
return isExternal() ? nullptr : &body();
|
|
|
|
}
|
|
|
|
|
|
|
|
// CallableOpInterface
|
|
|
|
ArrayRef<Type> spirv::FuncOp::getCallableResults() {
|
|
|
|
return getType().getResults();
|
|
|
|
}
|
|
|
|
|
2019-09-16 15:39:16 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-09-16 17:11:50 -07:00
|
|
|
// spv.FunctionCall
|
2019-09-16 15:39:16 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies spv.FunctionCall against its callee:
/// - the callee must resolve (from the parent op's nearest symbol table) to a
///   spv.func;
/// - the call may have at most one result;
/// - operand count and per-operand types must match the callee's inputs;
/// - result count and type must match the callee's results.
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
  auto fnName = functionCallOp.callee();

  auto funcOp =
      dyn_cast_or_null<spirv::FuncOp>(SymbolTable::lookupNearestSymbolFrom(
          functionCallOp.getParentOp(), fnName));
  if (!funcOp) {
    return functionCallOp.emitOpError("callee function '")
           << fnName << "' not found in nearest symbol table";
  }

  auto functionType = funcOp.getType();

  if (functionCallOp.getNumResults() > 1) {
    return functionCallOp.emitOpError(
               "expected callee function to have 0 or 1 result, but provided ")
           << functionCallOp.getNumResults();
  }

  if (functionType.getNumInputs() != functionCallOp.getNumOperands()) {
    return functionCallOp.emitOpError(
               "has incorrect number of operands for callee: expected ")
           << functionType.getNumInputs() << ", but provided "
           << functionCallOp.getNumOperands();
  }

  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
    if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) {
      return functionCallOp.emitOpError(
                 "operand type mismatch: expected operand type ")
             << functionType.getInput(i) << ", but provided "
             << functionCallOp.getOperand(i).getType() << " for operand number "
             << i;
    }
  }

  // Message previously read "...number of results has for callee"; the
  // duplicated "has" is removed.
  if (functionType.getNumResults() != functionCallOp.getNumResults()) {
    return functionCallOp.emitOpError(
               "has incorrect number of results for callee: expected ")
           << functionType.getNumResults() << ", but provided "
           << functionCallOp.getNumResults();
  }

  if (functionCallOp.getNumResults() &&
      (functionCallOp.getResult(0).getType() != functionType.getResult(0))) {
    return functionCallOp.emitOpError("result type mismatch: expected ")
           << functionType.getResult(0) << ", but provided "
           << functionCallOp.getResult(0).getType();
  }

  return success();
}
|
|
|
|
|
2019-10-16 17:36:58 -07:00
|
|
|
/// CallOpInterface: the callee is identified by the symbol reference stored
/// under kCallee.
CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
  SymbolRefAttr calleeRef = getAttrOfType<SymbolRefAttr>(kCallee);
  return calleeRef;
}
|
|
|
|
|
|
|
|
/// CallOpInterface: returns the `arguments` operand range passed to the callee.
Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
  auto callArgs = arguments();
  return callArgs;
}
|
|
|
|
|
2019-08-17 10:19:48 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.globalVariable
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-11-25 10:38:31 -08:00
|
|
|
/// Builds a spv.globalVariable of pointer `type` named `name` with no
/// initializer. It attaches DescriptorSet and Binding decorations as i32
/// attributes.
void spirv::GlobalVariableOp::build(Builder *builder, OperationState &state,
                                    Type type, StringRef name,
                                    unsigned descriptorSet, unsigned binding) {
  build(builder, state, TypeAttr::get(type), builder->getStringAttr(name),
        nullptr);

  auto decorate = [&](spirv::Decoration decoration, unsigned number) {
    state.addAttribute(spirv::SPIRVDialect::getAttributeName(decoration),
                       builder->getI32IntegerAttr(number));
  };
  decorate(spirv::Decoration::DescriptorSet, descriptorSet);
  decorate(spirv::Decoration::Binding, binding);
}
|
|
|
|
|
2019-12-10 10:11:19 -08:00
|
|
|
/// Builds a spv.globalVariable of pointer `type` named `name` with no
/// initializer. It records `builtin` as a string-valued BuiltIn decoration.
void spirv::GlobalVariableOp::build(Builder *builder, OperationState &state,
                                    Type type, StringRef name,
                                    spirv::BuiltIn builtin) {
  build(builder, state, TypeAttr::get(type), builder->getStringAttr(name),
        nullptr);

  auto builtinAttr = builder->getStringAttr(spirv::stringifyBuiltIn(builtin));
  state.addAttribute(
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
      builtinAttr);
}
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parses spv.globalVariable in this order:
/// - symbol name;
/// - optional `initializer(@sym)` clause;
/// - variable decorations;
/// - a trailing `: type`, which must be a spv.ptr type.
static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
                                         OperationState &state) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // Optional initializer: the keyword, then a parenthesized symbol reference.
  if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
    FlatSymbolRefAttr initSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
                              state.attributes) ||
        parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, state))
    return failure();

  Type varType;
  auto typeLoc = parser.getCurrentLocation();
  if (parser.parseColonType(varType))
    return failure();
  if (!varType.isa<spirv::PointerType>())
    return parser.emitError(typeLoc, "expected spv.ptr type");

  state.addAttribute(kTypeAttrName, TypeAttr::get(varType));
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints spv.globalVariable as:
/// `spv.globalVariable @name [initializer(@sym)] <decorations> : <type>`.
/// Attributes rendered explicitly are elided from the decoration output.
static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) {
  SmallVector<StringRef, 4> elidedAttrs{
      spirv::attributeName<spirv::StorageClass>(),
      SymbolTable::getSymbolAttrName()};

  printer << spirv::GlobalVariableOp::getOperationName() << ' ';
  printer.printSymbolName(varOp.sym_name());

  if (auto initializer = varOp.initializer()) {
    printer << " " << kInitializerAttrName << '(';
    printer.printSymbolName(initializer.getValue());
    printer << ')';
    elidedAttrs.push_back(kInitializerAttrName);
  }

  elidedAttrs.push_back(kTypeAttrName);
  printVariableDecorations(varOp.getOperation(), printer, elidedAttrs);
  printer << " : " << varOp.type();
}
|
|
|
|
|
|
|
|
/// Verifies spv.globalVariable:
/// - the storage class must not be Generic;
/// - an initializer, if present, must resolve to a spv.globalVariable or a
///   spv.specConstant.
static LogicalResult verify(spirv::GlobalVariableOp varOp) {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  if (varOp.storageClass() == spirv::StorageClass::Generic)
    return varOp.emitOpError("storage class cannot be 'Generic'");

  auto init = varOp.getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName);
  if (!init)
    return success();

  Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
      varOp.getParentOp(), init.getValue());
  // TODO: Currently only variable initialization with specialization
  // constants and other variables is supported. They could be normal
  // constants in the module scope as well.
  bool isSupportedInit = initOp && (isa<spirv::GlobalVariableOp>(initOp) ||
                                    isa<spirv::SpecConstantOp>(initOp));
  if (!isSupportedInit)
    return varOp.emitOpError("initializer must be result of a "
                             "spv.specConstant or spv.globalVariable op");
  return success();
}
|
|
|
|
|
2019-12-03 16:43:40 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformBallotOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// spv.GroupNonUniformBallot accepts only the Workgroup and Subgroup
/// execution scopes.
static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
  switch (ballotOp.execution_scope()) {
  case spirv::Scope::Workgroup:
  case spirv::Scope::Subgroup:
    return success();
  default:
    return ballotOp.emitOpError(
        "execution scope must be 'Workgroup' or 'Subgroup'");
  }
}
|
|
|
|
|
2020-01-26 10:19:24 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformElectOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Builds spv.GroupNonUniformElect with an i1 result for the given scope.
void spirv::GroupNonUniformElectOp::build(Builder *builder,
                                          OperationState &state,
                                          spirv::Scope scope) {
  auto i1Type = builder->getI1Type();
  build(builder, state, i1Type, scope);
}
|
|
|
|
|
|
|
|
/// spv.GroupNonUniformElect accepts only the Workgroup and Subgroup execution
/// scopes.
static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
  spirv::Scope execScope = groupOp.execution_scope();
  bool isAllowedScope = execScope == spirv::Scope::Workgroup ||
                        execScope == spirv::Scope::Subgroup;
  if (isAllowedScope)
    return success();
  return groupOp.emitOpError(
      "execution scope must be 'Workgroup' or 'Subgroup'");
}
|
|
|
|
|
2019-06-24 10:59:05 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.LoadOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-10-04 14:02:14 -07:00
|
|
|
void spirv::LoadOp::build(Builder *builder, OperationState &state,
|
2019-12-23 14:45:01 -08:00
|
|
|
Value basePtr, IntegerAttr memory_access,
|
2019-10-04 14:02:14 -07:00
|
|
|
IntegerAttr alignment) {
|
2020-01-11 08:54:04 -08:00
|
|
|
auto ptrType = basePtr.getType().cast<spirv::PointerType>();
|
2019-10-04 14:02:14 -07:00
|
|
|
build(builder, state, ptrType.getPointeeType(), basePtr, memory_access,
|
|
|
|
alignment);
|
|
|
|
}
|
|
|
|
|
2019-09-20 19:47:05 -07:00
|
|
|
/// Parses spv.Load in this order:
/// - storage class string;
/// - pointer operand;
/// - memory access attributes;
/// - optional attribute dict;
/// - `: <element type>`.
/// The pointer operand's type is rebuilt from the storage class and the
/// element type.
static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) {
  spirv::StorageClass storageClass;
  OpAsmParser::OperandType ptrInfo;
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
      parseMemoryAccessAttributes(parser, state) ||
      parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
      parser.parseType(elementType))
    return failure();

  Type ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperand(ptrInfo, ptrType, state.operands))
    return failure();

  // The loaded value has the element type.
  state.addTypes(elementType);
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints spv.Load as `spv.Load "<StorageClass>" %ptr <memory access> attrs :
/// <type>`. Memory-access attributes printed explicitly are elided from the
/// dict.
static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
  auto ptrType = loadOp.ptr().getType().cast<spirv::PointerType>();
  printer << spirv::LoadOp::getOperationName() << " \""
          << stringifyStorageClass(ptrType.getStorageClass()) << "\" "
          << loadOp.ptr();

  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(loadOp, printer, elidedAttrs);

  printer.printOptionalAttrDict(loadOp.getOperation()->getAttrs(),
                                elidedAttrs);
  printer << " : " << loadOp.getType();
}
|
|
|
|
|
|
|
|
/// Verifies spv.Load:
/// - pointer/value type consistency comes first;
/// - memory access attribute checks follow.
static LogicalResult verify(spirv::LoadOp loadOp) {
  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
  // type with fixed size; i.e., it cannot be, nor include, any
  // OpTypeRuntimeArray types."
  if (succeeded(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(),
                                              loadOp.value())))
    return verifyMemoryAccessAttribute(loadOp);
  return failure();
}
|
|
|
|
|
2019-09-05 12:45:08 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.loop
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-11-12 11:59:34 -08:00
|
|
|
/// Builds a spv.loop with `loop_control` set to None and one empty region.
void spirv::LoopOp::build(Builder *builder, OperationState &state) {
  auto noControl = static_cast<uint32_t>(spirv::LoopControl::None);
  state.addAttribute("loop_control", builder->getI32IntegerAttr(noControl));
  state.addRegion();
}
|
|
|
|
|
2019-09-20 19:47:05 -07:00
|
|
|
/// Parses spv.loop. `loop_control` is always set to None. The body region is
/// parsed with no entry arguments.
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
  // TODO(antiagainst): support loop control properly
  auto noControl = static_cast<uint32_t>(spirv::LoopControl::None);
  state.addAttribute("loop_control",
                     parser.getBuilder().getI32IntegerAttr(noControl));

  Region *body = state.addRegion();
  return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints spv.loop as the op name followed by its region, including block
/// terminators and omitting entry block arguments.
static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
  Region &body = loopOp.getOperation()->getRegion(0);
  printer << spirv::LoopOp::getOperationName();
  printer.printRegion(body, /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}
|
|
|
|
|
|
|
|
/// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
|
|
|
|
/// given `dstBlock`.
|
|
|
|
static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
|
|
|
|
// Check that there is only one op in the `srcBlock`.
|
2019-10-30 11:13:52 -07:00
|
|
|
if (!has_single_element(srcBlock))
|
2019-09-05 12:45:08 -07:00
|
|
|
return false;
|
|
|
|
|
|
|
|
auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
|
2020-03-05 12:39:46 -08:00
|
|
|
return branchOp && branchOp.getSuccessor() == &dstBlock;
|
2019-09-05 12:45:08 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Verifies the block structure of a `spv.loop` region against the layout
/// drawn below. An empty region is accepted as a degenerate case. On failure
/// an op error describing the first violated rule is emitted.
static LogicalResult verify(spirv::LoopOp loopOp) {
  auto *op = loopOp.getOperation();

  // We need to verify that the blocks follow the following layout:
  //
  //                     +-------------+
  //                     | entry block |
  //                     +-------------+
  //                            |
  //                            v
  //                     +-------------+
  //                     | loop header | <-----+
  //                     +-------------+       |
  //                                           |
  //                           ...             |
  //                          \ | /            |
  //                            v              |
  //                    +---------------+      |
  //                    | loop continue | -----+
  //                    +---------------+
  //
  //                           ...
  //                          \ | /
  //                            v
  //                     +-------------+
  //                     | merge block |
  //                     +-------------+

  auto &region = op->getRegion(0);
  // Allow empty region as a degenerated case, which can come from
  // optimizations.
  if (region.empty())
    return success();

  // The last block is the merge block.
  Block &merge = region.back();
  if (!isMergeBlock(merge))
    return loopOp.emitOpError(
        "last block must be the merge block with only one 'spv._merge' op");

  if (std::next(region.begin()) == region.end())
    return loopOp.emitOpError(
        "must have an entry block branching to the loop header block");
  // The first block is the entry block.
  Block &entry = region.front();

  if (std::next(region.begin(), 2) == region.end())
    return loopOp.emitOpError(
        "must have a loop header block branched from the entry block");
  // The second block is the loop header block.
  Block &header = *std::next(region.begin(), 1);

  if (!hasOneBranchOpTo(entry, header))
    return loopOp.emitOpError(
        "entry block must only have one 'spv.Branch' op to the second block");

  if (std::next(region.begin(), 3) == region.end())
    return loopOp.emitOpError(
        "requires a loop continue block branching to the loop header block");
  // The second to last block is the loop continue block.
  Block &cont = *std::prev(region.end(), 2);

  // Make sure that we have a branch from the loop continue block to the loop
  // header block.
  if (llvm::none_of(
          llvm::seq<unsigned>(0, cont.getNumSuccessors()),
          [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
    return loopOp.emitOpError("second to last block must be the loop continue "
                              "block that branches to the loop header block");

  // Make sure that no other blocks (except the entry and loop continue block)
  // branches to the loop header block. The range starts after the header and
  // stops before the continue block.
  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
                                      std::prev(region.end(), 2))) {
    for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
      if (block.getSuccessor(i) == &header) {
        return loopOp.emitOpError("can only have the entry and loop continue "
                                  "block branching to the loop header block");
      }
    }
  }

  return success();
}
|
|
|
|
|
2019-11-12 11:59:34 -08:00
|
|
|
/// Returns the loop's entry block, i.e. the first block of the body region.
/// The body region must not be empty.
Block *spirv::LoopOp::getEntryBlock() {
  Region &region = body();
  assert(!region.empty() && "op region should not be empty!");
  return &region.front();
}
|
|
|
|
|
2019-09-11 14:02:23 -07:00
|
|
|
/// Returns the loop header block, i.e. the second block of the body region.
/// The body region must contain at least two blocks.
Block *spirv::LoopOp::getHeaderBlock() {
  assert(!body().empty() && "op region should not be empty!");
  // The second block is the loop header block. A non-empty region with a
  // single block would make the dereference below read past the end, so
  // assert that a second block actually exists.
  auto headerIt = std::next(body().begin());
  assert(headerIt != body().end() &&
         "op region should have a loop header block!");
  return &*headerIt;
}
|
|
|
|
|
|
|
|
/// Returns the loop continue block, i.e. the second to last block of the body
/// region. The body region must contain at least two blocks.
Block *spirv::LoopOp::getContinueBlock() {
  assert(!body().empty() && "op region should not be empty!");
  // std::prev(end, 2) is only valid when the region has two or more blocks.
  assert(std::next(body().begin()) != body().end() &&
         "op region should have at least two blocks!");
  // The second to last block is the loop continue block.
  return &*std::prev(body().end(), 2);
}
|
|
|
|
|
|
|
|
/// Returns the loop merge block, which is the last block of the body region.
/// The body region must not be empty.
Block *spirv::LoopOp::getMergeBlock() {
  Region &region = body();
  assert(!region.empty() && "op region should not be empty!");
  return &region.back();
}
|
|
|
|
|
|
|
|
void spirv::LoopOp::addEntryAndMergeBlock() {
|
|
|
|
assert(body().empty() && "entry and merge block already exist");
|
|
|
|
body().push_back(new Block());
|
|
|
|
auto *mergeBlock = new Block();
|
|
|
|
body().push_back(mergeBlock);
|
|
|
|
OpBuilder builder(mergeBlock);
|
|
|
|
|
|
|
|
// Add a spv._merge op into the merge block.
|
2019-10-02 11:00:50 -07:00
|
|
|
builder.create<spirv::MergeOp>(getLoc());
|
2019-09-11 14:02:23 -07:00
|
|
|
}
|
|
|
|
|
2019-09-05 12:45:08 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv._merge
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies that `spv._merge` has a `spv.selection` or `spv.loop` parent and
/// is the terminator of the last block of its enclosing region.
static LogicalResult verify(spirv::MergeOp mergeOp) {
  Operation *parentOp = mergeOp.getParentOp();
  bool hasStructuredParent =
      parentOp && (isa<spirv::SelectionOp>(parentOp) ||
                   isa<spirv::LoopOp>(parentOp));
  if (!hasStructuredParent)
    return mergeOp.emitOpError(
        "expected parent op to be 'spv.selection' or 'spv.loop'");

  Operation *lastTerminator =
      mergeOp.getParentRegion()->back().getTerminator();
  if (lastTerminator != mergeOp.getOperation())
    return mergeOp.emitOpError(
        "can only be used in the last block of 'spv.selection' or 'spv.loop'");

  return success();
}
|
|
|
|
|
2019-05-29 10:47:16 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.module
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-09-20 19:47:05 -07:00
|
|
|
/// Builds an empty `spv.module`; its single body region is given the implicit
/// module terminator.
void spirv::ModuleOp::build(Builder *builder, OperationState &state) {
  Region *bodyRegion = state.addRegion();
  ensureTerminator(*bodyRegion, *builder, state.location);
}
|
|
|
|
|
2019-12-09 09:51:25 -08:00
|
|
|
/// Builds a `spv.module` with the given addressing and memory models and a
/// body region that carries the implicit module terminator.
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
                            spirv::AddressingModel addressing_model,
                            spirv::MemoryModel memory_model) {
  // Use the same attribute-name helper the custom printer uses when eliding
  // these attributes, instead of duplicating the names as string literals.
  state.addAttribute(
      spirv::attributeName<spirv::AddressingModel>(),
      builder->getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
  state.addAttribute(
      spirv::attributeName<spirv::MemoryModel>(),
      builder->getI32IntegerAttr(static_cast<int32_t>(memory_model)));
  ensureTerminator(*state.addRegion(), *builder, state.location);
}
|
|
|
|
|
2019-09-20 19:47:05 -07:00
|
|
|
/// Parses a `spv.module` op: an addressing model keyword, a memory model
/// keyword, an optional `requires` clause with the version/capability/
/// extension triple, an optional attribute dictionary, and the body region.
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
  Region *moduleBody = state.addRegion();

  // Both model keywords are mandatory.
  spirv::AddressingModel addressingModel;
  spirv::MemoryModel memModel;
  if (parseEnumKeywordAttr(addressingModel, parser, state) ||
      parseEnumKeywordAttr(memModel, parser, state))
    return failure();

  bool hasRequires = succeeded(parser.parseOptionalKeyword("requires"));
  spirv::VerCapExtAttr vceTriple;
  if (hasRequires &&
      parser.parseAttribute(vceTriple, spirv::ModuleOp::getVCETripleAttrName(),
                            state.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(state.attributes) ||
      parser.parseRegion(*moduleBody, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  spirv::ModuleOp::ensureTerminator(*moduleBody, parser.getBuilder(),
                                    state.location);
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints a `spv.module` op: the addressing and memory model keywords, an
/// optional `requires` clause for the VCE triple, the remaining attributes,
/// and the body region without block terminators.
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
  printer << spirv::ModuleOp::getOperationName();

  // Up to three attributes are printed in custom form and must be elided from
  // the attribute dictionary. Inline storage for three avoids a heap
  // allocation when the `requires` clause is present.
  SmallVector<StringRef, 3> elidedAttrs{
      spirv::attributeName<spirv::AddressingModel>(),
      spirv::attributeName<spirv::MemoryModel>()};

  printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
          << " " << spirv::stringifyMemoryModel(moduleOp.memory_model());

  if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
    printer << " requires " << *triple;
    elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
  }

  printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs);
  printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/false);
}
|
|
|
|
|
2019-06-18 11:15:55 -07:00
|
|
|
/// Verifies a `spv.module` op:
/// - every op directly inside the module must belong to the SPIR-V dialect;
/// - each spv.EntryPoint must name a spv.func in this module, its interface
///   may only reference spv.globalVariable symbols, and no two entry points
///   may share the same (function, execution model) pair;
/// - spv.func ops must not be external and may only contain SPIR-V ops.
static LogicalResult verify(spirv::ModuleOp moduleOp) {
  auto *spirvDialect = moduleOp.getOperation()->getDialect();
  DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
      entryPoints;
  SymbolTable table(moduleOp);

  for (auto &op : moduleOp.getBlock()) {
    if (op.getDialect() != spirvDialect)
      return op.emitError("'spv.module' can only contain spv.* ops");

    // For EntryPoint op, check that the function and execution model are not
    // duplicated in EntryPointOps. Also verify that the interface specified
    // comes from globalVariables here to make this check cheaper.
    if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
      auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
      if (!funcOp) {
        return entryPointOp.emitError("function '")
               << entryPointOp.fn() << "' not found in 'spv.module'";
      }
      if (auto interface = entryPointOp.interface()) {
        for (Attribute varRef : interface) {
          auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
          if (!varSymRef) {
            return entryPointOp.emitError(
                       "expected symbol reference for interface "
                       "specification instead of '")
                   << varRef << "'";
          }
          auto variableOp =
              table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
          if (!variableOp) {
            return entryPointOp.emitError("expected spv.globalVariable "
                                          "symbol reference instead of '")
                   << varSymRef << "'";
          }
        }
      }

      // try_emplace does one lookup that both detects an existing
      // (function, execution model) pair and records a new one.
      auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
          funcOp, entryPointOp.execution_model());
      if (!entryPoints.try_emplace(key, entryPointOp).second)
        return entryPointOp.emitError("duplicate of a previous EntryPointOp");
    } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
      if (funcOp.isExternal())
        return op.emitError("'spv.module' cannot contain external functions");

      // TODO(antiagainst): move this check to spv.func.
      for (auto &block : funcOp)
        for (auto &nestedOp : block) {
          if (nestedOp.getDialect() != spirvDialect)
            return nestedOp.emitError(
                "functions in 'spv.module' can only contain spv.* ops");
        }
    }
  }

  return success();
}
|
|
|
|
|
2019-08-20 13:33:41 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv._reference_of
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies that `spv._reference_of` names a `spv.specConstant` found via the
/// nearest symbol table from its parent op, and that its result type matches
/// the type of that constant's default value.
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
  Operation *symbolOp = SymbolTable::lookupNearestSymbolFrom(
      referenceOfOp.getParentOp(), referenceOfOp.spec_const());
  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(symbolOp);
  if (!specConstOp)
    return referenceOfOp.emitOpError("expected spv.specConstant symbol");

  Type constType = specConstOp.default_value().getType();
  if (referenceOfOp.reference().getType() == constType)
    return success();
  return referenceOfOp.emitOpError("result type mismatch with the referenced "
                                   "specialization constant's type");
}
|
2019-08-30 12:17:21 -07:00
|
|
|
|
2019-06-04 14:03:30 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.Return
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-08-23 11:07:13 -07:00
|
|
|
/// `spv.Return` has no checks of its own; it is verified as part of the
/// enclosing spv.func op.
static LogicalResult verify(spirv::ReturnOp returnOp) {
  return success();
}
|
|
|
|
|
2019-08-19 10:57:43 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ReturnValue
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// `spv.ReturnValue` has no checks of its own; it is verified as part of the
/// enclosing spv.func op.
static LogicalResult verify(spirv::ReturnValueOp retValOp) {
  return success();
}
|
|
|
|
|
2019-09-02 21:06:35 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.Select
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-12-23 14:45:01 -08:00
|
|
|
/// Builds a `spv.Select` op whose result type is taken from `trueValue`.
void spirv::SelectOp::build(Builder *builder, OperationState &state, Value cond,
                            Value trueValue, Value falseValue) {
  Type resultType = trueValue.getType();
  build(builder, state, resultType, cond, trueValue, falseValue);
}
|
|
|
|
|
2019-09-02 21:06:35 -07:00
|
|
|
/// Verifies `spv.Select`. When the condition is a vector, the result must be
/// a vector with the same number of elements. Non-vector conditions are not
/// checked here.
static LogicalResult verify(spirv::SelectOp op) {
  auto conditionTy = op.condition().getType().dyn_cast<VectorType>();
  if (!conditionTy)
    return success();

  auto resultVectorTy = op.result().getType().dyn_cast<VectorType>();
  if (!resultVectorTy)
    return op.emitOpError("result expected to be of vector type when "
                          "condition is of vector type");

  if (resultVectorTy.getNumElements() == conditionTy.getNumElements())
    return success();
  return op.emitOpError("result should have the same number of elements as "
                        "the condition when condition is of vector type");
}
|
|
|
|
|
2019-10-02 11:00:50 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.selection
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Parses a `spv.selection` op. Only the body region comes from the input;
/// the selection control attribute is always set to `None` for now.
static ParseResult parseSelectionOp(OpAsmParser &parser,
                                    OperationState &state) {
  // TODO(antiagainst): support selection control properly
  auto noneControl = static_cast<uint32_t>(spirv::SelectionControl::None);
  state.addAttribute("selection_control",
                     parser.getBuilder().getI32IntegerAttr(noneControl));

  Region *bodyRegion = state.addRegion();
  return parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{});
}
|
|
|
|
|
|
|
|
/// Prints a `spv.selection` op as its name followed by its body region. Block
/// terminators are printed; entry block arguments are not.
static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
  Region &bodyRegion = selectionOp.getOperation()->getRegion(0);
  printer << spirv::SelectionOp::getOperationName();
  printer.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}
|
|
|
|
|
|
|
|
/// Verifies the block structure of a `spv.selection` region: the last block
/// must be a merge block holding only `spv._merge`, and at least one block (the
/// selection header) must come before it. An empty region is accepted as a
/// degenerate case.
static LogicalResult verify(spirv::SelectionOp selectionOp) {
  auto *op = selectionOp.getOperation();

  // We need to verify that the blocks follow the following layout:
  // (No diagram line may end in a backslash: a trailing `\` would splice the
  // next source line into this `//` comment.)
  //
  //                     +--------------+
  //                     | header block |
  //                     +--------------+
  //                        /   |   \    ...
  //
  //         +---------+   +---------+   +---------+
  //         | case #0 |   | case #1 |   | case #2 |  ...
  //         +---------+   +---------+   +---------+
  //
  //                           ...
  //                          \ | /
  //                            v
  //                     +-------------+
  //                     | merge block |
  //                     +-------------+

  auto &region = op->getRegion(0);
  // Allow empty region as a degenerated case, which can come from
  // optimizations.
  if (region.empty())
    return success();

  // The last block is the merge block.
  if (!isMergeBlock(region.back()))
    return selectionOp.emitOpError(
        "last block must be the merge block with only one 'spv._merge' op");

  if (std::next(region.begin()) == region.end())
    return selectionOp.emitOpError("must have a selection header block");

  return success();
}
|
|
|
|
|
|
|
|
/// Returns the selection header block, which is the first block of the body
/// region. The body region must not be empty.
Block *spirv::SelectionOp::getHeaderBlock() {
  Region &region = body();
  assert(!region.empty() && "op region should not be empty!");
  return &region.front();
}
|
|
|
|
|
|
|
|
/// Returns the selection merge block, which is the last block of the body
/// region. The body region must not be empty.
Block *spirv::SelectionOp::getMergeBlock() {
  Region &region = body();
  assert(!region.empty() && "op region should not be empty!");
  return &region.back();
}
|
|
|
|
|
|
|
|
void spirv::SelectionOp::addMergeBlock() {
|
|
|
|
assert(body().empty() && "entry and merge block already exist");
|
|
|
|
auto *mergeBlock = new Block();
|
|
|
|
body().push_back(mergeBlock);
|
|
|
|
OpBuilder builder(mergeBlock);
|
|
|
|
|
|
|
|
// Add a spv._merge op into the merge block.
|
|
|
|
builder.create<spirv::MergeOp>(getLoc());
|
|
|
|
}
|
|
|
|
|
2020-01-26 11:10:29 -05:00
|
|
|
/// Creates, with `builder`, a `spv.selection` op that models
/// `if (condition) { thenBody }`. The region gets:
/// - a header block ending in spv.BranchConditional to the "then" block or
///   the merge block;
/// - a "then" block filled by `thenBody` and ending in a spv.Branch to the
///   merge block;
/// - the merge block.
/// Each block is built under an insertion guard, so the caller's insertion
/// point is restored afterwards.
spirv::SelectionOp spirv::SelectionOp::createIfThen(
    Location loc, Value condition,
    function_ref<void(OpBuilder *builder)> thenBody, OpBuilder *builder) {
  auto noneControl = static_cast<uint32_t>(spirv::SelectionControl::None);
  auto selectionOp = builder->create<spirv::SelectionOp>(
      loc, builder->getI32IntegerAttr(noneControl));

  selectionOp.addMergeBlock();
  Block *mergeBlock = selectionOp.getMergeBlock();

  // "then" block: created ahead of the merge block; runs the caller's body
  // and then branches to the merge block.
  Block *thenBlock = nullptr;
  {
    OpBuilder::InsertionGuard thenGuard(*builder);
    thenBlock = builder->createBlock(mergeBlock);
    thenBody(builder);
    builder->create<spirv::BranchOp>(loc, mergeBlock);
  }

  // Header block: created ahead of the "then" block; branches to it when
  // `condition` holds and to the merge block otherwise.
  {
    OpBuilder::InsertionGuard headerGuard(*builder);
    builder->createBlock(thenBlock);
    ArrayRef<Value> noArgs;
    builder->create<spirv::BranchConditionalOp>(loc, condition, thenBlock,
                                                /*trueArguments=*/noArgs,
                                                mergeBlock,
                                                /*falseArguments=*/noArgs);
  }

  return selectionOp;
}
|
|
|
|
|
2019-08-20 13:33:41 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.specConstant
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parses a `spv.specConstant` op: a symbol name, an optional
/// `kSpecIdAttrName`-keyword clause with a parenthesized integer, then `=` and
/// the default value attribute.
static ParseResult parseSpecConstantOp(OpAsmParser &parser,
                                       OperationState &state) {
  StringAttr symName;
  if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // The specialization id is optional and written in parentheses.
  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
    IntegerAttr specId;
    if (parser.parseLParen() ||
        parser.parseAttribute(specId, kSpecIdAttrName, state.attributes) ||
        parser.parseRParen())
      return failure();
  }

  Attribute defaultValue;
  if (parser.parseEqual() ||
      parser.parseAttribute(defaultValue, kDefaultValueAttrName,
                            state.attributes))
    return failure();
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints a `spv.specConstant` op: the op name, symbol name, optional spec id
/// in parentheses, then ` = ` and the default value.
static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
  printer << spirv::SpecConstantOp::getOperationName() << ' ';
  printer.printSymbolName(constOp.sym_name());

  auto specIdAttr = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName);
  if (specIdAttr) {
    printer << ' ' << kSpecIdAttrName << '(' << specIdAttr.getInt() << ')';
  }

  printer << " = " << constOp.default_value();
}
|
|
|
|
|
|
|
|
/// Verifies a `spv.specConstant` op. The optional spec id must be
/// non-negative. The default value must be a bool, integer, or float
/// attribute whose type the SPIR-V dialect accepts.
static LogicalResult verify(spirv::SpecConstantOp constOp) {
  auto specIdAttr = constOp.getAttrOfType<IntegerAttr>(kSpecIdAttrName);
  if (specIdAttr && specIdAttr.getValue().isNegative())
    return constOp.emitOpError("SpecId cannot be negative");

  auto defaultValue = constOp.default_value();
  bool isScalarKind = false;
  switch (defaultValue.getKind()) {
  case StandardAttributes::Bool:
  case StandardAttributes::Integer:
  case StandardAttributes::Float:
    isScalarKind = true;
    break;
  default:
    break;
  }

  if (!isScalarKind)
    return constOp.emitOpError(
        "default value can only be a bool, integer, or float scalar");

  // Only types (and thus bitwidths) valid in the SPIR-V dialect are allowed.
  if (!spirv::SPIRVDialect::isValidType(defaultValue.getType()))
    return constOp.emitOpError("default value bitwidth disallowed");
  return success();
}
|
|
|
|
|
2019-06-24 10:59:05 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.StoreOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-09-20 19:47:05 -07:00
|
|
|
/// Parses a `spv.Store` op: a storage class string, two operands (pointer,
/// then value), memory access attributes, `:`, and the value type. The pointer
/// operand is resolved as a pointer to the value type in that storage class.
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state) {
  // Capture the op's starting location for operand resolution diagnostics.
  auto startLoc = parser.getCurrentLocation();

  spirv::StorageClass storageClass;
  SmallVector<OpAsmParser::OperandType, 2> operands;
  Type valueType;
  if (parseEnumStrAttr(storageClass, parser) ||
      parser.parseOperandList(operands, 2) ||
      parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
      parser.parseType(valueType))
    return failure();

  auto ptrType = spirv::PointerType::get(valueType, storageClass);
  if (parser.resolveOperands(operands, {ptrType, valueType}, startLoc,
                             state.operands))
    return failure();
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints a `spv.Store` op: the pointer's storage class in quotes, the pointer
/// and value operands, memory access attributes, the value type, and any
/// remaining attributes.
static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
  auto ptrType = storeOp.ptr().getType().cast<spirv::PointerType>();
  StringRef storageClassStr = stringifyStorageClass(ptrType.getStorageClass());

  printer << spirv::StoreOp::getOperationName() << " \"" << storageClassStr
          << "\" " << storeOp.ptr() << ", " << storeOp.value();

  // printMemoryAccessAttribute may add names here so that the trailing
  // attribute dictionary does not repeat them.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(storeOp, printer, elidedAttrs);

  printer << " : " << storeOp.value().getType();
  printer.printOptionalAttrDict(storeOp.getOperation()->getAttrs(),
                                elidedAttrs);
}
|
|
|
|
|
|
|
|
/// Verifies a `spv.Store` op. Pointer/value type agreement is checked first,
/// then the memory access attributes.
static LogicalResult verify(spirv::StoreOp storeOp) {
  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
  // OpTypePointer whose Type operand is the same as the type of Object."
  if (succeeded(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(),
                                              storeOp.value())))
    return verifyMemoryAccessAttribute(storeOp);
  return failure();
}
|
|
|
|
|
2019-10-30 05:40:47 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.Unreachable
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies `spv.Unreachable`. It is rejected in an entry block and currently
/// accepted in every other block.
static LogicalResult verify(spirv::UnreachableOp unreachableOp) {
  Block *parentBlock = unreachableOp.getOperation()->getBlock();

  // Fast track: in an entry block it's invalid; in a block without
  // predecessors it's valid.
  if (parentBlock->isEntryBlock())
    return unreachableOp.emitOpError("cannot be used in reachable block");
  if (parentBlock->hasNoPredecessors())
    return success();

  // TODO(antiagainst): further verification needs to analyze reachability
  // from the entry block.
  return success();
}
|
|
|
|
|
2019-06-18 11:15:55 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.Variable
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-09-20 19:47:05 -07:00
|
|
|
/// Parses a `spv.Variable` op: an optional `init(<operand>)` clause, variable
/// decorations, `:`, and a `spv.ptr` result type. The storage class attribute
/// is derived from the result pointer type.
static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
  // Optional initializer clause.
  bool hasInit = succeeded(parser.parseOptionalKeyword("init"));
  OpAsmParser::OperandType initOperand{};
  if (hasInit && (parser.parseLParen() || parser.parseOperand(initOperand) ||
                  parser.parseRParen()))
    return failure();

  if (parseVariableDecorations(parser, state))
    return failure();

  // The result must be a SPIR-V pointer type.
  if (parser.parseColon())
    return failure();
  auto typeLoc = parser.getCurrentLocation();
  Type resultType;
  if (parser.parseType(resultType))
    return failure();

  auto ptrType = resultType.dyn_cast<spirv::PointerType>();
  if (!ptrType)
    return parser.emitError(typeLoc, "expected spv.ptr type");
  state.addTypes(ptrType);

  // The initializer, when present, has the pointee type.
  if (hasInit && parser.resolveOperand(initOperand, ptrType.getPointeeType(),
                                       state.operands))
    return failure();

  auto storageClassValue = llvm::bit_cast<int32_t>(ptrType.getStorageClass());
  state.addAttribute(spirv::attributeName<spirv::StorageClass>(),
                     parser.getBuilder().getI32IntegerAttr(storageClassValue));
  return success();
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
/// Prints a `spv.Variable` op: the optional `init(...)` clause, decorations,
/// and the result type. The storage class attribute is elided because the
/// parser rebuilds it from the result pointer type.
static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
  printer << spirv::VariableOp::getOperationName();

  bool hasInitializer = varOp.getNumOperands() != 0;
  if (hasInitializer)
    printer << " init(" << varOp.initializer() << ")";

  SmallVector<StringRef, 4> elidedAttrs;
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
  printVariableDecorations(varOp, printer, elidedAttrs);

  printer << " : " << varOp.getType();
}
|
|
|
|
|
|
|
|
/// Verifies a function-level `spv.Variable` op:
/// - its storage class must be Function and equal the result pointer's;
/// - an initializer, if present, must be produced by spirv::ConstantOp,
///   spirv::ReferenceOfOp, or spirv::AddressOfOp;
/// - it must not carry descriptor set, binding, or built-in attributes.
static LogicalResult verify(spirv::VariableOp varOp) {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  if (varOp.storage_class() != spirv::StorageClass::Function)
    return varOp.emitOpError(
        "can only be used to model function-level variables. Use "
        "spv.globalVariable for module-level variables.");

  auto pointerType = varOp.pointer().getType().cast<spirv::PointerType>();
  if (pointerType.getStorageClass() != varOp.storage_class())
    return varOp.emitOpError(
        "storage class must match result pointer's storage class");

  // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
  // a global (module scope) OpVariable instruction".
  if (varOp.getNumOperands() != 0) {
    Operation *initOp = varOp.getOperand(0).getDefiningOp();
    bool validInit =
        initOp && (isa<spirv::ConstantOp>(initOp) ||    // normal constant
                   isa<spirv::ReferenceOfOp>(initOp) || // spec constant
                   isa<spirv::AddressOfOp>(initOp));
    if (!validInit)
      return varOp.emitOpError("initializer must be the result of a "
                               "constant or spv.globalVariable op");
  }

  // TODO(antiagainst): generate these strings using ODS.
  auto descriptorSetName =
      convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
  auto bindingName =
      convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
  auto builtInName =
      convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));

  Operation *op = varOp.getOperation();
  for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
    if (op->getAttr(attr))
      return varOp.emitOpError("cannot have '")
             << attr << "' attribute (only allowed in spv.globalVariable)";
  }

  return success();
}
|
|
|
|
|
2019-05-26 05:43:20 -07:00
|
|
|
// Pull the TableGen-generated SPIR-V op code into mlir::spirv.
namespace mlir {
|
|
|
|
namespace spirv {
|
|
|
|
|
2019-12-27 16:24:33 -05:00
|
|
|
// TableGen'erated operation interfaces for querying versions, extensions, and
|
|
|
|
// capabilities.
|
|
|
|
#include "mlir/Dialect/SPIRV/SPIRVAvailability.cpp.inc"
|
|
|
|
|
|
|
|
// TableGen'erated operation definitions.
|
2019-05-26 05:43:20 -07:00
|
|
|
#define GET_OP_CLASSES
|
2019-07-16 05:06:57 -07:00
|
|
|
#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
|
2019-05-26 05:43:20 -07:00
|
|
|
|
2019-12-27 16:24:33 -05:00
|
|
|
// TableGen'erated operation availability interface implementations.
|
|
|
|
#include "mlir/Dialect/SPIRV/SPIRVOpAvailabilityImpl.inc"
|
|
|
|
|
2019-05-26 05:43:20 -07:00
|
|
|
} // namespace spirv
|
|
|
|
} // namespace mlir
|