2019-05-26 05:43:20 -07:00
|
|
|
//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
|
|
|
|
//
|
2020-01-26 03:58:30 +00:00
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
2019-12-23 09:35:36 -08:00
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2019-05-26 05:43:20 -07:00
|
|
|
//
|
2019-12-23 09:35:36 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-05-26 05:43:20 -07:00
|
|
|
//
|
|
|
|
// This file defines the operations in the SPIR-V dialect.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-12-17 10:55:45 -05:00
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
2019-05-26 05:43:20 -07:00
|
|
|
|
2020-12-17 10:55:45 -05:00
|
|
|
#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
|
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
|
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
2021-01-08 14:48:48 +01:00
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
|
2020-12-17 10:55:45 -05:00
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
|
|
|
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
|
2019-05-29 10:47:16 -07:00
|
|
|
#include "mlir/IR/Builders.h"
|
2020-11-19 10:43:12 -08:00
|
|
|
#include "mlir/IR/BuiltinOps.h"
|
2020-12-03 17:22:29 -08:00
|
|
|
#include "mlir/IR/BuiltinTypes.h"
|
2020-02-07 11:30:19 -05:00
|
|
|
#include "mlir/IR/FunctionImplementation.h"
|
2021-01-08 14:48:48 +01:00
|
|
|
#include "mlir/IR/OpDefinition.h"
|
2019-05-29 10:47:16 -07:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
2021-02-02 11:08:39 -05:00
|
|
|
#include "mlir/IR/TypeUtilities.h"
|
2020-03-10 12:20:24 -07:00
|
|
|
#include "mlir/Interfaces/CallInterfaces.h"
|
2021-01-21 22:20:18 -08:00
|
|
|
#include "llvm/ADT/APFloat.h"
|
|
|
|
#include "llvm/ADT/APInt.h"
|
2022-06-15 20:38:47 -04:00
|
|
|
#include "llvm/ADT/STLExtras.h"
|
2020-04-14 18:54:23 -07:00
|
|
|
#include "llvm/ADT/StringExtras.h"
|
2019-09-24 19:24:33 -07:00
|
|
|
#include "llvm/ADT/bit.h"
|
2022-07-27 19:16:56 -04:00
|
|
|
#include <numeric>
|
2019-05-26 05:43:20 -07:00
|
|
|
|
2019-05-29 10:47:16 -07:00
|
|
|
using namespace mlir;
|
|
|
|
|
2020-07-07 01:35:23 -07:00
|
|
|
// TODO: generate these strings using ODS.
//
// Names of the attributes and assembly keywords that the parsers, printers and
// verifiers in this file refer to by string.

// Memory access / alignment attributes, for the (target) pointer and for the
// source pointer respectively.
static constexpr const char kMemoryAccessAttrName[] = "memory_access";
static constexpr const char kSourceMemoryAccessAttrName[] =
    "source_memory_access";
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kSourceAlignmentAttrName[] = "source_alignment";

