Marius Brehler 56774bdda5 [mlir] Replace deprecated 'getAttrs'
'getAttrs' has been explicitly marked deprecated. This patch refactors
to use Operation::getAttrs().

Reviewed By: csigg

Differential Revision: https://reviews.llvm.org/D97546
2021-02-26 14:52:40 +01:00

755 lines
30 KiB
C++

//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to lower to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include <functional>
#define DEBUG_TYPE "mlir-spirv-conversion"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Returns success if every extension requirement in `candidates` can be met
/// under `targetEnv`.
///
/// `candidates` is in conjunctive normal form: the outer list is AND-ed and
/// each inner list is OR-ed, i.e.
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)).
/// `label` only identifies the offending entity in debug output.
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
  for (const auto &anyOf : candidates) {
    if (!targetEnv.allows(anyOf)) {
      // Report the whole alternative set that could not be satisfied.
      SmallVector<StringRef, 4> names;
      for (spirv::Extension extension : anyOf)
        names.push_back(spirv::stringifyExtension(extension));
      LLVM_DEBUG(llvm::dbgs()
                 << label << " illegal: requires at least one extension in ["
                 << llvm::join(names, ", ")
                 << "] but none allowed in target environment\n");
      return failure();
    }
  }
  return success();
}
/// Checks that the capability requirements in `candidates` can be satisfied
/// under the given `targetEnv`.
///
/// `candidates` is a vector of vectors of capability requirements following
/// the ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
/// convention: every inner list must have at least one capability allowed by
/// `targetEnv`. `label` is only used to identify the offending entity in the
/// debug output. Returns failure on the first unsatisfiable inner list.
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
  for (const auto &ors : candidates) {
    if (targetEnv.allows(ors))
      continue;

    // None of the alternatives is allowed; list them all in the debug output.
    SmallVector<StringRef, 4> capStrings;
    for (spirv::Capability cap : ors)
      capStrings.push_back(spirv::stringifyCapability(cap));
    LLVM_DEBUG(llvm::dbgs()
               << label << " illegal: requires at least one capability in ["
               << llvm::join(capStrings, ", ")
               << "] but none allowed in target environment\n");
    return failure();
  }
  return success();
}
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
  // `index` values are lowered to 32-bit integers for now. Might need a way to
  // control this in the future.
  //
  // TODO: 64-bit integers are probably a better choice, but that needs
  // conversion support in the SPIR-V dialect: the Vulkan spec requires
  // builtins such as GlobalInvocationID to be 32-bit (unsigned) integers,
  // which would then have to be sign-extended to 64 bits for index
  // computations.
  const unsigned indexBitwidth = 32;
  return IntegerType::get(context, indexBitwidth);
}
/// Mapping between SPIR-V storage classes and memref memory spaces.
///
/// Note: memref does not define semantics for each memory space; the meaning
/// depends on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes a smaller number. The hope is to use a symbolic
/// memory space representation eventually, once memref supports it.
///
/// Each entry has the form `MAP_FN(storage class, memory space number)`. The
/// list is expanded into switch cases by the two conversion functions that
/// follow it.
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN) \
  MAP_FN(spirv::StorageClass::Generic, 1) \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0) \
  MAP_FN(spirv::StorageClass::Workgroup, 3) \
  MAP_FN(spirv::StorageClass::Uniform, 4) \
  MAP_FN(spirv::StorageClass::Private, 5) \
  MAP_FN(spirv::StorageClass::Function, 6) \
  MAP_FN(spirv::StorageClass::PushConstant, 7) \
  MAP_FN(spirv::StorageClass::UniformConstant, 8) \
  MAP_FN(spirv::StorageClass::Input, 9) \
  MAP_FN(spirv::StorageClass::Output, 10) \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12) \
  MAP_FN(spirv::StorageClass::Image, 13) \
  MAP_FN(spirv::StorageClass::CallableDataNV, 14) \
  MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15) \
  MAP_FN(spirv::StorageClass::RayPayloadNV, 16) \
  MAP_FN(spirv::StorageClass::HitAttributeNV, 17) \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18) \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19) \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)
