llvm-project/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

1817 lines
67 KiB
C++
Raw Normal View History

//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate C++
// definitions for ops.
//
//===----------------------------------------------------------------------===//
#include "OpFormatGen.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
// The literal prefix "tblgen_". Presumably prepended to names introduced by
// the generator; its uses are not visible in this part of the file (TODO:
// confirm against the rest of the file).
static const char *const tblgenNamePrefix = "tblgen_";
// Base name for synthesized argument names: getArgumentName() appends
// "_<index>" to it for operands that have no name of their own.
static const char *const generatedArgName = "odsArg";
// Name given to the `OperationState &` parameter of generated build() methods.
static const char *const builderOpState = "odsState";
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same
// number of values.
//
// This is a formatv() template used as the body of a generated getter that
// takes an `index` parameter. Everything between the raw-string delimiters is
// emitted verbatim (after placeholder substitution), including the `//`
// comment lines; a doubled `{{` is the formatv escape for a literal `{`.
//
// {0}: The list of whether each declared operand/result is variadic.
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: The begin iterator of the actual values.
// {5}: "operand" or "result".
const char *sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
for (unsigned i = 0; i < index; ++i)
if (isVariadic[i]) ++prevVariadicCount;
// Calculate how many dynamic values a static variadic {5} corresponds to.
// This assumes all static variadic {5}s have the same dynamic value count.
int variadicSize = ({3} - {1}) / {2};
// `index` passed in as the parameter is the static index which counts each
// {5} (variadic or not) as size 1. So here for each previous static variadic
// {5}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {5} starts.
int offset = index + (variadicSize - 1) * prevVariadicCount;
int size = isVariadic[index] ? variadicSize : 1;
return {{std::next({4}, offset), std::next({4}, offset + size)};
)";
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic assumes the op
// has an attribute specifying the size of each operand/result segment
// (variadic or not).
//
// This is a formatv() template used as the body of a generated getter that
// takes an `index` parameter: the generated code sums the sizes of the
// segments before `index` to find where the requested segment starts. A
// doubled `{{` is the formatv escape for a literal `{`.
//
// {0}: The name of the attribute specifying the segment sizes.
// {1}: The begin iterator of the actual values.
const char *attrSizedSegmentValueRangeCalcCode = R"(
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += (*(sizeAttr.begin() + i)).getZExtValue();
unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue();
return {{std::next({1}, start), std::next({1}, end)};
)";
// Template for a banner comment placed in the generated output. It carries two
// positional placeholders, {0} and {1}, printed side by side on the middle
// line. NOTE(review): the code that fills them in is outside the visible part
// of this file -- presumably via formatv(); confirm what each one receives.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//
)";
//===----------------------------------------------------------------------===//
// Utility structs and functions
//===----------------------------------------------------------------------===//
// Replaces all occurrences of `match` in `str` with `substitute` and returns
// the resulting string.
//
// Occurrences are found left to right and never overlap; text inserted by a
// replacement is not rescanned, so a `substitute` that itself contains `match`
// does not cause repeated expansion. An empty `match` leaves `str` unchanged.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  // An empty pattern "matches" at every position, so the scan below would
  // never terminate. Treat it as nothing to replace.
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0, matchLoc = std::string::npos;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    // std::string::replace edits in place; no need to assign the result back.
    str.replace(matchLoc, match.size(), substitute);
    // Resume right after the inserted text.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
// Reports whether the field `fieldName` of `record` is backed by a code or
// string initializer, i.e. a value that getValueAsString can return.
static inline bool hasStringAttribute(const Record &record,
                                      StringRef fieldName) {
  auto *fieldInit = record.getValueInit(fieldName);
  if (isa<CodeInit>(fieldInit))
    return true;
  return isa<StringInit>(fieldInit);
}
// Returns the name to use for the operand at `index` of `op`: the operand's
// own name when it has one, otherwise a synthesized
// "<generatedArgName>_<index>".
static std::string getArgumentName(const Operator &op, int index) {
  const auto &operand = op.getOperand(index);
  // Unnamed operands get a generated, index-based name.
  if (operand.name.empty())
    return std::string(formatv("{0}_{1}", generatedArgName, index));
  return std::string(operand.name);
}
// Tells whether builders may take the unwrapped (raw) value for `attr`.
// This requires that the attribute's return type differs from its storage
// type and that the attribute provides a constant-builder template.
static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
  // Same return and storage type: there is no distinct raw value to accept.
  if (attr.getReturnType() == attr.getStorageType())
    return false;
  // The builder implementation has to wrap the raw value back into an
  // attribute, so the attribute must specify how to do that.
  return !attr.getConstBuilderTemplate().empty();
}
//===----------------------------------------------------------------------===//
// Op emitter
//===----------------------------------------------------------------------===//
namespace {
// Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope {
public:
IfDefScope(StringRef name, raw_ostream &os) : name(name), os(os) {
os << "#ifdef " << name << "\n"
<< "#undef " << name << "\n\n";
}
~IfDefScope() { os << "\n#endif // " << name << "\n\n"; }
private:
StringRef name;
raw_ostream &os;
};
} // end anonymous namespace
namespace {
// Helper class to emit a record into the given output stream.
//
// The constructor runs every gen* step and accumulates the result in
// `opClass`; the emit* entry points then write the accumulated declaration or
// definition to a stream.
class OpEmitter {
public:
  // Emits the C++ class declaration for `op` to `os`.
  static void emitDecl(const Operator &op, raw_ostream &os);
  // Emits the C++ class definition for `op` to `os`.
  static void emitDef(const Operator &op, raw_ostream &os);

private:
  OpEmitter(const Operator &op);

  void emitDecl(raw_ostream &os);
  void emitDef(raw_ostream &os);

  // Generates the OpAsmOpInterface for this operation if possible.
  void genOpAsmInterface();

  // Generates the `getOperationName` method for this op.
  void genOpNameGetter();

  // Generates getters for the attributes.
  void genAttrGetters();

  // Generates setter for the attributes.
  void genAttrSetters();

  // Generates getters for named operands.
  void genNamedOperandGetters();

  // Generates getters for named results.
  void genNamedResultGetters();

  // Generates getters for named regions.
  void genNamedRegionGetters();

  // Generates getters for named successors.
  void genNamedSuccessorGetters();

  // Generates builder methods for the operation.
  void genBuilder();

  // Generates the build() method that takes each operand/attribute
  // as a stand-alone parameter.
  void genSeparateArgParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first operand's
  // type as all results' types.
  void genUseOperandAsResultTypeSeparateParamBuilder();

  // Generates the build() method that takes all operands/attributes
  // collectively as one parameter. The generated build() method uses first
  // operand's type as all results' types.
  void genUseOperandAsResultTypeCollectiveParamBuilder();

  // Generates the build() method that takes aggregate operands/attributes
  // parameters. This build() method uses inferred types as result types.
  // Requires: The type needs to be inferable via InferTypeOpInterface.
  void genInferredTypeCollectiveParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first attribute's
  // type as all result's types.
  void genUseAttrAsResultTypeBuilder();

  // Generates the build() method that takes all result types collectively as
  // one parameter. Similarly for operands and attributes.
  void genCollectiveParamBuilder();

  // The kind of parameter to generate for result types in builders.
  enum class TypeParamKind {
    None,       // No result type in parameter list.
    Separate,   // A separate parameter for each result type.
    Collective, // An ArrayRef<Type> for all result types.
  };

  // The kind of parameter to generate for attributes in builders.
  enum class AttrParamKind {
    WrappedAttr,    // A wrapped MLIR Attribute instance.
    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
  };

  // Builds the parameter list for build() method of this op. This method writes
  // to `paramList` the comma-separated parameter list and updates
  // `resultTypeNames` with the names for parameters for specifying result
  // types. The given `typeParamKind` and `attrParamKind` controls how result
  // types and attributes are placed in the parameter list.
  void buildParamList(std::string &paramList,
                      SmallVectorImpl<std::string> &resultTypeNames,
                      TypeParamKind typeParamKind,
                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);

  // Adds op arguments and regions into operation state for build() methods.
  void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
                                              bool isRawValueAttr = false);

  // Generates canonicalizer declaration for the operation.
  void genCanonicalizerDecls();

  // Generates the folder declaration for the operation.
  void genFolderDecls();

  // Generates the parser for the operation.
  void genParser();