// Remaining names, kept in roughly alphabetical order.
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kCallee[] = "callee";
static constexpr const char kClusterSize[] = "cluster_size";
static constexpr const char kControl[] = "control";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kGroupOperationAttrName[] = "group_operation";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kInitializerAttrName[] = "initializer";
static constexpr const char kInterfaceAttrName[] = "interface";
static constexpr const char kMemoryScopeAttrName[] = "memory_scope";
static constexpr const char kSemanticsAttrName[] = "semantics";
static constexpr const char kSpecIdAttrName[] = "spec_id";
static constexpr const char kTypeAttrName[] = "type";
static constexpr const char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
static constexpr const char kCompositeSpecConstituentsName[] = "constituents";
|
2019-06-17 14:47:22 -07:00
|
|
|
|
2019-06-04 14:03:30 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Common utility functions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-06 12:33:08 -08:00
|
|
|
/// Parses an op with a single result in either of two textual forms:
///   * generic: `(` operands `)` attr-dict? `:` function-type
///   * compact: operands attr-dict? `:` type
/// In the compact form the one parsed type is used to resolve every operand
/// and is also added as the result type. The generic form is the fallback
/// emitted by `printOneResultOp`.
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                                   OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
  Type parsedType;

  // Captured before anything is consumed; the generic-form diagnostic and
  // operand resolution are anchored at this location.
  SMLoc startLoc = parser.getCurrentLocation();

  // No opening parenthesis: this is the compact form.
  if (failed(parser.parseOptionalLParen())) {
    if (parser.parseOperandList(operands) ||
        parser.parseOptionalAttrDict(result.attributes) ||
        parser.parseColonType(parsedType) ||
        parser.resolveOperands(operands, parsedType, result.operands) ||
        parser.addTypeToList(parsedType, result.types))
      return failure();
    return success();
  }

  // The operand list is in-between parentheses, so we have the generic form.
  if (parser.parseOperandList(operands) || parser.parseRParen() ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseType(parsedType))
    return failure();

  auto functionType = parsedType.dyn_cast<FunctionType>();
  if (!functionType)
    return parser.emitError(startLoc, "expected function type");

  if (parser.resolveOperands(operands, functionType.getInputs(), startLoc,
                             result.operands))
    return failure();

  result.addTypes(functionType.getResults());
  return success();
}
|
|
|
|
|
|
|
|
/// Prints a single-result op. When every operand has the same type as the
/// result, the compact form `operands attr-dict : type` is printed; otherwise
/// the generic form (without the op name) is used so that no type information
/// is lost.
static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
  assert(op->getNumResults() == 1 && "op should have one result");

  const Type commonType = op->getResult(0).getType();
  const bool allOperandsMatch =
      llvm::all_of(op->getOperandTypes(), [commonType](Type operandType) {
        return operandType == commonType;
      });

  // Mixed types cannot be expressed with a single trailing type.
  if (!allOperandsMatch) {
    p.printGenericOp(op, /*printOpName=*/false);
    return;
  }

  p << ' ';
  p.printOperands(op->getOperands());
  p.printOptionalAttrDict(op->getAttrs());
  // One type stands for all operands and the result.
  p << " : " << commonType;
}
|
|
|
|
|
2020-01-25 09:16:29 -05:00
|
|
|
/// Returns true if the given op is a function-like op or nested in a
|
|
|
|
/// function-like op without a module-like op in the middle.
|
2022-01-13 20:51:38 -08:00
|
|
|
static bool isNestedInFunctionOpInterface(Operation *op) {
|
2020-01-25 09:16:29 -05:00
|
|
|
if (!op)
|
|
|
|
return false;
|
|
|
|
if (op->hasTrait<OpTrait::SymbolTable>())
|
|
|
|
return false;
|
2022-01-13 20:51:38 -08:00
|
|
|
if (isa<FunctionOpInterface>(op))
|
2020-01-25 09:16:29 -05:00
|
|
|
return true;
|
2022-01-13 20:51:38 -08:00
|
|
|
return isNestedInFunctionOpInterface(op->getParentOp());
|
2020-01-25 09:16:29 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns true if the given op is an module-like op that maintains a symbol
|
|
|
|
/// table.
|
|
|
|
static bool isDirectInModuleLikeOp(Operation *op) {
|
|
|
|
return op && op->hasTrait<OpTrait::SymbolTable>();
|
|
|
|
}
|
|
|
|
|
2020-01-26 10:19:24 -05:00
|
|
|
/// Stores into `value` the integer held by `op` when `op` is a
/// `spirv::ConstantOp` whose value is an `IntegerAttr`. Fails (leaving `value`
/// untouched) when `op` is null, is not a `spirv::ConstantOp`, or holds a
/// non-integer attribute.
static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
  auto constant = dyn_cast_or_null<spirv::ConstantOp>(op);
  if (!constant)
    return failure();

  if (auto intAttr = constant.value().dyn_cast<IntegerAttr>()) {
    // Signless integers are read with getInt(); every other integer type is
    // read with getSInt().
    value = intAttr.getType().isSignlessInteger() ? intAttr.getInt()
                                                  : intAttr.getSInt();
    return success();
  }
  return failure();
}
|
|
|
|
|
2019-12-09 09:51:25 -08:00
|
|
|
template <typename Ty>
|
|
|
|
static ArrayAttr
|
|
|
|
getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
|
2019-12-18 09:28:48 -08:00
|
|
|
function_ref<StringRef(Ty)> stringifyFn) {
|
2019-12-09 09:51:25 -08:00
|
|
|
if (enumValues.empty()) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
SmallVector<StringRef, 1> enumValStrs;
|
|
|
|
enumValStrs.reserve(enumValues.size());
|
|
|
|
for (auto val : enumValues) {
|
|
|
|
enumValStrs.emplace_back(stringifyFn(val));
|
|
|
|
}
|
|
|
|
return builder.getStrArrayAttr(enumValStrs);
|
|
|
|
}
|
|
|
|
|
2020-03-11 16:04:25 -04:00
|
|
|
/// Parses the next string attribute in `parser` as an enumerant of the given
|
|
|
|
/// `EnumClass`.
|
2019-07-03 18:12:52 -07:00
|
|
|
template <typename EnumClass>
|
2019-09-21 10:18:00 -07:00
|
|
|
static ParseResult
|
2020-03-11 16:04:25 -04:00
|
|
|
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
|
|
|
|
StringRef attrName = spirv::attributeName<EnumClass>()) {
|
2019-07-03 18:12:52 -07:00
|
|
|
Attribute attrVal;
|
2020-05-06 13:48:36 -07:00
|
|
|
NamedAttrList attr;
|
2019-09-20 11:36:49 -07:00
|
|
|
auto loc = parser.getCurrentLocation();
|
|
|
|
if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
|
2019-09-21 10:18:00 -07:00
|
|
|
attrName, attr)) {
|
2019-06-24 10:59:05 -07:00
|
|
|
return failure();
|
|
|
|
}
|
2019-07-03 18:12:52 -07:00
|
|
|
if (!attrVal.isa<StringAttr>()) {
|
2019-09-20 11:36:49 -07:00
|
|
|
return parser.emitError(loc, "expected ")
|
2019-09-21 10:18:00 -07:00
|
|
|
<< attrName << " attribute specified as string";
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
2019-07-03 18:12:52 -07:00
|
|
|
auto attrOptional =
|
2020-04-12 19:03:33 -07:00
|
|
|
spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
|
2019-07-03 18:12:52 -07:00
|
|
|
if (!attrOptional) {
|
2019-09-20 11:36:49 -07:00
|
|
|
return parser.emitError(loc, "invalid ")
|
2019-09-21 10:18:00 -07:00
|
|
|
<< attrName << " attribute specification: " << attrVal;
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
2022-06-20 23:20:25 -07:00
|
|
|
value = *attrOptional;
|
2019-07-30 14:14:28 -07:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2020-03-11 16:04:25 -04:00
|
|
|
/// Parses the next string attribute in `parser` as an enumerant of the given
|
|
|
|
/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
|
|
|
|
/// attribute with the enum class's name as attribute name.
|
2019-07-30 14:14:28 -07:00
|
|
|
template <typename EnumClass>
|
2019-09-21 10:18:00 -07:00
|
|
|
static ParseResult
|
2020-03-11 16:04:25 -04:00
|
|
|
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
|
|
|
|
StringRef attrName = spirv::attributeName<EnumClass>()) {
|
|
|
|
if (parseEnumStrAttr(value, parser)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
|
|
|
|
llvm::bit_cast<int32_t>(value)));
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
|
|
|
|
/// and inserts the enumerant into `state` as an 32-bit integer attribute with
|
|
|
|
/// the enum class's name as attribute name.
|
|
|
|
template <typename EnumClass>
|
|
|
|
static ParseResult
|
|
|
|
parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
|
|
|
|
OperationState &state,
|
|
|
|
StringRef attrName = spirv::attributeName<EnumClass>()) {
|
|
|
|
if (parseEnumKeywordAttr(value, parser)) {
|
2019-07-30 14:14:28 -07:00
|
|
|
return failure();
|
|
|
|
}
|
2019-09-21 10:18:00 -07:00
|
|
|
state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
|
2019-09-24 19:24:33 -07:00
|
|
|
llvm::bit_cast<int32_t>(value)));
|
2019-06-24 10:59:05 -07:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2020-08-03 09:31:08 +03:00
|
|
|
/// Parses Function, Selection and Loop control attributes. If no control is
|
|
|
|
/// specified, "None" is used as a default.
|
|
|
|
template <typename EnumClass>
|
|
|
|
static ParseResult
|
|
|
|
parseControlAttribute(OpAsmParser &parser, OperationState &state,
|
|
|
|
StringRef attrName = spirv::attributeName<EnumClass>()) {
|
|
|
|
if (succeeded(parser.parseOptionalKeyword(kControl))) {
|
|
|
|
EnumClass control;
|
|
|
|
if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
|
|
|
|
parser.parseRParen())
|
|
|
|
return failure();
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
// Set control to "None" otherwise.
|
|
|
|
Builder builder = parser.getBuilder();
|
|
|
|
state.addAttribute(attrName, builder.getI32IntegerAttr(0));
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2020-07-09 19:08:51 -04:00
|
|
|
/// Parses optional memory access attributes attached to a memory access
|
|
|
|
/// operand/pointer. Specifically, parses the following syntax:
|
|
|
|
/// (`[` memory-access `]`)?
|
|
|
|
/// where:
|
|
|
|
/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
|
|
|
|
/// integer-literal | `"NonTemporal"`
|
2019-09-20 11:36:49 -07:00
|
|
|
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
|
2019-09-20 19:47:05 -07:00
|
|
|
OperationState &state) {
|
2019-06-24 10:59:05 -07:00
|
|
|
// Parse an optional list of attributes staring with '['
|
2019-09-20 11:36:49 -07:00
|
|
|
if (parser.parseOptionalLSquare()) {
|
2019-06-24 10:59:05 -07:00
|
|
|
// Nothing to do
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-07-03 18:12:52 -07:00
|
|
|
spirv::MemoryAccess memoryAccessAttr;
|
2020-07-09 19:08:51 -04:00
|
|
|
if (parseEnumStrAttr(memoryAccessAttr, parser, state,
|
|
|
|
kMemoryAccessAttrName)) {
|
2019-06-24 10:59:05 -07:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
2019-09-16 09:22:43 -07:00
|
|
|
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
|
2019-06-24 10:59:05 -07:00
|
|
|
// Parse integer attribute for alignment.
|
|
|
|
Attribute alignmentAttr;
|
2019-09-20 11:36:49 -07:00
|
|
|
Type i32Type = parser.getBuilder().getIntegerType(32);
|
|
|
|
if (parser.parseComma() ||
|
2020-07-02 15:42:10 -04:00
|
|
|
parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
|
2019-09-20 19:47:05 -07:00
|
|
|
state.attributes)) {
|
2019-06-24 10:59:05 -07:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
}
|
2019-09-20 11:36:49 -07:00
|
|
|
return parser.parseRSquare();
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
|
|
|
|
2020-07-09 19:08:51 -04:00
|
|
|
// TODO Make sure to merge this and the previous function into one template
|
2020-10-29 04:03:15 +09:00
|
|
|
// parameterized by memory access attribute name and alignment. Doing so now
|
2020-07-09 19:08:51 -04:00
|
|
|
// results in VS2017 in producing an internal error (at the call site) that's
|
2020-10-29 04:03:15 +09:00
|
|
|
// not detailed enough to understand what is happening.
|
2020-07-09 19:08:51 -04:00
|
|
|
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
|
|
|
|
OperationState &state) {
|
|
|
|
// Parse an optional list of attributes staring with '['
|
|
|
|
if (parser.parseOptionalLSquare()) {
|
|
|
|
// Nothing to do
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
spirv::MemoryAccess memoryAccessAttr;
|
|
|
|
if (parseEnumStrAttr(memoryAccessAttr, parser, state,
|
|
|
|
kSourceMemoryAccessAttrName)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
|
|
|
|
// Parse integer attribute for alignment.
|
|
|
|
Attribute alignmentAttr;
|
|
|
|
Type i32Type = parser.getBuilder().getIntegerType(32);
|
|
|
|
if (parser.parseComma() ||
|
|
|
|
parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
|
|
|
|
state.attributes)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return parser.parseRSquare();
|
|
|
|
}
|
|
|
|
|
2020-07-02 15:42:10 -04:00
|
|
|
template <typename MemoryOpTy>
|
2020-07-09 19:08:51 -04:00
|
|
|
static void printMemoryAccessAttribute(
|
|
|
|
MemoryOpTy memoryOp, OpAsmPrinter &printer,
|
|
|
|
SmallVectorImpl<StringRef> &elidedAttrs,
|
|
|
|
Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
|
2020-09-01 13:32:14 -07:00
|
|
|
Optional<uint32_t> alignmentAttrValue = None) {
|
2019-06-24 10:59:05 -07:00
|
|
|
// Print optional memory access attribute.
|
2020-07-09 19:08:51 -04:00
|
|
|
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
|
|
|
|
: memoryOp.memory_access())) {
|
|
|
|
elidedAttrs.push_back(kMemoryAccessAttrName);
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
|
2019-06-24 10:59:05 -07:00
|
|
|
|
2020-07-09 19:08:51 -04:00
|
|
|
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
|
|
|
|
// Print integer alignment attribute.
|
|
|
|
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
|
|
|
|
: memoryOp.alignment())) {
|
|
|
|
elidedAttrs.push_back(kAlignmentAttrName);
|
|
|
|
printer << ", " << alignment;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
printer << "]";
|
|
|
|
}
|
|
|
|
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO Make sure to merge this and the previous function into one template
|
2020-10-29 04:03:15 +09:00
|
|
|
// parameterized by memory access attribute name and alignment. Doing so now
|
2020-07-09 19:08:51 -04:00
|
|
|
// results in VS2017 in producing an internal error (at the call site) that's
|
2020-10-29 04:03:15 +09:00
|
|
|
// not detailed enough to understand what is happening.
|
2020-07-09 19:08:51 -04:00
|
|
|
template <typename MemoryOpTy>
|
|
|
|
static void printSourceMemoryAccessAttribute(
|
|
|
|
MemoryOpTy memoryOp, OpAsmPrinter &printer,
|
|
|
|
SmallVectorImpl<StringRef> &elidedAttrs,
|
|
|
|
Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
|
2020-09-01 13:32:14 -07:00
|
|
|
Optional<uint32_t> alignmentAttrValue = None) {
|
2020-07-09 19:08:51 -04:00
|
|
|
|
|
|
|
printer << ", ";
|
|
|
|
|
|
|
|
// Print optional memory access attribute.
|
|
|
|
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
|
|
|
|
: memoryOp.memory_access())) {
|
|
|
|
elidedAttrs.push_back(kSourceMemoryAccessAttrName);
|
|
|
|
|
|
|
|
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
|
|
|
|
|
|
|
|
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
|
|
|
|
// Print integer alignment attribute.
|
|
|
|
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
|
|
|
|
: memoryOp.alignment())) {
|
|
|
|
elidedAttrs.push_back(kSourceAlignmentAttrName);
|
|
|
|
printer << ", " << alignment;
|
|
|
|
}
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << "]";
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
2019-07-03 18:12:52 -07:00
|
|
|
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
|
|
|
|
2021-09-02 02:39:05 +08:00
|
|
|
/// Parses the optional image operands list
///     (`[` image-operands-string `]`)?
/// On success with a list present, `attr` is set to the corresponding
/// `spirv::ImageOperandsAttr`; when no `[` follows, `attr` is left untouched
/// and parsing succeeds.
static ParseResult parseImageOperands(OpAsmParser &parser,
                                      spirv::ImageOperandsAttr &attr) {
  // The bracketed list is optional.
  if (failed(parser.parseOptionalLSquare()))
    return success();

  spirv::ImageOperands operandsMask;
  if (failed(parseEnumStrAttr(operandsMask, parser)))
    return failure();

  attr = spirv::ImageOperandsAttr::get(parser.getContext(), operandsMask);
  return parser.parseRSquare();
}
|
|
|
|
|
|
|
|
/// Prints `attr` as `["<image-operands>"]`; prints nothing when `attr` is
/// null. `imageOp` is not used.
static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
                               spirv::ImageOperandsAttr attr) {
  if (!attr)
    return;

  printer << "[\"" << stringifyImageOperands(attr.getValue()) << "\"]";
}
|
|
|
|
|
|
|
|
template <typename Op>
|
|
|
|
static LogicalResult verifyImageOperands(Op imageOp,
|
|
|
|
spirv::ImageOperandsAttr attr,
|
|
|
|
Operation::operand_range operands) {
|
|
|
|
if (!attr) {
|
|
|
|
if (operands.empty())
|
|
|
|
return success();
|
|
|
|
|
|
|
|
return imageOp.emitError("the Image Operands should encode what operands "
|
|
|
|
"follow, as per Image Operands");
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: Add the validation rules for the following Image Operands.
|
|
|
|
spirv::ImageOperands noSupportOperands =
|
|
|
|
spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
|
|
|
|
spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
|
|
|
|
spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
|
|
|
|
spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
|
|
|
|
spirv::ImageOperands::MakeTexelAvailable |
|
|
|
|
spirv::ImageOperands::MakeTexelVisible |
|
|
|
|
spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
|
|
|
|
|
|
|
|
if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
|
|
|
|
llvm_unreachable("unimplemented operands of Image Operands");
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-10-30 14:41:26 -07:00
|
|
|
/// Verifies the bit widths of a cast op's first operand and first result.
///
/// With `skipBitWidthCheck` set, nothing is checked. Otherwise vector and
/// cooperative-matrix types are reduced to their element types, and the
/// element bit widths must be equal when `requireSameBitWidth` is true and
/// must differ when it is false.
static LogicalResult verifyCastOp(Operation *op,
                                  bool requireSameBitWidth = true,
                                  bool skipBitWidthCheck = false) {
  // Some CastOps have no limit on bit widths for result and operand type.
  if (skipBitWidthCheck)
    return success();

  Type srcType = op->getOperand(0).getType();
  Type dstType = op->getResult(0).getType();

  // ODS checks that result type and operand type have the same shape, so the
  // result can be cast to the same container kind as the operand.
  if (auto srcVector = srcType.dyn_cast<VectorType>()) {
    srcType = srcVector.getElementType();
    dstType = dstType.cast<VectorType>().getElementType();
  }

  if (auto srcCoopMatrix = srcType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
    srcType = srcCoopMatrix.getElementType();
    dstType = dstType.cast<spirv::CooperativeMatrixNVType>().getElementType();
  }

  const bool widthsMatch =
      srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();

  // The relationship between the widths is the one that was asked for.
  if (widthsMatch == requireSameBitWidth)
    return success();

  if (requireSameBitWidth)
    return op->emitOpError(
               "expected the same bit widths for operand type and result "
               "type, but provided ")
           << srcType << " and " << dstType;

  return op->emitOpError(
             "expected the different bit widths for operand type and result "
             "type, but provided ")
         << srcType << " and " << dstType;
}
|
|
|
|
|
2020-07-02 15:42:10 -04:00
|
|
|
template <typename MemoryOpTy>
|
2020-06-26 09:37:30 -04:00
|
|
|
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
|
2019-06-24 10:59:05 -07:00
|
|
|
// ODS checks for attributes values. Just need to verify that if the
|
|
|
|
// memory-access attribute is Aligned, then the alignment attribute must be
|
|
|
|
// present.
|
2020-06-26 09:37:30 -04:00
|
|
|
auto *op = memoryOp.getOperation();
|
2020-07-09 19:08:51 -04:00
|
|
|
auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
|
2019-07-02 06:02:20 -07:00
|
|
|
if (!memAccessAttr) {
|
|
|
|
// Alignment attribute shouldn't be present if memory access attribute is
|
|
|
|
// not present.
|
2020-07-02 15:42:10 -04:00
|
|
|
if (op->getAttr(kAlignmentAttrName)) {
|
2020-06-26 09:37:30 -04:00
|
|
|
return memoryOp.emitOpError(
|
2019-06-24 10:59:05 -07:00
|
|
|
"invalid alignment specification without aligned memory access "
|
|
|
|
"specification");
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-07-02 06:02:20 -07:00
|
|
|
auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
|
|
|
|
auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
|
|
|
|
|
|
|
|
if (!memAccess) {
|
2020-06-26 09:37:30 -04:00
|
|
|
return memoryOp.emitOpError("invalid memory access specifier: ")
|
2019-07-02 06:02:20 -07:00
|
|
|
<< memAccessVal;
|
|
|
|
}
|
|
|
|
|
2019-09-16 09:22:43 -07:00
|
|
|
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
|
2020-07-02 15:42:10 -04:00
|
|
|
if (!op->getAttr(kAlignmentAttrName)) {
|
2020-06-26 09:37:30 -04:00
|
|
|
return memoryOp.emitOpError("missing alignment value");
|
2019-06-24 10:59:05 -07:00
|
|
|
}
|
|
|
|
} else {
|
2020-07-02 15:42:10 -04:00
|
|
|
if (op->getAttr(kAlignmentAttrName)) {
|
2020-06-26 09:37:30 -04:00
|
|
|
return memoryOp.emitOpError(
|
2019-06-24 10:59:05 -07:00
|
|
|
"invalid alignment specification with non-aligned memory access "
|
|
|
|
"specification");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2020-07-09 19:08:51 -04:00
|
|
|
// TODO Make sure to merge this and the previous function into one template
|
2020-10-29 04:03:15 +09:00
|
|
|
// parameterized by memory access attribute name and alignment. Doing so now
|
2020-07-09 19:08:51 -04:00
|
|
|
// results in VS2017 in producing an internal error (at the call site) that's
|
2020-10-29 04:03:15 +09:00
|
|
|
// not detailed enough to understand what is happening.
|
2020-07-09 19:08:51 -04:00
|
|
|
template <typename MemoryOpTy>
|
|
|
|
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
|
|
|
|
// ODS checks for attributes values. Just need to verify that if the
|
|
|
|
// memory-access attribute is Aligned, then the alignment attribute must be
|
|
|
|
// present.
|
|
|
|
auto *op = memoryOp.getOperation();
|
|
|
|
auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
|
|
|
|
if (!memAccessAttr) {
|
|
|
|
// Alignment attribute shouldn't be present if memory access attribute is
|
|
|
|
// not present.
|
|
|
|
if (op->getAttr(kSourceAlignmentAttrName)) {
|
|
|
|
return memoryOp.emitOpError(
|
|
|
|
"invalid alignment specification without aligned memory access "
|
|
|
|
"specification");
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
|
|
|
|
auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
|
|
|
|
|
|
|
|
if (!memAccess) {
|
|
|
|
return memoryOp.emitOpError("invalid memory access specifier: ")
|
|
|
|
<< memAccessVal;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
|
|
|
|
if (!op->getAttr(kSourceAlignmentAttrName)) {
|
|
|
|
return memoryOp.emitOpError("missing alignment value");
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
if (op->getAttr(kSourceAlignmentAttrName)) {
|
|
|
|
return memoryOp.emitOpError(
|
|
|
|
"invalid alignment specification with non-aligned memory access "
|
|
|
|
"specification");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2021-10-10 22:40:20 +08:00
|
|
|
/// Verifies that at most one of the ordering bits `Acquire`, `Release`,
/// `AcquireRelease` and `SequentiallyConsistent` is set in `memorySemantics`;
/// emits an error on `op` otherwise.
static LogicalResult
verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
  // According to the SPIR-V specification:
  // "Despite being a mask and allowing multiple bits to be combined, it is
  // invalid for more than one of these four bits to be set: Acquire, Release,
  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
  // Release semantics is done by setting the AcquireRelease bit, not by setting
  // two bits."
  const auto exclusiveBits = spirv::MemorySemantics::Acquire |
                             spirv::MemorySemantics::Release |
                             spirv::MemorySemantics::AcquireRelease |
                             spirv::MemorySemantics::SequentiallyConsistent;

  const uint32_t selectedBits =
      static_cast<uint32_t>(memorySemantics & exclusiveBits);
  if (llvm::countPopulation(selectedBits) <= 1)
    return success();

  return op->emitError(
      "expected at most one of these four memory constraints "
      "to be set: `Acquire`, `Release`,"
      "`AcquireRelease` or `SequentiallyConsistent`");
}
|
|
|
|
|
2019-06-24 10:59:05 -07:00
|
|
|
template <typename LoadStoreOpTy>
|
2019-12-23 14:45:01 -08:00
|
|
|
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
|
|
|
|
Value val) {
|
2019-06-24 10:59:05 -07:00
|
|
|
// ODS already checks ptr is spirv::PointerType. Just check that the pointee
|
|
|
|
// type of the pointer and the type of the value are the same
|
|
|
|
//
|
2020-07-07 01:35:23 -07:00
|
|
|
// TODO: Check that the value type satisfies restrictions of
|
2019-06-24 10:59:05 -07:00
|
|
|
// SPIR-V OpLoad/OpStore operations
|
2020-01-11 08:54:04 -08:00
|
|
|
if (val.getType() !=
|
|
|
|
ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
|
2019-06-24 10:59:05 -07:00
|
|
|
return op.emitOpError("mismatch in result type and pointer type");
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2020-09-02 19:52:29 -07:00
|
|
|
template <typename BlockReadWriteOpTy>
|
|
|
|
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
|
|
|
|
Value ptr, Value val) {
|
|
|
|
auto valType = val.getType();
|
|
|
|
if (auto valVecTy = valType.dyn_cast<VectorType>())
|
|
|
|
valType = valVecTy.getElementType();
|
|
|
|
|
|
|
|
if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
|
|
|
|
return op.emitOpError("mismatch in result type and pointer type");
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-09-20 11:36:49 -07:00
|
|
|
/// Parses the optional decorations of a variable followed by an optional
/// attribute dictionary:
///     (`bind` `(` set `,` binding `)` | built-in-keyword `(` string `)`)?
///     attr-dict?
/// The set and binding are parsed as 32-bit integer attributes. Attribute
/// names (and the built-in keyword) are the snake_case forms of the
/// `DescriptorSet`, `Binding` and `BuiltIn` decoration names.
static ParseResult parseVariableDecorations(OpAsmParser &parser,
                                            OperationState &state) {
  const auto builtInKeyword = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::BuiltIn));

  if (succeeded(parser.parseOptionalKeyword("bind"))) {
    // Descriptor binding: `bind(<set>, <binding>)`.
    const auto setAttrName = llvm::convertToSnakeFromCamelCase(
        stringifyDecoration(spirv::Decoration::DescriptorSet));
    const auto bindingAttrName = llvm::convertToSnakeFromCamelCase(
        stringifyDecoration(spirv::Decoration::Binding));
    Type indexType = parser.getBuilder().getIntegerType(32);

    Attribute setAttr;
    Attribute bindingAttr;
    if (parser.parseLParen() ||
        parser.parseAttribute(setAttr, indexType, setAttrName,
                              state.attributes) ||
        parser.parseComma() ||
        parser.parseAttribute(bindingAttr, indexType, bindingAttrName,
                              state.attributes) ||
        parser.parseRParen())
      return failure();
  } else if (succeeded(parser.parseOptionalKeyword(builtInKeyword))) {
    // Built-in: `<keyword>("<name>")`.
    StringAttr builtInAttr;
    if (parser.parseLParen() ||
        parser.parseAttribute(builtInAttr, builtInKeyword,
                              state.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Any remaining attributes come from the trailing dictionary.
  return parser.parseOptionalAttrDict(state.attributes);
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
|
2019-08-17 10:19:48 -07:00
|
|
|
SmallVectorImpl<StringRef> &elidedAttrs) {
|
|
|
|
// Print optional descriptor binding
|
2020-04-14 18:54:23 -07:00
|
|
|
auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
|
|
|
|
stringifyDecoration(spirv::Decoration::DescriptorSet));
|
|
|
|
auto bindingName = llvm::convertToSnakeFromCamelCase(
|
|
|
|
stringifyDecoration(spirv::Decoration::Binding));
|
2019-08-17 10:19:48 -07:00
|
|
|
auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
|
|
|
|
auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
|
|
|
|
if (descriptorSet && binding) {
|
|
|
|
elidedAttrs.push_back(descriptorSetName);
|
|
|
|
elidedAttrs.push_back(bindingName);
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
|
|
|
|
<< ")";
|
2019-08-17 10:19:48 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
// Print BuiltIn attribute if present
|
2020-04-14 18:54:23 -07:00
|
|
|
auto builtInName = llvm::convertToSnakeFromCamelCase(
|
|
|
|
stringifyDecoration(spirv::Decoration::BuiltIn));
|
2019-08-17 10:19:48 -07:00
|
|
|
if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
|
2019-08-17 10:19:48 -07:00
|
|
|
elidedAttrs.push_back(builtInName);
|
|
|
|
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
2019-08-17 10:19:48 -07:00
|
|
|
}
|
|
|
|
|
2019-09-25 19:01:18 -07:00
|
|
|
// Get bit width of types.
|
|
|
|
static unsigned getBitWidth(Type type) {
|
|
|
|
if (type.isa<spirv::PointerType>()) {
|
|
|
|
// Just return 64 bits for pointer types for now.
|
|
|
|
// TODO: Make sure not caller relies on the actual pointer width value.
|
|
|
|
return 64;
|
|
|
|
}
|
2020-03-04 15:12:33 -05:00
|
|
|
|
|
|
|
if (type.isIntOrFloat())
|
2019-09-25 19:01:18 -07:00
|
|
|
return type.getIntOrFloatBitWidth();
|
2020-03-04 15:12:33 -05:00
|
|
|
|
2019-09-25 19:01:18 -07:00
|
|
|
if (auto vectorType = type.dyn_cast<VectorType>()) {
|
2020-03-04 15:12:33 -05:00
|
|
|
assert(vectorType.getElementType().isIntOrFloat());
|
2019-09-25 19:01:18 -07:00
|
|
|
return vectorType.getNumElements() *
|
|
|
|
vectorType.getElementType().getIntOrFloatBitWidth();
|
|
|
|
}
|
|
|
|
llvm_unreachable("unhandled bit width computation for type");
|
|
|
|
}
|
|
|
|
|
2019-12-05 13:10:10 -08:00
|
|
|
/// Walks the given type hierarchy with the given indices, potentially down
|
|
|
|
/// to component granularity, to select an element type. Returns null type and
|
|
|
|
/// emits errors with the given loc on failure.
|
2019-12-10 10:11:19 -08:00
|
|
|
static Type
|
|
|
|
getElementType(Type type, ArrayRef<int32_t> indices,
|
2019-12-18 09:28:48 -08:00
|
|
|
function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
|
2019-12-10 10:11:19 -08:00
|
|
|
if (indices.empty()) {
|
|
|
|
emitErrorFn("expected at least one index for spv.CompositeExtract");
|
2019-12-05 13:10:10 -08:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2019-12-10 10:11:19 -08:00
|
|
|
for (auto index : indices) {
|
2019-12-05 13:10:10 -08:00
|
|
|
if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
|
2020-05-21 11:59:31 -07:00
|
|
|
if (cType.hasCompileTimeKnownNumElements() &&
|
|
|
|
(index < 0 ||
|
|
|
|
static_cast<uint64_t>(index) >= cType.getNumElements())) {
|
2019-12-10 10:11:19 -08:00
|
|
|
emitErrorFn("index ") << index << " out of bounds for " << type;
|
2019-12-05 13:10:10 -08:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
type = cType.getElementType(index);
|
|
|
|
} else {
|
2019-12-10 10:11:19 -08:00
|
|
|
emitErrorFn("cannot extract from non-composite type ")
|
2019-12-05 13:10:10 -08:00
|
|
|
<< type << " with index " << index;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return type;
|
|
|
|
}
|
|
|
|
|
2019-12-10 10:11:19 -08:00
|
|
|
/// Overload taking the indices as an attribute.
///
/// `indices` must be a non-empty ArrayAttr whose elements are all
/// IntegerAttr; otherwise a diagnostic is produced through `emitErrorFn` and
/// a null type is returned. The integer values are collected into an int32_t
/// list and handed to the ArrayRef<int32_t> overload.
// NOTE(review): each value is stored as int32_t without a range check;
// confirm that callers never pass wider values.
static Type
getElementType(Type type, Attribute indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  auto arrayAttr = indices.dyn_cast<ArrayAttr>();
  if (!arrayAttr) {
    emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
    return nullptr;
  }
  if (arrayAttr.empty()) {
    emitErrorFn("expected at least one index for spv.CompositeExtract");
    return nullptr;
  }

  SmallVector<int32_t, 2> indexValues;
  for (Attribute element : arrayAttr) {
    if (auto intAttr = element.dyn_cast<IntegerAttr>()) {
      indexValues.push_back(intAttr.getInt());
      continue;
    }
    emitErrorFn("expected an 32-bit integer for index, but found '")
        << element << "'";
    return nullptr;
  }
  return getElementType(type, indexValues, emitErrorFn);
}
|
|
|
|
|
|
|
|
/// Overload that reports failures as errors attached to `loc`. Returns the
/// selected element type, or a null type on failure.
static Type getElementType(Type type, Attribute indices, Location loc) {
  auto emitAtLoc = [&](StringRef message) -> InFlightDiagnostic {
    return ::mlir::emitError(loc, message);
  };
  return getElementType(type, indices, emitAtLoc);
}
|
|
|
|
|
|
|
|
/// Overload that reports failures through `parser` at source location `loc`.
/// Returns the selected element type, or a null type on failure.
static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
                           SMLoc loc) {
  auto emitViaParser = [&](StringRef message) -> InFlightDiagnostic {
    return parser.emitError(loc, message);
  };
  return getElementType(type, indices, emitViaParser);
}
|
|
|
|
|
2020-11-19 09:48:58 -05:00
|
|
|
/// Returns true if the given `block` only contains one `spv.mlir.merge` op.
|
2019-10-02 11:00:50 -07:00
|
|
|
static inline bool isMergeBlock(Block &block) {
|
|
|
|
return !block.empty() && std::next(block.begin()) == block.end() &&
|
|
|
|
isa<spirv::MergeOp>(block.front());
|
|
|
|
}
|
|
|
|
|
2019-09-25 16:34:37 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Common parsers and printers
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-12-16 15:05:21 -08:00
|
|
|
// Parses an atomic update op. If the update op does not take a value (like
|
|
|
|
// AtomicIIncrement) `hasValue` must be false.
|
|
|
|
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
|
|
|
|
OperationState &state, bool hasValue) {
|
|
|
|
spirv::Scope scope;
|
|
|
|
spirv::MemorySemantics memoryScope;
|
2022-03-21 21:42:13 +01:00
|
|
|
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
|
|
|
|
OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
|
2019-12-16 15:05:21 -08:00
|
|
|
Type type;
|
2022-01-26 15:49:53 -08:00
|
|
|
SMLoc loc;
|
2020-03-11 16:04:25 -04:00
|
|
|
if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
|
|
|
|
parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
|
2019-12-16 15:05:21 -08:00
|
|
|
parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
|
|
|
|
parser.getCurrentLocation(&loc) || parser.parseColonType(type))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
auto ptrType = type.dyn_cast<spirv::PointerType>();
|
|
|
|
if (!ptrType)
|
|
|
|
return parser.emitError(loc, "expected pointer type");
|
|
|
|
|
|
|
|
SmallVector<Type, 2> operandTypes;
|
|
|
|
operandTypes.push_back(ptrType);
|
|
|
|
if (hasValue)
|
|
|
|
operandTypes.push_back(ptrType.getPointeeType());
|
|
|
|
if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
|
|
|
|
state.operands))
|
|
|
|
return failure();
|
|
|
|
return parser.addTypeToList(ptrType.getPointeeType(), state.types);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Prints an atomic update op.
|
|
|
|
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
|
2021-08-28 03:03:15 +00:00
|
|
|
printer << " \"";
|
2019-12-16 15:05:21 -08:00
|
|
|
auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
|
|
|
|
printer << spirv::stringifyScope(
|
|
|
|
static_cast<spirv::Scope>(scopeAttr.getInt()))
|
|
|
|
<< "\" \"";
|
|
|
|
auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
|
|
|
|
printer << spirv::stringifyMemorySemantics(
|
|
|
|
static_cast<spirv::MemorySemantics>(
|
|
|
|
memorySemanticsAttr.getInt()))
|
2020-01-11 08:54:04 -08:00
|
|
|
<< "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
|
2019-12-16 15:05:21 -08:00
|
|
|
}
|
|
|
|
|
2021-10-28 19:04:35 +03:00
|
|
|
template <typename T>
|
|
|
|
static StringRef stringifyTypeName();
|
|
|
|
|
|
|
|
template <>
|
|
|
|
StringRef stringifyTypeName<IntegerType>() {
|
|
|
|
return "integer";
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
StringRef stringifyTypeName<FloatType>() {
|
|
|
|
return "float";
|
|
|
|
}
|
|
|
|
|
2019-12-16 15:05:21 -08:00
|
|
|
// Verifies an atomic update op.
|
2021-10-28 19:04:35 +03:00
|
|
|
template <typename ExpectedElementType>
|
2019-12-16 15:05:21 -08:00
|
|
|
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
|
2020-01-11 08:54:04 -08:00
|
|
|
auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
|
2019-12-16 15:05:21 -08:00
|
|
|
auto elementType = ptrType.getPointeeType();
|
2021-10-28 19:04:35 +03:00
|
|
|
if (!elementType.isa<ExpectedElementType>())
|
|
|
|
return op->emitOpError() << "pointer operand must point to an "
|
|
|
|
<< stringifyTypeName<ExpectedElementType>()
|
|
|
|
<< " value, found " << elementType;
|
2019-12-16 15:05:21 -08:00
|
|
|
|
|
|
|
if (op->getNumOperands() > 1) {
|
2020-01-11 08:54:04 -08:00
|
|
|
auto valueType = op->getOperand(1).getType();
|
2019-12-16 15:05:21 -08:00
|
|
|
if (valueType != elementType)
|
|
|
|
return op->emitOpError("expected value to have the same type as the "
|
|
|
|
"pointer operand's pointee type ")
|
|
|
|
<< elementType << ", but found " << valueType;
|
|
|
|
}
|
2021-10-10 22:40:20 +08:00
|
|
|
auto memorySemantics = static_cast<spirv::MemorySemantics>(
|
|
|
|
op->getAttrOfType<IntegerAttr>(kSemanticsAttrName).getInt());
|
|
|
|
if (failed(verifyMemorySemantics(op, memorySemantics))) {
|
|
|
|
return failure();
|
|
|
|
}
|
2019-12-16 15:05:21 -08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2020-01-28 09:36:01 -05:00
|
|
|
/// Parses a group non-uniform arithmetic op.
///
/// Parsing order: the execution scope enum attribute, the group operation
/// enum attribute, the value operand, an optional
/// `<kClusterSize>(<operand>)` clause, then `: <result-type>`. The value
/// operand is resolved against the result type and the cluster size operand,
/// when present, against a 32-bit integer type. The result type becomes the
/// op's single result type.
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
                                                    OperationState &state) {
  spirv::Scope executionScope;
  spirv::GroupOperation groupOperation;
  OpAsmParser::UnresolvedOperand valueOperand;
  if (parseEnumStrAttr(executionScope, parser, state,
                       kExecutionScopeAttrName) ||
      parseEnumStrAttr(groupOperation, parser, state,
                       kGroupOperationAttrName) ||
      parser.parseOperand(valueOperand))
    return failure();

  // The cluster size clause is optional; remember whether it was given.
  Optional<OpAsmParser::UnresolvedOperand> clusterSizeOperand;
  if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
    clusterSizeOperand = OpAsmParser::UnresolvedOperand();
    if (parser.parseLParen() || parser.parseOperand(*clusterSizeOperand) ||
        parser.parseRParen())
      return failure();
  }

  Type resultType;
  if (parser.parseColonType(resultType))
    return failure();

  // Operand order matters: value first, then the optional cluster size.
  if (parser.resolveOperand(valueOperand, resultType, state.operands))
    return failure();

  if (clusterSizeOperand) {
    Type clusterSizeType = parser.getBuilder().getIntegerType(32);
    if (parser.resolveOperand(*clusterSizeOperand, clusterSizeType,
                              state.operands))
      return failure();
  }

  return parser.addTypeToList(resultType, state.types);
}
|
|
|
|
|
|
|
|
/// Prints a group non-uniform arithmetic op as
///   ` "<scope>" "<group-operation>" <value>[ <kClusterSize>(<size>)] : <type>`
/// where the cluster size clause appears only when the op has a second
/// operand and `<type>` is the type of result 0.
static void printGroupNonUniformArithmeticOp(Operation *groupOp,
                                             OpAsmPrinter &printer) {
  auto scopeAttr =
      groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName);
  auto groupOperationAttr =
      groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName);

  auto scope = static_cast<spirv::Scope>(scopeAttr.getInt());
  auto groupOperation =
      static_cast<spirv::GroupOperation>(groupOperationAttr.getInt());

  printer << " \"" << stringifyScope(scope) << "\" \""
          << stringifyGroupOperation(groupOperation) << "\" "
          << groupOp->getOperand(0);

  bool hasClusterSize = groupOp->getNumOperands() > 1;
  if (hasClusterSize)
    printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';

  printer << " : " << groupOp->getResult(0).getType();
}
|
|
|
|
|
|
|
|
/// Verifies a group non-uniform arithmetic op.
///
/// Checks, in order:
///   1. the execution scope is Workgroup or Subgroup;
///   2. a ClusteredReduce group operation comes with a cluster size operand;
///   3. if a cluster size operand is present, its value can be extracted by
///      extractValueFromConstOp and is a power of two.
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
  auto scopeAttr =
      groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName);
  auto scope = static_cast<spirv::Scope>(scopeAttr.getInt());
  bool scopeIsValid =
      scope == spirv::Scope::Workgroup || scope == spirv::Scope::Subgroup;
  if (!scopeIsValid) {
    return groupOp->emitOpError(
        "execution scope must be 'Workgroup' or 'Subgroup'");
  }

  auto groupOperationAttr =
      groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName);
  auto groupOperation =
      static_cast<spirv::GroupOperation>(groupOperationAttr.getInt());
  if (groupOperation == spirv::GroupOperation::ClusteredReduce &&
      groupOp->getNumOperands() == 1) {
    return groupOp->emitOpError("cluster size operand must be provided for "
                                "'ClusteredReduce' group operation");
  }

  // Nothing more to check when no cluster size operand was supplied.
  if (groupOp->getNumOperands() <= 1)
    return success();

  Operation *clusterSizeDef = groupOp->getOperand(1).getDefiningOp();
  int32_t clusterSize = 0;

  // TODO: support specialization constant here.
  if (failed(extractValueFromConstOp(clusterSizeDef, clusterSize))) {
    return groupOp->emitOpError(
        "cluster size operand must come from a constant op");
  }

  if (!llvm::isPowerOf2_32(clusterSize)) {
    return groupOp->emitOpError(
        "cluster size operand must be a power of two");
  }
  return success();
}
|
|
|
|
|
2019-09-30 10:40:07 -07:00
|
|
|
/// Result of a logical op must be a scalar or vector of boolean type.
|
2022-02-07 17:54:04 -08:00
|
|
|
static Type getUnaryOpResultType(Type operandType) {
|
|
|
|
Builder builder(operandType.getContext());
|
2019-09-30 10:40:07 -07:00
|
|
|
Type resultType = builder.getIntegerType(1);
|
2022-02-07 17:54:04 -08:00
|
|
|
if (auto vecType = operandType.dyn_cast<VectorType>())
|
2019-09-30 10:40:07 -07:00
|
|
|
return VectorType::get(vecType.getNumElements(), resultType);
|
|
|
|
return resultType;
|
|
|
|
}
|
|
|
|
|
2019-11-08 11:05:32 -08:00
|
|
|
/// Verifies a shift op: the first operand and the first result must have the
/// same type. Emits an error naming both types otherwise.
static LogicalResult verifyShiftOp(Operation *op) {
  Type baseType = op->getOperand(0).getType();
  Type resultType = op->getResult(0).getType();
  if (baseType == resultType)
    return success();

  return op->emitError("expected the same type for the first operand and "
                       "result, but provided ")
         << baseType << " and " << resultType;
}
|
|
|
|
|
2020-07-12 13:03:23 -04:00
|
|
|
/// Populates `state` for a binary logical op on `lhs` and `rhs`, which are
/// asserted to have the same type. The result type is i1, or a vector of i1
/// with the shape of `lhs` when `lhs` is a vector.
static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
                                 Value lhs, Value rhs) {
  assert(lhs.getType() == rhs.getType());

  Type resultType = builder.getI1Type();
  if (auto operandVecType = lhs.getType().dyn_cast<VectorType>())
    resultType = VectorType::get(operandVecType.getShape(), resultType);

  state.addOperands({lhs, rhs});
  state.addTypes(resultType);
}
|
|
|
|
|
2021-01-22 13:08:00 -05:00
|
|
|
/// Populates `state` for a unary logical op on `value`. The result type is
/// i1, or a vector of i1 with the shape of `value` when `value` is a vector.
static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
                                Value value) {
  Type resultType = builder.getI1Type();
  if (auto operandVecType = value.getType().dyn_cast<VectorType>())
    resultType = VectorType::get(operandVecType.getShape(), resultType);

  state.addOperands(value);
  state.addTypes(resultType);
}
|
|
|
|
|
2019-07-25 15:42:41 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AccessChainOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-12-07 10:35:01 -08:00
|
|
|
/// Computes the pointer type produced by walking from the pointer type `type`
/// through `indices`, one composite level per index.
///
/// `type` must be a SPIR-V pointer type. At every step the current pointee
/// must be a composite type. When it is a struct, the index value must be
/// extractable by extractValueFromConstOp and lie within the struct's element
/// count; for other composites element 0 is used to obtain the element type.
/// The result keeps the storage class of the incoming pointer.
///
/// On failure an error is emitted at `baseLoc` and a null type is returned.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
  auto basePtrType = type.dyn_cast<spirv::PointerType>();
  if (!basePtrType) {
    emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
                       "to composite type, but provided ")
        << type;
    return nullptr;
  }

  Type currentType = basePtrType.getPointeeType();
  auto storageClass = basePtrType.getStorageClass();
  // Deliberately declared outside the loop: the non-composite diagnostic
  // below reports the value left over from the previous step.
  int32_t index = 0;

  for (Value indexValue : indices) {
    auto compositeType = currentType.dyn_cast<spirv::CompositeType>();
    if (!compositeType) {
      emitError(baseLoc,
                "'spv.AccessChain' op cannot extract from non-composite type ")
          << currentType << " with index " << index;
      return nullptr;
    }

    index = 0;
    if (currentType.isa<spirv::StructType>()) {
      // Struct members can have distinct types, so the index has to be a
      // value that can be extracted here.
      Operation *indexDef = indexValue.getDefiningOp();
      if (!indexDef) {
        emitError(baseLoc, "'spv.AccessChain' op index must be an "
                           "integer spv.Constant to access "
                           "element of spv.struct");
        return nullptr;
      }

      // TODO: this should be relaxed to allow
      // integer literals of other bitwidths.
      if (failed(extractValueFromConstOp(indexDef, index))) {
        emitError(baseLoc,
                  "'spv.AccessChain' index must be an integer spv.Constant to "
                  "access element of spv.struct, but provided ")
            << indexDef->getName();
        return nullptr;
      }

      bool outOfBounds =
          index < 0 ||
          static_cast<uint64_t>(index) >= compositeType.getNumElements();
      if (outOfBounds) {
        emitError(baseLoc, "'spv.AccessChain' op index ")
            << index << " out of bounds for " << currentType;
        return nullptr;
      }
    }

    currentType = compositeType.getElementType(index);
  }

  return spirv::PointerType::get(currentType, storageClass);
}
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
|
2019-12-23 14:45:01 -08:00
|
|
|
Value basePtr, ValueRange indices) {
|
2020-01-11 08:54:04 -08:00
|
|
|
auto type = getElementPtrType(basePtr.getType(), indices, state.location);
|
2019-08-27 10:49:53 -07:00
|
|
|
assert(type && "Unable to deduce return type based on basePtr and indices");
|
|
|
|
build(builder, state, type, basePtr, indices);
|
|
|
|
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses an access chain op of the form
///   `<base-ptr> [ <indices> ] : <base-ptr-type> , <index-types>`.
///
/// At least one index is required, and the number of index types must match
/// the number of indices. The result type is not written in the assembly; it
/// is computed by getElementPtrType from the base pointer type and the
/// resolved index operands.
ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
                                        OperationState &state) {
  // Captured before any operand is consumed; used when resolving the indices.
  auto operandsLoc = parser.getCurrentLocation();

  OpAsmParser::UnresolvedOperand basePtrOperand;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexOperands;
  Type basePtrType;
  if (parser.parseOperand(basePtrOperand) ||
      parser.parseOperandList(indexOperands, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(basePtrType) ||
      parser.resolveOperand(basePtrOperand, basePtrType, state.operands))
    return failure();

  // Check that the provided indices list is not empty before parsing their
  // type list.
  if (indexOperands.empty())
    return mlir::emitError(state.location, "'spv.AccessChain' op expected at "
                                           "least one index ");

  SmallVector<Type, 4> indexTypes;
  if (parser.parseComma() || parser.parseTypeList(indexTypes))
    return failure();

  // Check that the indices types list is not empty and that it has a
  // one-to-one mapping to the provided indices.
  if (indexTypes.size() != indexOperands.size())
    return mlir::emitError(state.location,
                           "'spv.AccessChain' op indices types' count must be "
                           "equal to indices info count");

  if (parser.resolveOperands(indexOperands, indexTypes, operandsLoc,
                             state.operands))
    return failure();

  // Operand 0 is the base pointer; everything after it is an index.
  auto resolvedIndices = llvm::makeArrayRef(state.operands).drop_front();
  Type resultType =
      getElementPtrType(basePtrType, resolvedIndices, state.location);
  if (!resultType)
    return failure();

  state.addTypes(resultType);
  return success();
}
|
|
|
|
|
2021-08-14 11:57:02 +03:00
|
|
|
template <typename Op>
|
|
|
|
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
|
2021-08-28 03:03:15 +00:00
|
|
|
printer << ' ' << op.base_ptr() << '[' << indices
|
2021-08-14 11:57:02 +03:00
|
|
|
<< "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
|
|
|
|
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints this op through the shared access-chain printer.
void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
  ValueRange indexOperands = indices();
  printAccessChain(*this, indexOperands, printer);
}
|
|
|
|
|
2021-08-14 11:57:02 +03:00
|
|
|
template <typename Op>
|
|
|
|
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
|
2020-01-11 08:54:04 -08:00
|
|
|
auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
|
2019-07-25 15:42:41 -07:00
|
|
|
indices, accessChainOp.getLoc());
|
2021-08-14 11:57:02 +03:00
|
|
|
if (!resultType)
|
2019-07-25 15:42:41 -07:00
|
|
|
return failure();
|
|
|
|
|
|
|
|
auto providedResultType =
|
2021-08-14 11:57:02 +03:00
|
|
|
accessChainOp.getType().template dyn_cast<spirv::PointerType>();
|
|
|
|
if (!providedResultType)
|
2019-07-25 15:42:41 -07:00
|
|
|
return accessChainOp.emitOpError(
|
|
|
|
"result type must be a pointer, but provided")
|
|
|
|
<< providedResultType;
|
|
|
|
|
2021-08-14 11:57:02 +03:00
|
|
|
if (resultType != providedResultType)
|
2019-07-25 15:42:41 -07:00
|
|
|
return accessChainOp.emitOpError("invalid result type: expected ")
|
|
|
|
<< resultType << ", but provided " << providedResultType;
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies this op through the shared access-chain verifier.
LogicalResult spirv::AccessChainOp::verify() {
  ValueRange indexOperands = indices();
  return verifyAccessChain(*this, indexOperands);
}
|
|
|
|
|
2019-08-17 10:19:48 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2020-11-17 11:45:32 -05:00
|
|
|
// spv.mlir.addressof
|
2019-08-17 10:19:48 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
/// Builds an address-of op referring to the global variable `var`; the result
/// type is the variable's declared type.
void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
                               spirv::GlobalVariableOp var) {
  auto variableRef = SymbolRefAttr::get(var);
  build(builder, state, var.type(), variableRef);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies that the referenced symbol, looked up from the parent op,
/// resolves to a spv.GlobalVariable and that this op's result type equals the
/// variable's type.
LogicalResult spirv::AddressOfOp::verify() {
  Operation *symbolOp = SymbolTable::lookupNearestSymbolFrom(
      (*this)->getParentOp(), variableAttr());
  auto globalVar = dyn_cast_or_null<spirv::GlobalVariableOp>(symbolOp);
  if (!globalVar)
    return emitOpError("expected spv.GlobalVariable symbol");

  if (pointer().getType() != globalVar.type())
    return emitOpError(
        "result type mismatch with the referenced global variable's type");

  return success();
}
|
|
|
|
|
2021-10-28 19:04:35 +03:00
|
|
|
template <typename T>
|
|
|
|
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
|
|
|
|
printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
|
|
|
|
<< stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
|
|
|
|
<< stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
|
|
|
|
<< atomOp.getOperands() << " : " << atomOp.pointer().getType();
|
|
|
|
}
|
2019-12-05 10:05:54 -08:00
|
|
|
|
2021-10-28 19:04:35 +03:00
|
|
|
/// Parses an atomic compare-exchange op.
///
/// Parsing order: the memory scope, equal-semantics and unequal-semantics
/// enum attributes, exactly three operands, then `: <type>`. The type must be
/// a SPIR-V pointer type. The first operand is resolved against it and the
/// other two against its pointee type, which is also the single result type.
static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
                                                  OperationState &state) {
  spirv::Scope scope;
  spirv::MemorySemantics onEqual, onUnequal;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
      parseEnumStrAttr(onEqual, parser, state, kEqualSemanticsAttrName) ||
      parseEnumStrAttr(onUnequal, parser, state, kUnequalSemanticsAttrName) ||
      parser.parseOperandList(operands, 3))
    return failure();

  // Remember where the type clause starts for the diagnostic below.
  auto typeLoc = parser.getCurrentLocation();
  Type parsedType;
  if (parser.parseColonType(parsedType))
    return failure();

  auto pointerType = parsedType.dyn_cast<spirv::PointerType>();
  if (!pointerType)
    return parser.emitError(typeLoc, "expected pointer type");

  Type pointeeType = pointerType.getPointeeType();
  if (parser.resolveOperands(operands,
                             {pointerType, pointeeType, pointeeType},
                             parser.getNameLoc(), state.operands))
    return failure();

  return parser.addTypeToList(pointeeType, state.types);
}
|
|
|
|
|
2021-10-28 19:04:35 +03:00
|
|
|
template <typename T>
|
|
|
|
static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
|
2019-12-05 10:05:54 -08:00
|
|
|
// According to the spec:
|
|
|
|
// "The type of Value must be the same as Result Type. The type of the value
|
|
|
|
// pointed to by Pointer must be the same as Result Type. This type must also
|
|
|
|
// match the type of Comparator."
|
2020-01-11 08:54:04 -08:00
|
|
|
if (atomOp.getType() != atomOp.value().getType())
|
2019-12-05 10:05:54 -08:00
|
|
|
return atomOp.emitOpError("value operand must have the same type as the op "
|
|
|
|
"result, but found ")
|
2020-01-11 08:54:04 -08:00
|
|
|
<< atomOp.value().getType() << " vs " << atomOp.getType();
|
2019-12-05 10:05:54 -08:00
|
|
|
|
2020-01-11 08:54:04 -08:00
|
|
|
if (atomOp.getType() != atomOp.comparator().getType())
|
2019-12-05 10:05:54 -08:00
|
|
|
return atomOp.emitOpError(
|
|
|
|
"comparator operand must have the same type as the op "
|
|
|
|
"result, but found ")
|
2020-01-11 08:54:04 -08:00
|
|
|
<< atomOp.comparator().getType() << " vs " << atomOp.getType();
|
2019-12-05 10:05:54 -08:00
|
|
|
|
2021-10-28 19:04:35 +03:00
|
|
|
Type pointeeType = atomOp.pointer()
|
|
|
|
.getType()
|
|
|
|
.template cast<spirv::PointerType>()
|
|
|
|
.getPointeeType();
|
2019-12-05 10:05:54 -08:00
|
|
|
if (atomOp.getType() != pointeeType)
|
|
|
|
return atomOp.emitOpError(
|
|
|
|
"pointer operand's pointee type must have the same "
|
|
|
|
"as the op result type, but found ")
|
|
|
|
<< pointeeType << " vs " << atomOp.getType();
|
|
|
|
|
2020-07-07 01:35:23 -07:00
|
|
|
// TODO: Unequal cannot be set to Release or Acquire and Release.
|
2019-12-05 10:05:54 -08:00
|
|
|
// In addition, Unequal cannot be set to a stronger memory-order then Equal.
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicAndOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies this op with the shared atomic-update verifier, requiring the
/// pointer operand to point to an integer type.
LogicalResult spirv::AtomicAndOp::verify() {
  Operation *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses this op with the shared atomic-update parser; the op takes a value
/// operand in addition to the pointer.
ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  return ::parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
}
|
|
|
|
/// Prints this op with the shared atomic-update printer.
void spirv::AtomicAndOp::print(OpAsmPrinter &p) {
  Operation *op = getOperation();
  ::printAtomicUpdateOp(op, p);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicCompareExchangeOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies this op with the shared compare-exchange verifier.
LogicalResult spirv::AtomicCompareExchangeOp::verify() {
  LogicalResult status = ::verifyAtomicCompareExchangeImpl(*this);
  return status;
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses this op with the shared compare-exchange parser.
ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
                                                  OperationState &result) {
  ParseResult status = ::parseAtomicCompareExchangeImpl(parser, result);
  return status;
}
|
|
|
|
/// Prints this op with the shared compare-exchange printer.
void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
  OpAsmPrinter &printer = p;
  ::printAtomicCompareExchangeImpl(*this, printer);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicCompareExchangeWeakOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies this op with the shared compare-exchange verifier.
LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() {
  LogicalResult status = ::verifyAtomicCompareExchangeImpl(*this);
  return status;
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses this op with the shared compare-exchange parser.
ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
                                                      OperationState &result) {
  ParseResult status = ::parseAtomicCompareExchangeImpl(parser, result);
  return status;
}
|
|
|
|
/// Prints this op with the shared compare-exchange printer.
void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
  OpAsmPrinter &printer = p;
  ::printAtomicCompareExchangeImpl(*this, printer);
}
|
|
|
|
|
2021-10-28 19:04:35 +03:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicExchange
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints this op as
///   ` "<scope>" "<semantics>" <operands> : <pointer-type>`.
void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) {
  printer << " \"" << stringifyScope(memory_scope()) << "\" \"";
  printer << stringifyMemorySemantics(semantics()) << "\" ";
  printer << getOperands() << " : " << pointer().getType();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses an atomic exchange op.
///
/// Parsing order: the memory scope and memory semantics enum attributes,
/// exactly two operands, then `: <type>`. The type must be a SPIR-V pointer
/// type. The first operand is resolved against it and the second against its
/// pointee type, which is also the single result type.
ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
                                           OperationState &state) {
  spirv::Scope scope;
  spirv::MemorySemantics memSemantics;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
  if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
      parseEnumStrAttr(memSemantics, parser, state, kSemanticsAttrName) ||
      parser.parseOperandList(operands, 2))
    return failure();

  // Remember where the type clause starts for the diagnostic below.
  auto typeLoc = parser.getCurrentLocation();
  Type parsedType;
  if (parser.parseColonType(parsedType))
    return failure();

  auto pointerType = parsedType.dyn_cast<spirv::PointerType>();
  if (!pointerType)
    return parser.emitError(typeLoc, "expected pointer type");

  Type pointeeType = pointerType.getPointeeType();
  if (parser.resolveOperands(operands, {pointerType, pointeeType},
                             parser.getNameLoc(), state.operands))
    return failure();

  return parser.addTypeToList(pointeeType, state.types);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies spv.AtomicExchange: the result type must equal both the type of
