llvm-project/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Lei Zhang 1ac49ce0bd [TableGen] Remove asserts for attributes in aggregate builders
Attributes can have default values or be optional. Checking the validity of
    attributes in aggregate builder should consider that. And to be accurate,
    we should check all required attributes are indeed provided in the list.
    This is actually duplicating the work done by verifier. Checking the validity
    of attributes should be the responsibility of verifiers. This CL removes
    the assertion for attributes in aggregate builders for the above reason.
    (Assertions for operands/results are still kept since they are trivial.)

    Also added more tests for aggregate builders.

--

PiperOrigin-RevId: 241746059
2019-04-03 09:30:49 -07:00

999 lines
33 KiB
C++

//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// OpDefinitionsGen uses the description of operations to generate C++
// definitions for ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
using mlir::tblgen::Operator;
// Name of the `OperationState *` parameter in every generated build() method.
static const char *const builderOpState = "tblgen_state";
// Prefix for synthesized parameter names of unnamed operands ("_arg_<index>").
static const char *const generatedArgName = "_arg";
// Banner written before each op's generated code. {0} is the qualified op
// class name and {1} is either "declarations" or "definitions".
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//
)";
//===----------------------------------------------------------------------===//
// Utility structs and functions
//===----------------------------------------------------------------------===//
// Variation of llvm::formatv (FormatVariadic.h) that takes the format string
// as a StringRef instead of a `const char *`. This lets format templates that
// are held as StringRef (e.g. the attribute conversion/default-value
// templates used by the getter and verifier generators) be used directly.
// Calls with a literal / `const char *` format string presumably still resolve
// to llvm::formatv, since that needs no user-defined conversion.
// NOTE(review): the returned formatv_object appears to keep `fmt` by
// reference, so the characters it points to must outlive the object — confirm.
template <typename... Ts>
inline auto formatv(StringRef fmt, Ts &&... vals) -> formatv_object<decltype(
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...))> {
// Same tuple-of-adapters type as in the trailing return type above.
using ParamTuple = decltype(
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...));
return llvm::formatv_object<ParamTuple>(
fmt,
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...));
}
// Returns true if the value stored in `record`'s field `fieldName` is a string
// or code literal, i.e. something getValueAsString() can hand back.
static inline bool hasStringAttribute(const Record &record,
                                      StringRef fieldName) {
  auto *value = record.getValueInit(fieldName);
  if (isa<StringInit>(value))
    return true;
  return isa<CodeInit>(value);
}
// Returns the given `op`'s qualified C++ class name.
static std::string getOpQualClassName(const Record &op) {
SmallVector<StringRef, 2> splittedName;
llvm::SplitString(op.getName(), splittedName, "_");
return llvm::join(splittedName, "::");
}
// Returns the C++ parameter name used for the `index`-th operand of `op`:
// the operand's declared name, or "_arg_<index>" when it is unnamed.
static std::string getArgumentName(const Operator &op, int index) {
  const auto &operand = op.getOperand(index);
  // Unnamed operands get a synthesized, index-based name.
  if (operand.name.empty())
    return formatv("{0}_{1}", generatedArgName, index);
  return operand.name;
}
namespace {
// RAII helper that wraps everything written to `os` during its lifetime in an
// `#ifdef <name>` / `#undef <name>` ... `#endif` scope. A single generated
// file can thus carry several sections (e.g. GET_OP_LIST and GET_OP_CLASSES)
// that an includer selects by defining the corresponding macro.
//
// Copying is disabled on purpose: a copy would write a second, unmatched
// `#endif` when destroyed.
class IfDefScope {
public:
  IfDefScope(StringRef name, raw_ostream &os) : name(name), os(os) {
    os << "#ifdef " << name << "\n"
       << "#undef " << name << "\n\n";
  }
  IfDefScope(const IfDefScope &) = delete;
  IfDefScope &operator=(const IfDefScope &) = delete;
  ~IfDefScope() { os << "\n#endif // " << name << "\n\n"; }

private:
  // Non-owning; the callers in this file pass string literals.
  StringRef name;
  raw_ostream &os;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Classes for C++ code emission
//===----------------------------------------------------------------------===//
// We emit the op declaration and definition into separate files: *Ops.h.inc
// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
//
// In order to do this split, we need to track method signature and
// implementation logic separately. Signature information is used for both
// declaration and definition, while implementation logic is only for
// definition. So we have the following classes for C++ code emission.
namespace {
// Signature (return type, name, parameter list) of a generated op method.
// It is kept apart from the body because the signature is written both into
// the declaration (*Ops.h.inc) and into the definition (*Ops.cpp.inc).
class OpMethodSignature {
public:
// `params` is the raw C++ parameter-list text; it may contain default values.
OpMethodSignature(StringRef retType, StringRef name, StringRef params);
// Writes the signature as a method declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const;
// Writes the signature as the start of a method definition to the given `os`,
// with parameter default values stripped. `namePrefix` is prepended to the
// method name with "::" to qualify the out-of-line definition (OpClass passes
// its class name here).
void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
private:
// Returns true if the given C++ `type` ends with '&' or '*', in which case no
// space is emitted between the return type and the method name.
static bool endsWithRefOrPtr(StringRef type);
std::string returnType;
std::string methodName;
std::string parameters;
};
// Accumulates the text of a generated method body.
class OpMethodBody {
public:
// When `declOnly` is true, every append below is silently discarded.
explicit OpMethodBody(bool declOnly);
OpMethodBody &operator<<(Twine content);
OpMethodBody &operator<<(int content);
// Writes the body, appending a trailing newline if it lacks one.
void writeTo(raw_ostream &os) const;
private:
// Whether this class should record method body.
bool isEffective;
std::string body;
};
// One generated op method: qualifiers, signature and body.
class OpMethod {
public:
// Properties (qualifiers) of class methods. The values are bit flags so a
// property can be tested with `&` (see isStatic()).
enum Property {
MP_None = 0x0,
MP_Static = 0x1, // Static method
};
// `declOnly` methods get a declaration but no generated definition.
OpMethod(StringRef retType, StringRef name, StringRef params,
Property property, bool declOnly);
OpMethodSignature &signature();
OpMethodBody &body();
// Returns true if this is a static method.
bool isStatic() const;
// Writes the method as a declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const;
// Writes the method as a definition to the given `os` (nothing for a
// declaration-only method). `namePrefix` qualifies the method name.
void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
private:
Property properties;
// Whether this method only contains a declaration.
bool isDeclOnly;
OpMethodSignature methodSignature;
OpMethodBody methodBody;
};
// The C++ class generated for one op: its traits and methods.
class OpClass {
public:
explicit OpClass(StringRef name);
// Adds an op trait; "OpTrait::" is prepended to `trait`.
void addTrait(Twine trait);
// Creates a new method in this op's class. The returned reference points
// into `methods` (a SmallVector), so it can be invalidated by the next
// newMethod() call.
OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "",
OpMethod::Property = OpMethod::MP_None,
bool declOnly = false);
// Writes this op's class as a declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const;
// Writes the method definitions in this op's class to the given `os`.
void writeDefTo(raw_ostream &os) const;
private:
std::string className;
// Fully qualified trait names, in the order they were added.
SmallVector<std::string, 4> traits;
// Methods, in emission order.
SmallVector<OpMethod, 8> methods;
};
} // end anonymous namespace
// Copies the three pieces of the signature into owned strings.
OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name,
                                     StringRef params)
    : returnType(retType.str()), methodName(name.str()),
      parameters(params.str()) {}