/// Returns the memref memory space number assigned to `storage` by
/// STORAGE_SPACE_MAP_LIST. Storage classes absent from the list hit
/// llvm_unreachable.
unsigned
SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
  // Expand each list entry into `case <storage class>: return <space>;`.
#define STORAGE_SPACE_MAP_FN(storage, space) \
  case storage: \
    return space;

  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
#undef STORAGE_SPACE_MAP_FN
  llvm_unreachable("unhandled storage class!");
}
/// Returns the SPIR-V storage class that STORAGE_SPACE_MAP_LIST assigns to
/// memref memory space `space`, or llvm::None if the number is not in the
/// list.
Optional<spirv::StorageClass>
SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
  // Expand each list entry into `case <space>: return <storage class>;`.
#define STORAGE_SPACE_MAP_FN(storage, space) \
  case space: \
    return storage;

  switch (space) {
    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
  default:
    return llvm::None;
  }
#undef STORAGE_SPACE_MAP_FN
}
// Both directions of the mapping are defined now; the list is not needed
// beyond this point.
#undef STORAGE_SPACE_MAP_LIST
// TODO: This is a utility function that should probably be
// exposed by the SPIR-V dialect. Keeping it local till the use case arises.
/// Returns the number of bytes needed to store a value of type `t`.
///
/// Handles SPIR-V scalar types, vectors, statically shaped tensors, and
/// memrefs with a static shape and a static strided layout. Returns llvm::None
/// for everything else, and for booleans, which have no physical size.
static Optional<int64_t> getTypeNumBytes(Type t) {
  if (t.isa<spirv::ScalarType>()) {
    unsigned bits = t.getIntOrFloatBitWidth();
    // According to the SPIR-V spec:
    // "There is no physical size or bit pattern defined for values with boolean
    // type. If they are stored (in conjunction with OpVariable), they can only
    // be used with logical addressing operations, not physical, and only with
    // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
    // Private, Function, Input, and Output."
    if (bits == 1)
      return llvm::None;
    return bits / 8;
  }

  if (auto vecType = t.dyn_cast<VectorType>()) {
    Optional<int64_t> elemBytes = getTypeNumBytes(vecType.getElementType());
    if (!elemBytes)
      return llvm::None;
    return vecType.getNumElements() * *elemBytes;
  }

  if (auto memRefType = t.dyn_cast<MemRefType>()) {
    // TODO: Layout should also be controlled by the ABI attributes. For now
    // using the layout from MemRef.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (!memRefType.hasStaticShape() ||
        failed(getStridesAndOffset(memRefType, strides, offset)))
      return llvm::None;

    Optional<int64_t> elemBytes = getTypeNumBytes(memRefType.getElementType());
    if (!elemBytes)
      return llvm::None;
    if (memRefType.getRank() == 0)
      return elemBytes;

    ArrayRef<int64_t> shape = memRefType.getShape();
    const int64_t dynamicValue = MemRefType::getDynamicStrideOrOffset();
    if (llvm::is_contained(shape, ShapedType::kDynamicSize) ||
        offset == dynamicValue || llvm::is_contained(strides, dynamicValue))
      return llvm::None;

    // The memory footprint is the largest (stride * dimension size) over all
    // dimensions, shifted by the offset, in units of the element size.
    int64_t extent = -1;
    for (size_t dim = 0, rank = shape.size(); dim < rank; ++dim)
      extent = std::max(extent, shape[dim] * strides[dim]);
    return (offset + extent) * *elemBytes;
  }

  if (auto tensorType = t.dyn_cast<TensorType>()) {
    if (!tensorType.hasStaticShape())
      return llvm::None;
    Optional<int64_t> elemBytes = getTypeNumBytes(tensorType.getElementType());
    if (!elemBytes)
      return llvm::None;
    int64_t totalBytes = *elemBytes;
    for (int64_t dimSize : tensorType.getShape())
      totalBytes *= dimSize;
    return totalBytes;
  }

  // TODO: Add size computation for other types.
  return llvm::None;
}
/// Public entry point for the file-local getTypeNumBytes: returns the byte
/// size of `t`, or llvm::None when it cannot be determined.
Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
  return getTypeNumBytes(t);
}
/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
///
/// The type is kept as-is when the target environment satisfies all of its
/// capability and extension requirements for `storageClass`. Otherwise it is
/// replaced by its 32-bit counterpart: f32 for floats, or a 32-bit integer
/// with the same signedness for integers.
static Optional<Type>
convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
                  Optional<spirv::StorageClass> storageClass = {}) {
  SmallVector<ArrayRef<spirv::Extension>, 1> requiredExts;
  SmallVector<ArrayRef<spirv::Capability>, 2> requiredCaps;
  type.getExtensions(requiredExts, storageClass);
  type.getCapabilities(requiredCaps, storageClass);

  bool supportedAsIs =
      succeeded(checkCapabilityRequirements(type, targetEnv, requiredCaps)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, requiredExts));
  if (supportedAsIs)
    return type;

  // TODO: We are unconditionally converting the bitwidth here,
  // this might be okay for non-interface types (i.e., types used in
  // Private/Function storage classes), but not for interface types (i.e.,
  // types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
  // This is because the later actually affects the ABI contract with the
  // runtime. So we may want to expose a control on SPIRVTypeConverter to fail
  // conversion if we cannot change there.
  MLIRContext *ctx = targetEnv.getContext();
  if (type.isa<FloatType>()) {
    LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
    return Builder(ctx).getF32Type();
  }

  IntegerSignedness:;
  auto asInteger = type.cast<IntegerType>();
  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
  return IntegerType::get(ctx, /*width=*/32, asInteger.getSignedness());
}
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
///
/// A 1-D single-element vector becomes its element type. Vectors that are not
/// valid SPIR-V composites are rejected. When the target environment does not
/// meet the vector's requirements, the element type is converted with
/// convertScalarType and the shape is kept.
static Optional<Type>
convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
                  Optional<spirv::StorageClass> storageClass = {}) {
  bool isSingleElement = type.getRank() == 1 && type.getNumElements() == 1;
  if (isSingleElement)
    return type.getElementType();

  if (!spirv::CompositeType::isValid(type)) {
    // TODO: Vector types with more than four elements can be translated into
    // array types.
    LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
    return llvm::None;
  }

  auto composite = type.cast<spirv::CompositeType>();
  SmallVector<ArrayRef<spirv::Extension>, 1> requiredExts;
  SmallVector<ArrayRef<spirv::Capability>, 2> requiredCaps;
  composite.getExtensions(requiredExts, storageClass);
  composite.getCapabilities(requiredCaps, storageClass);

  // Accept the vector unchanged when every requirement is met.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, requiredCaps)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, requiredExts)))
    return type;

  Optional<Type> convertedElem = convertScalarType(
      targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
  if (!convertedElem)
    return llvm::None;
  return VectorType::get(type.getShape(), *convertedElem);
}
/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
/// create composite constants with OpConstantComposite to embed relatively
/// large constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate them, like what we do for vectors.
///
/// A statically shaped tensor of SPIR-V scalars becomes a spirv::ArrayType
/// holding the same number of (converted) elements. The array stride is the
/// converted element size.
static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
                                        TensorType type) {
  // TODO: Handle dynamic shapes.
  if (!type.hasStaticShape()) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: dynamic shape unimplemented\n");
    return llvm::None;
  }

  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
  if (!scalarType) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert non-scalar element type\n");
    return llvm::None;
  }

  Optional<int64_t> elemBytes = getTypeNumBytes(scalarType);
  Optional<int64_t> totalBytes = getTypeNumBytes(type);
  if (!elemBytes || !totalBytes) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return llvm::None;
  }
  int64_t numElements = *totalBytes / *elemBytes;

  Optional<Type> convertedElem = convertScalarType(targetEnv, scalarType);
  if (!convertedElem)
    return llvm::None;

  Optional<int64_t> convertedElemBytes = getTypeNumBytes(*convertedElem);
  if (!convertedElemBytes) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return llvm::None;
  }
  return spirv::ArrayType::get(*convertedElem, numElements,
                               *convertedElemBytes);
}
/// Converts a memref `type` to a SPIR-V pointer type under the given
/// `targetEnv`.
///
/// The memref's memory space selects the pointer's storage class. The pointee
/// is a struct wrapping an array of the converted element type: a
/// spirv::RuntimeArrayType for dynamically shaped memrefs, and a sized
/// spirv::ArrayType otherwise. In both cases the array stride is the size of
/// the *converted* element type. Returns llvm::None when any part of the type
/// cannot be converted.
static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
                                        MemRefType type) {
  Optional<spirv::StorageClass> storageClass =
      SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
  if (!storageClass) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert memory space\n");
    return llvm::None;
  }

  Optional<Type> arrayElemType;
  Type elementType = type.getElementType();
  if (auto vecType = elementType.dyn_cast<VectorType>()) {
    arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
  } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
    arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
  } else {
    LLVM_DEBUG(
        llvm::dbgs()
        << type
        << " unhandled: can only convert scalar or vector element type\n");
    return llvm::None;
  }
  if (!arrayElemType)
    return llvm::None;

  // Size of the original element type. It is only used to derive the element
  // count of statically shaped memrefs.
  Optional<int64_t> elementSize = getTypeNumBytes(elementType);
  if (!elementSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element size\n");
    return llvm::None;
  }

  // Size of the converted element type, which is the array stride. It must be
  // computed before the dynamic-shape branch. The conversion may change the
  // bitwidth (e.g. i64 -> i32 when Int64 is unavailable), and the stride has
  // to describe the elements actually stored in the array, not the original
  // ones.
  Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return llvm::None;
  }

  if (!type.hasStaticShape()) {
    auto arrayType =
        spirv::RuntimeArrayType::get(*arrayElemType, *arrayElemSize);
    // Wrap in a struct to satisfy Vulkan interface requirements.
    auto structType = spirv::StructType::get(arrayType, 0);
    return spirv::PointerType::get(structType, *storageClass);
  }

  Optional<int64_t> memrefSize = getTypeNumBytes(type);
  if (!memrefSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return llvm::None;
  }
  auto arrayElemCount = *memrefSize / *elementSize;

  auto arrayType =
      spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);

  // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
  // workgroup storage class do not need the struct to be laid out explicitly.
  auto structType = *storageClass == spirv::StorageClass::Workgroup
                        ? spirv::StructType::get(arrayType)
                        : spirv::StructType::get(arrayType, 0);
  return spirv::PointerType::get(structType, *storageClass);
}
/// Builds a type converter that maps builtin types to SPIR-V types permitted
/// by the target environment described by `targetAttr`.
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
    : targetEnv(targetAttr) {
  // Add conversions. The order matters here: later ones will be tried earlier.

  // All other cases failed. Then we cannot convert this type.
  addConversion([](Type type) { return llvm::None; });

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: this assumes that the SPIR-V types are valid to use in
  // the given target environment, which should be the case if the whole
  // pipeline is driven by the same target environment. Still, we probably still
  // want to validate and convert to be safe.
  addConversion([](spirv::SPIRVType type) { return type; });

  // `index` lowers to the fixed-width integer chosen by getIndexType().
  addConversion([](IndexType indexType) {
    return SPIRVTypeConverter::getIndexType(indexType.getContext());
  });

  // Integer and float types are converted only when they are valid SPIR-V
  // scalar types. convertScalarType keeps them or falls back to 32 bits,
  // depending on the target environment's capabilities/extensions.
  addConversion([this](IntegerType intType) -> Optional<Type> {
    if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  addConversion([this](FloatType floatType) -> Optional<Type> {
    if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  // Aggregate types delegate to their dedicated helpers.
  addConversion([this](VectorType vectorType) {
    return convertVectorType(targetEnv, vectorType);
  });

  addConversion([this](TensorType tensorType) {
    return convertTensorType(targetEnv, tensorType);
  });

  addConversion([this](MemRefType memRefType) {
    return convertMemrefType(targetEnv, memRefType);
  });
}
//===----------------------------------------------------------------------===//
// FuncOp Conversion Patterns
//===----------------------------------------------------------------------===//
namespace {
/// Rewrites a builtin function into a spv.func. Argument and result types are
/// converted with the pattern's type converter, and the body is moved over.
class FuncOpConversion final : public OpConversionPattern<FuncOp> {
public:
  // Inherit the (TypeConverter, MLIRContext, benefit) constructors.
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
/// Converts `funcOp` into an equivalent spv.func.
///
/// Fails (leaving the IR untouched) when the function has more than one
/// result, or when any argument type or the result type cannot be converted.
/// All attributes except the symbol name and function type are copied over.
LogicalResult
FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  auto fnType = funcOp.getType();
  // spv.func can return at most one value.
  if (fnType.getNumResults() > 1)
    return failure();

  // Convert every argument type; one unconvertible argument makes the whole
  // function unconvertible.
  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
  for (auto argType : enumerate(fnType.getInputs())) {
    auto convertedType = getTypeConverter()->convertType(argType.value());
    if (!convertedType)
      return failure();
    signatureConverter.addInputs(argType.index(), convertedType);
  }

  // Convert the result type, if any. Bail out if it cannot be converted;
  // otherwise a null `resultType` would silently turn the function into one
  // that returns nothing.
  Type resultType;
  if (fnType.getNumResults() == 1) {
    resultType = getTypeConverter()->convertType(fnType.getResult(0));
    if (!resultType)
      return failure();
  }

  // Create the converted spv.func op.
  auto newFuncOp = rewriter.create<spirv::FuncOp>(
      funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                               resultType ? TypeRange(resultType)
                                          : TypeRange()));

  // Copy over all attributes other than the function name and type.
  for (const auto &namedAttr : funcOp->getAttrs()) {
    if (namedAttr.first != impl::getTypeAttrName() &&
        namedAttr.first != SymbolTable::getSymbolAttrName())
      newFuncOp->setAttr(namedAttr.first, namedAttr.second);
  }

  // Move the body into the new function, then retype its entry block
  // arguments according to the signature conversion.
  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                              newFuncOp.end());
  if (failed(rewriter.convertRegionTypes(
          &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
    return failure();
  rewriter.eraseOp(funcOp);
  return success();
}
/// Appends the pattern that converts builtin functions to spv.func (using
/// `typeConverter` for the signature) to `patterns`.
void mlir::populateBuiltinFuncToSPIRVPatterns(
    MLIRContext *context, SPIRVTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<FuncOpConversion>(typeConverter, context);
}
//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//
/// Returns the spv.globalVariable directly inside `body` whose BuiltIn
/// decoration attribute names `builtin`, or a null op if there is none.
static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
                                                  spirv::BuiltIn builtin) {
  auto decorationName =
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn);
  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
    auto builtinAttr = varOp->getAttrOfType<StringAttr>(decorationName);
    if (!builtinAttr)
      continue;
    auto declared = spirv::symbolizeBuiltIn(builtinAttr.getValue());
    if (declared && declared.getValue() == builtin)
      return varOp;
  }
  return nullptr;
}
/// Gets name of global variable for a builtin.
/// Gets name of global variable for a builtin: `__builtin_var_<Name>__`.
static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
  std::string varName = "__builtin_var_";
  varName += stringifyBuiltIn(builtin).str();
  varName += "__";
  return varName;
}
/// Gets or inserts a global variable for a builtin within `body` block.
/// Gets or inserts a global variable for a builtin within `body` block.
///
/// A new variable is created at the start of `body` with Input storage class.
/// Builtins not handled below produce an error at `loc` and a null op.
static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
                           OpBuilder &builder) {
  if (auto existing = getBuiltinVariable(body, builtin))
    return existing;

  // The guard restores the caller's insertion point on every exit path.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&body);

  Type pointeeType;
  switch (builtin) {
  case spirv::BuiltIn::NumWorkgroups:
  case spirv::BuiltIn::WorkgroupSize:
  case spirv::BuiltIn::WorkgroupId:
  case spirv::BuiltIn::LocalInvocationId:
  case spirv::BuiltIn::GlobalInvocationId:
    // Three-component vector of 32-bit integers.
    pointeeType = VectorType::get({3}, builder.getIntegerType(32));
    break;
  case spirv::BuiltIn::SubgroupId:
  case spirv::BuiltIn::NumSubgroups:
  case spirv::BuiltIn::SubgroupSize:
    // Single 32-bit integer.
    pointeeType = builder.getIntegerType(32);
    break;
  default:
    emitError(loc, "unimplemented builtin variable generation for ")
        << stringifyBuiltIn(builtin);
    return nullptr;
  }

  auto ptrType =
      spirv::PointerType::get(pointeeType, spirv::StorageClass::Input);
  std::string varName = getBuiltinVarName(builtin);
  return builder.create<spirv::GlobalVariableOp>(loc, ptrType, varName,
                                                 builtin);
}
/// Returns the value of builtin variable `builtin`, loaded at the current
/// insertion point of `builder`.
///
/// The variable is looked up (or created) in the nearest symbol table
/// enclosing `op`. Returns a null Value, after emitting an error, when there
/// is no such symbol table or the builtin cannot be materialized.
Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                           spirv::BuiltIn builtin,
                                           OpBuilder &builder) {
  // Top-level ops have no parent to search from; getNearestSymbolTable expects
  // a valid operation.
  Operation *parentOp = op->getParentOp();
  Operation *parent =
      parentOp ? SymbolTable::getNearestSymbolTable(parentOp) : nullptr;
  if (!parent) {
    op->emitError("expected operation to be within a module-like op");
    return nullptr;
  }

  spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
      *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
  // Unsupported builtins yield a null op and an error has already been
  // emitted; don't build an AddressOfOp on top of it.
  if (!varOp)
    return nullptr;

  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
}
//===----------------------------------------------------------------------===//
// Index calculation
//===----------------------------------------------------------------------===//
/// Builds a spv.AccessChain to the element of `basePtr` (a converted memref of
/// type `baseType`) addressed by `indices`.
///
/// The chain is `[0, offset + sum(index_i * stride_i)]`: the leading 0 selects
/// the array wrapped in the struct, and the second index is the linearized
/// element position. Returns a null op when the memref layout has a dynamic
/// stride or offset.
spirv::AccessChainOp mlir::spirv::getElementPtr(
    SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
    ValueRange indices, Location loc, OpBuilder &builder) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(baseType, strides, offset)))
    return nullptr;
  const int64_t dynamicValue = MemRefType::getDynamicStrideOrOffset();
  if (llvm::is_contained(strides, dynamicValue) || offset == dynamicValue)
    return nullptr;

  Type indexType = typeConverter.getIndexType(builder.getContext());

  // Index 0 selects the struct member that holds the array.
  Value zero = spirv::ConstantOp::getZero(indexType, loc, builder);
  SmallVector<Value, 2> chainIndices;
  chainIndices.push_back(zero);

  // A 0-D memref has a single element at array position 0.
  if (baseType.getRank() == 0) {
    chainIndices.push_back(zero);
    return builder.create<spirv::AccessChainOp>(loc, basePtr, chainIndices);
  }

  // TODO: Instead of this logic, use affine.apply and add patterns for
  // lowering affine.apply to standard ops. These will get lowered to SPIR-V
  // ops by the DialectConversion framework.
  Value linearIndex = builder.create<spirv::ConstantOp>(
      loc, indexType, IntegerAttr::get(indexType, offset));
  assert(indices.size() == strides.size() &&
         "must provide indices for all dimensions");
  for (unsigned dim = 0, rank = indices.size(); dim < rank; ++dim) {
    Value stride = builder.create<spirv::ConstantOp>(
        loc, indexType, IntegerAttr::get(indexType, strides[dim]));
    Value scaled = builder.create<spirv::IMulOp>(loc, stride, indices[dim]);
    linearIndex = builder.create<spirv::IAddOp>(loc, linearIndex, scaled);
  }
  chainIndices.push_back(linearIndex);
  return builder.create<spirv::AccessChainOp>(loc, basePtr, chainIndices);
}
//===----------------------------------------------------------------------===//
// Set ABI attributes for lowering entry functions.
//===----------------------------------------------------------------------===//
/// Attaches entry-point ABI information to `funcOp`. `argABIInfo[i]` is set as
/// the interface-variable ABI attribute of argument `i`, and `entryPointInfo`
/// as the function's entry-point ABI attribute. Always succeeds.
LogicalResult
mlir::spirv::setABIAttrs(spirv::FuncOp funcOp,
                         spirv::EntryPointABIAttr entryPointInfo,
                         ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
  StringRef interfaceAttrName = spirv::getInterfaceVarABIAttrName();
  for (unsigned argIdx = 0, numArgs = argABIInfo.size(); argIdx != numArgs;
       ++argIdx)
    funcOp.setArgAttr(argIdx, interfaceAttrName, argABIInfo[argIdx]);

  funcOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
  return success();
}
//===----------------------------------------------------------------------===//
// SPIR-V ConversionTarget
//===----------------------------------------------------------------------===//
/// Creates a conversion target for `targetAttr`. SPIR-V dialect ops are
/// dynamically legal, as decided by isLegalOp().
std::unique_ptr<spirv::SPIRVConversionTarget>
spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
  // The constructor is private, so std::make_unique cannot be used here.
  std::unique_ptr<SPIRVConversionTarget> target(
      new SPIRVConversionTarget(targetAttr));

  // Capturing the raw pointer is safe: the target lives on the heap, so its
  // address does not change when the unique_ptr is handed to the caller, and
  // the callback is owned by (and dies with) the target itself.
  SPIRVConversionTarget *self = target.get();
  target->addDynamicallyLegalDialect<SPIRVDialect>(
      [self](Operation *op) { return self->isLegalOp(op); });
  return target;
}
/// Private constructor used by get(); binds the target to the context of
/// `targetAttr` and records the target environment used by isLegalOp().
spirv::SPIRVConversionTarget::SPIRVConversionTarget(
    spirv::TargetEnvAttr targetAttr)
    : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