/// the value operand and the pointee type of the pointer operand.
LogicalResult spirv::AtomicExchangeOp::verify() {
  Type resultType = getType();

  // Value operand vs. result.
  Type valueType = value().getType();
  if (resultType != valueType) {
    return emitOpError("value operand must have the same type as the op "
                       "result, but found ")
           << valueType << " vs " << resultType;
  }

  // Pointee of the pointer operand vs. result.
  // NOTE(review): the message below reads "must have the same as"; a word
  // looks to be missing, but it is kept verbatim since diagnostics may be
  // matched textually by existing tests -- confirm before rewording.
  auto pointerType = pointer().getType().cast<spirv::PointerType>();
  Type pointeeType = pointerType.getPointeeType();
  if (resultType != pointeeType) {
    return emitOpError("pointer operand's pointee type must have the same "
                       "as the op result type, but found ")
           << pointeeType << " vs " << resultType;
  }

  return success();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2022-02-07 17:54:04 -08:00
|
|
|
// spv.AtomicIAddOp
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicIAdd: hands the underlying operation to the shared
/// atomic-update verifier instantiated for IntegerType.
LogicalResult spirv::AtomicIAddOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicIAdd; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `true` presumably tells the helper to expect a
/// value operand -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
                                       OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, true);
}
|
|
|
|
void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicFAddEXTOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicFAddEXT: hands the underlying operation to the
/// shared atomic-update verifier instantiated for FloatType (the other atomic
/// update ops in this file instantiate it for IntegerType).
LogicalResult spirv::AtomicFAddEXTOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<FloatType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicFAddEXT; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `true` presumably tells the helper to expect a
/// value operand -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
                                          OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, true);
}
|
|
|
|
void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicIDecrementOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicIDecrement: hands the underlying operation to the
/// shared atomic-update verifier instantiated for IntegerType.
LogicalResult spirv::AtomicIDecrementOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicIDecrement; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `false` presumably tells the helper that no
/// value operand follows the pointer -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
                                             OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, false);
}
|
|
|
|
void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicIIncrementOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicIIncrement: hands the underlying operation to the
/// shared atomic-update verifier instantiated for IntegerType.
LogicalResult spirv::AtomicIIncrementOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicIIncrement; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `false` presumably tells the helper that no
/// value operand follows the pointer -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
                                             OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, false);
}
|
|
|
|
void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicISubOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicISub: hands the underlying operation to the shared
/// atomic-update verifier instantiated for IntegerType.
LogicalResult spirv::AtomicISubOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicISub; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `true` presumably tells the helper to expect a
/// value operand -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
                                       OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, true);
}
|
|
|
|
void spirv::AtomicISubOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicOrOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicOr: hands the underlying operation to the shared
/// atomic-update verifier instantiated for IntegerType.
LogicalResult spirv::AtomicOrOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicOr; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `true` presumably tells the helper to expect a
/// value operand -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
                                     OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, true);
}
|
|
|
|
void spirv::AtomicOrOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicSMaxOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicSMax: hands the underlying operation to the shared
/// atomic-update verifier instantiated for IntegerType.
LogicalResult spirv::AtomicSMaxOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicSMax; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `true` presumably tells the helper to expect a
/// value operand -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
                                       OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, true);
}
|
|
|
|
void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicSMinOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicSMin: hands the underlying operation to the shared
/// atomic-update verifier instantiated for IntegerType.
LogicalResult spirv::AtomicSMinOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicSMin; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `true` presumably tells the helper to expect a
/// value operand -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
                                       OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, true);
}
|
|
|
|
void spirv::AtomicSMinOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicUMaxOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicUMax: hands the underlying operation to the shared
/// atomic-update verifier instantiated for IntegerType.
LogicalResult spirv::AtomicUMaxOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicUMax; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `true` presumably tells the helper to expect a
/// value operand -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
                                       OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, true);
}
|
|
|
|
void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicUMinOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicUMin: hands the underlying operation to the shared
/// atomic-update verifier instantiated for IntegerType.
LogicalResult spirv::AtomicUMinOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicUMin; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `true` presumably tells the helper to expect a
/// value operand -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
                                       OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, true);
}
|
|
|
|
void spirv::AtomicUMinOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.AtomicXorOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifier for spv.AtomicXor: hands the underlying operation to the shared
/// atomic-update verifier instantiated for IntegerType.
LogicalResult spirv::AtomicXorOp::verify() {
  auto *op = getOperation();
  return ::verifyAtomicUpdateOp<IntegerType>(op);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Custom-assembly parser for spv.AtomicXor; defers to the shared
/// atomic-update parser.
/// NOTE(review): the trailing `true` presumably tells the helper to expect a
/// value operand -- confirm against parseAtomicUpdateOp.
ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
                                      OperationState &state) {
  return ::parseAtomicUpdateOp(parser, state, true);
}
|
|
|
|
void spirv::AtomicXorOp::print(OpAsmPrinter &p) {
|
|
|
|
::printAtomicUpdateOp(*this, p);
|
|
|
|
}
|
|
|
|
|
2019-09-25 19:01:18 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.BitcastOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies spv.Bitcast:
///  * operand and result types must differ;
///  * either both or neither of them is a SPIR-V pointer type;
///  * both must have the same bit width as computed by getBitWidth().
LogicalResult spirv::BitcastOp::verify() {
  // TODO: The SPIR-V spec validation rules are different for different
  // versions.
  auto srcType = operand().getType();
  auto dstType = result().getType();
  if (srcType == dstType)
    return emitError("result type must be different from operand type");

  // Mixed pointer / non-pointer casts are rejected in both directions.
  const bool srcIsPointer = srcType.isa<spirv::PointerType>();
  const bool dstIsPointer = dstType.isa<spirv::PointerType>();
  if (srcIsPointer && !dstIsPointer)
    return emitError(
        "unhandled bit cast conversion from pointer type to non-pointer type");
  if (!srcIsPointer && dstIsPointer)
    return emitError(
        "unhandled bit cast conversion from non-pointer type to pointer type");

  auto srcBitWidth = getBitWidth(srcType);
  auto dstBitWidth = getBitWidth(dstType);
  if (srcBitWidth == dstBitWidth)
    return success();

  return emitOpError("mismatch in result type bitwidth ")
         << dstBitWidth << " and operand type bitwidth " << srcBitWidth;
}
|
|
|
|
|
2020-03-05 12:40:23 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.BranchOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-04-08 08:17:36 +02:00
|
|
|
/// Returns the operands forwarded to the single successor of spv.Branch.
/// Only successor index 0 is valid (asserted).
/// NOTE(review): the leading `0` is presumably the number of operands produced
/// by the op itself rather than forwarded -- confirm against the
/// SuccessorOperands constructor.
SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  auto forwarded = targetOperandsMutable();
  return SuccessorOperands(0, forwarded);
}
|
|
|
|
|
2019-08-30 12:17:21 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.BranchConditionalOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-04-08 08:17:36 +02:00
|
|
|
/// Returns the operands forwarded to the requested successor of
/// spv.BranchConditional: the true-target operands when `index` equals
/// kTrueIndex, otherwise the false-target operands. `index` must be below 2
/// (asserted).
SuccessorOperands
spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
  assert(index < 2 && "invalid successor index");
  if (index == kTrueIndex)
    return SuccessorOperands(trueTargetOperandsMutable());
  return SuccessorOperands(falseTargetOperandsMutable());
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses spv.BranchConditional from its custom assembly:
///   <cond> (`[` <true-weight> `,` <false-weight> `]`)?
///          `,` <true-successor> `,` <false-successor>
///
/// Operands are appended to `state` in the order: condition, true-branch
/// operands, false-branch operands. The operand segment sizes attribute
/// records that split as {1, #true-operands, #false-operands}.
ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
                                              OperationState &state) {
  auto &builder = parser.getBuilder();

  // Condition: one operand, resolved against i1.
  OpAsmParser::UnresolvedOperand condition;
  Type i1Type = builder.getI1Type();
  if (parser.parseOperand(condition))
    return failure();
  if (parser.resolveOperand(condition, i1Type, state.operands))
    return failure();

  // Optional `[trueWeight, falseWeight]`, stored as a two-element array
  // attribute of 32-bit integers.
  if (succeeded(parser.parseOptionalLSquare())) {
    Type weightType = builder.getIntegerType(32);
    // Scratch list required by parseAttribute; its contents are not used.
    NamedAttrList scratchAttrs;
    IntegerAttr trueWeight;
    IntegerAttr falseWeight;

    if (parser.parseAttribute(trueWeight, weightType, "weight",
                              scratchAttrs) ||
        parser.parseComma() ||
        parser.parseAttribute(falseWeight, weightType, "weight",
                              scratchAttrs) ||
        parser.parseRSquare())
      return failure();

    state.addAttribute(kBranchWeightAttrName,
                       builder.getArrayAttr({trueWeight, falseWeight}));
  }

  // True branch: `, ^block(args...)`.
  Block *trueDest = nullptr;
  SmallVector<Value, 4> trueArgs;
  if (parser.parseComma() ||
      parser.parseSuccessorAndUseList(trueDest, trueArgs))
    return failure();
  state.addSuccessors(trueDest);
  state.addOperands(trueArgs);

  // False branch: `, ^block(args...)`.
  Block *falseDest = nullptr;
  SmallVector<Value, 4> falseArgs;
  if (parser.parseComma() ||
      parser.parseSuccessorAndUseList(falseDest, falseArgs))
    return failure();
  state.addSuccessors(falseDest);
  state.addOperands(falseArgs);

  // Record how the flat operand list splits into the three segments.
  const auto numTrueArgs = static_cast<int32_t>(trueArgs.size());
  const auto numFalseArgs = static_cast<int32_t>(falseArgs.size());
  state.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
                     builder.getI32VectorAttr({1, numTrueArgs, numFalseArgs}));

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
|
|
|
|
printer << ' ' << condition();
|
2019-08-30 12:17:21 -07:00
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
if (auto weights = branch_weights()) {
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << " [";
|
2020-04-14 14:53:28 -07:00
|
|
|
llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << a.cast<IntegerAttr>().getInt();
|
2019-09-11 14:02:23 -07:00
|
|
|
});
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << "]";
|
2019-08-30 12:17:21 -07:00
|
|
|
}
|
|
|
|
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << ", ";
|
2022-02-07 17:54:04 -08:00
|
|
|
printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
|
2019-09-20 20:43:02 -07:00
|
|
|
printer << ", ";
|
2022-02-07 17:54:04 -08:00
|
|
|
printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
|
2019-08-30 12:17:21 -07:00
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the optional branch weights of spv.BranchConditional: when
/// present there must be exactly two of them, and they must not both be zero.
/// An op without branch weights verifies successfully.
LogicalResult spirv::BranchConditionalOp::verify() {
  auto weights = branch_weights();
  if (!weights)
    return success();

  if (weights->getValue().size() != 2)
    return emitOpError("must have exactly two branch weights");

  // APInt::isZero() is the current spelling of the legacy isNullValue() alias.
  const bool allZero = llvm::all_of(*weights, [](Attribute attr) {
    return attr.cast<IntegerAttr>().getValue().isZero();
  });
  if (allZero)
    return emitOpError("branch weights cannot both be zero");

  return success();
}
|
|
|
|
|
2019-12-09 12:43:23 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.CompositeConstruct
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies spv.CompositeConstruct. Three cases are handled in order:
///  1. Cooperative matrix result: exactly one constituent whose type equals
///     the matrix element type.
///  2. Constituent count equals the composite's element count: each
///     constituent's type must equal the corresponding element type.
///  3. Otherwise the result must be a vector built from scalars and/or
///     vectors whose element type matches the result element type and whose
///     component counts sum to the result's element count.
LogicalResult spirv::CompositeConstructOp::verify() {
  auto cType = getType().cast<spirv::CompositeType>();
  operand_range constituents = this->constituents();

  // Case 1: cooperative matrix.
  if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
    if (constituents.size() != 1)
      return emitOpError("has incorrect number of operands: expected ")
             << "1, but provided " << constituents.size();
    if (coopType.getElementType() != constituents.front().getType())
      return emitOpError("operand type mismatch: expected operand type ")
             << coopType.getElementType() << ", but provided "
             << constituents.front().getType();
    return success();
  }

  // Case 2: one constituent per element of the composite.
  if (constituents.size() == cType.getNumElements()) {
    for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
      if (constituents[index].getType() != cType.getElementType(index)) {
        return emitOpError("operand type mismatch: expected operand type ")
               << cType.getElementType(index) << ", but provided "
               << constituents[index].getType();
      }
    }
    return success();
  }

  // Case 3: if not constructing a cooperative matrix type, then we must be
  // constructing a vector type.
  auto resultType = cType.dyn_cast<VectorType>();
  if (!resultType)
    return emitOpError(
        "expected to return a vector or cooperative matrix when the number of "
        "constituents is less than what the result needs");

  // Running count of scalar components contributed by the constituents. Kept
  // as a single unsigned accumulator instead of collecting per-operand sizes
  // in a temporary vector and summing them with an `int` seed.
  unsigned totalCount = 0;
  for (Value component : constituents) {
    Type componentType = component.getType();
    if (!componentType.isa<VectorType>() && !componentType.isIntOrFloat())
      return emitOpError("operand type mismatch: expected operand to have "
                         "a scalar or vector type, but provided ")
             << componentType;

    Type elementType = componentType;
    if (auto vectorType = componentType.dyn_cast<VectorType>()) {
      totalCount += static_cast<unsigned>(vectorType.getNumElements());
      elementType = vectorType.getElementType();
    } else {
      totalCount += 1;
    }

    if (elementType != resultType.getElementType())
      return emitOpError("operand element type mismatch: expected to be ")
             << resultType.getElementType() << ", but provided " << elementType;
  }

  if (totalCount != cType.getNumElements())
    return emitOpError("has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided " << totalCount;
  return success();
}
|
|
|
|
|
2019-07-12 06:14:53 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.CompositeExtractOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
|
2019-12-23 14:45:01 -08:00
|
|
|
Value composite,
|
2019-12-10 10:11:19 -08:00
|
|
|
ArrayRef<int32_t> indices) {
|
2020-04-23 16:02:46 +02:00
|
|
|
auto indexAttr = builder.getI32ArrayAttr(indices);
|
2019-12-10 10:11:19 -08:00
|
|
|
auto elementType =
|
2020-01-11 08:54:04 -08:00
|
|
|
getElementType(composite.getType(), indexAttr, state.location);
|
2019-12-10 10:11:19 -08:00
|
|
|
if (!elementType) {
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
build(builder, state, elementType, composite, indexAttr);
|
|
|
|
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses spv.CompositeExtract from its custom assembly:
///   <composite> <indices-attr> : <composite-type>
/// The indices attribute is recorded on `state` under kIndicesAttrName, and
/// the result type is derived from the composite type and the indices via
/// getElementType(), which is handed the location of the indices attribute.
ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
                                             OperationState &state) {
  OpAsmParser::UnresolvedOperand composite;
  if (parser.parseOperand(composite))
    return failure();

  // Location of the indices attribute, passed on to getElementType().
  SMLoc indicesLoc;
  if (parser.getCurrentLocation(&indicesLoc))
    return failure();

  Attribute indices;
  if (parser.parseAttribute(indices, kIndicesAttrName, state.attributes))
    return failure();

  Type compositeTy;
  if (parser.parseColonType(compositeTy))
    return failure();
  if (parser.resolveOperand(composite, compositeTy, state.operands))
    return failure();

  Type extractedType = getElementType(compositeTy, indices, parser, indicesLoc);
  if (!extractedType)
    return failure();

  state.addTypes(extractedType);
  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
|
|
|
|
printer << ' ' << composite() << indices() << " : " << composite().getType();
|
2019-07-12 06:14:53 -07:00
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies spv.CompositeExtract: the element type reached by applying the
/// indices to the composite type (per getElementType()) must equal the op's
/// result type. Fails without an extra op error when getElementType() returns
/// a null type.
LogicalResult spirv::CompositeExtractOp::verify() {
  auto indexArray = indices().dyn_cast<ArrayAttr>();
  auto expectedType =
      getElementType(composite().getType(), indexArray, getLoc());
  if (!expectedType)
    return failure();

  if (expectedType == getType())
    return success();

  return emitOpError("invalid result type: expected ")
         << expectedType << " but provided " << getType();
}
|
|
|
|
|
2019-12-05 13:10:10 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.CompositeInsert
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-10-06 11:35:14 -07:00
|
|
|
/// Convenience builder for spv.CompositeInsert taking the indices as plain
/// integers. The result type is the type of `composite`.
void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
                                     Value object, Value composite,
                                     ArrayRef<int32_t> indices) {
  Type resultType = composite.getType();
  auto indicesAttr = builder.getI32ArrayAttr(indices);
  build(builder, state, resultType, object, composite, indicesAttr);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses spv.CompositeInsert from its custom assembly:
///   <object>, <composite> <indices-attr> : <object-type> into <composite-type>
/// The two operands are resolved against the object and composite types, and
/// the composite type becomes the result type.
ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
                                            OperationState &state) {
  // Location at the start of the operand list, used when resolving operands.
  auto operandsLoc = parser.getCurrentLocation();

  SmallVector<OpAsmParser::UnresolvedOperand, 2> unresolved;
  if (parser.parseOperandList(unresolved, 2))
    return failure();

  Attribute indices;
  if (parser.parseAttribute(indices, kIndicesAttrName, state.attributes))
    return failure();

  Type objectTy, compositeTy;
  if (parser.parseColonType(objectTy))
    return failure();
  if (parser.parseKeywordType("into", compositeTy))
    return failure();

  if (parser.resolveOperands(unresolved, {objectTy, compositeTy}, operandsLoc,
                             state.operands))
    return failure();
  if (parser.addTypesToList(compositeTy, state.types))
    return failure();

  return success();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies spv.CompositeInsert:
///  * the element type reached by applying the indices to the composite type
///    (per getElementType()) must equal the object operand's type;
///  * the result type must equal the composite operand's type.
LogicalResult spirv::CompositeInsertOp::verify() {
  auto indexArray = indices().dyn_cast<ArrayAttr>();
  Type compositeTy = composite().getType();

  auto expectedObjectType = getElementType(compositeTy, indexArray, getLoc());
  if (!expectedObjectType)
    return failure();

  Type actualObjectType = object().getType();
  if (expectedObjectType != actualObjectType)
    return emitOpError("object operand type should be ")
           << expectedObjectType << ", but found " << actualObjectType;

  if (compositeTy != getType())
    return emitOpError("result type should be the same as "
                       "the composite type, but found ")
           << compositeTy << " vs " << getType();

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
|
|
|
|
printer << " " << object() << ", " << composite() << indices() << " : "
|
|
|
|
<< object().getType() << " into " << composite().getType();
|
2019-12-05 13:10:10 -08:00
|
|
|
}
|
|
|
|
|
2019-06-17 14:47:22 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-03-04 16:15:46 -05:00
|
|
|
// spv.Constant
|
2019-06-17 14:47:22 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses spv.Constant from its custom assembly: a value attribute, followed
/// by `: <type>` only when the attribute's own type is NoneType or a tensor
/// type. The resulting type becomes the op's result type.
ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
                                     OperationState &state) {
  Attribute valueAttr;
  if (parser.parseAttribute(valueAttr, kValueAttrName, state.attributes))
    return failure();

  // Start from the attribute's type; an explicit trailing type replaces it
  // in the NoneType / TensorType cases.
  Type resultType = valueAttr.getType();
  const bool needsExplicitType = resultType.isa<NoneType, TensorType>();
  if (needsExplicitType && parser.parseColonType(resultType))
    return failure();

  return parser.addTypeToList(resultType, state.types);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
void spirv::ConstantOp::print(OpAsmPrinter &printer) {
|
|
|
|
printer << ' ' << value();
|
|
|
|
if (getType().isa<spirv::ArrayType>())
|
|
|
|
printer << " : " << getType();
|
2019-06-17 14:47:22 -07:00
|
|
|
}
|
|
|
|
|
2022-02-07 12:42:23 -08:00
|
|
|
/// Checks that the attribute `value` is consistent with the expected type
/// `opType`, emitting diagnostics on `op` when it is not.
///
///  * Integer / float attributes: the attribute type must equal `opType`.
///  * Dense int-or-fp / sparse elements attributes: either the attribute type
///    equals `opType`, or `opType` is a (possibly nested) spv.array whose
///    innermost element type is an int or float equal to the attribute's
///    element type and whose total element count equals the attribute's.
///  * Array attributes: `opType` must be a spv.array, and every element is
///    verified recursively against the array's element type.
///  * Any other attribute kind is rejected.
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
                                        Type opType) {
  auto valueType = value.getType();

  if (value.isa<IntegerAttr, FloatAttr>()) {
    if (valueType != opType)
      return op.emitOpError("result type (")
             << opType << ") does not match value type (" << valueType << ")";
    return success();
  }

  if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
    if (valueType == opType)
      return success();

    auto arrayType = opType.dyn_cast<spirv::ArrayType>();
    auto shapedType = valueType.dyn_cast<ShapedType>();
    if (!arrayType)
      return op.emitOpError("result or element type (")
             << opType << ") does not match value type (" << valueType
             << "), must be the same or spv.array";

    // Flatten nested arrays: multiply the element counts of every level and
    // descend to the innermost element type. The product is kept in 64 bits
    // so that nested arrays cannot overflow a plain `int`, and so that it
    // compares against ShapedType::getNumElements() without narrowing.
    int64_t numElements = arrayType.getNumElements();
    auto opElemType = arrayType.getElementType();
    while (auto nested = opElemType.dyn_cast<spirv::ArrayType>()) {
      numElements *= nested.getNumElements();
      opElemType = nested.getElementType();
    }
    if (!opElemType.isIntOrFloat())
      return op.emitOpError("only support nested array result type");

    auto valueElemType = shapedType.getElementType();
    if (valueElemType != opElemType) {
      return op.emitOpError("result element type (")
             << opElemType << ") does not match value element type ("
             << valueElemType << ")";
    }

    if (numElements != shapedType.getNumElements()) {
      return op.emitOpError("result number of elements (")
             << numElements << ") does not match value number of elements ("
             << shapedType.getNumElements() << ")";
    }
    return success();
  }

  if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
    auto arrayType = opType.dyn_cast<spirv::ArrayType>();
    if (!arrayType)
      return op.emitOpError("must have spv.array result type for array value");

    Type elemType = arrayType.getElementType();
    for (Attribute element : arrayAttr.getValue()) {
      // Verify array elements recursively.
      if (failed(verifyConstantType(op, element, elemType)))
        return failure();
    }
    return success();
  }

  return op.emitOpError("cannot have value of type ") << valueType;
}
|
|
|
|
|
|
|
|
/// Verifies spv.Constant. The validity of the result type itself is already
/// covered by the ODS-generated checks, so the only thing left to do here is
/// to make sure the value attribute's type agrees with the result type.
LogicalResult spirv::ConstantOp::verify() {
  Attribute constantValue = valueAttr();
  Type resultType = getType();
  return verifyConstantType(*this, constantValue, resultType);
}
|
|
|
|
|
2019-09-03 12:09:07 -07:00
|
|
|
/// Returns true if a spv.Constant can be built with the given result type.
/// The type must be a valid SPIR-V type; among types that belong to the SPIR-V
/// dialect itself, only spv.array is accepted.
bool spirv::ConstantOp::isBuildableWith(Type type) {
  // Must be valid SPIR-V type first.
  if (!type.isa<spirv::SPIRVType>())
    return false;

  // TODO: support constant struct
  const bool isSPIRVDialectType = isa<SPIRVDialect>(type.getDialect());
  return !isSPIRVDialectType || type.isa<spirv::ArrayType>();
}
|
|
|
|
|
2019-11-27 14:12:32 -08:00
|
|
|
/// Creates a spv.Constant holding the zero value of `type` at `loc`.
///
/// Supported types: integers (1-bit integers use a `false` bool attribute),
/// floats, and vectors of integers or floats (built as a dense attribute from
/// the zero element value). Any other type hits llvm_unreachable.
spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
                                             OpBuilder &builder) {
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    if (width == 1)
      return builder.create<spirv::ConstantOp>(loc, type,
                                               builder.getBoolAttr(false));
    return builder.create<spirv::ConstantOp>(
        loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
  }

  if (auto floatType = type.dyn_cast<FloatType>()) {
    return builder.create<spirv::ConstantOp>(
        loc, type, builder.getFloatAttr(floatType, 0.0));
  }

  if (auto vectorType = type.dyn_cast<VectorType>()) {
    Type elemType = vectorType.getElementType();
    if (elemType.isa<IntegerType>()) {
      // Use an explicit integer zero; a floating-point literal here would
      // only reach the integer overload through an implicit conversion.
      return builder.create<spirv::ConstantOp>(
          loc, type,
          DenseElementsAttr::get(
              vectorType, IntegerAttr::get(elemType, int64_t{0}).getValue()));
    }
    if (elemType.isa<FloatType>()) {
      return builder.create<spirv::ConstantOp>(
          loc, type,
          DenseFPElementsAttr::get(vectorType,
                                   FloatAttr::get(elemType, 0.0).getValue()));
    }
  }

  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
}
|
|
|
|
|
|
|
|
/// Creates a spv.Constant holding the value one of `type` at `loc`.
///
/// Supported types: integers (1-bit integers use a `true` bool attribute),
/// floats, and vectors of integers or floats (built as a dense attribute from
/// the element value one). Any other type hits llvm_unreachable.
spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
                                            OpBuilder &builder) {
  if (auto intType = type.dyn_cast<IntegerType>()) {
    unsigned width = intType.getWidth();
    if (width == 1)
      return builder.create<spirv::ConstantOp>(loc, type,
                                               builder.getBoolAttr(true));
    return builder.create<spirv::ConstantOp>(
        loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
  }

  if (auto floatType = type.dyn_cast<FloatType>()) {
    return builder.create<spirv::ConstantOp>(
        loc, type, builder.getFloatAttr(floatType, 1.0));
  }

  if (auto vectorType = type.dyn_cast<VectorType>()) {
    Type elemType = vectorType.getElementType();
    if (elemType.isa<IntegerType>()) {
      // Use an explicit integer one; a floating-point literal here would
      // only reach the integer overload through an implicit conversion.
      return builder.create<spirv::ConstantOp>(
          loc, type,
          DenseElementsAttr::get(
              vectorType, IntegerAttr::get(elemType, int64_t{1}).getValue()));
    }
    if (elemType.isa<FloatType>()) {
      return builder.create<spirv::ConstantOp>(
          loc, type,
          DenseFPElementsAttr::get(vectorType,
                                   FloatAttr::get(elemType, 1.0).getValue()));
    }
  }

  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
}
|
|
|
|
|
2021-05-28 08:49:45 +02:00
|
|
|
/// Suggests a readable SSA name for the constant's result.
///
/// The name starts with "cst". For integer-attribute values the numeric value
/// is appended, except that 1-bit integers are named "true" / "false"
/// directly. Scalar integer and float result types then append "_<type>";
/// vector result types append "_vec_<dim0>" and, for integer or float
/// elements, "x<element type>".
void mlir::spirv::ConstantOp::getAsmResultNames(
    llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
  Type type = getType();

  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << "cst";

  // Null when the result type is not an integer type.
  IntegerType intTy = type.dyn_cast<IntegerType>();

  if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
    if (intTy && intTy.getWidth() == 1) {
      // Read the raw bits rather than a signedness-specific accessor so that
      // any 1-bit integer flavor is handled.
      return setNameFn(getResult(),
                       (intCst.getValue().getBoolValue() ? "true" : "false"));
    }

    // Print straight from the APInt with an explicit signedness. The
    // signedness-specific IntegerAttr accessors assert on a mismatching type
    // (e.g. the signed accessor on an unsigned integer), and `intTy` may be
    // null here, so neither can be relied upon unconditionally.
    bool isUnsigned = intTy && intTy.isUnsigned();
    intCst.getValue().print(specialName, /*isSigned=*/!isUnsigned);
  }

  if (intTy || type.isa<FloatType>()) {
    specialName << '_' << type;
  }

  if (auto vecType = type.dyn_cast<VectorType>()) {
    specialName << "_vec_";
    specialName << vecType.getDimSize(0);

    Type elementType = vecType.getElementType();

    if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
      specialName << "x" << elementType;
    }
  }

  setNameFn(getResult(), specialName.str());
}
|
|
|
|
|
2021-06-07 13:19:39 +02:00
|
|
|
/// Suggests "<variable>_addr" as the SSA name of the result, where
/// <variable> is whatever `variable()` streams as.
void mlir::spirv::AddressOfOp::getAsmResultNames(
    llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
  SmallString<32> nameStorage;
  {
    llvm::raw_svector_ostream nameStream(nameStorage);
    nameStream << variable();
    nameStream << "_addr";
  }
  setNameFn(getResult(), StringRef(nameStorage));
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ControlBarrierOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies the op by checking its memory semantics attribute.
LogicalResult spirv::ControlBarrierOp::verify() {
  Operation *op = getOperation();
  return verifyMemorySemantics(op, memory_semantics());
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ConvertFToSOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies the cast; operand and result bit widths are not checked.
LogicalResult spirv::ConvertFToSOp::verify() {
  constexpr bool requireSameBitWidth = false;
  constexpr bool skipBitWidthCheck = true;
  return verifyCastOp(*this, requireSameBitWidth, skipBitWidthCheck);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ConvertFToUOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies the cast; operand and result bit widths are not checked.
LogicalResult spirv::ConvertFToUOp::verify() {
  constexpr bool requireSameBitWidth = false;
  constexpr bool skipBitWidthCheck = true;
  return verifyCastOp(*this, requireSameBitWidth, skipBitWidthCheck);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ConvertSToFOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies the cast; operand and result bit widths are not checked.
LogicalResult spirv::ConvertSToFOp::verify() {
  constexpr bool requireSameBitWidth = false;
  constexpr bool skipBitWidthCheck = true;
  return verifyCastOp(*this, requireSameBitWidth, skipBitWidthCheck);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ConvertUToFOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies the cast; operand and result bit widths are not checked.
LogicalResult spirv::ConvertUToFOp::verify() {
  constexpr bool requireSameBitWidth = false;
  constexpr bool skipBitWidthCheck = true;
  return verifyCastOp(*this, requireSameBitWidth, skipBitWidthCheck);
}
|
|
|
|
|
2019-07-08 10:56:20 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.EntryPoint
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
|
2019-10-04 14:02:14 -07:00
|
|
|
spirv::ExecutionModel executionModel,
|
2020-02-07 11:30:19 -05:00
|
|
|
spirv::FuncOp function,
|
2019-10-04 14:02:14 -07:00
|
|
|
ArrayRef<Attribute> interfaceVars) {
|
|
|
|
build(builder, state,
|
2021-02-27 15:21:00 +03:00
|
|
|
spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
|
2021-08-30 09:31:48 -07:00
|
|
|
SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
|
2019-10-04 14:02:14 -07:00
|
|
|
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the custom assembly form of the entry point op:
///
///   <execution-model-string> <function-symbol> (`,` <interface-symbol>
///   (`,` <interface-symbol>)*)?
///
/// The execution model and function symbol are stored on `state` as
/// attributes. The interface variable symbols are collected into an array
/// attribute named `kInterfaceAttrName`, which is always added (it is empty
/// when no interface variables are listed).
ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
                                       OperationState &state) {
  spirv::ExecutionModel execModel;
  SmallVector<Attribute, 4> interfaceVars;

  FlatSymbolRefAttr fn;
  if (parseEnumStrAttr(execModel, parser, state) ||
      parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
    return failure();
  }

  if (succeeded(parser.parseOptionalComma())) {
    // Parse the interface variables.
    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
          // The name of the interface variable attribute isn't important:
          // the parsed symbol goes into a scratch attribute list and only
          // the attribute value is kept.
          FlatSymbolRefAttr var;
          NamedAttrList attrs;
          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
            return failure();
          interfaceVars.push_back(var);
          return success();
        }))
      return failure();
  }
  state.addAttribute(kInterfaceAttrName,
                     parser.getBuilder().getArrayAttr(interfaceVars));
  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the quoted execution model, the function symbol and, if any, the
/// comma-separated interface variables.
void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
  printer << " \"" << stringifyExecutionModel(execution_model()) << "\" ";
  printer.printSymbolName(fn());

  auto vars = interface().getValue();
  if (vars.empty())
    return;

  printer << ", ";
  llvm::interleaveComma(vars, printer);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Always succeeds: nothing is checked at the op level.
LogicalResult spirv::EntryPointOp::verify() {
  // The fn and interface symbol references are checked as part of
  // spirv::ModuleOp verification instead.
  return success();
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ExecutionMode
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
/// Convenience builder: wraps the function, the execution mode and its
/// parameters into attributes and forwards to the attribute-based `build`
/// overload.
void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
                                   spirv::FuncOp function,
                                   spirv::ExecutionMode executionMode,
                                   ArrayRef<int32_t> params) {
  auto fnSymbol = SymbolRefAttr::get(function);
  auto modeAttr =
      spirv::ExecutionModeAttr::get(builder.getContext(), executionMode);
  auto paramsAttr = builder.getI32ArrayAttr(params);
  build(builder, state, fnSymbol, modeAttr, paramsAttr);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the custom assembly form of the execution mode op:
///
///   <function-attribute> <execution-mode-string> (`,` <integer>)*
///
/// The trailing integers are parsed with a 32-bit integer type and stored as
/// an i32 array attribute named `kValuesAttrName`, which is always added (it
/// is empty when no values are listed).
ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
                                          OperationState &state) {
  spirv::ExecutionMode execMode;
  Attribute fn;
  if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
      parseEnumStrAttr(execMode, parser, state)) {
    return failure();
  }

  SmallVector<int32_t, 4> values;
  Type i32Type = parser.getBuilder().getIntegerType(32);
  while (!parser.parseOptionalComma()) {
    // Scratch list: only the parsed attribute value is kept.
    NamedAttrList attr;
    // Parse directly as an IntegerAttr so that a non-integer attribute is
    // reported as a parse error instead of tripping an assertion in a later
    // cast.
    IntegerAttr value;
    if (parser.parseAttribute(value, i32Type, "value", attr)) {
      return failure();
    }
    values.push_back(static_cast<int32_t>(value.getInt()));
  }
  state.addAttribute(kValuesAttrName,
                     parser.getBuilder().getI32ArrayAttr(values));
  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the function symbol, the quoted execution mode and, if any, the
/// comma-separated integer values.
void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
  printer << " ";
  printer.printSymbolName(fn());
  printer << " \"" << stringifyExecutionMode(execution_mode()) << "\"";

  auto modeValues = this->values();
  if (!modeValues.empty()) {
    printer << ", ";
    auto printValue = [&](Attribute attr) {
      printer << attr.cast<IntegerAttr>().getInt();
    };
    llvm::interleaveComma(modeValues, printer, printValue);
  }
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.FConvertOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies the cast without requiring equal operand/result bit widths.
LogicalResult spirv::FConvertOp::verify() {
  constexpr bool requireSameBitWidth = false;
  return verifyCastOp(*this, requireSameBitWidth);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.SConvertOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies the cast without requiring equal operand/result bit widths.
LogicalResult spirv::SConvertOp::verify() {
  constexpr bool requireSameBitWidth = false;
  return verifyCastOp(*this, requireSameBitWidth);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.UConvertOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies the cast without requiring equal operand/result bit widths.
LogicalResult spirv::UConvertOp::verify() {
  constexpr bool requireSameBitWidth = false;
  return verifyCastOp(*this, requireSameBitWidth);
}
|
|
|
|
|
2020-02-07 11:30:19 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.func
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the custom assembly form of the function op, in this order: symbol
/// name, function signature (variadic signatures are rejected), function
/// control enum string, optional keyword-introduced attribute dictionary,
/// and optional body region.
ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
  auto &builder = parser.getBuilder();

  // The function name is recorded as the symbol-name attribute.
  StringAttr symbolName;
  if (parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // Signature: entry arguments plus result types and result attributes.
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<Type> resultTypes;
  SmallVector<DictionaryAttr> resultAttrs;
  bool isVariadic = false;
  if (function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // Record the function type built from the parsed signature.
  SmallVector<Type> argTypes;
  argTypes.reserve(entryArgs.size());
  for (auto &entryArg : entryArgs)
    argTypes.push_back(entryArg.type);
  state.addAttribute(
      FunctionOpInterface::getTypeAttrName(),
      TypeAttr::get(builder.getFunctionType(argTypes, resultTypes)));

  // Function control keyword.
  spirv::FunctionControl fnControl;
  if (parseEnumStrAttr(fnControl, parser, state))
    return failure();

  // Any additional attributes.
  if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
    return failure();

  // Attach the argument and result attributes gathered from the signature.
  assert(resultAttrs.size() == resultTypes.size());
  function_interface_impl::addArgAndResultAttrs(builder, state, entryArgs,
                                                resultAttrs);

  // The body is optional; only a present-but-malformed region is an error.
  auto *bodyRegion = state.addRegion();
  OptionalParseResult bodyResult =
      parser.parseOptionalRegion(*bodyRegion, entryArgs);
  if (bodyResult.hasValue() && failed(*bodyResult))
    return failure();
  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the function name, signature, quoted function control, remaining
/// attributes, and the body region when it is non-empty.
void spirv::FuncOp::print(OpAsmPrinter &printer) {
  // Name and signature.
  printer << " ";
  printer.printSymbolName(sym_name());
  auto signature = getFunctionType();
  function_interface_impl::printFunctionSignature(
      printer, *this, signature.getInputs(),
      /*isVariadic=*/false, signature.getResults());

  // Function control, then the attributes that were not already printed
  // (the function control attribute is elided from that list).
  printer << " \"" << spirv::stringifyFunctionControl(function_control())
          << "\"";
  function_interface_impl::printFunctionAttributes(
      printer, *this, signature.getNumInputs(), signature.getNumResults(),
      {spirv::attributeName<spirv::FunctionControl>()});

  // An empty body region is not printed.
  Region &bodyRegion = this->body();
  if (bodyRegion.empty())
    return;
  printer << ' ';
  printer.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}
|
|
|
|
|
|
|
|
/// Checks that the type attribute holds a function type with at most one
/// result.
LogicalResult spirv::FuncOp::verifyType() {
  auto attrType = getFunctionTypeAttr().getValue();
  if (!attrType.isa<FunctionType>()) {
    return emitOpError("requires '")
           << getTypeAttrName() << "' attribute of function type";
  }

  if (getFunctionType().getNumResults() > 1)
    return emitOpError("cannot have more than one result");

  return success();
}
|
|
|
|
|
|
|
|
/// Walks the function and checks every return op against the function type:
/// `spirv::ReturnOp` is only allowed when the function has no results, and
/// `spirv::ReturnValueOp` requires exactly one result whose type equals the
/// returned value's type. Fails if the walk was interrupted by an error.
LogicalResult spirv::FuncOp::verifyBody() {
  FunctionType fnType = getFunctionType();

  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
    // Plain return: the function must not declare a result.
    if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
      if (fnType.getNumResults() != 0)
        return returnOp.emitOpError(
            "cannot be used in functions returning value");
      return WalkResult::advance();
    }

    // Anything that is not a value-returning op is of no interest here.
    auto returnValueOp = dyn_cast<spirv::ReturnValueOp>(op);
    if (!returnValueOp)
      return WalkResult::advance();

    if (fnType.getNumResults() != 1)
      return returnValueOp.emitOpError(
                 "returns 1 value but enclosing function requires ")
             << fnType.getNumResults() << " results";

    auto operandType = returnValueOp.value().getType();
    auto resultType = fnType.getResult(0);
    if (operandType != resultType)
      return returnValueOp.emitOpError(" return value's type (")
             << operandType << ") mismatch with function's result type ("
             << resultType << ")";

    return WalkResult::advance();
  });

  // TODO: verify other bits like linkage type.

  if (walkResult.wasInterrupted())
    return failure();
  return success();
}
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
|
2020-02-07 11:30:19 -05:00
|
|
|
StringRef name, FunctionType type,
|
|
|
|
spirv::FunctionControl control,
|
|
|
|
ArrayRef<NamedAttribute> attrs) {
|
|
|
|
state.addAttribute(SymbolTable::getSymbolAttrName(),
|
2020-04-23 16:02:46 +02:00
|
|
|
builder.getStringAttr(name));
|
2020-02-07 11:30:19 -05:00
|
|
|
state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
|
2020-04-23 16:02:46 +02:00
|
|
|
state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
|
|
|
|
builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
|
2020-02-07 11:30:19 -05:00
|
|
|
state.attributes.append(attrs.begin(), attrs.end());
|
|
|
|
state.addRegion();
|
|
|
|
}
|
|
|
|
|
|
|
|
// CallableOpInterface
|
|
|
|
Region *spirv::FuncOp::getCallableRegion() {
|
|
|
|
return isExternal() ? nullptr : &body();
|
|
|
|
}
|
|
|
|
|
|
|
|
// CallableOpInterface
|
|
|
|
ArrayRef<Type> spirv::FuncOp::getCallableResults() {
|
2022-03-15 17:36:15 -07:00
|
|
|
return getFunctionType().getResults();
|
2020-02-07 11:30:19 -05:00
|
|
|
}
|
|
|
|
|
2019-09-16 15:39:16 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-09-16 17:11:50 -07:00
|
|
|
// spv.FunctionCall
|
2019-09-16 15:39:16 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the call against its callee: the callee must resolve to a
/// `spirv::FuncOp` through the nearest symbol table of the parent op, the
/// call may have at most one result, and operand/result counts and types
/// must match the callee's function type.
LogicalResult spirv::FunctionCallOp::verify() {
  auto calleeName = calleeAttr();

  // Resolve the callee symbol starting from the parent operation.
  Operation *parentOp = (*this)->getParentOp();
  auto callee = dyn_cast_or_null<spirv::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(parentOp, calleeName));
  if (!callee) {
    return emitOpError("callee function '")
           << calleeName.getValue() << "' not found in nearest symbol table";
  }

  auto calleeType = callee.getFunctionType();

  if (getNumResults() > 1) {
    return emitOpError(
               "expected callee function to have 0 or 1 result, but provided ")
           << getNumResults();
  }

  // Operand count, then each operand type in order.
  if (calleeType.getNumInputs() != getNumOperands()) {
    return emitOpError("has incorrect number of operands for callee: expected ")
           << calleeType.getNumInputs() << ", but provided "
           << getNumOperands();
  }

  for (uint32_t idx = 0, numInputs = calleeType.getNumInputs();
       idx != numInputs; ++idx) {
    if (getOperand(idx).getType() == calleeType.getInput(idx))
      continue;
    return emitOpError("operand type mismatch: expected operand type ")
           << calleeType.getInput(idx) << ", but provided "
           << getOperand(idx).getType() << " for operand number " << idx;
  }

  // Result count, then the single result type if there is one.
  if (calleeType.getNumResults() != getNumResults()) {
    return emitOpError(
               "has incorrect number of results has for callee: expected ")
           << calleeType.getNumResults() << ", but provided "
           << getNumResults();
  }

  if (getNumResults() &&
      (getResult(0).getType() != calleeType.getResult(0))) {
    return emitOpError("result type mismatch: expected ")
           << calleeType.getResult(0) << ", but provided "
           << getResult(0).getType();
  }

  return success();
}
|
|
|
|
|
2019-10-16 17:36:58 -07:00
|
|
|
/// Returns the callee as the symbol reference stored in the `kCallee`
/// attribute.
CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
  auto calleeRef = (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
  return calleeRef;
}
|
|
|
|
|
|
|
|
/// Returns the call's argument operands.
Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
  Operation::operand_range callArgs = arguments();
  return callArgs;
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2022-07-21 13:02:45 -04:00
|
|
|
// spv.GLFClampOp
|
2022-02-07 17:54:04 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-07-21 13:02:45 -04:00
|
|
|
/// Parses the op by delegating to `parseOneResultSameOperandTypeOp`.
ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  ParseResult status = parseOneResultSameOperandTypeOp(parser, result);
  return status;
}
|
2022-07-21 13:02:45 -04:00
|
|
|
/// Prints the op by delegating to `printOneResultOp`.
void spirv::GLFClampOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}
|
2022-02-07 17:54:04 -08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-07-21 13:02:45 -04:00
|
|
|
// spv.GLUClampOp
|
2022-02-07 17:54:04 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-07-21 13:02:45 -04:00
|
|
|
/// Parses the op by delegating to `parseOneResultSameOperandTypeOp`.
ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  ParseResult status = parseOneResultSameOperandTypeOp(parser, result);
  return status;
}
|
2022-07-21 13:02:45 -04:00
|
|
|
/// Prints the op by delegating to `printOneResultOp`.
void spirv::GLUClampOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}
|
2022-02-07 17:54:04 -08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-07-21 13:02:45 -04:00
|
|
|
// spv.GLSClampOp
|
2022-02-07 17:54:04 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-07-21 13:02:45 -04:00
|
|
|
/// Parses the op by delegating to `parseOneResultSameOperandTypeOp`.
ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  ParseResult status = parseOneResultSameOperandTypeOp(parser, result);
  return status;
}
|
2022-07-21 13:02:45 -04:00
|
|
|
/// Prints the op by delegating to `printOneResultOp`.
void spirv::GLSClampOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}
|
2022-02-07 17:54:04 -08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-07-21 13:02:45 -04:00
|
|
|
// spv.GLFmaOp
|
2022-02-07 17:54:04 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-07-21 13:02:45 -04:00
|
|
|
/// Parses the op by delegating to `parseOneResultSameOperandTypeOp`.
ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
  ParseResult status = parseOneResultSameOperandTypeOp(parser, result);
  return status;
}
|
2022-07-21 13:02:45 -04:00
|
|
|
/// Prints the op by delegating to `printOneResultOp`.
void spirv::GLFmaOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}
|
2022-02-07 17:54:04 -08:00
|
|
|
|
2019-08-17 10:19:48 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-03-04 16:17:12 -05:00
|
|
|
// spv.GlobalVariable
|
2019-08-17 10:19:48 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
|
2019-11-25 10:38:31 -08:00
|
|
|
Type type, StringRef name,
|
|
|
|
unsigned descriptorSet, unsigned binding) {
|
2021-10-05 00:04:33 +08:00
|
|
|
build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
|
2019-11-25 10:38:31 -08:00
|
|
|
state.addAttribute(
|
|
|
|
spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
|
2020-04-23 16:02:46 +02:00
|
|
|
builder.getI32IntegerAttr(descriptorSet));
|
2019-11-25 10:38:31 -08:00
|
|
|
state.addAttribute(
|
|
|
|
spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
|
2020-04-23 16:02:46 +02:00
|
|
|
builder.getI32IntegerAttr(binding));
|
2019-11-25 10:38:31 -08:00
|
|
|
}
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
|
2019-12-10 10:11:19 -08:00
|
|
|
Type type, StringRef name,
|
|
|
|
spirv::BuiltIn builtin) {
|
2021-10-05 00:04:33 +08:00
|
|
|
build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
|
2019-12-10 10:11:19 -08:00
|
|
|
state.addAttribute(
|
|
|
|
spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
|
2020-04-23 16:02:46 +02:00
|
|
|
builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
|
2019-12-10 10:11:19 -08:00
|
|
|
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the custom assembly form of the global variable op: the symbol
/// name, an optional parenthesized initializer symbol introduced by the
/// `kInitializerAttrName` keyword, the variable decorations, and a trailing
/// `: type`, which must be a `spirv::PointerType`.
ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
                                           OperationState &state) {
  // Variable name, stored as the symbol-name attribute.
  StringAttr symbolName;
  if (parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // Optional initializer: keyword `(` symbol `)`.
  if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
    FlatSymbolRefAttr initializerSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(initializerSymbol, Type(), kInitializerAttrName,
                              state.attributes) ||
        parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, state))
    return failure();

  // Remember where the type starts so the error points at it.
  Type varType;
  auto typeLoc = parser.getCurrentLocation();
  if (parser.parseColonType(varType))
    return failure();
  if (!varType.isa<spirv::PointerType>())
    return parser.emitError(typeLoc, "expected spv.ptr type");
  state.addAttribute(kTypeAttrName, TypeAttr::get(varType));

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the variable name, the optional initializer, the variable
/// decorations and the type. Attributes printed explicitly are collected in
/// an elision list that is handed to `printVariableDecorations`.
void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
  SmallVector<StringRef, 4> elided;
  elided.push_back(spirv::attributeName<spirv::StorageClass>());

  // Variable name.
  printer << ' ';
  printer.printSymbolName(sym_name());
  elided.push_back(SymbolTable::getSymbolAttrName());

  // Optional initializer.
  auto init = this->initializer();
  if (init) {
    printer << " " << kInitializerAttrName << '(';
    printer.printSymbolName(*init);
    printer << ')';
    elided.push_back(kInitializerAttrName);
  }

  // The type is printed explicitly after the decorations.
  elided.push_back(kTypeAttrName);
  printVariableDecorations(*this, printer, elided);
  printer << " : " << type();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Rejects the Generic and Function storage classes and, when an initializer
/// symbol is present, requires it to resolve to a `spirv::GlobalVariableOp`
/// or `spirv::SpecConstantOp`.
LogicalResult spirv::GlobalVariableOp::verify() {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  // Also, Function storage class is reserved by spv.Variable.
  auto varStorageClass = this->storageClass();
  switch (varStorageClass) {
  case spirv::StorageClass::Generic:
  case spirv::StorageClass::Function:
    return emitOpError("storage class cannot be '")
           << stringifyStorageClass(varStorageClass) << "'";
  default:
    break;
  }

  auto initRef = (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName);
  if (!initRef)
    return success();

  Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
      (*this)->getParentOp(), initRef.getAttr());
  // TODO: Currently only variable initialization with specialization
  // constants and other variables is supported. They could be normal
  // constants in the module scope as well.
  bool isSupportedInit =
      initOp && isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp);
  if (!isSupportedInit) {
    return emitOpError("initializer must be result of a "
                       "spv.SpecConstant or spv.GlobalVariable op");
  }

  return success();
}
|
|
|
|
|
2020-08-10 09:39:27 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupBroadcast
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Requires the execution scope to be Workgroup or Subgroup and, when
/// `localid` has a vector type, requires it to have 2 or 3 elements.
LogicalResult spirv::GroupBroadcastOp::verify() {
  spirv::Scope scope = execution_scope();
  bool isAllowedScope =
      scope == spirv::Scope::Workgroup || scope == spirv::Scope::Subgroup;
  if (!isAllowedScope)
    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");

  // A non-vector localid needs no further checking.
  auto localIdVecTy = localid().getType().dyn_cast<VectorType>();
  if (!localIdVecTy)
    return success();

  auto numElements = localIdVecTy.getNumElements();
  if (numElements == 2 || numElements == 3)
    return success();

  return emitOpError("localid is a vector and can be with only "
                     " 2 or 3 components, actual number is ")
         << numElements;
}
|
|
|
|
|
2019-12-03 16:43:40 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformBallotOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies a GroupNonUniformBallotOp: only the Workgroup and Subgroup
/// execution scopes are accepted.
LogicalResult spirv::GroupNonUniformBallotOp::verify() {
  switch (execution_scope()) {
  case spirv::Scope::Workgroup:
  case spirv::Scope::Subgroup:
    return success();
  default:
    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
  }
}
|
|
|
|
|
2020-09-16 22:53:52 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformBroadcast
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies a GroupNonUniformBroadcastOp: the execution scope must be
/// Workgroup or Subgroup and, when targeting a SPIR-V version below 1.5, `id`
/// must be produced by a constant-like op.
LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
  const spirv::Scope scope = execution_scope();
  const bool scopeIsSupported =
      scope == spirv::Scope::Workgroup || scope == spirv::Scope::Subgroup;
  if (!scopeIsSupported)
    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");

  // SPIR-V spec: "Before version 1.5, Id must come from a
  // constant instruction.
  // Start from the default target environment and refine it with the one
  // looked up from the enclosing spv.module, when there is such a module.
  auto targetEnv = spirv::getDefaultTargetEnv(getContext());
  if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
    targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);

  // From version 1.5 on there is no restriction on how `id` is produced.
  if (targetEnv.getVersion() >= spirv::Version::V_1_5)
    return success();

  // spirv::ConstantOp covers normal constants; spirv::ReferenceOfOp covers
  // spec constants. A null defining op (e.g. a block argument) is rejected.
  Operation *idOp = id().getDefiningOp();
  if (idOp && isa<spirv::ConstantOp, spirv::ReferenceOfOp>(idOp))
    return success();

  return emitOpError("id must be the result of a constant op");
}
|
|
|
|
|
2020-09-02 19:52:29 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.SubgroupBlockReadINTEL
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a SubgroupBlockReadINTELOp of the form
///   <storage-class-string> <ptr-operand> `:` <result-type>
/// The pointer operand type is rebuilt from the storage class and the result
/// type (using the element type when the result is a vector).
ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
                                                   OperationState &state) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  OpAsmParser::UnresolvedOperand ptrInfo;
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser))
    return failure();
  if (parser.parseOperand(ptrInfo) || parser.parseColon() ||
      parser.parseType(elementType))
    return failure();

  // For a vector result the pointer points at the vector's element type;
  // otherwise it points at the result type itself.
  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (auto valVecTy = elementType.dyn_cast<VectorType>())
    ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);

  if (parser.resolveOperand(ptrInfo, ptrType, state.operands))
    return failure();

  state.addTypes(elementType);
  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints a SubgroupBlockReadINTELOp as