// Emits "<ret> <name>(<params>)". A return type ending in '&' or '*' is glued
// directly to the name, e.g. "Value *foo()".
void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
  const char *separator = endsWithRefOrPtr(returnType) ? "" : " ";
  os << returnType << separator << methodName;
  os << "(" << parameters << ")";
}
// Emits the head of an out-of-line definition:
// "<ret> <namePrefix>::<name>(<params without defaults>)".
void OpMethodSignature::writeDefTo(raw_ostream &os,
                                   StringRef namePrefix) const {
  // Default arguments may only appear on the declaration, so every
  // "= <value>" segment is dropped from the parameter list here.
  // TODO(antiagainst): We are using '=' and ',' as delimiters for parameter
  // initializers. This is incorrect for initializer list with more than one
  // element. Change to a more robust approach.
  auto stripDefaults = [](StringRef rest) {
    std::string stripped;
    while (!rest.empty()) {
      auto halves = rest.split('=');
      if (!stripped.empty())
        stripped += ", ";
      stripped += halves.first.str();
      // Skip the default value itself, up to the next ','.
      rest = halves.second.split(',').second;
    }
    return stripped;
  };

  const char *separator = endsWithRefOrPtr(returnType) ? "" : " ";
  const char *scope = namePrefix.empty() ? "" : "::";
  os << returnType << separator << namePrefix << scope << methodName;
  os << "(" << stripDefaults(parameters) << ")";
}
// Returns true if the C++ type spelling `type` ends with '&' or '*'; such a
// return type is written without a space before the method name.
// (The original definition was followed by a stray ';' at namespace scope,
// which -Wextra-semi / -pedantic diagnose; it has been removed.)
bool OpMethodSignature::endsWithRefOrPtr(StringRef type) {
  return type.endswith("&") || type.endswith("*");
}
OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
// Appends `content` to the body unless the body is being ignored.
OpMethodBody &OpMethodBody::operator<<(Twine content) {
  if (!isEffective)
    return *this;
  body += content.str();
  return *this;
}
// Appends the decimal spelling of `content` unless the body is being ignored.
OpMethodBody &OpMethodBody::operator<<(int content) {
  if (!isEffective)
    return *this;
  body += std::to_string(content);
  return *this;
}
// Writes the recorded body, making sure it ends with a newline so that the
// caller's closing brace starts on its own line.
void OpMethodBody::writeTo(raw_ostream &os) const {
  os << body;
  const bool newlineTerminated = !body.empty() && body.back() == '\n';
  if (!newlineTerminated)
    os << "\n";
}
OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params,
OpMethod::Property property, bool declOnly)
: properties(property), isDeclOnly(declOnly),
methodSignature(retType, name, params), methodBody(declOnly) {}
// Mutable access to the method's signature.
OpMethodSignature &OpMethod::signature() {
  return this->methodSignature;
}
// Mutable access to the method's body, for streaming code into it.
OpMethodBody &OpMethod::body() {
  return this->methodBody;
}
// Tests the MP_Static bit of the method's property flags.
bool OpMethod::isStatic() const {
  return (properties & MP_Static) != 0;
}
// Emits "  [static ]<signature>;" as one declaration line inside the class.
void OpMethod::writeDeclTo(raw_ostream &os) const {
  os.indent(2) << (isStatic() ? "static " : "");
  methodSignature.writeDeclTo(os);
  os << ';';
}
// Emits "<qualified signature> {\n<body>}" for methods that have a generated
// definition. Declaration-only methods produce no output here.
void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
  if (!isDeclOnly) {
    methodSignature.writeDefTo(os, namePrefix);
    os << " {\n";
    methodBody.writeTo(os);
    os << '}';
  }
}
// Starts an empty op class named `name` (no traits, no methods).
OpClass::OpClass(StringRef name)
    : className(name.str()) {}