  // Generates the printer for the operation.
  void genPrinter();

  // Generates verify method for the operation.
  void genVerifier();

  // Generates verify statements for operands and results in the operation.
  // The generated code will be attached to `body`.
  void genOperandResultVerifier(OpMethodBody &body,
                                Operator::value_range values,
                                StringRef valueKind);

  // Generates verify statements for regions in the operation.
  // The generated code will be attached to `body`.
  void genRegionVerifier(OpMethodBody &body);

  // Generates verify statements for successors in the operation.
  // The generated code will be attached to `body`.
  void genSuccessorVerifier(OpMethodBody &body);

  // Generates the traits used by the object.
  void genTraits();

  // Generate the OpInterface methods.
  void genOpInterfaceMethods();

  // Generate the side effect interface methods.
  void genSideEffectInterfaceMethods();

private:
  // The TableGen record for this op.
  // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
  // it should rather go through the Operator for better abstraction.
  const Record &def;

  // The wrapper operator class for querying information from this op.
  Operator op;

  // The C++ code builder for this op
  OpClass opClass;

  // The format context for verification code generation.
  FmtContext verifyCtx;
};
} // end anonymous namespace
// Builds the complete in-memory C++ class for `op`.
//
// The member initializers capture the op's TableGen record, copy the Operator
// wrapper, and create the OpClass from the op's C++ class name and extra class
// declaration. The body then points the verification format context at
// `(*this->getOperation())` and runs every generation step; nothing is written
// to a stream here (see emitDecl/emitDef for that).
OpEmitter::OpEmitter(const Operator &op)
: def(op.getDef()), op(op),
opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
verifyCtx.withOp("(*this->getOperation())");
// Traits are generated before any of the methods below.
genTraits();
// Generate C++ code for various op methods. The order here determines the
// methods in the generated file.
genOpAsmInterface();
genOpNameGetter();
genNamedOperandGetters();
genNamedResultGetters();
genNamedRegionGetters();
genNamedSuccessorGetters();
genAttrGetters();
genAttrSetters();
genBuilder();
genParser();
genPrinter();
genVerifier();
genCanonicalizerDecls();
genFolderDecls();
genOpInterfaceMethods();
// Free function (not a member): contributes the format-derived code for `op`
// to `opClass`. NOTE(review): declared outside this file, presumably in
// OpFormatGen.h -- confirm.
generateOpFormat(op, opClass);
genSideEffectInterfaceMethods();
}
void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
OpEmitter(op).emitDecl(os);
}
void OpEmitter::emitDef(const Operator &op, raw_ostream &os) {
OpEmitter(op).emitDef(os);
}
// Writes the accumulated class declaration to `os`.
void OpEmitter::emitDecl(raw_ostream &os) {
  opClass.writeDeclTo(os);
}
// Writes the accumulated class definition to `os`.
void OpEmitter::emitDef(raw_ostream &os) {
  opClass.writeDefTo(os);
}
// Adds attribute accessors to the op class.
//
// For each derived attribute a single getter wrapping the user-provided code
// is produced. For every other attribute two getters are produced: `<name>Attr`
// returning the storage type and `<name>` returning the converted return type.
// If the op has any derived attribute, a static `isDerivedAttribute` query is
// added as well.
void OpEmitter::genAttrGetters() {
  FmtContext fctx;
  fctx.withBuilder("mlir::Builder(this->getContext())");

  // Getter for a derived attribute: the body is the attribute's own code.
  auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
    auto &getter = opClass.newMethod(attr.getReturnType(), name);
    getter.body() << " " << attr.getDerivedCodeBody() << "\n";
  };

  // Getter that yields the attribute converted to its return type.
  auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
    auto &getter = opClass.newMethod(attr.getReturnType(), name);
    auto &code = getter.body();
    code << " auto attr = " << name << "Attr();\n";
    if (attr.hasDefaultValue()) {
      // Fall back to the default value when the attribute is not set.
      // TODO: this is inefficient, we are recreating the attribute for every
      // call. This should be set instead.
      std::string defaultValue = std::string(
          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
      code << " if (!attr)\n return "
           << tgfmt(attr.getConvertFromStorageCall(),
                    &fctx.withSelf(defaultValue))
           << ";\n";
    }
    code << " return "
         << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
         << ";\n";
  };

  // Getter that yields the attribute in its storage type. Having a named
  // accessor avoids the string-keyed interface and gives better compile time
  // verification.
  auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
    auto &getter =
        opClass.newMethod(attr.getStorageType(), (name + "Attr").str());
    auto &code = getter.body();
    // Attributes that may be absent must tolerate a null result.
    const bool mayBeAbsent = attr.isOptional() || attr.hasDefaultValue();
    code << " return this->getAttr(\"" << name << "\").";
    code << (mayBeAbsent ? "dyn_cast_or_null<" : "cast<");
    code << attr.getStorageType() << ">();";
  };

  for (auto &namedAttr : op.getAttributes()) {
    const auto &attrName = namedAttr.name;
    const auto &attr = namedAttr.attr;
    if (attr.isDerivedAttr()) {
      emitDerivedAttr(attrName, attr);
      continue;
    }
    emitAttrWithStorageType(attrName, attr);
    emitAttrWithReturnType(attrName, attr);
  }

  // Helper method to query whether a named attribute is a derived attribute.
  // This enables, for example, avoiding adding an attribute that overlaps with
  // a derived attribute.
  auto derivedAttrs = make_filter_range(
      op.getAttributes(), [](const NamedAttribute &namedAttr) {
        return namedAttr.attr.isDerivedAttr();
      });
  if (derivedAttrs.empty())
    return;

  opClass.addTrait("DerivedAttributeOpInterface::Trait");
  auto &query = opClass.newMethod("bool", "isDerivedAttribute",
                                  "StringRef name", OpMethod::MP_Static);
  auto &queryCode = query.body();
  for (const auto &namedAttr : derivedAttrs)
    queryCode << " if (name == \"" << namedAttr.name << "\") return true;\n";
  queryCode << " return false;";
}
// Adds a `<name>Attr(<storage type> attr)` setter for every non-derived
// attribute of the op.
void OpEmitter::genAttrSetters() {
  // Named setters let callers avoid the string-keyed interface, for better
  // compile time verification.
  auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
    auto &setter = opClass.newMethod("void", (name + "Attr").str(),
                                     (attr.getStorageType() + " attr").str());
    setter.body() << " this->getOperation()->setAttr(\"" << name
                  << "\", attr);";
  };

  for (auto &namedAttr : op.getAttributes()) {
    // Derived attributes get no setter.
    if (namedAttr.attr.isDerivedAttr())
      continue;
    emitAttrWithStorageType(namedAttr.name, namedAttr.attr);
  }
}
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`.
//
// `rangeType` is the return type of getters that return a range of operands
// (individual operands are returned as `Value`). `rangeBeginCall` is the
// expression yielding an iterator to the beginning of the operand range and
// `rangeSizeCall` the expression yielding the number of operands.
// `getOperandCallPattern` is a pattern for obtaining a single operand, with
// "{0}" standing for its position; this function does not read it.
//
// Emits a fatal error when the op's variadic operands and its size-related
// traits are inconsistent.
static void generateNamedOperandGetters(const Operator &op, Class &opClass,
                                        StringRef rangeType,
                                        StringRef rangeBeginCall,
                                        StringRef rangeSizeCall,
                                        StringRef getOperandCallPattern) {
  const int numOperands = op.getNumOperands();
  const int numVariadicOperands = op.getNumVariableLengthOperands();
  const int numNormalOperands = numOperands - numVariadicOperands;

  const auto *sameVariadicSize =
      op.getTrait("OpTrait::SameVariadicOperandSize");
  const auto *attrSizedOperands =
      op.getTrait("OpTrait::AttrSizedOperandSegments");

  // Validate the combination of variadic operands and size traits.
  if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
    PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
                                 "specification over their sizes");
  }
  if (numVariadicOperands < 2 && attrSizedOperands) {
    PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
                                 "to use 'AttrSizedOperandSegments' trait");
  }
  if (attrSizedOperands && sameVariadicSize) {
    PrintFatalError(op.getLoc(),
                    "op cannot have both 'AttrSizedOperandSegments' and "
                    "'SameVariadicOperandSize' traits");
  }

  // The "sink" getter: every named getter below is layered on top of it.
  auto &sink = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");

  if (numVariadicOperands == 0) {
    // No variadics: each operand is a range of exactly one value. The return
    // type still has to be a range.
    sink.body() << " return {std::next(" << rangeBeginCall
                << ", index), std::next(" << rangeBeginCall
                << ", index + 1)};";
  } else if (attrSizedOperands) {
    sink.body() << formatv(attrSizedSegmentValueRangeCalcCode,
                           "operand_segment_sizes", rangeBeginCall);
  } else {
    // Variadic and non-variadic operands can be interleaved arbitrarily, so
    // the sink embeds a per-operand flag list for calculation at run-time.
    llvm::SmallVector<StringRef, 4> isVariadic;
    isVariadic.reserve(numOperands);
    for (int i = 0; i < numOperands; ++i) {
      const bool variableLength = op.getOperand(i).isVariableLength();
      isVariadic.push_back(variableLength ? "true" : "false");
    }
    std::string isVariadicList = llvm::join(isVariadic, ", ");
    sink.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
                           numNormalOperands, numVariadicOperands,
                           rangeSizeCall, rangeBeginCall, "operand");
  }

  // Named getters simply redirect to the sink with the operand's static index.
  for (int i = 0; i != numOperands; ++i) {
    const auto &operand = op.getOperand(i);
    if (operand.name.empty())
      continue;

    if (operand.isOptional()) {
      auto &getter = opClass.newMethod("Value", operand.name);
      getter.body()
          << " auto operands = getODSOperands(" << i << ");\n"
          << " return operands.empty() ? Value() : *operands.begin();";
      continue;
    }
    if (operand.isVariadic()) {
      auto &getter = opClass.newMethod(rangeType, operand.name);
      getter.body() << " return getODSOperands(" << i << ");";
      continue;
    }
    auto &getter = opClass.newMethod("Value", operand.name);
    getter.body() << " return *getODSOperands(" << i << ").begin();";
  }
}
// Adds the operand getters for this op, reading operands straight from the
// underlying Operation. Ops carrying the AttrSizedOperandSegments trait have
// their operand adaptor class switched off first.
void OpEmitter::genNamedOperandGetters() {
  const bool hasAttrSizedSegments =
      op.getTrait("OpTrait::AttrSizedOperandSegments") != nullptr;
  if (hasAttrSizedSegments)
    opClass.setHasOperandAdaptorClass(false);

  generateNamedOperandGetters(
      op, opClass,
      /*rangeType=*/"Operation::operand_range",
      /*rangeBeginCall=*/"getOperation()->operand_begin()",
      /*rangeSizeCall=*/"getOperation()->getNumOperands()",
      /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
}
// Adds the result getters for this op: a `getODSResults(unsigned index)` sink
// returning the value range for a declared result, plus one getter per named
// result that redirects to it.
//
// Emits a fatal error when the op's variadic results and its size-related
// traits are inconsistent.
void OpEmitter::genNamedResultGetters() {
  const int numResults = op.getNumResults();
  const int numVariadicResults = op.getNumVariableLengthResults();
  const int numNormalResults = numResults - numVariadicResults;

  // With more than one variadic result, more complicated logic is needed to
  // calculate the value range for each result; these traits select it.
  const auto *sameVariadicSize = op.getTrait("OpTrait::SameVariadicResultSize");
  const auto *attrSizedResults =
      op.getTrait("OpTrait::AttrSizedResultSegments");

  if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
    PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
                                 "specification over their sizes");
  }
  if (numVariadicResults < 2 && attrSizedResults) {
    PrintFatalError(op.getLoc(), "op must have at least two variadic results "
                                 "to use 'AttrSizedResultSegments' trait");
  }
  if (attrSizedResults && sameVariadicSize) {
    PrintFatalError(op.getLoc(),
                    "op cannot have both 'AttrSizedResultSegments' and "
                    "'SameVariadicResultSize' traits");
  }

  // Expression yielding the begin iterator of the op's results.
  const char *resultBegin = "getOperation()->result_begin()";

  auto &sink = opClass.newMethod("Operation::result_range", "getODSResults",
                                 "unsigned index");

  if (numVariadicResults == 0) {
    // No variadics: each result is a range of exactly one value.
    sink.body() << " return {std::next(" << resultBegin
                << ", index), std::next(" << resultBegin << ", index + 1)};";
  } else if (attrSizedResults) {
    sink.body() << formatv(attrSizedSegmentValueRangeCalcCode,
                           "result_segment_sizes", resultBegin);
  } else {
    // Embed a per-result flag list so the range is computed at run-time.
    llvm::SmallVector<StringRef, 4> isVariadic;
    isVariadic.reserve(numResults);
    for (int i = 0; i < numResults; ++i) {
      const bool variableLength = op.getResult(i).isVariableLength();
      isVariadic.push_back(variableLength ? "true" : "false");
    }
    std::string isVariadicList = llvm::join(isVariadic, ", ");
    sink.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
                           numNormalResults, numVariadicResults,
                           "getOperation()->getNumResults()", resultBegin,
                           "result");
  }

  // Named getters redirect to the sink with the result's static index.
  for (int i = 0; i != numResults; ++i) {
    const auto &result = op.getResult(i);
    if (result.name.empty())
      continue;

    if (result.isOptional()) {
      auto &getter = opClass.newMethod("Value", result.name);
      getter.body()
          << " auto results = getODSResults(" << i << ");\n"
          << " return results.empty() ? Value() : *results.begin();";
      continue;
    }
    if (result.isVariadic()) {
      auto &getter = opClass.newMethod("Operation::result_range", result.name);
      getter.body() << " return getODSResults(" << i << ");";
      continue;
    }
    auto &getter = opClass.newMethod("Value", result.name);
    getter.body() << " return *getODSResults(" << i << ").begin();";
  }
}
// Adds one getter per named region of the op. A single region is returned as
// `Region &`; a variadic region is returned as the op's regions with the
// first `i` dropped, where `i` is the region's index.
void OpEmitter::genNamedRegionGetters() {
  const unsigned numRegions = op.getNumRegions();
  for (unsigned i = 0; i < numRegions; ++i) {
    const auto &region = op.getRegion(i);
    // Unnamed regions get no accessor.
    if (region.name.empty())
      continue;

    if (!region.isVariadic()) {
      auto &getter = opClass.newMethod("Region &", region.name);
      getter.body() << formatv(
          " return this->getOperation()->getRegion({0});", i);
      continue;
    }

    // Accessor for a variadic region.
    auto &getter = opClass.newMethod("MutableArrayRef<Region>", region.name);
    getter.body() << formatv(
        " return this->getOperation()->getRegions().drop_front({0});", i);
  }
}
// Adds one getter per named successor of the op. A single successor is
// returned as `Block *`; a variadic successor list is returned as a
// `SuccessorRange` running from the successor's index to the end.
void OpEmitter::genNamedSuccessorGetters() {
  const unsigned numSuccessors = op.getNumSuccessors();
  for (unsigned i = 0; i < numSuccessors; ++i) {
    const NamedSuccessor &successor = op.getSuccessor(i);
    // Unnamed successors get no accessor.
    if (successor.name.empty())
      continue;

    if (!successor.isVariadic()) {
      auto &getter = opClass.newMethod("Block *", successor.name);
      getter.body() << formatv(
          " return this->getOperation()->getSuccessor({0});", i);
      continue;
    }

    // Accessor for a variadic successor list.
    auto &getter = opClass.newMethod("SuccessorRange", successor.name);
    getter.body() << formatv(
        " return {std::next(this->getOperation()->successor_begin(), {0}), "
        "this->getOperation()->successor_end()};",
        i);
  }
}
// Tells whether a builder taking raw (unwrapped) attribute values should be
// generated for `op`. That is the case when at least one native attribute has
// an unwrapped value type distinct from its wrapped mlir::Attribute type;
// otherwise the raw-value builder would merely redefine the wrapped one.
static bool canGenerateUnwrappedBuilder(Operator &op) {
  const int numNativeAttrs = op.getNumNativeAttributes();
  // No native attributes at all: return directly to avoid redefining builders.
  if (numNativeAttrs == 0)
    return false;

  // One qualifying native attribute is enough.
  for (int i = 0; i < numNativeAttrs; ++i) {
    NamedAttribute &namedAttr = op.getAttribute(i);
    if (canUseUnwrappedRawValue(namedAttr.attr))
      return true;
  }
  return false;
}
void OpEmitter::genSeparateArgParamBuilder() {
SmallVector<AttrParamKind, 2> attrBuilderType;
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
if (canGenerateUnwrappedBuilder(op))
attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
// Emit with separate builders with or without unwrapped attributes and/or
// inferring result type.
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
bool inferType) {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, paramKind, attrType);
auto &m =
opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
auto &body = m.body();
genCodeForAddingArgAndRegionForBuilder(
body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
// Push all result types to the operation state
if (inferType) {
// Generate builder that infers type too.
// TODO(jpienaar): Subsume this with general checking if type can be
2020-02-28 10:59:34 -08:00
// inferred automatically.
// TODO(jpienaar): Expand to handle regions.
body << formatv(R"(
2020-02-28 10:59:34 -08:00
SmallVector<Type, 2> inferredReturnTypes;
if (succeeded({0}::inferReturnTypes(odsBuilder->getContext(),
{1}.location, {1}.operands, {1}.attributes,
2020-02-28 10:59:34 -08:00
/*regions=*/{{}, inferredReturnTypes)))
{1}.addTypes(inferredReturnTypes);
else
llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
return;
}
switch (paramKind) {
case TypeParamKind::None:
return;
case TypeParamKind::Separate:
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
if (op.getResult(i).isOptional())
body << " if (" << resultNames[i] << ")\n ";
body << " " << builderOpState << ".addTypes(" << resultNames[i]
<< ");\n";
}
return;
case TypeParamKind::Collective:
body << " "
<< "assert(resultTypes.size() "
<< (op.getNumVariableLengthResults() == 0 ? "==" : ">=") << " "
<< (op.getNumResults() - op.getNumVariableLengthResults())
<< "u && \"mismatched number of results\");\n";
body << " " << builderOpState << ".addTypes(resultTypes);\n";
return;
}
llvm_unreachable("unhandled TypeParamKind");
};
bool canInferType =
op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0;
for (auto attrType : attrBuilderType) {
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
if (canInferType)
emit(attrType, TypeParamKind::None, /*inferType=*/true);
// Emit separate arg build with collective type, unless there is only one
// variadic result, in which case the above would have already generated
// the same build method.
if (!(op.getNumResults() == 1 && op.getResult(0).isVariableLength()))
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
}
}
// Generates the build() method that takes all operands and all attributes as
// collective parameters and uses the first operand's type for every result.
// Nothing is generated for ops with a variable-length result.
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
  // With a variadic result we cannot know how many results to create.
  if (op.getNumVariableLengthResults() != 0)
    return;

  int numResults = op.getNumResults();

  // Signature.
  std::string signature = "Builder *odsBuilder, OperationState &";
  signature += builderOpState;
  signature += ", ValueRange operands, ArrayRef<NamedAttribute> attributes";
  if (op.getNumVariadicRegions())
    signature += ", unsigned numRegions";

  auto &builder =
      opClass.newMethod("void", "build", signature, OpMethod::MP_Static);
  auto &code = builder.body();

  // Operands.
  code << " " << builderOpState << ".addOperands(operands);\n";
  if (op.hasResizableOperandList())
    code << formatv(" {0}.setOperandListToResizable();\n\n", builderOpState);

  // Attributes.
  code << " " << builderOpState << ".addAttributes(attributes);\n";

  // Regions: a fixed count, or the caller-supplied `numRegions` when the op
  // has a variadic region.
  if (int numRegions = op.getNumRegions()) {
    code << llvm::formatv(
        " for (unsigned i = 0; i != {0}; ++i)\n",
        (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
    code << " (void)" << builderOpState << ".addRegion();\n";
  }

  // Result types: the first operand's type, repeated once per result.
  SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
  code << " " << builderOpState << ".addTypes({"
       << llvm::join(resultTypes, ", ") << "});\n\n";
}
2020-02-28 10:59:34 -08:00
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
// TODO(jpienaar): Expand to support regions.
const char *params =
"Builder *odsBuilder, OperationState &{0}, "
"ValueRange operands, ArrayRef<NamedAttribute> attributes";
auto &m =
opClass.newMethod("void", "build", formatv(params, builderOpState).str(),
OpMethod::MP_Static);
auto &body = m.body();
body << formatv(R"(
2020-02-28 10:59:34 -08:00
SmallVector<Type, 2> inferredReturnTypes;
if (succeeded({0}::inferReturnTypes(odsBuilder->getContext(),
{1}.location, operands, attributes,
2020-02-28 10:59:34 -08:00
/*regions=*/{{}, inferredReturnTypes)))
build(odsBuilder, odsState, inferredReturnTypes, operands, attributes);
else
llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
}
void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, TypeParamKind::None);
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body());
auto numResults = op.getNumResults();
if (numResults == 0)
return;
// Push all result types to the operation state
const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
std::string resultType =
formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
m.body() << " " << builderOpState << ".addTypes({" << resultType;
for (int i = 1; i != numResults; ++i)
m.body() << ", " << resultType;
m.body() << "});\n\n";
}
// Generates the build() method that takes operands and attributes as
// collective parameters and derives every result type from the op's first
// attribute: the generated body scans `attributes` for that attribute by name
// and, once found, populates the operation state.
void OpEmitter::genUseAttrAsResultTypeBuilder() {
  std::string signature = "Builder *odsBuilder, OperationState &";
  signature += builderOpState;
  signature += ", ValueRange operands, ArrayRef<NamedAttribute> attributes";

  auto &builder =
      opClass.newMethod("void", "build", signature, OpMethod::MP_Static);
  auto &code = builder.body();

  const auto &namedAttr = op.getAttribute(0);
  // A type attribute carries the result type as its value; any other
  // attribute contributes its own type.
  const std::string resultType = namedAttr.attr.isTypeAttr()
                                     ? "attr.second.cast<TypeAttr>().getValue()"
                                     : "attr.second.getType()";

  // Skip every attribute except the one the result type is derived from.
  code << " for (auto attr : attributes) {\n";
  code << " if (attr.first != \"" << namedAttr.name << "\") continue;\n";

  // Operands.
  code << " " << builderOpState << ".addOperands(operands);\n";
  if (op.hasResizableOperandList())
    code << formatv(" {0}.setOperandListToResizable();\n\n", builderOpState);

  // Attributes.
  code << " " << builderOpState << ".addAttributes(attributes);\n";

  // Result types: the derived type, repeated once per result.
  SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
  code << " " << builderOpState << ".addTypes({"
       << llvm::join(resultTypes, ", ") << "});\n";
  code << " }\n";
}
void OpEmitter::genBuilder() {
// Handle custom builders if provided.
// TODO(antiagainst): Create wrapper class for OpBuilder to hide the native
// TableGen API calls here.
{
auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
if (listInit) {
for (Init *init : listInit->getValues()) {
Record *builderDef = cast<DefInit>(init)->getDef();
StringRef params = builderDef->getValueAsString("params");
StringRef body = builderDef->getValueAsString("body");
bool hasBody = !body.empty();
auto &method =
opClass.newMethod("void", "build", params, OpMethod::MP_Static,
/*declOnly=*/!hasBody);
if (hasBody)
method.body() << body;
}
}
if (op.skipDefaultBuilders()) {
if (!listInit || listInit->empty())
PrintFatalError(
op.getLoc(),
"default builders are skipped and no custom builders provided");
return;
}
}
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
// We generate three classes of builders here:
// 1. one having a stand-alone parameter for each operand / attribute, and
genSeparateArgParamBuilder();
// 2. one having an aggregated parameter for all result types / operands /
// attributes, and
genCollectiveParamBuilder();
// 3. one having a stand-alone parameter for each operand and attribute,
// use the first operand or attribute's type as all result types
// to facilitate different call patterns.
if (op.getNumVariableLengthResults() == 0) {
if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
genUseOperandAsResultTypeSeparateParamBuilder();
genUseOperandAsResultTypeCollectiveParamBuilder();
}
if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
genUseAttrAsResultTypeBuilder();
}
}
void OpEmitter::genCollectiveParamBuilder() {
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariableLengthResults();
int numNonVariadicResults = numResults - numVariadicResults;
int numOperands = op.getNumOperands();
int numVariadicOperands = op.getNumVariableLengthOperands();
int numNonVariadicOperands = numOperands - numVariadicOperands;
// Signature
std::string params = std::string("Builder *, OperationState &") +
builderOpState +
", ArrayRef<Type> resultTypes, ValueRange operands, "
"ArrayRef<NamedAttribute> attributes";
if (op.getNumVariadicRegions())
params += ", unsigned numRegions";
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
auto &body = m.body();
// Operands
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
body << " assert(operands.size()"
<< (numVariadicOperands != 0 ? " >= " : " == ")
<< numNonVariadicOperands
<< "u && \"mismatched number of parameters\");\n";
body << " " << builderOpState << ".addOperands(operands);\n";
if (op.hasResizableOperandList())
body << formatv(" {0}.setOperandListToResizable();\n\n", builderOpState);
// Attributes
body << " " << builderOpState << ".addAttributes(attributes);\n";
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
body << llvm::formatv(
" for (unsigned i = 0; i != {0}; ++i)\n",
(op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
body << " (void)" << builderOpState << ".addRegion();\n";
}
// Result types
if (numVariadicResults == 0 || numNonVariadicResults != 0)
body << " assert(resultTypes.size()"
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
<< "u && \"mismatched number of return types\");\n";
body << " " << builderOpState << ".addTypes(resultTypes);\n";
// Generate builder that infers type too.
2020-02-28 10:59:34 -08:00
// TODO(jpienaar): Subsume this with general checking if type can be inferred
// automatically.
// TODO(jpienaar): Expand to handle regions and successors.
if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0 &&
op.getNumSuccessors() == 0)
2020-02-28 10:59:34 -08:00
genInferredTypeCollectiveParamBuilder();
}
void OpEmitter::buildParamList(std::string &paramList,
SmallVectorImpl<std::string> &resultTypeNames,
TypeParamKind typeParamKind,
AttrParamKind attrParamKind) {
resultTypeNames.clear();
auto numResults = op.getNumResults();
resultTypeNames.reserve(numResults);
paramList = "Builder *odsBuilder, OperationState &";
paramList.append(builderOpState);
switch (typeParamKind) {
case TypeParamKind::None:
break;
case TypeParamKind::Separate: {
// Add parameters for all return types
for (int i = 0; i < numResults; ++i) {
const auto &result = op.getResult(i);
std::string resultName = std::string(result.name);
if (resultName.empty())
resultName = std::string(formatv("resultType{0}", i));
if (result.isOptional())
paramList.append(", /*optional*/Type ");
else if (result.isVariadic())
paramList.append(", ArrayRef<Type> ");
else
paramList.append(", Type ");
paramList.append(resultName);
resultTypeNames.emplace_back(std::move(resultName));
}
} break;
case TypeParamKind::Collective: {
paramList.append(", ArrayRef<Type> resultTypes");
resultTypeNames.push_back("resultTypes");
} break;
}
// Add parameters for all arguments (operands and attributes).
int numOperands = 0;
int numAttrs = 0;
int defaultValuedAttrStartIndex = op.getNumArgs();
if (attrParamKind == AttrParamKind::UnwrappedValue) {
// Calculate the start index from which we can attach default values in the
// builder declaration.
for (int i = op.getNumArgs() - 1; i >= 0; --i) {
auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
if (!namedAttr || !namedAttr->attr.hasDefaultValue())
break;
if (!canUseUnwrappedRawValue(namedAttr->attr))
break;
// Creating an APInt requires us to provide bitwidth, value, and
// signedness, which is complicated compared to others. Similarly
// for APFloat.
// TODO(b/144412160) Adjust the 'returnType' field of such attributes
// to support them.
StringRef retType = namedAttr->attr.getReturnType();
if (retType == "APInt" || retType == "APFloat")
break;
defaultValuedAttrStartIndex = i;
}
}
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands);
if (operand.isOptional())
paramList.append(", /*optional*/Value ");
else if (operand.isVariadic())
paramList.append(", ValueRange ");
else
paramList.append(", Value ");
paramList.append(getArgumentName(op, numOperands));
++numOperands;
} else {
const auto &namedAttr = op.getAttribute(numAttrs);
const auto &attr = namedAttr.attr;
paramList.append(", ");
if (attr.isOptional())
paramList.append("/*optional*/");
switch (attrParamKind) {
case AttrParamKind::WrappedAttr:
paramList.append(std::string(attr.getStorageType()));
break;
case AttrParamKind::UnwrappedValue:
if (canUseUnwrappedRawValue(attr)) {
paramList.append(std::string(attr.getReturnType()));
} else {
paramList.append(std::string(attr.getStorageType()));
}
break;
}
paramList.append(" ");
paramList.append(std::string(namedAttr.name));
// Attach default value if requested and possible.
if (attrParamKind == AttrParamKind::UnwrappedValue &&
i >= defaultValuedAttrStartIndex) {
bool isString = attr.getReturnType() == "StringRef";
paramList.append(" = ");
if (isString)
paramList.append("\"");
paramList.append(std::string(attr.getDefaultValue()));
if (isString)
paramList.append("\"");
}
++numAttrs;
}
}
/// Insert parameters for each successor.
for (const NamedSuccessor &succ : op.getSuccessors()) {
paramList += (succ.isVariadic() ? ", ArrayRef<Block *> " : ", Block *");
paramList += succ.name;
}
/// Insert parameters for variadic regions.
for (const NamedRegion &region : op.getRegions()) {
if (region.isVariadic())
paramList += llvm::formatv(", unsigned {0}Count", region.name).str();
}
}
/// Appends to `body` the builder code that adds operands, the operand
/// segment size attribute (if the op has the corresponding trait), the
/// non-derived attributes, the regions and the successors to the operation
/// state. When `isRawValueAttr` is set, attributes that support it are
/// expected as raw values and are wrapped into an Attribute first.
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
                                                       bool isRawValueAttr) {
  // Operands: optional ones are only added when present.
  for (int operandIdx = 0, operandCount = op.getNumOperands();
       operandIdx < operandCount; ++operandIdx) {
    std::string operandName = getArgumentName(op, operandIdx);
    if (op.getOperand(operandIdx).isOptional())
      body << " if (" << operandName << ")\n ";
    body << " " << builderOpState << ".addOperands(" << operandName
         << ");\n";
  }
  if (op.hasResizableOperandList())
    body << formatv(" {0}.setOperandListToResizable();\n", builderOpState);

  // Operand segment sizes: 0/1 for an optional operand, the range size for a
  // variadic operand, 1 otherwise.
  if (op.getTrait("OpTrait::AttrSizedOperandSegments")) {
    body << " " << builderOpState
         << ".addAttribute(\"operand_segment_sizes\", "
            "odsBuilder->getI32VectorAttr({";
    llvm::interleaveComma(
        llvm::seq<int>(0, op.getNumOperands()), body, [&](int operandIdx) {
          const auto &operand = op.getOperand(operandIdx);
          if (operand.isOptional()) {
            body << "(" << getArgumentName(op, operandIdx) << " ? 1 : 0)";
            return;
          }
          if (operand.isVariadic()) {
            body << "static_cast<int32_t>("
                 << getArgumentName(op, operandIdx) << ".size())";
            return;
          }
          body << "1";
        });
    body << "}));\n";
  }

  // Attributes: derived attributes are not stored and therefore skipped.
  for (const auto &namedAttr : op.getAttributes()) {
    auto &attr = namedAttr.attr;
    if (attr.isDerivedAttr())
      continue;

    // Optional attributes are only added when they are non-null.
    const bool guardedByNullCheck = attr.isOptional();
    if (guardedByNullCheck)
      body << formatv(" if ({0}) ", namedAttr.name) << "{\n";

    const bool wrapRawValue = isRawValueAttr && canUseUnwrappedRawValue(attr);
    if (!wrapRawValue) {
      body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
                      namedAttr.name);
    } else {
      // A raw value has to be wrapped in an Attribute instance using the
      // attribute's constant builder.
      FmtContext fctx;
      fctx.withBuilder("(*odsBuilder)");
      std::string wrapTemplate = std::string(attr.getConstBuilderTemplate());
      // For StringAttr, its constant builder call will wrap the input in
      // quotes, which is correct for normal string literals, but incorrect
      // here given we use function arguments. So we need to strip the
      // wrapping quotes.
      if (StringRef(wrapTemplate).contains("\"$0\""))
        wrapTemplate = replaceAllSubstrs(wrapTemplate, "\"$0\"", "$0");
      std::string wrapped =
          std::string(tgfmt(wrapTemplate, &fctx, namedAttr.name));
      body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
                      namedAttr.name, wrapped);
    }

    if (guardedByNullCheck)
      body << " }\n";
  }

  // Regions: a variadic region is added `<name>Count` times.
  for (const NamedRegion &region : op.getRegions()) {
    if (region.isVariadic())
      body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
                      region.name);
    body << " (void)" << builderOpState << ".addRegion();\n";
  }

  // Successors.
  for (const NamedSuccessor &succ : op.getSuccessors())
    body << formatv(" {0}.addSuccessors({1});\n", builderOpState, succ.name);
}
/// Declares the static `getCanonicalizationPatterns` hook when the op
/// definition sets `hasCanonicalizer`. Only a declaration is emitted.
void OpEmitter::genCanonicalizerDecls() {
  if (def.getValueAsBit("hasCanonicalizer"))
    opClass.newMethod("void", "getCanonicalizationPatterns",
                      "OwningRewritePatternList &results, MLIRContext *context",
                      OpMethod::MP_Static, /*declOnly=*/true);
}
/// Declares the `fold` hook when the op definition sets `hasFolder`. An op
/// with exactly one, non-variable-length result gets the form returning an
/// `OpFoldResult`; every other op gets the form that fills a result vector.
void OpEmitter::genFolderDecls() {
  if (!def.getValueAsBit("hasFolder"))
    return;

  const bool foldsToSingleValue =
      op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
  if (foldsToSingleValue) {
    opClass.newMethod("OpFoldResult", "fold", "ArrayRef<Attribute> operands",
                      OpMethod::MP_None, /*declOnly=*/true);
    return;
  }

  opClass.newMethod("LogicalResult", "fold",
                    "ArrayRef<Attribute> operands, "
                    "SmallVectorImpl<OpFoldResult> &results",
                    OpMethod::MP_None, /*declOnly=*/true);
}
void OpEmitter::genOpInterfaceMethods() {
for (const auto &trait : op.getTraits()) {
auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait);
if (!opTrait || !opTrait->shouldDeclareMethods())
continue;
auto interface = opTrait->getOpInterface();
for (auto method : interface.getMethods()) {
// Don't declare if the method has a body or a default implementation.
if (method.getBody() || method.getDefaultImplementation())
continue;
std::string args;
llvm::raw_string_ostream os(args);
llvm::interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
opClass.newMethod(method.getReturnType(), method.getName(), os.str(),
method.isStatic() ? OpMethod::MP_Static
: OpMethod::MP_None,
/*declOnly=*/true);
}
}
}
void OpEmitter::genSideEffectInterfaceMethods() {
enum EffectKind { Operand, Result, Static };
struct EffectLocation {
/// The effect applied.
SideEffect effect;
/// The index if the kind is either operand or result.
unsigned index : 30;
/// The kind of the location.
unsigned kind : 2;
};
StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
auto resolveDecorators = [&](Operator::var_decorator_range decorators,
unsigned index, unsigned kind) {
for (auto decorator : decorators)
if (SideEffect *effect = dyn_cast<SideEffect>(&decorator))
interfaceEffects[effect->getBaseEffectName()].push_back(
EffectLocation{*effect, index, kind});
};
// Collect effects that were specified via:
/// Traits.
for (const auto &trait : op.getTraits()) {
const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
if (!opTrait)
continue;
auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
for (auto decorator : opTrait->getEffects())
effects.push_back(EffectLocation{cast<SideEffect>(decorator),
/*index=*/0, EffectKind::Static});
}
/// Operands.
for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
if (op.getArg(i).is<NamedTypeConstraint *>()) {
resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
++operandIt;
}
}
/// Results.
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
for (auto &it : interfaceEffects) {
auto effectsParam =
llvm::formatv(
"SmallVectorImpl<SideEffects::EffectInstance<{0}>> &effects",
it.first())
.str();
// Generate the 'getEffects' method.
auto &getEffects = opClass.newMethod("void", "getEffects", effectsParam);
auto &body = getEffects.body();
// Add effect instances for each of the locations marked on the operation.
for (auto &location : it.second) {
if (location.kind != EffectKind::Static) {
body << " for (Value value : getODS"
<< (location.kind == EffectKind::Operand ? "Operands" : "Results")
<< "(" << location.index << "))\n ";
}
body << " effects.emplace_back(" << location.effect.getName()
<< "::get()";
// If the effect isn't static, it has a specific value attached to it.
if (location.kind != EffectKind::Static)
body << ", value";
body << ", " << location.effect.getResource() << "::get());\n";
}
}
}
/// Emits the static `parse` method from the `parser` code of the op
/// definition. Nothing is emitted when there is no `parser` string or when
/// the op has an `assemblyFormat`.
void OpEmitter::genParser() {
  if (!hasStringAttribute(def, "parser"))
    return;
  if (hasStringAttribute(def, "assemblyFormat"))
    return;

  auto &parseMethod = opClass.newMethod(
      "ParseResult", "parse", "OpAsmParser &parser, OperationState &result",
      OpMethod::MP_Static);

  // `cppClass` in the user code expands to the generated class name.
  FmtContext fctx;
  fctx.addSubst("cppClass", opClass.getClassName());
  auto parserCode = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
  parseMethod.body() << " " << tgfmt(parserCode, &fctx);
}
void OpEmitter::genPrinter() {
if (hasStringAttribute(def, "assemblyFormat"))
return;
auto valueInit = def.getValueInit("printer");
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
if (!codeInit)
return;
auto &method = opClass.newMethod("void", "print", "OpAsmPrinter &p");
FmtContext fctx;
fctx.addSubst("cppClass", opClass.getClassName());
auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r");
method.body() << " " << tgfmt(printer, &fctx);
}
void OpEmitter::genVerifier() {
auto valueInit = def.getValueInit("verifier");
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
auto &body = method.body();
const char *checkAttrSizedValueSegmentsCode = R"(
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements();
if (numElements != {1}) {{
return emitOpError("'{0}' attribute for specifying {2} segments "
"must have {1} elements");
}
)";
// Verify a few traits first so that we can use
// getODSOperands()/getODSResults() in the rest of the verifier.
for (auto &trait : op.getTraits()) {
if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") {
body << formatv(checkAttrSizedValueSegmentsCode,
"operand_segment_sizes", op.getNumOperands(),
"operand");
} else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") {
body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
op.getNumResults(), "result");
}
}
}
// Populate substitutions for attributes and named operands and results.
for (const auto &namedAttr : op.getAttributes())
verifyCtx.addSubst(namedAttr.name,
formatv("this->getAttr(\"{0}\")", namedAttr.name));
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
auto &value = op.getOperand(i);
if (value.name.empty())
continue;
if (value.isVariadic())
verifyCtx.addSubst(value.name, formatv("this->getODSOperands({0})", i));
else
verifyCtx.addSubst(value.name,
formatv("(*this->getODSOperands({0}).begin())", i));
}
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
auto &value = op.getResult(i);
if (value.name.empty())
continue;
if (value.isVariadic())
verifyCtx.addSubst(value.name, formatv("this->getODSResults({0})", i));
else
verifyCtx.addSubst(value.name,
formatv("(*this->getODSResults({0}).begin())", i));
}
// Verify the attributes have the correct type.
for (const auto &namedAttr : op.getAttributes()) {
const auto &attr = namedAttr.attr;
if (attr.isDerivedAttr())
continue;
auto attrName = namedAttr.name;
// Prefix with `tblgen_` to avoid hiding the attribute accessor.
auto varName = tblgenNamePrefix + attrName;
body << formatv(" auto {0} = this->getAttr(\"{1}\");\n", varName,
attrName);
bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
Spike to define real math ops and lowering of one variant of add to corresponding integer ops. The only reason in starting with a fixedpoint add is that it is the absolute simplest variant and illustrates the level of abstraction I'm aiming for. The overall flow would be: 1. Determine quantization parameters (out of scope of this cl). 2. Source dialect rules to lower supported math ops to the quantization dialect (out of scope of this cl). 3. Quantization passes: [-quant-convert-const, -quant-lower-uniform-real-math, -quant-lower-unsupported-to-float] (the last one not implemented yet) 4. Target specific lowering of the integral arithmetic ops (roughly at the level of gemmlowp) to more fundamental operations (i.e. calls to gemmlowp, simd instructions, DSP instructions, etc). How I'm doing this should facilitate implementation of just about any kind of backend except TFLite, which has a very course, adhoc surface area for its quantized kernels. Options there include (I'm not taking an opinion on this - just trying to provide options): a) Not using any of this: just match q/dbarrier + tf math ops to the supported TFLite quantized op set. b) Implement the more fundamental integer math ops on TFLite and convert to those instead of the current op set. Note that I've hand-waved over the process of choosing appropriate quantization parameters. Getting to that next. As you can see, different implementations will likely have different magic combinations of specific math support, and we will need the target system that has been discussed for some of the esoteric cases (i.e. many DSPs only support POT fixedpoint). Two unrelated changes to the overall goal of this CL and can be broken out of desired: - Adding optional attribute support to TabelGen - Allowing TableGen native rewrite hooks to return nullptr, signalling that no rewrite has been done. PiperOrigin-RevId: 235267229
2019-02-22 15:11:00 -08:00
if (allowMissingAttr) {
// If the attribute has a default value, then only verify the predicate if
// set. This does effectively assume that the default value is valid.
// TODO: verify the debug value is valid (perhaps in debug mode only).
body << " if (" << varName << ") {\n";
} else {
body << " if (!" << varName
<< ") return emitOpError(\"requires attribute '" << attrName
<< "'\");\n {\n";
}
auto attrPred = attr.getPredicate();
if (!attrPred.isNull()) {
body << tgfmt(
" if (!($0)) return emitOpError(\"attribute '$1' "
"failed to satisfy constraint: $2\");\n",
/*ctx=*/nullptr,
tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)),
attrName, attr.getDescription());
}
body << " }\n";
}
genOperandResultVerifier(body, op.getOperands(), "operand");
genOperandResultVerifier(body, op.getResults(), "result");
for (auto &trait : op.getTraits()) {
if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
body << tgfmt(" if (!($0)) {\n "
"return emitOpError(\"failed to verify that $1\");\n }\n",
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
t->getDescription());
}
}
genRegionVerifier(body);
genSuccessorVerifier(body);
if (hasCustomVerify) {
FmtContext fctx;
fctx.addSubst("cppClass", opClass.getClassName());
auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r");
body << " " << tgfmt(printer, &fctx);
} else {
body << " return mlir::success();\n";
}
}
/// Appends to `body` the verification code for the operands or the results
/// of the op (`valueKind` is "operand" or "result"). For every declared
/// value that is optional or has a predicate, the generated code fetches the
/// corresponding ODS value group, checks that an optional group has at most
/// one element, and checks the type predicate on each value of the group.
///
/// NOTE(review): the generated `index` counter is only advanced inside the
/// per-value loop of groups that have a predicate; groups skipped here do
/// not contribute to it.
void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
                                         Operator::value_range values,
                                         StringRef valueKind) {
  FmtContext fctx;

  body << " {\n";
  body << " unsigned index = 0; (void)index;\n";

  for (auto group : llvm::enumerate(values)) {
    const bool hasPredicate = group.value().hasPredicate();
    const bool isOptional = group.value().isOptional();
    if (!(hasPredicate || isOptional))
      continue;

    body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
                    // Capitalize the first letter to match the function name
                    valueKind.substr(0, 1).upper(), valueKind.substr(1),
                    group.index());

    // If the constraint is optional check that the value group has at most 1
    // value.
    if (isOptional) {
      body << formatv(" if (valueGroup{0}.size() > 1)\n"
                      " return emitOpError(\"{1} group starting at #\") "
                      "<< index << \" requires 0 or 1 element, but found \" << "
                      "valueGroup{0}.size();\n",
                      group.index(), valueKind);
    }

    // Otherwise, if there is no predicate there is nothing left to do.
    if (!hasPredicate)
      continue;

    // Emit a loop to check all the dynamic values in the pack.
    auto constraint = group.value().constraint;
    body << " for (Value v : valueGroup" << group.index() << ") {\n";
    body << " (void)v;\n"
         << " if (!("
         << tgfmt(constraint.getConditionTemplate(),
                  &fctx.withSelf("v.getType()"))
         << ")) {\n";
    body << formatv(" return emitOpError(\"{0} #\") << index "
                    "<< \" must be {1}, but got \" << v.getType();\n",
                    valueKind, constraint.getDescription());
    body << " }\n"; // if
    body << " ++index;\n";
    body << " }\n"; // for
  }

  body << " }\n";
}
/// Appends to `body` the verification code for the regions of the op. For
/// every region with a predicate, the generated code checks the predicate on
/// the region, or on each region returned by the accessor of a variadic
/// region. Nothing is emitted for an op without regions.
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
  const unsigned regionCount = op.getNumRegions();
  if (regionCount == 0)
    return;

  body << "{\n";
  body << " unsigned index = 0; (void)index;\n";

  for (unsigned regionIdx = 0; regionIdx < regionCount; ++regionIdx) {
    const auto &region = op.getRegion(regionIdx);
    // Regions without a predicate need no check.
    if (region.constraint.getPredicate().isNull())
      continue;

    // A variadic region is iterated through its accessor; a single region is
    // wrapped into a one element range.
    const char *rangeFormat =
        region.isVariadic()
            ? "{0}()"
            : "MutableArrayRef<Region>(this->getOperation()->getRegion({1}))";
    body << " for (Region &region : ";
    body << formatv(rangeFormat, region.name, regionIdx);
    body << ") {\n";

    auto condition = tgfmt(region.constraint.getConditionTemplate(),
                           &verifyCtx.withSelf("region"))
                         .str();

    // The region name, if any, is included in the diagnostic.
    body << formatv(" (void)region;\n"
                    " if (!({0})) {\n "
                    "return emitOpError(\"region #\") << index << \" {1}"
                    "failed to "
                    "verify constraint: {2}\";\n }\n",
                    condition,
                    region.name.empty() ? "" : "('" + region.name + "') ",
                    region.constraint.getDescription())
         << " ++index;\n"
         << " }\n";
  }

  body << " }\n";
}
/// Appends to `body` the verification code for the successors of the op. For
/// every successor with a predicate, the generated code checks the predicate
/// on each block returned by the successor accessor. Nothing is emitted for
/// an op without successors.
void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
  const unsigned successorCount = op.getNumSuccessors();
  if (successorCount == 0)
    return;

  body << "{\n";
  body << " unsigned index = 0; (void)index;\n";

  for (unsigned succIdx = 0; succIdx < successorCount; ++succIdx) {
    const auto &successor = op.getSuccessor(succIdx);
    // Successors without a predicate need no check.
    if (successor.constraint.getPredicate().isNull())
      continue;

    // A variadic successor accessor already yields a range; a single one is
    // wrapped into an ArrayRef.
    const char *rangeFormat =
        successor.isVariadic() ? "{0}()" : "ArrayRef<Block *>({0}())";
    body << " for (Block *successor : ";
    body << formatv(rangeFormat, successor.name);
    body << ") {\n";

    auto condition = tgfmt(successor.constraint.getConditionTemplate(),
                           &verifyCtx.withSelf("successor"))
                         .str();

    body << formatv(" (void)successor;\n"
                    " if (!({0})) {\n "
                    "return emitOpError(\"successor #\") << index << \"('{1}') "
                    "failed to "
                    "verify constraint: {2}\";\n }\n",
                    condition, successor.name,
                    successor.constraint.getDescription())
         << " ++index;\n"
         << " }\n";
  }

  body << " }\n";
}
/// Adds to `opClass` the trait describing how many entities of `traitKind`
/// (e.g. "Region") the op has:
///   - all of them variadic:   `OpTrait::Variadic<Kind>s`
///   - some of them variadic:  `OpTrait::AtLeastN<Kind>s<fixed>::Impl`
///   - none variadic, 0 total: `OpTrait::Zero<Kind>`
///   - none variadic, 1 total: `OpTrait::One<Kind>`
///   - none variadic, N total: `OpTrait::N<Kind>s<N>::Impl`
static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
                              int numTotal, int numVariadic) {
  if (numVariadic != 0) {
    const int numFixed = numTotal - numVariadic;
    if (numFixed == 0)
      opClass.addTrait("OpTrait::Variadic" + traitKind + "s");
    else
      opClass.addTrait("OpTrait::AtLeastN" + traitKind + "s<" +
                       Twine(numFixed) + ">::Impl");
    return;
  }

  if (numTotal == 0)
    opClass.addTrait("OpTrait::Zero" + traitKind);
  else if (numTotal == 1)
    opClass.addTrait("OpTrait::One" + traitKind);
  else
    opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
                     ">::Impl");
}
/// Adds to the op class the size traits for regions, results, successors and
/// operands (in that order), followed by the native and interface traits
/// attached to the op.
void OpEmitter::genTraits() {
  // Region, result and successor size traits.
  addSizeCountTrait(opClass, "Region", op.getNumRegions(),
                    op.getNumVariadicRegions());
  addSizeCountTrait(opClass, "Result", op.getNumResults(),
                    op.getNumVariableLengthResults());
  addSizeCountTrait(opClass, "Successor", op.getNumSuccessors(),
                    op.getNumVariadicSuccessors());

  // Operand size trait. This is spelled out here since the operand trait
  // names ("ZeroOperands", "OneOperand") do not follow the pattern produced
  // by addSizeCountTrait.
  const int operandCount = op.getNumOperands();
  const int variadicOperandCount = op.getNumVariableLengthOperands();
  if (variadicOperandCount != 0) {
    if (operandCount == variadicOperandCount)
      opClass.addTrait("OpTrait::VariadicOperands");
    else
      opClass.addTrait("OpTrait::AtLeastNOperands<" +
                       Twine(operandCount - variadicOperandCount) + ">::Impl");
  } else if (operandCount == 0) {
    opClass.addTrait("OpTrait::ZeroOperands");
  } else if (operandCount == 1) {
    opClass.addTrait("OpTrait::OneOperand");
  } else {
    opClass.addTrait("OpTrait::NOperands<" + Twine(operandCount) + ">::Impl");
  }

  // Add the native and interface traits.
  for (const auto &trait : op.getTraits()) {
    if (auto nativeTrait = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
      opClass.addTrait(nativeTrait->getTrait());
      continue;
    }
    if (auto ifaceTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
      opClass.addTrait(ifaceTrait->getTrait());
  }
}
/// Emits the static `getOperationName` method, which returns the operation
/// name as a string literal.
void OpEmitter::genOpNameGetter() {
  auto &nameGetter = opClass.newMethod("StringRef", "getOperationName",
                                       /*params=*/"", OpMethod::MP_Static);
  auto &getterBody = nameGetter.body();
  getterBody << " return \"" << op.getOperationName() << "\";\n";
}
void OpEmitter::genOpAsmInterface() {
// If the user only has one results or specifically added the Asm trait,
// then don't generate it for them. We specifically only handle multi result
// operations, because the name of a single result in the common case is not
// interesting(generally 'result'/'output'/etc.).
// TODO: We could also add a flag to allow operations to opt in to this
// generation, even if they only have a single operation.
int numResults = op.getNumResults();
if (numResults <= 1 || op.getTrait("OpAsmOpInterface::Trait"))
return;
SmallVector<StringRef, 4> resultNames(numResults);
for (int i = 0; i != numResults; ++i)
resultNames[i] = op.getResultName(i);
// Don't add the trait if none of the results have a valid name.
if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
return;
opClass.addTrait("OpAsmOpInterface::Trait");
// Generate the right accessor for the number of results.
auto &method = opClass.newMethod("void", "getAsmResultNames",
"OpAsmSetValueNameFn setNameFn");
auto &body = method.body();
for (int i = 0; i != numResults; ++i) {
body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
<< " if (!llvm::empty(resultGroup" << i << "))\n"
<< " setNameFn(*resultGroup" << i << ".begin(), \""
<< resultNames[i] << "\");\n";
}
}
//===----------------------------------------------------------------------===//
// OpOperandAdaptor emitter
//===----------------------------------------------------------------------===//
namespace {
/// Emits the operand adaptor class of an op to an output stream. An operand
/// adaptor wraps an ArrayRef<Value> and exposes named operand getters
/// identical to those defined in the Op.
class OpOperandAdaptorEmitter {
public:
  /// Writes the declaration of the adaptor class of `op` to `os`.
  static void emitDecl(const Operator &op, raw_ostream &os);

  /// Writes the definition of the adaptor class of `op` to `os`.
  static void emitDef(const Operator &op, raw_ostream &os);

private:
  /// Builds the adaptor class for `op`; only used by the static emitters.
  explicit OpOperandAdaptorEmitter(const Operator &op);

  /// The adaptor class being generated.
  Class adapterClass;
};
} // namespace
/// Populates the adaptor class, named `<OpClass>OperandAdaptor`, with a
/// `tblgen_operands` field, a constructor that stores the given values in
/// it, and the named operand getters reading from that field.
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
    : adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
  adapterClass.newField("ArrayRef<Value>", "tblgen_operands");

  auto &ctor = adapterClass.newConstructor("ArrayRef<Value> values");
  auto &ctorBody = ctor.body();
  ctorBody << " tblgen_operands = values;\n";

  // The getters are generated against the stored operand list.
  generateNamedOperandGetters(op, adapterClass,
                              /*rangeType=*/"ArrayRef<Value>",
                              /*rangeBeginCall=*/"tblgen_operands.begin()",
                              /*rangeSizeCall=*/"tblgen_operands.size()",
                              /*getOperandCallPattern=*/"tblgen_operands[{0}]");
}
void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os);
}
void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os);
}
// Emits, inside a GET_OP_CLASSES scope, the op classes (and operand adaptor
// classes) of all `defs`: their declarations when `emitDecl` is set, their
// definitions otherwise.
static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
                          bool emitDecl) {
  IfDefScope scope("GET_OP_CLASSES", os);

  // First emit forward declaration for each class, this allows them to refer
  // to each others in traits for example.
  if (emitDecl) {
    for (auto *def : defs) {
      Operator op(*def);
      os << "class " << op.getCppClassName() << ";\n";
    }
  }

  const char *sectionKind = emitDecl ? "declarations" : "definitions";
  for (auto *def : defs) {
    Operator op(*def);
    // We cannot generate the operand adaptor class if operand getters depend
    // on an attribute.
    const bool emitAdaptor =
        !op.getTrait("OpTrait::AttrSizedOperandSegments");

    os << formatv(opCommentHeader, op.getQualCppClassName(), sectionKind);
    if (emitDecl) {
      if (emitAdaptor)
        OpOperandAdaptorEmitter::emitDecl(op, os);
      OpEmitter::emitDecl(op, os);
      continue;
    }

    if (emitAdaptor)
      OpOperandAdaptorEmitter::emitDef(op, os);
    OpEmitter::emitDef(op, os);
  }
}
// Emits, inside a GET_OP_LIST scope, the qualified class names of all ops,
// separated by a comma and a newline.
static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
  IfDefScope scope("GET_OP_LIST", os);

  // TODO: We are constructing the Operator wrapper instance just for
  // getting its qualified class name here. Reduce the overhead by having a
  // lightweight version of Operator class just for that purpose.
  auto emitQualifiedName = [&os](Record *def) {
    os << Operator(def).getQualCppClassName();
  };
  auto emitSeparator = [&os]() { os << ",\n"; };
  llvm::interleave(defs, emitQualifiedName, emitSeparator);
}
/// Generator entry point for op declarations: writes the file header and the
/// declarations of all records derived from `Op`. Always returns false.
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Op Declarations", os);

  const auto &opDefs = recordKeeper.getAllDerivedDefinitions("Op");
  emitOpClasses(opDefs, os, /*emitDecl=*/true);

  return false;
}
/// Generator entry point for op definitions: writes the file header, the op
/// list and the definitions of all records derived from `Op`. Always returns
/// false.
static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Op Definitions", os);

  const auto &opDefs = recordKeeper.getAllDerivedDefinitions("Op");
  emitOpList(opDefs, os);
  emitOpClasses(opDefs, os, /*emitDecl=*/false);

  return false;
}
// Registers the "gen-op-decls" generator, which forwards to emitOpDecls.
static mlir::GenRegistration
    genOpDecls("gen-op-decls", "Generate op declarations",
               [](const RecordKeeper &recordKeeper, raw_ostream &out) -> bool {
                 return emitOpDecls(recordKeeper, out);
               });
// Registers the "gen-op-defs" generator, which forwards to emitOpDefs.
static mlir::GenRegistration
    genOpDefs("gen-op-defs", "Generate op definitions",
              [](const RecordKeeper &recordKeeper, raw_ostream &out) -> bool {
                return emitOpDefs(recordKeeper, out);
              });