///   "<storage-class>" <ptr-operand> `:` <result-type>
/// The quoted storage class is emitted first because
/// SubgroupBlockReadINTELOp::parse requires it before the operand; without it
/// the printed form could not be parsed back.
void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) {
  // Recover the storage class from the pointer operand's type, the same way
  // spirv::LoadOp::print does.
  StringRef sc = stringifyStorageClass(
      ptr().getType().cast<spirv::PointerType>().getStorageClass());
  printer << " \"" << sc << "\" " << ptr() << " : " << getType();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies a SubgroupBlockReadINTELOp by checking its pointer operand against
/// its result value with the block read/write type verifier.
LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
  const LogicalResult typesMatch =
      verifyBlockReadWritePtrAndValTypes(*this, ptr(), value());
  return failed(typesMatch) ? failure() : success();
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.SubgroupBlockWriteINTEL
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a SubgroupBlockWriteINTELOp of the form
///   <storage-class-string> <ptr-operand> `,` <value-operand> `:` <value-type>
/// The pointer operand type is rebuilt from the storage class and the value
/// type (using the element type when the value is a vector).
ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
                                                    OperationState &state) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
  // Location used for diagnostics when resolving the operands below.
  auto loc = parser.getCurrentLocation();
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser))
    return failure();
  if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
      parser.parseType(elementType))
    return failure();

  // For a vector value the pointer points at the vector's element type;
  // otherwise it points at the value type itself.
  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (auto valVecTy = elementType.dyn_cast<VectorType>())
    ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);

  // Operand order is (pointer, value).
  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
                             state.operands))
    return failure();

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints a SubgroupBlockWriteINTELOp as
///   "<storage-class>" <ptr-operand> `,` <value-operand> `:` <value-type>
/// The quoted storage class is emitted first because
/// SubgroupBlockWriteINTELOp::parse requires it before the operands; without
/// it the printed form could not be parsed back.
void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) {
  // Recover the storage class from the pointer operand's type, the same way
  // spirv::LoadOp::print does.
  StringRef sc = stringifyStorageClass(
      ptr().getType().cast<spirv::PointerType>().getStorageClass());
  printer << " \"" << sc << "\" " << ptr() << ", " << value() << " : "
          << value().getType();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies a SubgroupBlockWriteINTELOp by checking its pointer operand
/// against its value operand with the block read/write type verifier.
LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() {
  const LogicalResult typesMatch =
      verifyBlockReadWritePtrAndValTypes(*this, ptr(), value());
  return failed(typesMatch) ? failure() : success();
}
|
|
|
|
|
2020-01-26 10:19:24 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformElectOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies a GroupNonUniformElectOp: only the Workgroup and Subgroup
/// execution scopes are accepted.
LogicalResult spirv::GroupNonUniformElectOp::verify() {
  switch (execution_scope()) {
  case spirv::Scope::Workgroup:
  case spirv::Scope::Subgroup:
    return success();
  default:
    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
  }
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformFAddOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniformFAddOp via the verifier shared by the group
/// non-uniform arithmetic ops.
LogicalResult spirv::GroupNonUniformFAddOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a GroupNonUniformFAddOp via the parser shared by the group
/// non-uniform arithmetic ops.
ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
|
|
|
|
/// Prints a GroupNonUniformFAddOp via the printer shared by the group
/// non-uniform arithmetic ops.
void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformFMaxOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniformFMaxOp via the verifier shared by the group
/// non-uniform arithmetic ops.
LogicalResult spirv::GroupNonUniformFMaxOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a GroupNonUniformFMaxOp via the parser shared by the group
/// non-uniform arithmetic ops.
ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
|
|
|
|
/// Prints a GroupNonUniformFMaxOp via the printer shared by the group
/// non-uniform arithmetic ops.
void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformFMinOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniformFMinOp via the verifier shared by the group
/// non-uniform arithmetic ops.
LogicalResult spirv::GroupNonUniformFMinOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a GroupNonUniformFMinOp via the parser shared by the group
/// non-uniform arithmetic ops.
ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
|
|
|
|
/// Prints a GroupNonUniformFMinOp via the printer shared by the group
/// non-uniform arithmetic ops.
void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformFMulOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniformFMulOp via the verifier shared by the group
/// non-uniform arithmetic ops.
LogicalResult spirv::GroupNonUniformFMulOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a GroupNonUniformFMulOp via the parser shared by the group
/// non-uniform arithmetic ops.
ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
|
|
|
|
/// Prints a GroupNonUniformFMulOp via the printer shared by the group
/// non-uniform arithmetic ops.
void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformIAddOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniformIAddOp via the verifier shared by the group
/// non-uniform arithmetic ops.
LogicalResult spirv::GroupNonUniformIAddOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a GroupNonUniformIAddOp via the parser shared by the group
/// non-uniform arithmetic ops.
ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
|
|
|
|
/// Prints a GroupNonUniformIAddOp via the printer shared by the group
/// non-uniform arithmetic ops.
void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformIMulOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniformIMulOp via the verifier shared by the group
/// non-uniform arithmetic ops.
LogicalResult spirv::GroupNonUniformIMulOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a GroupNonUniformIMulOp via the parser shared by the group
/// non-uniform arithmetic ops.
ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
|
|
|
|
/// Prints a GroupNonUniformIMulOp via the printer shared by the group
/// non-uniform arithmetic ops.
void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformSMaxOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniformSMaxOp via the verifier shared by the group
/// non-uniform arithmetic ops.
LogicalResult spirv::GroupNonUniformSMaxOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a GroupNonUniformSMaxOp via the parser shared by the group
/// non-uniform arithmetic ops.
ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
|
|
|
|
/// Prints a GroupNonUniformSMaxOp via the printer shared by the group
/// non-uniform arithmetic ops.
void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformSMinOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniformSMinOp via the verifier shared by the group
/// non-uniform arithmetic ops.
LogicalResult spirv::GroupNonUniformSMinOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a GroupNonUniformSMinOp via the parser shared by the group
/// non-uniform arithmetic ops.
ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
|
|
|
|
/// Prints a GroupNonUniformSMinOp via the printer shared by the group
/// non-uniform arithmetic ops.
void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformUMaxOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniformUMaxOp via the verifier shared by the group
/// non-uniform arithmetic ops.
LogicalResult spirv::GroupNonUniformUMaxOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a GroupNonUniformUMaxOp via the parser shared by the group
/// non-uniform arithmetic ops.
ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
|
|
|
|
/// Prints a GroupNonUniformUMaxOp via the printer shared by the group
/// non-uniform arithmetic ops.
void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.GroupNonUniformUMinOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a GroupNonUniformUMinOp via the verifier shared by the group
/// non-uniform arithmetic ops.
LogicalResult spirv::GroupNonUniformUMinOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a GroupNonUniformUMinOp via the parser shared by the group
/// non-uniform arithmetic ops.
ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
|
|
|
|
/// Prints a GroupNonUniformUMinOp via the printer shared by the group
/// non-uniform arithmetic ops.
void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}
|
|
|
|
|
2022-06-15 20:38:47 -04:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ISubBorrowOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies an ISubBorrowOp: the result must be a two-member struct whose
/// member types both equal the (identical) types of the two operands.
LogicalResult spirv::ISubBorrowOp::verify() {
  auto structTy = getType().cast<spirv::StructType>();
  if (structTy.getNumElements() != 2)
    return emitOpError("expected result struct type containing two members");

  // Gather both operand types and both struct member types; all four must be
  // one and the same type.
  SmallVector<Type, 4> allTypes;
  allTypes.push_back(operand1().getType());
  allTypes.push_back(operand2().getType());
  for (unsigned memberIdx = 0; memberIdx < 2; ++memberIdx)
    allTypes.push_back(structTy.getElementType(memberIdx));

  if (llvm::is_splat(allTypes))
    return success();

  return emitOpError(
      "expected all operand types and struct member types are the same");
}
|
|
|
|
|
|
|
|
/// Parses an ISubBorrowOp of the form
///   [attr-dict] <operand-list> `:` <result-struct-type>
/// Both operands are resolved against the type of the first struct member.
ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
                                       OperationState &state) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
  if (parser.parseOptionalAttrDict(state.attributes))
    return failure();
  if (parser.parseOperandList(operands) || parser.parseColon())
    return failure();

  // Remember where the type starts so diagnostics point at it.
  Type resultType;
  auto loc = parser.getCurrentLocation();
  if (parser.parseType(resultType))
    return failure();

  auto structType = resultType.dyn_cast<spirv::StructType>();
  const bool isTwoMemberStruct = structType && structType.getNumElements() == 2;
  if (!isTwoMemberStruct)
    return parser.emitError(loc, "expected spv.struct type with two members");

  // Each of the two operands takes the type of the struct's first member.
  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
  if (parser.resolveOperands(operands, operandTypes, loc, state.operands))
    return failure();

  state.addTypes(resultType);
  return success();
}
|
|
|
|
|
|
|
|
/// Prints an ISubBorrowOp as: optional attribute dictionary, operands, then
/// ` : ` and the result type.
void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
  Operation *op = getOperation();
  printer << ' ';
  printer.printOptionalAttrDict(op->getAttrs());
  printer.printOperands(op->getOperands());
  printer << " : " << getType();
}
|
|
|
|
|
2019-06-24 10:59:05 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.LoadOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
|
2021-02-27 15:21:00 +03:00
|
|
|
Value basePtr, MemoryAccessAttr memoryAccess,
|
2019-10-04 14:02:14 -07:00
|
|
|
IntegerAttr alignment) {
|
2020-01-11 08:54:04 -08:00
|
|
|
auto ptrType = basePtr.getType().cast<spirv::PointerType>();
|
2021-02-27 15:21:00 +03:00
|
|
|
build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
|
2019-10-04 14:02:14 -07:00
|
|
|
alignment);
|
|
|
|
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a LoadOp of the form
///   <storage-class-string> <ptr-operand> [memory-access] [attr-dict]
///   `:` <element-type>
/// The pointer operand type is rebuilt from the element type and the storage
/// class; the element type is also the result type.
ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &state) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  OpAsmParser::UnresolvedOperand ptrInfo;
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo))
    return failure();
  if (parseMemoryAccessAttributes(parser, state) ||
      parser.parseOptionalAttrDict(state.attributes))
    return failure();
  if (parser.parseColon() || parser.parseType(elementType))
    return failure();

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperand(ptrInfo, ptrType, state.operands))
    return failure();

  state.addTypes(elementType);
  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints a LoadOp as