// Records `trait` as a base-class trait of the op, qualified with the
// "OpTrait::" namespace.
void OpClass::addTrait(Twine trait) {
  std::string qualified = ("OpTrait::" + trait).str();
  traits.push_back(std::move(qualified));
}
// Appends a new method and returns it so the caller can fill in its body.
// The reference points into `methods`; a later newMethod() call may
// reallocate the vector and invalidate it.
OpMethod &OpClass::newMethod(StringRef retType, StringRef name,
                             StringRef params, OpMethod::Property property,
                             bool declOnly) {
  OpMethod method(retType, name, params, property, declOnly);
  methods.push_back(std::move(method));
  return methods.back();
}
// Writes "class X : public Op<X, traits...> {", a `using Op::Op;` line, one
// declaration line per method, and the closing "};".
void OpClass::writeDeclTo(raw_ostream &os) const {
  os << "class " << className << " : public Op<" << className;
  for (const std::string &trait : traits)
    os << ", " << trait;
  os << "> {\npublic:\n"
     << "  using Op::Op;\n";
  for (const OpMethod &method : methods) {
    method.writeDeclTo(os);
    os << '\n';
  }
  os << "};";
}
// Writes every method definition qualified with the class name, each followed
// by a blank line (also emitted for declaration-only methods).
void OpClass::writeDefTo(raw_ostream &os) const {
  for (const OpMethod &method : methods) {
    method.writeDefTo(os, className);
    os << "\n\n";
  }
}
//===----------------------------------------------------------------------===//
// Op emitter
//===----------------------------------------------------------------------===//
namespace {
// Builds the C++ class for one TableGen `Op` record and writes either its
// declaration or its method definitions to an output stream.
class OpEmitter {
public:
// Writes the C++ class declaration for op `def` to `os`.
static void emitDecl(const Record &def, raw_ostream &os);
// Writes the C++ method definitions for op `def` to `os`.
static void emitDef(const Record &def, raw_ostream &os);
private:
// Populates `opClass` with all traits and methods for `def`.
OpEmitter(const Record &def);
void emitDecl(raw_ostream &os);
void emitDef(raw_ostream &os);
// Generates getters for the attributes.
void genAttrGetters();
// Generates getters for named operands.
void genNamedOperandGetters();
// Generates getters for named results.
void genNamedResultGetters();
// Generates the build() methods (custom and default) for the operation.
void genBuilder();
// Generates canonicalizer declaration for the operation.
void genCanonicalizerDecls();
// Generates the folder declaration for the operation.
void genFolderDecls();
// Generates the parser for the operation.
void genParser();
// Generates the printer for the operation.
void genPrinter();
// Generates verify method for the operation.
void genVerifier();
// Adds the result-count, native and operand-count traits to the op class.
void genTraits();
// Generates a build() method that takes each result type / operand /
// attribute as a stand-alone parameter. If `useOperandType` is true, the
// first operand's type is used for all results (no result-type parameters);
// if `useAttrType` is true, the type derived from the first attribute is used
// instead. Setting both at the same time is a fatal error.
void genStandaloneParamBuilder(bool useOperandType, bool useAttrType);
// Generates the static getOperationName() method.
void genOpNameGetter();
// The TableGen record for this op.
const Record &def;
// The wrapper operator class for querying information from this op.
// Declared before `opClass` because the constructor initializes `opClass`
// from `op`.
Operator op;
// The C++ code builder for this op
OpClass opClass;
};
} // end anonymous namespace
// Builds the complete op class for `def` up front: traits first, then every
// method. The actual writing happens later in emitDecl()/emitDef().
// `op` is initialized before `opClass` (member declaration order), so using
// op.getCppClassName() in the initializer list is safe.
OpEmitter::OpEmitter(const Record &def)
: def(def), op(def), opClass(op.getCppClassName()) {
genTraits();
// Generate C++ code for various op methods. The order here determines the
// methods in the generated file.
genOpNameGetter();
genNamedOperandGetters();
genNamedResultGetters();
genAttrGetters();
genBuilder();
genParser();
genPrinter();
genVerifier();
genCanonicalizerDecls();
genFolderDecls();
}
void OpEmitter::emitDecl(const Record &def, raw_ostream &os) {
OpEmitter(def).emitDecl(os);
}
void OpEmitter::emitDef(const Record &def, raw_ostream &os) {
OpEmitter(def).emitDef(os);
}
// Writes the prepared op class as a declaration.
void OpEmitter::emitDecl(raw_ostream &os) {
  this->opClass.writeDeclTo(os);
}
// Writes the prepared op class's method definitions.
void OpEmitter::emitDef(raw_ostream &os) {
  this->opClass.writeDefTo(os);
}
void OpEmitter::genAttrGetters() {
for (auto &namedAttr : op.getAttributes()) {
auto name = namedAttr.getName();
const auto &attr = namedAttr.attr;
// Determine the name of the attribute getter. The name matches the
// attribute name excluding dialect prefix.
StringRef getter = name;
auto it = getter.split('.');
if (!it.second.empty())
getter = it.second;
auto &method = opClass.newMethod(attr.getReturnType(), getter,
/*params=*/"");
// Emit the derived attribute body.
if (attr.isDerivedAttr()) {
method.body() << " " << attr.getDerivedCodeBody() << "\n";
continue;
}
// Emit normal emitter.
// Return the queried attribute with the correct return type.
std::string attrVal =
formatv("this->getAttr(\"{1}\").dyn_cast_or_null<{0}>()",
attr.getStorageType(), name);
method.body() << " auto attr = " << attrVal << ";\n";
if (attr.hasDefaultValue()) {
// Returns the default value if not set.
// TODO: this is inefficient, we are recreating the attribute for every
// call. This should be set instead.
method.body() << " if (!attr)\n"
" return "
<< formatv(attr.getConvertFromStorageCall(),
formatv(attr.getDefaultValueTemplate(),
"mlir::Builder(this->getContext())"))
<< ";\n";
}
method.body() << " return "
<< formatv(attr.getConvertFromStorageCall(), "attr") << ";\n";
}
}
// Generates an accessor for every named operand. A non-variadic operand gets a
// `Value *` getter; the (trailing) variadic operand gets an
// `Operation::operand_range` covering all remaining operands.
void OpEmitter::genNamedOperandGetters() {
  const int numOperands = op.getNumOperands();
  for (int index = 0; index < numOperands; ++index) {
    const auto &operand = op.getOperand(index);
    if (operand.name.empty())
      continue;

    if (operand.constraint.isVariadic()) {
      assert(index + 1 == numOperands &&
             "only the last operand can be variadic");
      const char *const code = R"(
assert(getOperation()->getNumOperands() >= {0});
return {std::next(operand_begin(), {0}), operand_end()};
)";
      auto &rangeGetter =
          opClass.newMethod("Operation::operand_range", operand.name);
      rangeGetter.body() << formatv(code, index);
      continue;
    }

    auto &getter = opClass.newMethod("Value *", operand.name);
    getter.body() << "  return this->getOperation()->getOperand(" << index
                  << ");\n";
  }
}
// Generates a `Value *` getter for every named, non-variadic result.
void OpEmitter::genNamedResultGetters() {
  const int numResults = op.getNumResults();
  for (int index = 0; index < numResults; ++index) {
    const auto &result = op.getResult(index);
    const bool wantsGetter =
        !result.constraint.isVariadic() && !result.name.empty();
    if (!wantsGetter)
      continue;
    auto &getter = opClass.newMethod("Value *", result.name);
    getter.body() << "  return this->getOperation()->getResult(" << index
                  << ");\n";
  }
}
// Generates a static build() method with a separate parameter for each result
// type, operand and attribute, in that order.
//
// When `useOperandType` is set, no result-type parameters are emitted and all
// results get the type of the first operand. When `useAttrType` is set, all
// results instead get the type derived from the first attribute: the wrapped
// type for a type attribute, otherwise the attribute's own type. Setting both
// is a fatal error, as is asking for a deduced result type from an operand or
// attribute the op does not have.
void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
                                          bool useAttrType) {
  if (useOperandType && useAttrType) {
    // Name the traits exactly as genBuilder() queries them.
    PrintFatalError(def.getLoc(),
                    "Op definition has both 'SameOperandsAndResultType' and "
                    "'FirstAttrDerivedResultType' trait specified.");
  }

  auto numResults = op.getNumResults();
  llvm::SmallVector<std::string, 4> resultNames;
  resultNames.reserve(numResults);

  std::string paramList = "Builder *, OperationState *";
  paramList.append(builderOpState);

  // Emit parameters for all return types, unless they are deduced.
  if (!useOperandType && !useAttrType) {
    for (unsigned i = 0; i != numResults; ++i) {
      std::string resultName = op.getResultName(i);
      if (resultName.empty())
        resultName = formatv("resultType{0}", i);
      bool isVariadic = op.getResultTypeConstraint(i).isVariadic();
      paramList.append(isVariadic ? ", ArrayRef<Type> " : ", Type ");
      paramList.append(resultName);
      resultNames.emplace_back(std::move(resultName));
    }
  }

  // Emit parameters for all arguments (operands and attributes), in
  // declaration order.
  int numOperands = 0;
  int numAttrs = 0;
  for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
    auto argument = op.getArg(i);
    if (argument.is<tblgen::NamedTypeConstraint *>()) {
      auto &operand = op.getOperand(numOperands);
      paramList.append(operand.constraint.isVariadic() ? ", ArrayRef<Value *> "
                                                       : ", Value *");
      paramList.append(getArgumentName(op, numOperands));
      ++numOperands;
    } else {
      // TODO(antiagainst): Support default initializer for attributes
      const auto &namedAttr = op.getAttribute(numAttrs);
      const auto &attr = namedAttr.attr;
      paramList.append(", ");
      if (attr.isOptional())
        paramList.append("/*optional*/");
      paramList.append(
          (attr.getStorageType() + Twine(" ") + namedAttr.name).str());
      ++numAttrs;
    }
  }

  if (numOperands + numAttrs != op.getNumArgs())
    return PrintFatalError(
        def.getLoc(), "op arguments must be either operands or attributes");

  // Deduced result types are read from the first operand / attribute
  // parameter below. Diagnose a missing one here instead of indexing past the
  // end of the operand / attribute list.
  if (numResults > 0) {
    if (useOperandType && numOperands == 0)
      PrintFatalError(def.getLoc(),
                      "cannot deduce result types from the first operand: "
                      "op has no operands");
    if (useAttrType && numAttrs == 0)
      PrintFatalError(def.getLoc(),
                      "cannot deduce result types from the first attribute: "
                      "op has no attributes");
  }

  auto &method =
      opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);

  // Push all result types to the result
  if (numResults > 0) {
    if (!useOperandType && !useAttrType) {
      bool hasVariadicResult = op.hasVariadicResult();
      int numNonVariadicResults =
          numResults - static_cast<int>(hasVariadicResult);
      if (numNonVariadicResults > 0) {
        method.body() << "  " << builderOpState << "->addTypes({"
                      << resultNames.front();
        for (int i = 1; i < numNonVariadicResults; ++i) {
          method.body() << ", " << resultNames[i];
        }
        method.body() << "});\n";
      }
      if (hasVariadicResult) {
        method.body() << "  " << builderOpState << "->addTypes("
                      << resultNames.back() << ");\n";
      }
    } else {
      std::string resultType;
      if (useAttrType) {
        const auto &namedAttr = op.getAttribute(0);
        if (namedAttr.attr.isTypeAttr()) {
          resultType = formatv("{0}.getValue()", namedAttr.name);
        } else {
          resultType = formatv("{0}.getType()", namedAttr.name);
        }
      } else {
        resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str();
      }
      method.body() << "  " << builderOpState << "->addTypes({" << resultType;
      for (unsigned i = 1; i != numResults; ++i)
        method.body() << ", " << resultType;
      method.body() << "});\n\n";
    }
  }

  // Push all operands to the result
  bool hasVariadicOperand = op.hasVariadicOperand();
  int numNonVariadicOperands =
      numOperands - static_cast<int>(hasVariadicOperand);
  if (numNonVariadicOperands > 0) {
    method.body() << "  " << builderOpState << "->addOperands({"
                  << getArgumentName(op, 0);
    for (int i = 1; i < numNonVariadicOperands; ++i) {
      method.body() << ", " << getArgumentName(op, i);
    }
    method.body() << "});\n";
  }
  if (hasVariadicOperand) {
    method.body() << "  " << builderOpState << "->addOperands("
                  << getArgumentName(op, numOperands - 1) << ");\n";
  }

  // Push all attributes to the result. Optional attributes are only added
  // when the caller passed a non-null value.
  for (const auto &namedAttr : op.getAttributes()) {
    if (!namedAttr.attr.isDerivedAttr()) {
      bool emitNotNullCheck = namedAttr.attr.isOptional();
      if (emitNotNullCheck) {
        method.body() << formatv("  if ({0}) ", namedAttr.name) << "{\n";
      }
      method.body() << formatv("    {0}->addAttribute(\"{1}\", {2});\n",
                               builderOpState, namedAttr.getName(),
                               namedAttr.name);
      if (emitNotNullCheck) {
        method.body() << "  }\n";
      }
    }
  }
}
void OpEmitter::genBuilder() {
// Handle custom builders if provided.
// TODO(antiagainst): Create wrapper class for OpBuilder to hide the native
// TableGen API calls here.
{
auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
if (listInit) {
for (Init *init : listInit->getValues()) {
Record *builderDef = cast<DefInit>(init)->getDef();
StringRef params = builderDef->getValueAsString("params");
StringRef body = builderDef->getValueAsString("body");
bool hasBody = !body.empty();
auto &method =
opClass.newMethod("void", "build", params, OpMethod::MP_Static,
/*declOnly=*/!hasBody);
if (hasBody)
method.body() << body;
}
}
}
auto numResults = op.getNumResults();
bool hasVariadicResult = op.hasVariadicResult();
int numNonVariadicResults = numResults - int(hasVariadicResult);
auto numOperands = op.getNumOperands();
bool hasVariadicOperand = op.hasVariadicOperand();
int numNonVariadicOperands = numOperands - int(hasVariadicOperand);
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
// We generate three builders here:
// 1. one having a stand-alone parameter for each result type / operand /
// attribute, and
// 2. one having an aggregated parameter for all result types / operands /
// attributes, and
// 3. one having a stand-alone prameter for each operand and attribute,
// use the first operand's type as all result types
// to facilitate different call patterns.
// 1. Stand-alone parameters
genStandaloneParamBuilder(/*useOperandType=*/false, /*useAttrType=*/false);
// 2. Aggregated parameters
// Signature
std::string params =
std::string("Builder *, OperationState *") + builderOpState +
", ArrayRef<Type> resultTypes, ArrayRef<Value *> operands, "
"ArrayRef<NamedAttribute> attributes";
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
auto &body = m.body();
// Result types
body << " assert(resultTypes.size()" << (hasVariadicResult ? " >= " : " == ")
<< numNonVariadicResults
<< "u && \"mismatched number of return types\");\n"
<< " " << builderOpState << "->addTypes(resultTypes);\n";
// Operands
body << " assert(operands.size()" << (hasVariadicOperand ? " >= " : " == ")
<< numNonVariadicOperands
<< "u && \"mismatched number of parameters\");\n"
<< " " << builderOpState << "->addOperands(operands);\n\n";
// Attributes
body << " for (const auto& pair : attributes)\n"
<< " " << builderOpState
<< "->addAttribute(pair.first, pair.second);\n";
// 3. Deduced result types
bool useOperandType = op.hasTrait("SameOperandsAndResultType");
bool useAttrType = op.hasTrait("FirstAttrDerivedResultType");
if (!op.hasVariadicResult() && (useOperandType || useAttrType))
genStandaloneParamBuilder(useOperandType, useAttrType);
}
// Declares the static getCanonicalizationPatterns() hook when the op sets
// `hasCanonicalizer`. Only the declaration is generated.
void OpEmitter::genCanonicalizerDecls() {
  if (def.getValueAsBit("hasCanonicalizer")) {
    opClass.newMethod(
        "void", "getCanonicalizationPatterns",
        "OwningRewritePatternList &results, MLIRContext *context",
        OpMethod::MP_Static, /*declOnly=*/true);
  }
}
void OpEmitter::genFolderDecls() {
bool hasSingleResult = op.getNumResults() == 1;
if (def.getValueAsBit("hasConstantFolder")) {
if (hasSingleResult) {
const char *const params =
"ArrayRef<Attribute> operands, MLIRContext *context";
opClass.newMethod("Attribute", "constantFold", params, OpMethod::MP_None,
/*declOnly=*/true);
} else {
const char *const params =
"ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results, "
"MLIRContext *context";
opClass.newMethod("LogicalResult", "constantFold", params,
OpMethod::MP_None, /*declOnly=*/true);
}
}
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
opClass.newMethod("Value *", "fold", /*params=*/"", OpMethod::MP_None,
/*declOnly=*/true);
} else {
opClass.newMethod("bool", "fold", "SmallVectorImpl<Value *> &results",
OpMethod::MP_None,
/*declOnly=*/true);
}
}
}
void OpEmitter::genParser() {
if (!hasStringAttribute(def, "parser"))
return;
auto &method = opClass.newMethod(
"bool", "parse", "OpAsmParser *parser, OperationState *result",
OpMethod::MP_Static);
auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
method.body() << " " << parser;
}
void OpEmitter::genPrinter() {
auto valueInit = def.getValueInit("printer");
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
if (!codeInit)
return;
auto &method = opClass.newMethod("void", "print", "OpAsmPrinter *p");
auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r");
method.body() << " " << printer;
}
// Generates the verify() method. The generated checks, in order, are:
// - presence, storage kind and predicates of each non-derived attribute;
// - type constraints of non-variadic operands and results;
// - predicate op traits;
// - the op's custom `verifier` code, or `return mlir::success();` if none.
// Nothing is generated when there is no custom verifier and the op has no
// arguments and no results.
// NOTE(review): in that case predicate op traits are not verified either —
// confirm this is intended.
void OpEmitter::genVerifier() {
auto valueInit = def.getValueInit("verifier");
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
if (!hasCustomVerify && op.getNumArgs() == 0 && op.getNumResults() == 0)
return;
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
auto &body = method.body();
// Verify the attributes have the correct type.
for (const auto &namedAttr : op.getAttributes()) {
const auto &attr = namedAttr.attr;
// Derived attributes are computed, not stored, so there is nothing to check.
if (attr.isDerivedAttr())
continue;
auto name = namedAttr.getName();
// Without a storage type (and without a default), only presence can be
// checked.
if (!attr.hasStorageType() && !attr.hasDefaultValue()) {
// TODO: Some verification can be done even without storage type.
body << "  if (!this->getAttr(\"" << name
<< "\")) return emitOpError(\"requires attribute '" << name
<< "'\");\n";
continue;
}
bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
if (allowMissingAttr) {
// If the attribute has a default value (or is optional), then only verify
// the predicate if set. This does effectively assume that the default value
// is valid.
// TODO: verify the default value is valid (perhaps in debug mode only).
body << "  if (this->getAttr(\"" << name << "\")) {\n";
}
// The attribute must be of the declared storage kind...
body << "  if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
<< attr.getStorageType() << ">()) return emitOpError(\"requires "
<< attr.getDescription() << " attribute '" << name << "'\");\n";
// ...and satisfy its predicate, if it has one.
auto attrPred = attr.getPredicate();
if (!attrPred.isNull()) {
body << formatv("  if (!({0})) return emitOpError(\"attribute '{1}' "
"failed to satisfy {2} attribute constraints\");\n",
formatv(attrPred.getCondition(),
formatv("this->getAttr(\"{0}\")", name)),
name, attr.getDescription());
}
// Close the `if (this->getAttr(...)) {` opened above.
if (allowMissingAttr)
body << "  }\n";
}
// Emits verification code for an operand or result.
auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index,
bool isOperand) -> void {
// TODO: Handle variadic operand/result verification.
if (value.constraint.isVariadic())
return;
// TODO: Commonality between matchers could be extracted to have a more
// concise code.
if (value.hasPredicate()) {
auto description = value.constraint.getDescription();
// Substitute "this->getOperation()->getOperand(i)->getType()" (or
// getResult) into the constraint's condition template.
body << "  if (!("
<< formatv(value.constraint.getConditionTemplate(),
"this->getOperation()->get" +
Twine(isOperand ? "Operand" : "Result") + "(" +
Twine(index) + ")->getType()")
<< "))\n";
body << "    return emitOpError(\"" << (isOperand ? "operand" : "result")
<< " #" << index
<< (description.empty() ? " type precondition failed"
: " must be " + Twine(description))
<< "\");\n";
}
};
for (unsigned i = 0, e = op.getNumOperands(); i < e; ++i) {
verifyValue(op.getOperand(i), i, /*isOperand=*/true);
}
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
verifyValue(op.getResult(i), i, /*isOperand=*/false);
}
// Predicate op traits are checked against the whole operation.
for (auto &trait : op.getTraits()) {
if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
body << "  if (!("
<< formatv(t->getPredTemplate().c_str(), "(*this->getOperation())")
<< "))\n";
body << "    return emitOpError(\"failed to verify that "
<< t->getDescription() << "\");\n";
}
}
// The custom verifier code (if any) is responsible for the final return.
if (hasCustomVerify)
body << codeInit->getValue() << "\n";
else
body << "  return mlir::success();\n";
}
// Adds the op's traits in a fixed order:
// 1. the result-count trait;
// 2. the native traits declared on the op;
// 3. the operand-count trait.
// A variadic result/operand list turns the count trait into VariadicX (when it
// is the only entry) or AtLeastNX<count of the others>.
void OpEmitter::genTraits() {
  auto numResults = op.getNumResults();
  if (op.hasVariadicResult()) {
    if (numResults == 1)
      opClass.addTrait("VariadicResults");
    else
      opClass.addTrait("AtLeastNResults<" + Twine(numResults - 1) +
                       ">::Impl");
  } else if (numResults == 0) {
    opClass.addTrait("ZeroResult");
  } else if (numResults == 1) {
    opClass.addTrait("OneResult");
  } else {
    opClass.addTrait("NResults<" + Twine(numResults) + ">::Impl");
  }

  for (const auto &trait : op.getTraits())
    if (auto *nativeTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
      opClass.addTrait(nativeTrait->getTrait());

  auto numOperands = op.getNumOperands();
  if (op.hasVariadicOperand()) {
    if (numOperands == 1)
      opClass.addTrait("VariadicOperands");
    else
      opClass.addTrait("AtLeastNOperands<" + Twine(numOperands - 1) +
                       ">::Impl");
  } else if (numOperands == 0) {
    opClass.addTrait("ZeroOperands");
  } else if (numOperands == 1) {
    opClass.addTrait("OneOperand");
  } else {
    opClass.addTrait("NOperands<" + Twine(numOperands) + ">::Impl");
  }
}
// Emits `static StringRef getOperationName()` returning the op's full name.
void OpEmitter::genOpNameGetter() {
  auto &getter = opClass.newMethod("StringRef", "getOperationName",
                                   /*params=*/"", OpMethod::MP_Static);
  auto &body = getter.body();
  body << "  return \"" << op.getOperationName() << "\";\n";
}
// Emits, inside a GET_OP_CLASSES scope, either the class declarations
// (`emitDecl` true) or the method definitions of every op in `defs`. Each op
// is preceded by a banner comment.
static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
                          bool emitDecl) {
  IfDefScope scope("GET_OP_CLASSES", os);
  const char *section = emitDecl ? "declarations" : "definitions";
  for (auto *def : defs) {
    os << formatv(opCommentHeader, getOpQualClassName(*def), section);
    if (emitDecl)
      OpEmitter::emitDecl(*def, os);
    else
      OpEmitter::emitDef(*def, os);
  }
}
// Emits, inside a GET_OP_LIST scope, the qualified class names of all ops,
// separated by ",\n" (no trailing separator).
static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
  IfDefScope scope("GET_OP_LIST", os);
  bool isFirst = true;
  for (Record *def : defs) {
    if (!isFirst)
      os << ",\n";
    isFirst = false;
    os << getOpQualClassName(*def);
  }
}
// Backend for "gen-op-decls": a file header followed by the declarations of
// every record derived from `Op`. Always returns false (no failure path).
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Op Declarations", os);
  auto opDefs = recordKeeper.getAllDerivedDefinitions("Op");
  emitOpClasses(opDefs, os, /*emitDecl=*/true);
  return false;
}
// Backend for "gen-op-defs": a file header, the GET_OP_LIST section, then the
// method definitions of every record derived from `Op`. Always returns false
// (no failure path).
static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Op Definitions", os);
  auto opDefs = recordKeeper.getAllDerivedDefinitions("Op");
  emitOpList(opDefs, os);
  emitOpClasses(opDefs, os, /*emitDecl=*/false);
  return false;
}
// Registers the "gen-op-decls" generator, which produces the op class
// declarations (the *Ops.h.inc contents).
static mlir::GenRegistration
    genOpDecls("gen-op-decls", "Generate op declarations",
               [](const RecordKeeper &keeper, raw_ostream &out) {
                 return emitOpDecls(keeper, out);
               });
// Registers the "gen-op-defs" generator, which produces the op list and the
// op method definitions (the *Ops.cpp.inc contents).
static mlir::GenRegistration
    genOpDefs("gen-op-defs", "Generate op definitions",
              [](const RecordKeeper &keeper, raw_ostream &out) {
                return emitOpDefs(keeper, out);
              });