/// Returns true if SPIR-V op `op` may be used in the target environment.
///
/// Checks, in order: the op's min/max version, the extensions and
/// capabilities it requires, and the extensions and capabilities required by
/// its operand and result types (plus the declared type of a
/// spv.globalVariable).
bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
  // Make sure this op is available at the given version. Ops not implementing
  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
  // SPIR-V versions.
  if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
    if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring min version "
                 << spirv::stringifyVersion(minVersion.getMinVersion())
                 << "\n");
      return false;
    }
  if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
    if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring max version "
                 << spirv::stringifyVersion(maxVersion.getMaxVersion())
                 << "\n");
      return false;
    }

  // Make sure this op's required extensions are allowed to use. Ops not
  // implementing QueryExtensionInterface do not require extensions to be
  // available.
  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          extensions.getExtensions())))
      return false;

  // Make sure this op's required capabilities are allowed to use. Ops not
  // implementing QueryCapabilityInterface do not require capabilities to be
  // available.
  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           capabilities.getCapabilities())))
      return false;

  SmallVector<Type, 4> valueTypes;
  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
  valueTypes.append(op->result_type_begin(), op->result_type_end());

  // Special treatment for global variables, whose type requirements are
  // conveyed by type attributes.
  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
    valueTypes.push_back(globalVar.type());

  // Make sure the op's operands/results use types that are allowed by the
  // target environment. The scratch vectors are reused across iterations.
  // Every type is assumed to already be a SPIR-V type: the cast asserts
  // otherwise.
  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
  for (Type valueType : valueTypes) {
    typeExtensions.clear();
    valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          typeExtensions)))
      return false;

    typeCapabilities.clear();
    valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           typeCapabilities)))
      return false;
  }

  return true;
}