///   "<storage-class>" <ptr-operand> [memory-access] [attr-dict] `:` <type>
void spirv::LoadOp::print(OpAsmPrinter &printer) {
  // The storage class is taken from the pointer operand's type.
  auto ptrType = ptr().getType().cast<spirv::PointerType>();
  StringRef sc = stringifyStorageClass(ptrType.getStorageClass());
  printer << " \"" << sc << "\" " << ptr();

  // NOTE(review): printMemoryAccessAttribute presumably records the attributes
  // it prints in `elidedAttrs` so they are not repeated in the dictionary
  // below; confirm against its definition.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);

  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
  printer << " : " << getType();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies a LoadOp: first the pointer/value type relationship, then the
/// memory access attribute.
LogicalResult spirv::LoadOp::verify() {
  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
  // type with fixed size; i.e., it cannot be, nor include, any
  // OpTypeRuntimeArray types."
  const LogicalResult typesOk =
      verifyLoadStorePtrAndValTypes(*this, ptr(), value());
  if (failed(typesOk))
    return failure();

  return verifyMemoryAccessAttribute(*this);
}
|
|
|
|
|
2019-09-05 12:45:08 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-03-05 15:35:35 -05:00
|
|
|
// spv.mlir.loop
|
2019-09-05 12:45:08 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
/// Builds a LoopOp with `loop_control` set to LoopControl::None and a single
/// region that is left empty.
void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
  const auto noControl = static_cast<uint32_t>(spirv::LoopControl::None);
  state.addAttribute("loop_control", builder.getI32IntegerAttr(noControl));
  state.addRegion();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a LoopOp: the loop control attribute followed by the body region,
/// which takes no entry block arguments.
ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &state) {
  if (parseControlAttribute<spirv::LoopControl>(parser, state))
    return failure();

  Region *bodyRegion = state.addRegion();
  return parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{});
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints a LoopOp: ` control(<loop-control>)` when the control is not None,
/// followed by the body region with its block terminators shown.
void spirv::LoopOp::print(OpAsmPrinter &printer) {
  auto control = loop_control();
  const bool hasExplicitControl = control != spirv::LoopControl::None;
  if (hasExplicitControl)
    printer << " control(" << spirv::stringifyLoopControl(control) << ")";

  printer << ' ';
  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}
|
|
|
|
|
|
|
|
/// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
|
|
|
|
/// given `dstBlock`.
|
|
|
|
static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
|
|
|
|
// Check that there is only one op in the `srcBlock`.
|
2020-04-14 14:53:07 -07:00
|
|
|
if (!llvm::hasSingleElement(srcBlock))
|
2019-09-05 12:45:08 -07:00
|
|
|
return false;
|
|
|
|
|
|
|
|
auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
|
2020-03-05 12:39:46 -08:00
|
|
|
return branchOp && branchOp.getSuccessor() == &dstBlock;
|
2019-09-05 12:45:08 -07:00
|
|
|
}
|
|
|
|
|
2022-03-10 22:10:45 +00:00
|
|
|
/// Verifies the block structure of the loop region.
///
/// An empty region is accepted. Otherwise the region must consist of, in
/// order: an entry block, a loop header block, ..., a loop continue block and
/// a merge block, with the branching constraints checked below.
LogicalResult spirv::LoopOp::verifyRegions() {
  auto *op = getOperation();

  // We need to verify that the blocks follow the following layout:
  //
  //                     +-------------+
  //                     | entry block |
  //                     +-------------+
  //                            |
  //                            v
  //                     +-------------+
  //                     | loop header | <-----+
  //                     +-------------+       |
  //                                           |
  //                           ...             |
  //                          \ | /            |
  //                            v              |
  //                    +---------------+      |
  //                    | loop continue | -----+
  //                    +---------------+
  //
  //                           ...
  //                          \ | /
  //                            v
  //                     +-------------+
  //                     | merge block |
  //                     +-------------+

  auto &region = op->getRegion(0);
  // Allow empty region as a degenerated case, which can come from
  // optimizations.
  if (region.empty())
    return success();

  // The last block is the merge block.
  Block &merge = region.back();
  if (!isMergeBlock(merge))
    return emitOpError(
        "last block must be the merge block with only one 'spv.mlir.merge' op");

  // A single block can only be the merge block; the entry block is missing.
  if (std::next(region.begin()) == region.end())
    return emitOpError(
        "must have an entry block branching to the loop header block");
  // The first block is the entry block.
  Block &entry = region.front();

  // With exactly two blocks there is no room for a loop header.
  if (std::next(region.begin(), 2) == region.end())
    return emitOpError(
        "must have a loop header block branched from the entry block");
  // The second block is the loop header block.
  Block &header = *std::next(region.begin(), 1);

  if (!hasOneBranchOpTo(entry, header))
    return emitOpError(
        "entry block must only have one 'spv.Branch' op to the second block");

  // With exactly three blocks there is no room for a loop continue block.
  if (std::next(region.begin(), 3) == region.end())
    return emitOpError(
        "requires a loop continue block branching to the loop header block");
  // The second to last block is the loop continue block.
  Block &cont = *std::prev(region.end(), 2);

  // Make sure that we have a branch from the loop continue block to the loop
  // header block.
  if (llvm::none_of(
          llvm::seq<unsigned>(0, cont.getNumSuccessors()),
          [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
    return emitOpError("second to last block must be the loop continue "
                       "block that branches to the loop header block");

  // Make sure that no other blocks (except the entry and loop continue block)
  // branches to the loop header block.
  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
                                      std::prev(region.end(), 2))) {
    for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
      if (block.getSuccessor(i) == &header) {
        return emitOpError("can only have the entry and loop continue "
                           "block branching to the loop header block");
      }
    }
  }

  return success();
}
|
|
|
|
|
2019-11-12 11:59:34 -08:00
|
|
|
/// Returns the entry block, i.e. the first block of the loop region. The
/// region must not be empty.
Block *spirv::LoopOp::getEntryBlock() {
  Region &region = body();
  assert(!region.empty() && "op region should not be empty!");
  return &region.front();
}
|
|
|
|
|
2019-09-11 14:02:23 -07:00
|
|
|
/// Returns the loop header block, i.e. the second block of the loop region.
/// The region must not be empty.
Block *spirv::LoopOp::getHeaderBlock() {
  Region &region = body();
  assert(!region.empty() && "op region should not be empty!");
  // The second block is the loop header block.
  auto headerIt = std::next(region.begin());
  return &*headerIt;
}
|
|
|
|
|
|
|
|
/// Returns the loop continue block, i.e. the second to last block of the loop
/// region. The region must not be empty.
Block *spirv::LoopOp::getContinueBlock() {
  Region &region = body();
  assert(!region.empty() && "op region should not be empty!");
  // The second to last block is the loop continue block.
  auto continueIt = std::prev(region.end(), 2);
  return &*continueIt;
}
|
|
|
|
|
|
|
|
/// Returns the merge block, i.e. the last block of the loop region. The
/// region must not be empty.
Block *spirv::LoopOp::getMergeBlock() {
  Region &region = body();
  assert(!region.empty() && "op region should not be empty!");
  // The last block is the loop merge block.
  return &region.back();
}
|
|
|
|
|
|
|
|
void spirv::LoopOp::addEntryAndMergeBlock() {
|
|
|
|
assert(body().empty() && "entry and merge block already exist");
|
|
|
|
body().push_back(new Block());
|
|
|
|
auto *mergeBlock = new Block();
|
|
|
|
body().push_back(mergeBlock);
|
2020-03-30 16:52:59 +02:00
|
|
|
OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
|
2019-09-11 14:02:23 -07:00
|
|
|
|
2020-11-19 09:48:58 -05:00
|
|
|
// Add a spv.mlir.merge op into the merge block.
|
2019-10-02 11:00:50 -07:00
|
|
|
builder.create<spirv::MergeOp>(getLoc());
|
2019-09-11 14:02:23 -07:00
|
|
|
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.MemoryBarrierOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies a MemoryBarrierOp by checking its memory semantics value.
LogicalResult spirv::MemoryBarrierOp::verify() {
  Operation *op = getOperation();
  return verifyMemorySemantics(op, memory_semantics());
}
|
|
|
|
|
2019-09-05 12:45:08 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2020-11-19 09:48:58 -05:00
|
|
|
// spv.mlir.merge
|
2019-09-05 12:45:08 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies a MergeOp: its parent must be a spirv::SelectionOp or a
/// spirv::LoopOp, and it must be the last operation of the last block of the
/// parent's region.
LogicalResult spirv::MergeOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
    return emitOpError(
        "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'");

  // TODO: This check should be done in `verifyRegions` of parent op.
  Block &parentLastBlock = (*this)->getParentRegion()->back();
  // Compare against the block's last operation directly rather than calling
  // Block::getTerminator(): that accessor asserts that the block ends in a
  // terminator, so an empty or malformed last block would abort here instead
  // of producing a diagnostic. For well-formed IR both give the same result.
  if (parentLastBlock.empty() || getOperation() != &parentLastBlock.back())
    return emitOpError("can only be used in the last block of "
                       "'spv.mlir.selection' or 'spv.mlir.loop'");
  return success();
}
|
|
|
|
|
2019-05-29 10:47:16 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.module
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-08-27 07:10:14 +03:00
|
|
|
/// Builds a ModuleOp with one region holding a single block and, when `name`
/// is provided, the symbol name attribute. The builder's insertion point is
/// restored on return.
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                            Optional<StringRef> name) {
  // createBlock moves the insertion point; the guard puts it back.
  OpBuilder::InsertionGuard guard(builder);
  Region *bodyRegion = state.addRegion();
  builder.createBlock(bodyRegion);

  if (!name)
    return;
  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
                          builder.getStringAttr(*name));
}
|
|
|
|
|
2020-04-23 16:02:46 +02:00
|
|
|
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
|
2020-08-27 07:10:14 +03:00
|
|
|
spirv::AddressingModel addressingModel,
|
|
|
|
spirv::MemoryModel memoryModel,
|
2021-07-28 10:30:54 -04:00
|
|
|
Optional<VerCapExtAttr> vceTriple,
|
2020-08-27 07:10:14 +03:00
|
|
|
Optional<StringRef> name) {
|
2019-12-09 09:51:25 -08:00
|
|
|
state.addAttribute(
|
|
|
|
"addressing_model",
|
2020-08-27 07:10:14 +03:00
|
|
|
builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
|
2020-04-23 16:02:46 +02:00
|
|
|
state.addAttribute("memory_model", builder.getI32IntegerAttr(
|
2020-08-27 07:10:14 +03:00
|
|
|
static_cast<int32_t>(memoryModel)));
|
2021-06-09 13:58:13 -04:00
|
|
|
OpBuilder::InsertionGuard guard(builder);
|
|
|
|
builder.createBlock(state.addRegion());
|
2021-07-28 10:30:54 -04:00
|
|
|
if (vceTriple)
|
|
|
|
state.addAttribute(getVCETripleAttrName(), *vceTriple);
|
|
|
|
if (name)
|
|
|
|
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
|
|
|
|
builder.getStringAttr(*name));
|
2019-07-30 11:29:48 -07:00
|
|
|
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a ModuleOp of the form
///   [symbol-name] <addressing-model> <memory-model> [`requires` <vce-triple>]
///   [`attributes` attr-dict] <region>
/// An empty body region is given one block so the module always has one.
ParseResult spirv::ModuleOp::parse(OpAsmParser &parser, OperationState &state) {
  Region *body = state.addRegion();

  // If the name is present, parse it.
  // The name is optional, so the result is deliberately ignored.
  StringAttr nameAttr;
  (void)parser.parseOptionalSymbolName(
      nameAttr, mlir::SymbolTable::getSymbolAttrName(), state.attributes);

  // Parse attributes
  spirv::AddressingModel addrModel;
  spirv::MemoryModel memoryModel;
  if (::parseEnumKeywordAttr(addrModel, parser, state))
    return failure();
  if (::parseEnumKeywordAttr(memoryModel, parser, state))
    return failure();

  // The VCE triple only appears after the `requires` keyword.
  const bool hasVceTriple = succeeded(parser.parseOptionalKeyword("requires"));
  if (hasVceTriple) {
    spirv::VerCapExtAttr vceTriple;
    if (parser.parseAttribute(vceTriple,
                              spirv::ModuleOp::getVCETripleAttrName(),
                              state.attributes))
      return failure();
  }

  if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
    return failure();
  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  // Make sure we have at least one block.
  if (body->empty())
    body->push_back(new Block());

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints a ModuleOp as
///   [symbol-name] <addressing-model> <memory-model> [`requires` <vce-triple>]
///   [`attributes` attr-dict] <region>
/// Attributes already shown in the custom syntax are elided from the
/// dictionary.
void spirv::ModuleOp::print(OpAsmPrinter &printer) {
  if (Optional<StringRef> name = getName()) {
    printer << ' ';
    printer.printSymbolName(*name);
  }

  printer << " " << spirv::stringifyAddressingModel(addressing_model()) << " "
          << spirv::stringifyMemoryModel(memory_model());

  // The models and the symbol name were printed above (or are absent), so
  // they must not show up again in the attribute dictionary.
  SmallVector<StringRef, 2> elidedAttrs;
  elidedAttrs.push_back(spirv::attributeName<spirv::AddressingModel>());
  elidedAttrs.push_back(spirv::attributeName<spirv::MemoryModel>());
  elidedAttrs.push_back(mlir::SymbolTable::getSymbolAttrName());

  if (Optional<spirv::VerCapExtAttr> triple = vce_triple()) {
    printer << " requires " << *triple;
    elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
  }

  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
  printer << ' ';
  printer.printRegion(getRegion());
}
|
|
|
|
|
2022-03-10 22:10:45 +00:00
|
|
|
/// Verifies the ops nested directly in the module body.
///
/// For every op in the body, in order:
///  * the op must belong to the same dialect as this module op;
///  * an entry point must name a spirv::FuncOp found in this module's symbol
///    table, every element of its interface (if any) must be a flat symbol
///    reference resolving to a spirv::GlobalVariableOp, and no earlier entry
///    point may use the same (function, execution model) pair;
///  * a function must not be external, and the ops directly inside its blocks
///    must belong to this dialect as well.
///
/// Returns the first diagnostic produced, or success when all checks pass.
LogicalResult spirv::ModuleOp::verifyRegions() {
  Dialect *dialect = (*this)->getDialect();
  DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
      entryPoints;
  mlir::SymbolTable table(*this);

  for (auto &op : *getBody()) {
    if (op.getDialect() != dialect)
      return op.emitError("'spv.module' can only contain spv.* ops");

    // For EntryPoint op, check that the function and execution model is not
    // duplicated in EntryPointOps. Also verify that the interface specified
    // comes from globalVariables here to make this check cheaper.
    if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
      auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
      if (!funcOp) {
        return entryPointOp.emitError("function '")
               << entryPointOp.fn() << "' not found in 'spv.module'";
      }

      // NOTE(review): the two interface diagnostics below have an unbalanced
      // quote / a missing space. They are kept verbatim on purpose; confirm
      // whether any test matches on the exact text before rewording them.
      if (auto interface = entryPointOp.interface()) {
        for (Attribute varRef : interface) {
          auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
          if (!varSymRef) {
            return entryPointOp.emitError(
                       "expected symbol reference for interface "
                       "specification instead of '")
                   << varRef;
          }
          auto variableOp =
              table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
          if (!variableOp) {
            return entryPointOp.emitError("expected spv.GlobalVariable "
                                          "symbol reference instead of'")
                   << varSymRef << "'";
          }
        }
      }

      // A single try_emplace both detects a duplicate and records the new
      // entry point, instead of a find() followed by operator[].
      auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
          funcOp, entryPointOp.execution_model());
      if (!entryPoints.try_emplace(key, entryPointOp).second)
        return entryPointOp.emitError("duplicate of a previous EntryPointOp");
    } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
      if (funcOp.isExternal())
        return op.emitError("'spv.module' cannot contain external functions");

      // TODO: move this check to spv.func.
      // The nested op gets its own name so it does not shadow the module-level
      // `op` of the enclosing loop.
      for (auto &block : funcOp) {
        for (auto &nestedOp : block) {
          if (nestedOp.getDialect() != dialect)
            return nestedOp.emitError(
                "functions in 'spv.module' can only contain spv.* ops");
        }
      }
    }
  }

  return success();
}
|
|
|
|
|
2019-08-20 13:33:41 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2020-11-18 12:43:06 -05:00
|
|
|
// spv.mlir.referenceof
|
2019-08-20 13:33:41 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies that the referenced symbol is a spirv::SpecConstantOp or a
/// spirv::SpecConstantCompositeOp and that this op's result type equals the
/// type of that specialization constant.
LogicalResult spirv::ReferenceOfOp::verify() {
  // Resolve the symbol starting from the parent op of this operation.
  Operation *symbolOp = SymbolTable::lookupNearestSymbolFrom(
      (*this)->getParentOp(), spec_constAttr());

  // Determine the type the result must have, depending on which kind of
  // specialization constant the symbol resolved to (if any).
  Type expectedType;
  if (auto scalarOp = dyn_cast_or_null<spirv::SpecConstantOp>(symbolOp)) {
    expectedType = scalarOp.default_value().getType();
  } else if (auto compositeOp =
                 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(symbolOp)) {
    expectedType = compositeOp.type();
  } else {
    return emitOpError(
        "expected spv.SpecConstant or spv.SpecConstantComposite symbol");
  }

  if (reference().getType() == expectedType)
    return success();

  return emitOpError("result type mismatch with the referenced "
                     "specialization constant's type");
}
|
2019-08-30 12:17:21 -07:00
|
|
|
|
2019-06-04 14:03:30 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.Return
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Op-local verifier for spv.Return; it performs no checks of its own.
LogicalResult spirv::ReturnOp::verify() {
  // Nothing to check here: verification is performed in spv.func op.
  return success();
}
|
|
|
|
|
2019-08-19 10:57:43 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ReturnValue
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Op-local verifier for spv.ReturnValue; it performs no checks of its own.
LogicalResult spirv::ReturnValueOp::verify() {
  // Nothing to check here: verification is performed in spv.func op.
  return success();
}
|
|
|
|
|
2019-09-02 21:06:35 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.Select
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the vector form of spv.Select: when the condition is a vector, the
/// result must be a vector with the same number of elements. A non-vector
/// condition needs no further checks here.
LogicalResult spirv::SelectOp::verify() {
  auto conditionVecTy = condition().getType().dyn_cast<VectorType>();
  if (!conditionVecTy)
    return success();

  auto resultVecTy = result().getType().dyn_cast<VectorType>();
  if (!resultVecTy)
    return emitOpError("result expected to be of vector type when "
                       "condition is of vector type");

  if (resultVecTy.getNumElements() == conditionVecTy.getNumElements())
    return success();

  return emitOpError("result should have the same number of elements as "
                     "the condition when condition is of vector type");
}
|
|
|
|
|
2019-10-02 11:00:50 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-03-06 15:40:14 +01:00
|
|
|
// spv.mlir.selection
|
2019-10-02 11:00:50 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the selection control attribute followed by the op's single region,
/// which takes no entry block arguments.
ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
                                      OperationState &state) {
  if (parseControlAttribute<spirv::SelectionControl>(parser, state))
    return failure();

  // The region is only added to the state once the control attribute parsed.
  Region *bodyRegion = state.addRegion();
  return parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{});
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints an optional `control(<value>)` clause followed by the region. The
/// clause is omitted when the selection control is `None`.
void spirv::SelectionOp::print(OpAsmPrinter &printer) {
  const auto selectionControl = selection_control();
  if (selectionControl != spirv::SelectionControl::None) {
    printer << " control(";
    printer << spirv::stringifySelectionControl(selectionControl);
    printer << ")";
  }

  printer << ' ';
  // Entry block arguments are not printed; block terminators are.
  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}
|
|
|
|
|
2022-03-10 22:10:45 +00:00
|
|
|
/// Verifies the block structure of the selection region.
///
/// An empty region is accepted. Otherwise the last block must satisfy
/// `isMergeBlock`, and the region must contain at least one more block in
/// front of it to act as the selection header.
LogicalResult spirv::SelectionOp::verifyRegions() {
  auto *op = getOperation();

  // We need to verify that the blocks follow the following layout:
  //
  //                     +--------------+
  //                     | header block |
  //                     +--------------+
  //                          / | \
  //                           ...
  //
  //
  //         +---------+   +---------+   +---------+
  //         | case #0 |   | case #1 |   | case #2 |  ...
  //         +---------+   +---------+   +---------+
  //
  //
  //                           ...
  //                          \ | /
  //                            v
  //                     +-------------+
  //                     | merge block |
  //                     +-------------+

  auto &region = op->getRegion(0);
  // Allow empty region as a degenerated case, which can come from
  // optimizations.
  if (region.empty())
    return success();

  // The last block is the merge block.
  if (!isMergeBlock(region.back()))
    return emitOpError(
        "last block must be the merge block with only one 'spv.mlir.merge' op");

  // A region holding only the merge block has no header block.
  if (std::next(region.begin()) == region.end())
    return emitOpError("must have a selection header block");

  return success();
}
|
|
|
|
|
|
|
|
/// Returns the selection header block, i.e. the first block of the body
/// region. The region must not be empty.
Block *spirv::SelectionOp::getHeaderBlock() {
  auto &region = body();
  assert(!region.empty() && "op region should not be empty!");
  return &region.front();
}
|
|
|
|
|
|
|
|
/// Returns the selection merge block, i.e. the last block of the body region.
/// The region must not be empty.
Block *spirv::SelectionOp::getMergeBlock() {
  auto &region = body();
  assert(!region.empty() && "op region should not be empty!");
  return &region.back();
}
|
|
|
|
|
|
|
|
void spirv::SelectionOp::addMergeBlock() {
|
|
|
|
assert(body().empty() && "entry and merge block already exist");
|
|
|
|
auto *mergeBlock = new Block();
|
|
|
|
body().push_back(mergeBlock);
|
2020-03-30 16:52:59 +02:00
|
|
|
OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
|
2019-10-02 11:00:50 -07:00
|
|
|
|
2020-11-19 09:48:58 -05:00
|
|
|
// Add a spv.mlir.merge op into the merge block.
|
2019-10-02 11:00:50 -07:00
|
|
|
builder.create<spirv::MergeOp>(getLoc());
|
|
|
|
}
|
|
|
|
|
2020-01-26 11:10:29 -05:00
|
|
|
/// Builds a selection op modelling `if (condition) { thenBody }`.
///
/// The resulting region holds three blocks: a header that conditionally
/// branches to the "then" block or directly to the merge block, the "then"
/// block filled in by `thenBody` and terminated by a branch to the merge
/// block, and the merge block itself. The insertion point of `builder` is
/// restored before returning.
spirv::SelectionOp spirv::SelectionOp::createIfThen(
    Location loc, Value condition,
    function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
  auto selection =
      builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
  selection.addMergeBlock();
  Block *merge = selection.getMergeBlock();

  // Every createBlock() call below moves the insertion point; the guard puts
  // it back where the caller left it once both blocks have been built.
  OpBuilder::InsertionGuard restoreInsertionPoint(builder);

  // "then" block: placed in front of the merge block, populated by the
  // callback and closed with an unconditional branch to the merge block.
  Block *thenBlock = builder.createBlock(merge);
  thenBody(builder);
  builder.create<spirv::BranchOp>(loc, merge);

  // Header block: placed in front of the "then" block; it dispatches on the
  // condition without forwarding any block arguments.
  builder.createBlock(thenBlock);
  builder.create<spirv::BranchConditionalOp>(
      loc, condition, thenBlock,
      /*trueArguments=*/ArrayRef<Value>(), merge,
      /*falseArguments=*/ArrayRef<Value>());

  return selection;
}
|
|
|
|
|
2019-08-20 13:33:41 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-03-04 16:01:28 -05:00
|
|
|
// spv.SpecConstant
|
2019-08-20 13:33:41 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses `@name [spec_id-keyword(<integer-attr>)] = <default-value-attr>`,
/// where the keyword is `kSpecIdAttrName`, and records the parsed attributes
/// on `state`.
ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
                                         OperationState &state) {
  // Mandatory symbol name.
  StringAttr symbolName;
  if (parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // Optional spec id, written as the keyword followed by a parenthesized
  // integer attribute.
  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
    IntegerAttr specId;
    if (parser.parseLParen())
      return failure();
    if (parser.parseAttribute(specId, kSpecIdAttrName, state.attributes))
      return failure();
    if (parser.parseRParen())
      return failure();
  }

  // Mandatory default value.
  if (parser.parseEqual())
    return failure();

  Attribute defaultValue;
  if (parser.parseAttribute(defaultValue, kDefaultValueAttrName,
                            state.attributes))
    return failure();

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the symbol name, the optional spec id clause and the default value.
void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
  printer << ' ';
  printer.printSymbolName(sym_name());

  // The spec id clause is only emitted when the integer attribute is present.
  auto specIdAttr = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName);
  if (specIdAttr) {
    printer << ' ' << kSpecIdAttrName;
    printer << '(' << specIdAttr.getInt() << ')';
  }

  printer << " = " << default_value();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies that the spec id (when present) is not negative and that the
/// default value is an integer or float attribute whose type is a SPIR-V type.
LogicalResult spirv::SpecConstantOp::verify() {
  auto specIdAttr = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName);
  if (specIdAttr && specIdAttr.getValue().isNegative())
    return emitOpError("SpecId cannot be negative");

  auto defaultValue = default_value();
  if (!defaultValue.isa<IntegerAttr, FloatAttr>())
    return emitOpError(
        "default value can only be a bool, integer, or float scalar");

  // Make sure bitwidth is allowed.
  if (defaultValue.getType().isa<spirv::SPIRVType>())
    return success();

  return emitOpError("default value bitwidth disallowed");
}
|
|
|
|
|
2019-06-24 10:59:05 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.StoreOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses a store: the storage class, two operands (pointer and value), the
/// optional memory access attributes, and `: <element-type>`. The pointer
/// operand type is rebuilt from the element type and the storage class.
ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &state) {
  // Captured before anything is consumed; used when resolving the operands.
  auto operandsLoc = parser.getCurrentLocation();

  // Parse the storage class specification
  spirv::StorageClass storageClass;
  if (parseEnumStrAttr(storageClass, parser))
    return failure();

  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
  if (parser.parseOperandList(operandInfo, 2))
    return failure();

  if (parseMemoryAccessAttributes(parser, state))
    return failure();

  Type elementType;
  if (parser.parseColon() || parser.parseType(elementType))
    return failure();

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, operandsLoc,
                             state.operands))
    return failure();

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the quoted storage class of the pointer, the pointer and value
/// operands, the memory access attribute, the value type and finally any
/// attributes that were not elided.
void spirv::StoreOp::print(OpAsmPrinter &printer) {
  auto pointerType = ptr().getType().cast<spirv::PointerType>();
  StringRef storageClassName =
      stringifyStorageClass(pointerType.getStorageClass());

  printer << " \"" << storageClassName << "\" ";
  printer << ptr() << ", " << value();

  // printMemoryAccessAttribute receives the elision list so that the trailing
  // attribute dictionary can skip whatever it has already handled.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);

  printer << " : " << value().getType();
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the pointer/value type relationship first and, when that holds,
/// the memory access attribute.
LogicalResult spirv::StoreOp::verify() {
  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
  // OpTypePointer whose Type operand is the same as the type of Object."
  if (succeeded(verifyLoadStorePtrAndValTypes(*this, ptr(), value())))
    return verifyMemoryAccessAttribute(*this);

  return failure();
}
|
|
|
|
|
2019-10-30 05:40:47 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.Unreachable
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Rejects spv.Unreachable when it sits in an entry block; every other
/// placement is currently accepted.
LogicalResult spirv::UnreachableOp::verify() {
  Block *parentBlock = (*this)->getBlock();

  // Fast track: if this is in entry block, it's invalid. Otherwise, if no
  // predecessors, it's valid.
  if (parentBlock->isEntryBlock())
    return emitOpError("cannot be used in reachable block");

  if (parentBlock->hasNoPredecessors())
    return success();

  // TODO: further verification needs to analyze reachability from
  // the entry block.
  return success();
}
|
|
|
|
|
2019-06-18 11:15:55 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.Variable
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses `[init(<operand>)] <decorations> : <pointer-type>`.
///
/// The result type must be a spirv::PointerType. The optional initializer is
/// resolved against its pointee type, and the pointer's storage class is
/// stored on the op as an i32 integer attribute.
ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
                                     OperationState &state) {
  // Parse optional initializer
  Optional<OpAsmParser::UnresolvedOperand> initializer;
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    initializer = OpAsmParser::UnresolvedOperand();
    if (parser.parseLParen())
      return failure();
    if (parser.parseOperand(*initializer))
      return failure();
    if (parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, state))
    return failure();

  // Parse result pointer type
  if (parser.parseColon())
    return failure();

  // Remember where the type starts so a non-pointer type is reported there.
  auto typeLoc = parser.getCurrentLocation();
  Type resultType;
  if (parser.parseType(resultType))
    return failure();

  auto pointerType = resultType.dyn_cast<spirv::PointerType>();
  if (!pointerType)
    return parser.emitError(typeLoc, "expected spv.ptr type");
  state.addTypes(pointerType);

  // Resolve the initializer operand
  if (initializer &&
      parser.resolveOperand(*initializer, pointerType.getPointeeType(),
                            state.operands))
    return failure();

  auto storageClassAttr = parser.getBuilder().getI32IntegerAttr(
      llvm::bit_cast<int32_t>(pointerType.getStorageClass()));
  state.addAttribute(spirv::attributeName<spirv::StorageClass>(),
                     storageClassAttr);

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the optional `init(...)` clause, the variable decorations and the
/// result type. The storage class attribute is put on the elision list.
void spirv::VariableOp::print(OpAsmPrinter &printer) {
  SmallVector<StringRef, 4> elidedAttrs{
      spirv::attributeName<spirv::StorageClass>()};

  // Print optional initializer
  const bool hasInitializer = getNumOperands() != 0;
  if (hasInitializer)
    printer << " init(" << initializer() << ")";

  printVariableDecorations(*this, printer, elidedAttrs);

  printer << " : ";
  printer << getType();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies a function-level variable:
///  * its storage class must be `Function` and match the storage class of the
///    result pointer type;
///  * an initializer, when present, must be defined by a spirv::ConstantOp,
///    spirv::ReferenceOfOp or spirv::AddressOfOp;
///  * the descriptor set, binding and built-in decoration attributes must be
///    absent.
LogicalResult spirv::VariableOp::verify() {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  if (storage_class() != spirv::StorageClass::Function)
    return emitOpError(
        "can only be used to model function-level variables. Use "
        "spv.GlobalVariable for module-level variables.");

  auto resultPtrType = pointer().getType().cast<spirv::PointerType>();
  if (storage_class() != resultPtrType.getStorageClass())
    return emitOpError(
        "storage class must match result pointer's storage class");

  if (getNumOperands() != 0) {
    // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
    // a global (module scope) OpVariable instruction".
    Operation *initDef = getOperand(0).getDefiningOp();
    const bool isAllowedInit =
        initDef && isa<spirv::ConstantOp,    // for normal constant
                       spirv::ReferenceOfOp, // for spec constant
                       spirv::AddressOfOp>(initDef);
    if (!isAllowedInit)
      return emitOpError("initializer must be the result of a "
                         "constant or spv.GlobalVariable op");
  }

  // TODO: generate these strings using ODS.
  // Each decoration's attribute name is its stringified form converted from
  // camel case to snake case. They are checked in this fixed order so the
  // first offending attribute is the one reported.
  Operation *self = getOperation();
  for (spirv::Decoration decoration :
       {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
        spirv::Decoration::BuiltIn}) {
    auto attrName =
        llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
    if (self->getAttr(attrName))
      return emitOpError("cannot have '")
             << attrName << "' attribute (only allowed in spv.GlobalVariable)";
  }

  return success();
}
|
|
|
|
|
2021-02-02 11:08:39 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.VectorShuffle
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies that the result vector has one element per component selector and
/// that every selector is either an index into the concatenation of the two
/// source vectors or the special value 0xffffffff.
LogicalResult spirv::VectorShuffleOp::verify() {
  auto selectors = components();

  size_t numResultElements = getType().cast<VectorType>().getNumElements();
  if (numResultElements != selectors.size())
    return emitOpError("result type element count (")
           << numResultElements
           << ") mismatch with the number of component selectors ("
           << selectors.size() << ")";

  // Selectors index into both source vectors taken together.
  auto firstSrcType = vector1().getType().cast<VectorType>();
  auto secondSrcType = vector2().getType().cast<VectorType>();
  size_t totalSrcElements =
      firstSrcType.getNumElements() + secondSrcType.getNumElements();

  const uint32_t allOnesSelector = std::numeric_limits<uint32_t>::max();
  for (const auto &selector : selectors.getAsValueRange<IntegerAttr>()) {
    uint32_t index = selector.getZExtValue();
    if (index < totalSrcElements || index == allOnesSelector)
      continue;

    return emitOpError("component selector ")
           << index << " out of range: expected to be in [0, "
           << totalSrcElements << ") or 0xffffffff";
  }
  return success();
}
|
|
|
|
|
2020-05-19 19:07:21 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.CooperativeMatrixLoadNV
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses three operands (pointer, stride, column-major flag), the optional
/// memory access attributes and `: <pointer-type> as <result-type>`. The
/// stride operand is resolved as i32 and the column-major flag as i1.
ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
                                                    OperationState &state) {
  auto &builder = parser.getBuilder();
  Type strideType = builder.getIntegerType(32);
  Type columnMajorType = builder.getIntegerType(1);

  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
  if (parser.parseOperandList(operandInfo, 3))
    return failure();
  if (parseMemoryAccessAttributes(parser, state))
    return failure();

  Type ptrType;
  Type elementType;
  if (parser.parseColon() || parser.parseType(ptrType) ||
      parser.parseKeywordType("as", elementType))
    return failure();

  if (parser.resolveOperands(operandInfo,
                             {ptrType, strideType, columnMajorType},
                             parser.getNameLoc(), state.operands))
    return failure();

  state.addTypes(elementType);
  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the three operands, the optional bracketed memory access value, and
/// `: <pointer-type> as <result-type>`.
void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) {
  printer << " " << pointer();
  printer << ", " << stride();
  printer << ", " << columnmajor();

  // Print optional memory access attribute.
  if (auto access = memory_access()) {
    printer << " [\"" << stringifyMemoryAccess(*access) << "\"]";
  }

  printer << " : " << pointer().getType();
  printer << " as " << getType();
}
|
|
|
|
|
2020-05-21 11:35:32 -07:00
|
|
|
/// Checks the pointer operand type of a cooperative matrix load/store: the
/// pointee must be a SPIR-V scalar or a vector type, and the storage class
/// must be Workgroup, StorageBuffer or PhysicalStorageBuffer. Errors are
/// emitted on `op`. The `coopMatrix` type is currently not inspected.
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
                                                    Type coopMatrix) {
  auto pointerType = pointer.cast<spirv::PointerType>();

  Type pointeeType = pointerType.getPointeeType();
  const bool isScalarOrVector =
      pointeeType.isa<spirv::ScalarType>() || pointeeType.isa<VectorType>();
  if (!isScalarOrVector)
    return op->emitError(
               "Pointer must point to a scalar or vector type but provided ")
           << pointeeType;

  spirv::StorageClass storage = pointerType.getStorageClass();
  const bool isAllowedStorage =
      storage == spirv::StorageClass::Workgroup ||
      storage == spirv::StorageClass::StorageBuffer ||
      storage == spirv::StorageClass::PhysicalStorageBuffer;
  if (!isAllowedStorage)
    return op->emitError(
               "Pointer storage class must be Workgroup, StorageBuffer or "
               "PhysicalStorageBufferEXT but provided ")
           << stringifyStorageClass(storage);

  return success();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Delegates to verifyPointerAndCoopMatrixType with the pointer operand type
/// and the result type.
LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
  Type pointerType = pointer().getType();
  Type matrixType = result().getType();
  return verifyPointerAndCoopMatrixType(*this, pointerType, matrixType);
}
|
|
|
|
|
2020-05-21 11:35:32 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.CooperativeMatrixStoreNV
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses four operands (pointer, object, stride, column-major flag), the
/// optional memory access attributes and `: <pointer-type>, <object-type>`.
/// The stride operand is resolved as i32 and the column-major flag as i1.
ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
                                                     OperationState &state) {
  auto &builder = parser.getBuilder();
  Type strideType = builder.getIntegerType(32);
  Type columnMajorType = builder.getIntegerType(1);

  SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
  if (parser.parseOperandList(operandInfo, 4))
    return failure();
  if (parseMemoryAccessAttributes(parser, state))
    return failure();

  Type ptrType;
  Type elementType;
  if (parser.parseColon() || parser.parseType(ptrType) ||
      parser.parseComma() || parser.parseType(elementType))
    return failure();

  if (parser.resolveOperands(
          operandInfo, {ptrType, elementType, strideType, columnMajorType},
          parser.getNameLoc(), state.operands))
    return failure();

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the four operands, the optional bracketed memory access value, and
/// `: <pointer-type>, <type of operand #1>`.
void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
  printer << " " << pointer();
  printer << ", " << object();
  printer << ", " << stride();
  printer << ", " << columnmajor();

  // Print optional memory access attribute.
  if (auto access = memory_access()) {
    printer << " [\"" << stringifyMemoryAccess(*access) << "\"]";
  }

  printer << " : " << pointer().getType();
  printer << ", " << getOperand(1).getType();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Delegates to verifyPointerAndCoopMatrixType with the pointer operand type
/// and the type of the stored object.
LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
  Type pointerType = pointer().getType();
  Type matrixType = object().getType();
  return verifyPointerAndCoopMatrixType(*this, pointerType, matrixType);
}
|
|
|
|
|
2020-05-21 11:35:32 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.CooperativeMatrixMulAddNV
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies `result = a * b + c` on cooperative matrices:
///  * `c` and the result have the same type;
///  * the shapes compose: rows(a) == rows(result), columns(a) == rows(b) and
///    columns(b) == columns(result);
///  * all four matrices share one scope;
///  * `a` and `b` share an element type, as do `c` and the result.
static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
  Type accumulatorType = op.c().getType();
  Type resultType = op.result().getType();
  if (accumulatorType != resultType)
    return op.emitOpError("result and third operand must have the same type");

  auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
  auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
  auto typeC = accumulatorType.cast<spirv::CooperativeMatrixNVType>();
  auto typeR = resultType.cast<spirv::CooperativeMatrixNVType>();

  const bool shapesCompose = typeA.getRows() == typeR.getRows() &&
                             typeA.getColumns() == typeB.getRows() &&
                             typeB.getColumns() == typeR.getColumns();
  if (!shapesCompose)
    return op.emitOpError("matrix size must match");

  const bool scopesAgree = typeR.getScope() == typeA.getScope() &&
                           typeR.getScope() == typeB.getScope() &&
                           typeR.getScope() == typeC.getScope();
  if (!scopesAgree)
    return op.emitOpError("matrix scope must match");

  const bool elementTypesAgree =
      typeA.getElementType() == typeB.getElementType() &&
      typeR.getElementType() == typeC.getElementType();
  if (!elementTypesAgree)
    return op.emitOpError("matrix element type must match");

  return success();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the op by delegating to verifyCoopMatrixMulAdd.
LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
  spirv::CooperativeMatrixMulAddNVOp self = *this;
  return verifyCoopMatrixMulAdd(self);
}
|
|
|
|
|
[mlir][spirv] Add MatrixTimesScalar operation
Summary:
- Define the MatrixTimesScalar operation and add roundtrip tests.
- Added a new base class for matrix-specific operations to avoid invalid operands type mismatch check.
- Created a separate Matrix arithmetic operations td file to add more operations in the future.
- Augmented the automatically generated verify method to print more fine-grained error messages.
- Made minor Updates to the matrix type tests.
Reviewers: antiagainst, rriddle, mravishankar
Reviewed By: antiagainst
Subscribers: mehdi_amini, jpienaar, shauheen, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, bader, grosul1, frgossen, Kayjukh, jurahul, msifontes
Tags: #mlir
Differential Revision: https://reviews.llvm.org/D81677
2020-06-15 21:50:18 -04:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.MatrixTimesScalar
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies that the scalar has the element type of the input matrix and that
/// the input and result matrices agree on column count, row count and element
/// type.
LogicalResult spirv::MatrixTimesScalarOp::verify() {
  // The auto-generated verifier has already established that both the operand
  // and the result are of matrix type, so the casts below cannot fail.
  auto operandMatrixTy = matrix().getType().cast<spirv::MatrixType>();
  auto resultMatrixTy = result().getType().cast<spirv::MatrixType>();

  // The scaling value must have the matrix element type.
  Type scalarTy = scalar().getType();
  if (scalarTy != operandMatrixTy.getElementType())
    return emitError("input matrix components' type and scaling value must "
                     "have the same type");

  // The three checks below could be expressed with the AllTypesMatch trait in
  // the op definition file, but that produces a vague error message.

  // Same number of columns.
  const bool sameColumnCount =
      operandMatrixTy.getNumColumns() == resultMatrixTy.getNumColumns();
  if (!sameColumnCount)
    return emitError("input and result matrices must have the same "
                     "number of columns");

  // Same number of rows.
  const bool sameRowCount =
      operandMatrixTy.getNumRows() == resultMatrixTy.getNumRows();
  if (!sameRowCount)
    return emitError("input and result matrices' columns must have "
                     "the same size");

  // Same component type.
  const bool sameElementType =
      operandMatrixTy.getElementType() == resultMatrixTy.getElementType();
  if (!sameElementType)
    return emitError("input and result matrices' columns must have "
                     "the same component type");

  return success();
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.CopyMemory
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the custom assembly form of the op: the quoted storage class and
/// the SSA value for the target and then for the source, the memory access
/// attributes, any remaining attributes, and finally the pointee type of the
/// target pointer.
void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
  auto targetPtrType = target().getType().cast<spirv::PointerType>();
  auto sourcePtrType = source().getType().cast<spirv::PointerType>();

  printer << ' ';

  // Target: storage class followed by the pointer value.
  printer << " \"" << stringifyStorageClass(targetPtrType.getStorageClass())
          << "\" " << target() << ", ";

  // Source: storage class followed by the pointer value.
  printer << " \"" << stringifyStorageClass(sourcePtrType.getStorageClass())
          << "\" " << source();

  // The memory access helpers receive this list so the generic attribute
  // dictionary printed afterwards can skip the names collected in it.
  SmallVector<StringRef, 4> handledAttrs;
  printMemoryAccessAttribute(*this, printer, handledAttrs);
  printSourceMemoryAccessAttribute(*this, printer, handledAttrs,
                                   source_memory_access(), source_alignment());
  printer.printOptionalAttrDict((*this)->getAttrs(), handledAttrs);

  printer << " : " << targetPtrType.getPointeeType();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the custom assembly form of the op:
///
///   <target-storage-class> <target> `,` <source-storage-class> <source>
///   <memory-access>? (`,` <source-memory-access>)? <attr-dict>?
///   `:` <element-type> <attr-dict>?
///
/// The operand pointer types are rebuilt from the parsed element type and the
/// two storage classes.
ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
                                       OperationState &state) {
  spirv::StorageClass targetStorageClass;
  OpAsmParser::UnresolvedOperand targetPtrInfo;

  spirv::StorageClass sourceStorageClass;
  OpAsmParser::UnresolvedOperand sourcePtrInfo;

  Type elementType;

  if (parseEnumStrAttr(targetStorageClass, parser) ||
      parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
      parseEnumStrAttr(sourceStorageClass, parser) ||
      parser.parseOperand(sourcePtrInfo) ||
      parseMemoryAccessAttributes(parser, state)) {
    return failure();
  }

  // A comma after the first memory access introduces the second one, which
  // applies to the source.
  if (!parser.parseOptionalComma()) {
    if (parseSourceMemoryAccessAttributes(parser, state))
      return failure();
  }

  // The printer emits the attribute dictionary *before* the `: type` suffix,
  // so it has to be accepted here for printed IR to parse back.
  if (parser.parseOptionalAttrDict(state.attributes))
    return failure();

  if (parser.parseColon() || parser.parseType(elementType))
    return failure();

  // Keep accepting a dictionary after the type as well; this is the only
  // position that was accepted previously.
  if (parser.parseOptionalAttrDict(state.attributes))
    return failure();

  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);

  if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
      parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
    return failure();
  }

  return success();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies that target and source point to the same type and that both
/// memory access attributes pass their respective verification helpers.
LogicalResult spirv::CopyMemoryOp::verify() {
  auto targetPtrType = target().getType().cast<spirv::PointerType>();
  auto sourcePtrType = source().getType().cast<spirv::PointerType>();

  const bool samePointee =
      targetPtrType.getPointeeType() == sourcePtrType.getPointeeType();
  if (!samePointee)
    return emitOpError("both operands must be pointers to the same type");

  LogicalResult accessStatus = verifyMemoryAccessAttribute(*this);
  if (failed(accessStatus))
    return failure();

  // TODO - According to the spec:
  //
  // If two masks are present, the first applies to Target and cannot include
  // MakePointerVisible, and the second applies to Source and cannot include
  // MakePointerAvailable.
  //
  // Add such verification here.

  return verifySourceMemoryAccessAttribute(*this);
}
|
|
|
|
|
2020-06-24 20:34:34 -04:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.Transpose
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies that the result matrix has the transposed shape of the input
/// matrix and the same component type.
LogicalResult spirv::TransposeOp::verify() {
  auto sourceTy = matrix().getType().cast<spirv::MatrixType>();
  auto transposedTy = result().getType().cast<spirv::MatrixType>();

  // Rows of the input become columns of the output ...
  const bool rowsMapToColumns =
      sourceTy.getNumRows() == transposedTy.getNumColumns();
  if (!rowsMapToColumns)
    return emitError("input matrix rows count must be equal to "
                     "output matrix columns count");

  // ... and columns of the input become rows of the output.
  const bool columnsMapToRows =
      sourceTy.getNumColumns() == transposedTy.getNumRows();
  if (!columnsMapToRows)
    return emitError("input matrix columns count must be equal to "
                     "output matrix rows count");

  // The component type is unaffected by the transposition.
  const bool sameElementType =
      sourceTy.getElementType() == transposedTy.getElementType();
  if (!sameElementType)
    return emitError("input and output matrices must have the same "
                     "component type");

  return success();
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.MatrixTimesMatrix
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the shapes and component types of a matrix-matrix product. The
/// checks run in a fixed order and the first violated one is reported.
LogicalResult spirv::MatrixTimesMatrixOp::verify() {
  auto lhsTy = leftmatrix().getType().cast<spirv::MatrixType>();
  auto rhsTy = rightmatrix().getType().cast<spirv::MatrixType>();
  auto productTy = result().getType().cast<spirv::MatrixType>();

  // Inner dimensions: columns of the left matrix against rows of the right.
  const bool innerDimsAgree = lhsTy.getNumColumns() == rhsTy.getNumRows();
  if (!innerDimsAgree)
    return emitError("left matrix columns' count must be equal to "
                     "the right matrix rows' count");

  // The result takes its column count from the right matrix.
  const bool columnsAgree = rhsTy.getNumColumns() == productTy.getNumColumns();
  if (!columnsAgree)
    return emitError(
        "right and result matrices must have equal columns' count");

  // Component types: right vs. result, then left vs. result.
  Type productElementTy = productTy.getElementType();
  if (rhsTy.getElementType() != productElementTy)
    return emitError("right and result matrices' component type must"
                     " be the same");

  if (lhsTy.getElementType() != productElementTy)
    return emitError("left and result matrices' component type"
                     " must be the same");

  // The result takes its row count from the left matrix.
  const bool rowsAgree = lhsTy.getNumRows() == productTy.getNumRows();
  if (!rowsAgree)
    return emitError("left and result matrices must have equal rows' count");

  return success();
}
|
|
|
|
|
2020-10-02 14:56:17 -04:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-03-04 16:01:28 -05:00
|
|
|
// spv.SpecConstantComposite
|
2020-10-02 14:56:17 -04:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the custom assembly form of the op:
///
///   <symbol-name> `(` <symbol-ref> (`,` <symbol-ref>)* `)` `:` <type>
///
/// The references are stored as an array attribute named
/// kCompositeSpecConstituentsName and the type as a type attribute named
/// kTypeAttrName.
ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
                                                  OperationState &state) {
  // Symbol name of the composite itself.
  StringAttr symbolName;
  if (parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  if (parser.parseLParen())
    return failure();

  // At least one constituent is required; further ones are comma separated.
  SmallVector<Attribute, 4> constituentRefs;
  while (true) {
    // parseAttribute needs an attribute name and a list to record into; the
    // list is local to this iteration, so the name isn't important.
    const char *scratchName = "spec_const";
    NamedAttrList scratchAttrs;
    FlatSymbolRefAttr constituentRef;

    if (parser.parseAttribute(constituentRef, Type(), scratchName,
                              scratchAttrs))
      return failure();
    constituentRefs.push_back(constituentRef);

    // No comma means the list is complete.
    if (parser.parseOptionalComma())
      break;
  }

  if (parser.parseRParen())
    return failure();

  state.addAttribute(kCompositeSpecConstituentsName,
                     parser.getBuilder().getArrayAttr(constituentRefs));

  Type compositeType;
  if (parser.parseColonType(compositeType))
    return failure();

  state.addAttribute(kTypeAttrName, TypeAttr::get(compositeType));

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the custom assembly form of the op: the symbol name, the
/// parenthesized comma-separated constituents, and the composite type.
void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
  printer << " ";
  printer.printSymbolName(sym_name());

  printer << " (";
  // Interleaving an empty range prints nothing, so no emptiness check is
  // needed.
  llvm::interleaveComma(this->constituents().getValue(), printer);
  printer << ") : " << type();
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies that:
///  - the declared type is a composite type other than a cooperative matrix;
///  - the number of constituents equals the number of composite elements;
///  - every constituent resolves to a spec constant op whose default value
///    has the type of the corresponding composite element.
LogicalResult spirv::SpecConstantCompositeOp::verify() {
  auto cType = type().dyn_cast<spirv::CompositeType>();
  auto constituents = this->constituents().getValue();

  if (!cType)
    return emitError("result type must be a composite type, but provided ")
           << type();

  if (cType.isa<spirv::CooperativeMatrixNVType>())
    return emitError("unsupported composite type ") << cType;
  if (constituents.size() != cType.getNumElements())
    return emitError("has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided "
           << constituents.size();

  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
    auto constituent = constituents[index].cast<FlatSymbolRefAttr>();

    // The lookup yields null when the symbol cannot be found, and the symbol
    // may name an op of a different kind. Both must produce a diagnostic
    // instead of a null dereference, hence dyn_cast_or_null plus the check.
    auto constituentSpecConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
        SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
                                             constituent.getAttr()));
    if (!constituentSpecConstOp)
      return emitError("constituent ")
             << constituent << " does not refer to a spec constant";

    Type expectedType = cType.getElementType(index);
    Type providedType = constituentSpecConstOp.default_value().getType();
    if (providedType != expectedType)
      return emitError("has incorrect types of operands: expected ")
             << expectedType << ", but provided " << providedType;
  }

  return success();
}
|
|
|
|
|
2020-12-08 09:02:02 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
2020-12-16 08:20:24 -05:00
|
|
|
// spv.SpecConstantOperation
|
2020-12-08 09:02:02 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the custom assembly form of the op:
///
///   `wraps` <generic-operation> <attr-dict>?
///
/// The wrapped operation is placed into a fresh block of the op's region and
/// followed by a yield of its first result. The op takes its location and its
/// result type from the wrapped operation.
ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
                                                  OperationState &state) {
  Region *body = state.addRegion();

  if (parser.parseKeyword("wraps"))
    return failure();

  body->push_back(new Block);
  Block &block = body->back();

  // Remember where the wrapped op starts so a malformed one can be reported
  // at a meaningful location.
  auto wrappedOpLoc = parser.getCurrentLocation();
  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());

  if (!wrappedOp)
    return failure();

  // The code below yields result #0 and uses its type as this op's result
  // type; querying it on a result-less op is invalid, so reject that here.
  if (wrappedOp->getNumResults() == 0)
    return parser.emitError(wrappedOpLoc,
                            "expected wrapped operation to produce a result");

  OpBuilder builder(parser.getContext());
  builder.setInsertionPointToEnd(&block);
  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
  state.location = wrappedOp->getLoc();

  state.addTypes(wrappedOp->getResult(0).getType());

  if (parser.parseOptionalAttrDict(state.attributes))
    return failure();

  return success();
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the `wraps` keyword followed by the generic form of the first
/// operation in the first block of the body region.
void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
  Operation &wrappedOp = body().front().front();
  printer << " wraps ";
  printer.printGenericOp(&wrappedOp);
}
|
|
|
|
|
2022-03-10 22:10:45 +00:00
|
|
|
/// Verifies the body of the op:
///  - the block must contain exactly two operations;
///  - the first one must carry the UsableInSpecConstantOp trait;
///  - each of its operands must be defined by a spirv::ConstantOp,
///    spirv::ReferenceOfOp or spirv::SpecConstantOperationOp.
LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
  Block &block = getRegion().getBlocks().front();

  if (block.getOperations().size() != 2)
    return emitOpError("expected exactly 2 nested ops");

  Operation &enclosedOp = block.getOperations().front();

  if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
    return emitOpError("invalid enclosed op");

  for (Value operand : enclosedOp.getOperands()) {
    // A value without a defining op (e.g. a block argument) yields null here;
    // it must be rejected with a diagnostic rather than handed to isa<>,
    // which requires a non-null argument.
    Operation *definingOp = operand.getDefiningOp();
    if (!definingOp ||
        !isa<spirv::ConstantOp, spirv::ReferenceOfOp,
             spirv::SpecConstantOperationOp>(definingOp))
      return emitOpError(
          "invalid operand, must be defined by a constant operation");
  }

  return success();
}
|
|
|
|
|
2021-02-17 09:00:28 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
2022-07-21 13:02:45 -04:00
|
|
|
// spv.GL.FrexpStruct
|
2021-02-17 09:00:28 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
2022-02-02 10:06:30 -08:00
|
|
|
|
2022-07-21 13:02:45 -04:00
|
|
|
/// Verifies the result struct of the op:
///  - it must be a struct with exactly two members;
///  - member zero must have the operand type;
///  - member one must be a 32-bit integer or a vector of 32-bit integers;
///  - operand and member one must both be vectors with the same number of
///    elements, or be a float scalar and an integer scalar respectively.
LogicalResult spirv::GLFrexpStructOp::verify() {
  // Guard against a non-struct result instead of dereferencing a null type.
  // NOTE(review): the spelling "memebers" is kept as is because existing
  // diagnostics tests may match on the exact text - confirm before fixing.
  auto structTy = result().getType().dyn_cast<spirv::StructType>();
  if (!structTy || structTy.getNumElements() != 2)
    return emitError("result type must be a struct type with two memebers");

  Type significandTy = structTy.getElementType(0);
  Type exponentTy = structTy.getElementType(1);
  VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
  IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();

  Type operandTy = operand().getType();
  VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
  FloatType operandFTy = operandTy.dyn_cast<FloatType>();

  if (significandTy != operandTy)
    return emitError("member zero of the resulting struct type must be the "
                     "same type as the operand");

  // For a vector exponent the constraint applies to its element type. Both
  // the scalar and the vector case report the same message; previously the
  // vector case lost the space between "must" and "be".
  Type exponentScalarTy =
      exponentVecTy ? exponentVecTy.getElementType() : exponentTy;
  auto exponentScalarIntTy = exponentScalarTy.dyn_cast<IntegerType>();
  if (!exponentScalarIntTy || exponentScalarIntTy.getWidth() != 32)
    return emitError("member one of the resulting struct type "
                     "must be a scalar or vector of 32 bit integer type");

  // Check that the two member types have the same number of components.
  if (operandVecTy && exponentVecTy &&
      (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
    return success();

  if (operandFTy && exponentIntTy)
    return success();

  return emitError("member one of the resulting struct type must have the same "
                   "number of components as the operand type");
}
|
|
|
|
|
2021-02-24 13:07:05 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
2022-07-21 13:02:45 -04:00
|
|
|
// spv.GL.Ldexp
|
2021-02-24 13:07:05 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-07-21 13:02:45 -04:00
|
|
|
/// Verifies that `x` and `exp` are either both scalars or both vectors, and
/// that they have the same number of elements.
LogicalResult spirv::GLLdexpOp::verify() {
  Type xType = x().getType();
  Type expType = exp().getType();

  // `x` being a float scalar must coincide with `exp` being an integer scalar.
  const bool xIsScalar = xType.isa<FloatType>();
  const bool expIsScalar = expType.isa<IntegerType>();
  if (xIsScalar != expIsScalar)
    return emitOpError("operands must both be scalars or vectors");

  // Anything that is not a vector counts as a single element.
  auto elementCount = [](Type type) -> unsigned {
    auto vecType = type.dyn_cast<VectorType>();
    if (!vecType)
      return 1;
    return vecType.getNumElements();
  };

  const unsigned xCount = elementCount(xType);
  const unsigned expCount = elementCount(expType);
  if (xCount != expCount)
    return emitOpError("operands must have the same number of elements");

  return success();
}
|
|
|
|
|
2021-04-08 19:22:25 -04:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ImageDrefGather
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies that:
///  - the result is a vector of four components;
///  - its component type equals the sampled type of the underlying image,
///    unless that sampled type is NoneType;
///  - the image Dim is 2D, Cube or Rect;
///  - the image is single-sampled;
///  - the image operands attribute and arguments pass verifyImageOperands.
LogicalResult spirv::ImageDrefGatherOp::verify() {
  auto resultVecTy = result().getType().cast<VectorType>();
  auto sampledImageTy =
      sampledimage().getType().cast<spirv::SampledImageType>();
  auto imageTy = sampledImageTy.getImageType().cast<spirv::ImageType>();

  if (resultVecTy.getNumElements() != 4)
    return emitOpError("result type must be a vector of four components");

  // A NoneType sampled type places no constraint on the result component.
  Type componentTy = resultVecTy.getElementType();
  Type sampledTy = imageTy.getElementType();
  const bool sampledTypeIsKnown = !sampledTy.isa<NoneType>();
  if (sampledTypeIsKnown && componentTy != sampledTy)
    return emitOpError(
        "the component type of result must be the same as sampled type of the "
        "underlying image type");

  const spirv::Dim dim = imageTy.getDim();
  const bool dimIsSupported = dim == spirv::Dim::Dim2D ||
                              dim == spirv::Dim::Cube ||
                              dim == spirv::Dim::Rect;
  if (!dimIsSupported)
    return emitOpError(
        "the Dim operand of the underlying image type must be 2D, Cube, or "
        "Rect");

  const spirv::ImageSamplingInfo sampling = imageTy.getSamplingInfo();
  if (sampling != spirv::ImageSamplingInfo::SingleSampled)
    return emitOpError("the MS operand of the underlying image type must be 0");

  spirv::ImageOperandsAttr imageOperands = imageoperandsAttr();
  auto imageOperandArgs = operand_arguments();
  return verifyImageOperands(*this, imageOperands, imageOperandArgs);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ShiftLeftLogicalOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
2021-04-08 19:22:25 -04:00
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the op by delegating to verifyShiftOp.
LogicalResult spirv::ShiftLeftLogicalOp::verify() {
  LogicalResult status = verifyShiftOp(*this);
  return status;
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ShiftRightArithmeticOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies the op by delegating to verifyShiftOp.
LogicalResult spirv::ShiftRightArithmeticOp::verify() {
  LogicalResult status = verifyShiftOp(*this);
  return status;
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ShiftRightLogicalOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
2021-09-02 02:39:05 +08:00
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the op by delegating to verifyShiftOp.
LogicalResult spirv::ShiftRightLogicalOp::verify() {
  LogicalResult status = verifyShiftOp(*this);
  return status;
}
|
|
|
|
|
2021-05-13 13:06:53 -04:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.ImageQuerySize
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the image operand and the shape of the result:
///  - Dim must be 1D, 2D, 3D, Buffer, Cube or Rect;
///  - for 1D, 2D, 3D and Cube the image must be multi-sampled, or have a
///    sampler use of SamplerUnknown or NoSampler;
///  - the result must have one component per queried dimension (1 for
///    1D/Buffer, 2 for 2D/Cube/Rect, 3 for 3D), plus one if the image is
///    arrayed. A non-vector result counts as one component.
LogicalResult spirv::ImageQuerySizeOp::verify() {
  auto imageTy = image().getType().cast<spirv::ImageType>();
  Type resultTy = result().getType();

  // Classify the dimensionality once: does it need the sampling check, and
  // how many size components does it contribute?
  bool needsSamplingCheck = false;
  unsigned expectedComponents = 0;
  switch (imageTy.getDim()) {
  case spirv::Dim::Dim1D:
    needsSamplingCheck = true;
    expectedComponents = 1;
    break;
  case spirv::Dim::Buffer:
    expectedComponents = 1;
    break;
  case spirv::Dim::Dim2D:
  case spirv::Dim::Cube:
    needsSamplingCheck = true;
    expectedComponents = 2;
    break;
  case spirv::Dim::Rect:
    expectedComponents = 2;
    break;
  case spirv::Dim::Dim3D:
    needsSamplingCheck = true;
    expectedComponents = 3;
    break;
  default:
    return emitError("the Dim operand of the image type must "
                     "be 1D, 2D, 3D, Buffer, Cube, or Rect");
  }

  if (needsSamplingCheck) {
    const spirv::ImageSamplingInfo sampling = imageTy.getSamplingInfo();
    const spirv::ImageSamplerUseInfo samplerUse = imageTy.getSamplerUseInfo();
    const bool acceptable =
        sampling == spirv::ImageSamplingInfo::MultiSampled ||
        samplerUse == spirv::ImageSamplerUseInfo::SamplerUnknown ||
        samplerUse == spirv::ImageSamplerUseInfo::NoSampler;
    if (!acceptable)
      return emitError(
          "if Dim is 1D, 2D, 3D, or Cube, "
          "it must also have either an MS of 1 or a Sampled of 0 or 2");
  }

  // An arrayed image reports one additional component.
  if (imageTy.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
    ++expectedComponents;

  unsigned actualComponents = 1;
  if (auto resultVecTy = resultTy.dyn_cast<VectorType>())
    actualComponents = resultVecTy.getNumElements();

  if (expectedComponents != actualComponents)
    return emitError("expected the result to have ")
           << expectedComponents << " component(s), but found "
           << actualComponents << " component(s)";

  return success();
}
|
|
|
|
|
2021-08-14 11:57:02 +03:00
|
|
|
/// Shared parser for the pointer access chain ops. Parses
///
///   <base-ptr> `[` <index> (`,` <index>)* `]` `:` <ptr-type> `,` <index-types>
///
/// resolves the operands and adds the result type computed by
/// getElementPtrType. `opName` is only used as a prefix of the diagnostics
/// emitted here.
static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
                                             OpAsmParser &parser,
                                             OperationState &state) {
  OpAsmParser::UnresolvedOperand basePtrOperand;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexOperands;
  Type basePtrType;
  SmallVector<Type, 4> indexTypes;

  // Captured before anything is consumed; used when resolving the indices.
  auto startLoc = parser.getCurrentLocation();

  if (parser.parseOperand(basePtrOperand))
    return failure();
  if (parser.parseOperandList(indexOperands, OpAsmParser::Delimiter::Square))
    return failure();
  if (parser.parseColonType(basePtrType))
    return failure();
  if (parser.resolveOperand(basePtrOperand, basePtrType, state.operands))
    return failure();

  // The bracketed list must not be empty; check before parsing its types.
  if (indexOperands.empty())
    return emitError(state.location) << opName << " expected element";

  if (parser.parseComma() || parser.parseTypeList(indexTypes))
    return failure();

  // There must be exactly one type per parsed index.
  if (indexTypes.size() != indexOperands.size())
    return emitError(state.location)
           << opName
           << " indices types' count must be equal to indices info count";

  if (parser.resolveOperands(indexOperands, indexTypes, startLoc,
                             state.operands))
    return failure();

  // The first two resolved operands (the base pointer and the first bracketed
  // value) are dropped; the rest drive the result type computation.
  auto trailingIndices = llvm::makeArrayRef(state.operands).drop_front(2);
  auto resultType =
      getElementPtrType(basePtrType, trailingIndices, state.location);
  if (!resultType)
    return failure();

  state.addTypes(resultType);
  return success();
}
|
|
|
|
|
|
|
|
template <typename Op>
|
|
|
|
static auto concatElemAndIndices(Op op) {
|
|
|
|
SmallVector<Value> ret(op.indices().size() + 1);
|
|
|
|
ret[0] = op.element();
|
|
|
|
llvm::copy(op.indices(), ret.begin() + 1);
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.InBoundsPtrAccessChainOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Builder that deduces the result type instead of taking it explicitly.
///
/// The result type is obtained from getElementPtrType using the type of
/// `basePtr` and `indices`; the call is then forwarded to the `build` overload
/// that takes the result type as an explicit argument. Deduction failure is
/// only caught by an assertion.
void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
                                            OperationState &state,
                                            Value basePtr, Value element,
                                            ValueRange indices) {
  auto resultType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultType, basePtr, element, indices);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the custom assembly form by delegating to the shared pointer access
/// chain parser, passing this op's name for use in its diagnostics.
ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
                                                   OperationState &state) {
  return parsePtrAccessChainOpImpl(getOperationName(), parser, state);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the op through printAccessChain, handing it the element operand
/// followed by the indices as one combined list.
void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
  auto elementAndIndices = concatElemAndIndices(*this);
  printAccessChain(*this, elementAndIndices, printer);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the op by running the shared access chain verifier over this op
/// and its indices (the element operand is not part of the list passed).
LogicalResult spirv::InBoundsPtrAccessChainOp::verify() {
  auto chainIndices = indices();
  return verifyAccessChain(*this, chainIndices);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.PtrAccessChainOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Builder that deduces the result type instead of taking it explicitly.
///
/// The result type is obtained from getElementPtrType using the type of
/// `basePtr` and `indices`; the call is then forwarded to the `build` overload
/// that takes the result type as an explicit argument. Deduction failure is
/// only caught by an assertion.
void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
                                    Value basePtr, Value element,
                                    ValueRange indices) {
  auto resultType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultType, basePtr, element, indices);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Parses the custom assembly form by delegating to the shared pointer access
/// chain parser, passing this op's name for use in its diagnostics.
ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
                                           OperationState &state) {
  return parsePtrAccessChainOpImpl(getOperationName(), parser, state);
}
|
|
|
|
|
2022-02-07 17:54:04 -08:00
|
|
|
/// Prints the op through printAccessChain, handing it the element operand
/// followed by the indices as one combined list.
void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) {
  auto elementAndIndices = concatElemAndIndices(*this);
  printAccessChain(*this, elementAndIndices, printer);
}
|
|
|
|
|
2022-02-02 10:06:30 -08:00
|
|
|
/// Verifies the op by running the shared access chain verifier over this op
/// and its indices (the element operand is not part of the list passed).
LogicalResult spirv::PtrAccessChainOp::verify() {
  auto chainIndices = indices();
  return verifyAccessChain(*this, chainIndices);
}
|
|
|
|
|
2022-03-08 15:58:31 -05:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// spv.VectorTimesScalarOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Verifies that the vector operand has exactly the result type and that the
/// scalar operand has the result vector's element type.
LogicalResult spirv::VectorTimesScalarOp::verify() {
  auto resultType = getType();
  if (vector().getType() != resultType)
    return emitOpError("vector operand and result type mismatch");

  // NOTE(review): the diagnostic below says "match" although it is emitted on
  // a mismatch; the wording is kept verbatim because diagnostics may be
  // matched by existing tests -- confirm before rewording.
  auto expectedScalarType = resultType.cast<VectorType>().getElementType();
  if (scalar().getType() != expectedScalarType)
    return emitOpError("scalar operand and result element type match");

  return success();
}
|
|
|
|
|
2019-12-27 16:24:33 -05:00
|
|
|
// TableGen'erated operation interfaces for querying versions, extensions, and
|
|
|
|
// capabilities.
|
2020-12-17 10:55:45 -05:00
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
|
2019-12-27 16:24:33 -05:00
|
|
|
|
|
|
|
// TableGen'erated operation definitions.
|
2019-05-26 05:43:20 -07:00
|
|
|
#define GET_OP_CLASSES
|
2020-12-17 10:55:45 -05:00
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
|
2019-05-26 05:43:20 -07:00
|
|
|
|
2020-09-14 20:01:07 +00:00
|
|
|
namespace mlir {
|
|
|
|
namespace spirv {
|
2019-12-27 16:24:33 -05:00
|
|
|
// TableGen'erated operation availability interface implementations.
|
2020-12-17 10:55:45 -05:00
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
|
2019-05-26 05:43:20 -07:00
|
|
|
} // namespace spirv
|
|
|
|
} // namespace mlir
|