[MLIR][NFC] Remove tblgen:: prefix in TableGen/*.cpp files

- Add "using namespace mlir::tblgen" in several of the TableGen/*.cpp files and
  eliminate the tblgen::prefix to reduce code clutter.

Differential Revision: https://reviews.llvm.org/D85800
This commit is contained in:
Rahul Joshi 2020-08-11 17:47:07 -07:00
parent 7ddfb956e1
commit 12d16de538
10 changed files with 284 additions and 354 deletions

View File

@ -88,13 +88,13 @@ public:
Property property, bool declOnly); Property property, bool declOnly);
virtual ~OpMethod() = default; virtual ~OpMethod() = default;
OpMethodBody &body(); OpMethodBody &body() { return methodBody; }
// Returns true if this is a static method. // Returns true if this is a static method.
bool isStatic() const; bool isStatic() const { return properties & MP_Static; }
// Returns true if this is a private method. // Returns true if this is a private method.
bool isPrivate() const; bool isPrivate() const { return properties & MP_Private; }
// Writes the method as a declaration to the given `os`. // Writes the method as a declaration to the given `os`.
virtual void writeDeclTo(raw_ostream &os) const; virtual void writeDeclTo(raw_ostream &os) const;

View File

@ -10,15 +10,12 @@
#include "llvm/TableGen/Record.h" #include "llvm/TableGen/Record.h"
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen;
bool tblgen::NamedTypeConstraint::hasPredicate() const { bool NamedTypeConstraint::hasPredicate() const {
return !constraint.getPredicate().isNull(); return !constraint.getPredicate().isNull();
} }
bool tblgen::NamedTypeConstraint::isOptional() const { bool NamedTypeConstraint::isOptional() const { return constraint.isOptional(); }
return constraint.isOptional();
}
bool tblgen::NamedTypeConstraint::isVariadic() const { bool NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); }
return constraint.isVariadic();
}

View File

@ -16,6 +16,7 @@
#include "llvm/TableGen/Record.h" #include "llvm/TableGen/Record.h"
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen;
using llvm::CodeInit; using llvm::CodeInit;
using llvm::DefInit; using llvm::DefInit;
@ -28,41 +29,35 @@ using llvm::StringInit;
static StringRef getValueAsString(const Init *init) { static StringRef getValueAsString(const Init *init) {
if (const auto *code = dyn_cast<CodeInit>(init)) if (const auto *code = dyn_cast<CodeInit>(init))
return code->getValue().trim(); return code->getValue().trim();
else if (const auto *str = dyn_cast<StringInit>(init)) if (const auto *str = dyn_cast<StringInit>(init))
return str->getValue().trim(); return str->getValue().trim();
return {}; return {};
} }
tblgen::AttrConstraint::AttrConstraint(const Record *record) AttrConstraint::AttrConstraint(const Record *record)
: Constraint(Constraint::CK_Attr, record) { : Constraint(Constraint::CK_Attr, record) {
assert(isSubClassOf("AttrConstraint") && assert(isSubClassOf("AttrConstraint") &&
"must be subclass of TableGen 'AttrConstraint' class"); "must be subclass of TableGen 'AttrConstraint' class");
} }
bool tblgen::AttrConstraint::isSubClassOf(StringRef className) const { bool AttrConstraint::isSubClassOf(StringRef className) const {
return def->isSubClassOf(className); return def->isSubClassOf(className);
} }
tblgen::Attribute::Attribute(const Record *record) : AttrConstraint(record) { Attribute::Attribute(const Record *record) : AttrConstraint(record) {
assert(record->isSubClassOf("Attr") && assert(record->isSubClassOf("Attr") &&
"must be subclass of TableGen 'Attr' class"); "must be subclass of TableGen 'Attr' class");
} }
tblgen::Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {} Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
bool tblgen::Attribute::isDerivedAttr() const { bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
return isSubClassOf("DerivedAttr");
}
bool tblgen::Attribute::isTypeAttr() const { bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
return isSubClassOf("TypeAttrBase");
}
bool tblgen::Attribute::isEnumAttr() const { bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
return isSubClassOf("EnumAttrInfo");
}
StringRef tblgen::Attribute::getStorageType() const { StringRef Attribute::getStorageType() const {
const auto *init = def->getValueInit("storageType"); const auto *init = def->getValueInit("storageType");
auto type = getValueAsString(init); auto type = getValueAsString(init);
if (type.empty()) if (type.empty())
@ -70,35 +65,35 @@ StringRef tblgen::Attribute::getStorageType() const {
return type; return type;
} }
StringRef tblgen::Attribute::getReturnType() const { StringRef Attribute::getReturnType() const {
const auto *init = def->getValueInit("returnType"); const auto *init = def->getValueInit("returnType");
return getValueAsString(init); return getValueAsString(init);
} }
// Return the type constraint corresponding to the type of this attribute, or // Return the type constraint corresponding to the type of this attribute, or
// None if this is not a TypedAttr. // None if this is not a TypedAttr.
llvm::Optional<tblgen::Type> tblgen::Attribute::getValueType() const { llvm::Optional<Type> Attribute::getValueType() const {
if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType"))) if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
return tblgen::Type(defInit->getDef()); return Type(defInit->getDef());
return llvm::None; return llvm::None;
} }
StringRef tblgen::Attribute::getConvertFromStorageCall() const { StringRef Attribute::getConvertFromStorageCall() const {
const auto *init = def->getValueInit("convertFromStorage"); const auto *init = def->getValueInit("convertFromStorage");
return getValueAsString(init); return getValueAsString(init);
} }
bool tblgen::Attribute::isConstBuildable() const { bool Attribute::isConstBuildable() const {
const auto *init = def->getValueInit("constBuilderCall"); const auto *init = def->getValueInit("constBuilderCall");
return !getValueAsString(init).empty(); return !getValueAsString(init).empty();
} }
StringRef tblgen::Attribute::getConstBuilderTemplate() const { StringRef Attribute::getConstBuilderTemplate() const {
const auto *init = def->getValueInit("constBuilderCall"); const auto *init = def->getValueInit("constBuilderCall");
return getValueAsString(init); return getValueAsString(init);
} }
tblgen::Attribute tblgen::Attribute::getBaseAttr() const { Attribute Attribute::getBaseAttr() const {
if (const auto *defInit = if (const auto *defInit =
llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) { llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
return Attribute(defInit).getBaseAttr(); return Attribute(defInit).getBaseAttr();
@ -106,178 +101,166 @@ tblgen::Attribute tblgen::Attribute::getBaseAttr() const {
return *this; return *this;
} }
bool tblgen::Attribute::hasDefaultValue() const { bool Attribute::hasDefaultValue() const {
const auto *init = def->getValueInit("defaultValue"); const auto *init = def->getValueInit("defaultValue");
return !getValueAsString(init).empty(); return !getValueAsString(init).empty();
} }
StringRef tblgen::Attribute::getDefaultValue() const { StringRef Attribute::getDefaultValue() const {
const auto *init = def->getValueInit("defaultValue"); const auto *init = def->getValueInit("defaultValue");
return getValueAsString(init); return getValueAsString(init);
} }
bool tblgen::Attribute::isOptional() const { bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); }
return def->getValueAsBit("isOptional");
}
StringRef tblgen::Attribute::getAttrDefName() const { StringRef Attribute::getAttrDefName() const {
if (def->isAnonymous()) { if (def->isAnonymous()) {
return getBaseAttr().def->getName(); return getBaseAttr().def->getName();
} }
return def->getName(); return def->getName();
} }
StringRef tblgen::Attribute::getDerivedCodeBody() const { StringRef Attribute::getDerivedCodeBody() const {
assert(isDerivedAttr() && "only derived attribute has 'body' field"); assert(isDerivedAttr() && "only derived attribute has 'body' field");
return def->getValueAsString("body"); return def->getValueAsString("body");
} }
tblgen::Dialect tblgen::Attribute::getDialect() const { Dialect Attribute::getDialect() const {
return Dialect(def->getValueAsDef("dialect")); return Dialect(def->getValueAsDef("dialect"));
} }
tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") && assert(def->isSubClassOf("ConstantAttr") &&
"must be subclass of TableGen 'ConstantAttr' class"); "must be subclass of TableGen 'ConstantAttr' class");
} }
tblgen::Attribute tblgen::ConstantAttr::getAttribute() const { Attribute ConstantAttr::getAttribute() const {
return Attribute(def->getValueAsDef("attr")); return Attribute(def->getValueAsDef("attr"));
} }
StringRef tblgen::ConstantAttr::getConstantValue() const { StringRef ConstantAttr::getConstantValue() const {
return def->getValueAsString("value"); return def->getValueAsString("value");
} }
tblgen::EnumAttrCase::EnumAttrCase(const llvm::Record *record) EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
: Attribute(record) {
assert(isSubClassOf("EnumAttrCaseInfo") && assert(isSubClassOf("EnumAttrCaseInfo") &&
"must be subclass of TableGen 'EnumAttrInfo' class"); "must be subclass of TableGen 'EnumAttrInfo' class");
} }
tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
: EnumAttrCase(init->getDef()) {} : EnumAttrCase(init->getDef()) {}
bool tblgen::EnumAttrCase::isStrCase() const { bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
return isSubClassOf("StrEnumAttrCase");
}
StringRef tblgen::EnumAttrCase::getSymbol() const { StringRef EnumAttrCase::getSymbol() const {
return def->getValueAsString("symbol"); return def->getValueAsString("symbol");
} }
StringRef tblgen::EnumAttrCase::getStr() const { StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
return def->getValueAsString("str");
}
int64_t tblgen::EnumAttrCase::getValue() const { int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
return def->getValueAsInt("value");
}
const llvm::Record &tblgen::EnumAttrCase::getDef() const { return *def; } const llvm::Record &EnumAttrCase::getDef() const { return *def; }
tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrInfo") && assert(isSubClassOf("EnumAttrInfo") &&
"must be subclass of TableGen 'EnumAttr' class"); "must be subclass of TableGen 'EnumAttr' class");
} }
tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init) EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
: EnumAttr(init->getDef()) {}
bool tblgen::EnumAttr::classof(const Attribute *attr) { bool EnumAttr::classof(const Attribute *attr) {
return attr->isSubClassOf("EnumAttrInfo"); return attr->isSubClassOf("EnumAttrInfo");
} }
bool tblgen::EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
StringRef tblgen::EnumAttr::getEnumClassName() const { StringRef EnumAttr::getEnumClassName() const {
return def->getValueAsString("className"); return def->getValueAsString("className");
} }
StringRef tblgen::EnumAttr::getCppNamespace() const { StringRef EnumAttr::getCppNamespace() const {
return def->getValueAsString("cppNamespace"); return def->getValueAsString("cppNamespace");
} }
StringRef tblgen::EnumAttr::getUnderlyingType() const { StringRef EnumAttr::getUnderlyingType() const {
return def->getValueAsString("underlyingType"); return def->getValueAsString("underlyingType");
} }
StringRef tblgen::EnumAttr::getUnderlyingToSymbolFnName() const { StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
return def->getValueAsString("underlyingToSymbolFnName"); return def->getValueAsString("underlyingToSymbolFnName");
} }
StringRef tblgen::EnumAttr::getStringToSymbolFnName() const { StringRef EnumAttr::getStringToSymbolFnName() const {
return def->getValueAsString("stringToSymbolFnName"); return def->getValueAsString("stringToSymbolFnName");
} }
StringRef tblgen::EnumAttr::getSymbolToStringFnName() const { StringRef EnumAttr::getSymbolToStringFnName() const {
return def->getValueAsString("symbolToStringFnName"); return def->getValueAsString("symbolToStringFnName");
} }
StringRef tblgen::EnumAttr::getSymbolToStringFnRetType() const { StringRef EnumAttr::getSymbolToStringFnRetType() const {
return def->getValueAsString("symbolToStringFnRetType"); return def->getValueAsString("symbolToStringFnRetType");
} }
StringRef tblgen::EnumAttr::getMaxEnumValFnName() const { StringRef EnumAttr::getMaxEnumValFnName() const {
return def->getValueAsString("maxEnumValFnName"); return def->getValueAsString("maxEnumValFnName");
} }
std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const { std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
const auto *inits = def->getValueAsListInit("enumerants"); const auto *inits = def->getValueAsListInit("enumerants");
std::vector<tblgen::EnumAttrCase> cases; std::vector<EnumAttrCase> cases;
cases.reserve(inits->size()); cases.reserve(inits->size());
for (const llvm::Init *init : *inits) { for (const llvm::Init *init : *inits) {
cases.push_back(tblgen::EnumAttrCase(cast<llvm::DefInit>(init))); cases.push_back(EnumAttrCase(cast<llvm::DefInit>(init)));
} }
return cases; return cases;
} }
tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record *record) StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
: def(record) {
assert(def->isSubClassOf("StructFieldAttr") && assert(def->isSubClassOf("StructFieldAttr") &&
"must be subclass of TableGen 'StructFieldAttr' class"); "must be subclass of TableGen 'StructFieldAttr' class");
} }
tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record &record) StructFieldAttr::StructFieldAttr(const llvm::Record &record)
: StructFieldAttr(&record) {} : StructFieldAttr(&record) {}
tblgen::StructFieldAttr::StructFieldAttr(const llvm::DefInit *init) StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
: StructFieldAttr(init->getDef()) {} : StructFieldAttr(init->getDef()) {}
StringRef tblgen::StructFieldAttr::getName() const { StringRef StructFieldAttr::getName() const {
return def->getValueAsString("name"); return def->getValueAsString("name");
} }
tblgen::Attribute tblgen::StructFieldAttr::getType() const { Attribute StructFieldAttr::getType() const {
auto init = def->getValueInit("type"); auto init = def->getValueInit("type");
return tblgen::Attribute(cast<llvm::DefInit>(init)); return Attribute(cast<llvm::DefInit>(init));
} }
tblgen::StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) { StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
assert(isSubClassOf("StructAttr") && assert(isSubClassOf("StructAttr") &&
"must be subclass of TableGen 'StructAttr' class"); "must be subclass of TableGen 'StructAttr' class");
} }
tblgen::StructAttr::StructAttr(const llvm::DefInit *init) StructAttr::StructAttr(const llvm::DefInit *init)
: StructAttr(init->getDef()) {} : StructAttr(init->getDef()) {}
StringRef tblgen::StructAttr::getStructClassName() const { StringRef StructAttr::getStructClassName() const {
return def->getValueAsString("className"); return def->getValueAsString("className");
} }
StringRef tblgen::StructAttr::getCppNamespace() const { StringRef StructAttr::getCppNamespace() const {
Dialect dialect(def->getValueAsDef("structDialect")); Dialect dialect(def->getValueAsDef("structDialect"));
return dialect.getCppNamespace(); return dialect.getCppNamespace();
} }
std::vector<mlir::tblgen::StructFieldAttr> std::vector<StructFieldAttr> StructAttr::getAllFields() const {
tblgen::StructAttr::getAllFields() const { std::vector<StructFieldAttr> attributes;
std::vector<mlir::tblgen::StructFieldAttr> attributes;
const auto *inits = def->getValueAsListInit("fields"); const auto *inits = def->getValueAsListInit("fields");
attributes.reserve(inits->size()); attributes.reserve(inits->size());
@ -289,4 +272,4 @@ tblgen::StructAttr::getAllFields() const {
return attributes; return attributes;
} }
const char *mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";

View File

@ -13,18 +13,16 @@
#include "mlir/TableGen/Dialect.h" #include "mlir/TableGen/Dialect.h"
#include "llvm/TableGen/Record.h" #include "llvm/TableGen/Record.h"
namespace mlir { using namespace mlir;
namespace tblgen { using namespace mlir::tblgen;
StringRef tblgen::Dialect::getName() const { StringRef Dialect::getName() const { return def->getValueAsString("name"); }
return def->getValueAsString("name");
}
StringRef tblgen::Dialect::getCppNamespace() const { StringRef Dialect::getCppNamespace() const {
return def->getValueAsString("cppNamespace"); return def->getValueAsString("cppNamespace");
} }
std::string tblgen::Dialect::getCppClassName() const { std::string Dialect::getCppClassName() const {
// Simply use the name and remove any '_' tokens. // Simply use the name and remove any '_' tokens.
std::string cppName = def->getName().str(); std::string cppName = def->getName().str();
llvm::erase_if(cppName, [](char c) { return c == '_'; }); llvm::erase_if(cppName, [](char c) { return c == '_'; });
@ -40,32 +38,32 @@ static StringRef getAsStringOrEmpty(const llvm::Record &record,
return ""; return "";
} }
StringRef tblgen::Dialect::getSummary() const { StringRef Dialect::getSummary() const {
return getAsStringOrEmpty(*def, "summary"); return getAsStringOrEmpty(*def, "summary");
} }
StringRef tblgen::Dialect::getDescription() const { StringRef Dialect::getDescription() const {
return getAsStringOrEmpty(*def, "description"); return getAsStringOrEmpty(*def, "description");
} }
llvm::Optional<StringRef> tblgen::Dialect::getExtraClassDeclaration() const { llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
auto value = def->getValueAsString("extraClassDeclaration"); auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value; return value.empty() ? llvm::Optional<StringRef>() : value;
} }
bool tblgen::Dialect::hasConstantMaterializer() const { bool Dialect::hasConstantMaterializer() const {
return def->getValueAsBit("hasConstantMaterializer"); return def->getValueAsBit("hasConstantMaterializer");
} }
bool tblgen::Dialect::hasOperationAttrVerify() const { bool Dialect::hasOperationAttrVerify() const {
return def->getValueAsBit("hasOperationAttrVerify"); return def->getValueAsBit("hasOperationAttrVerify");
} }
bool tblgen::Dialect::hasRegionArgAttrVerify() const { bool Dialect::hasRegionArgAttrVerify() const {
return def->getValueAsBit("hasRegionArgAttrVerify"); return def->getValueAsBit("hasRegionArgAttrVerify");
} }
bool tblgen::Dialect::hasRegionResultAttrVerify() const { bool Dialect::hasRegionResultAttrVerify() const {
return def->getValueAsBit("hasRegionResultAttrVerify"); return def->getValueAsBit("hasRegionResultAttrVerify");
} }
@ -76,6 +74,3 @@ bool Dialect::operator==(const Dialect &other) const {
bool Dialect::operator<(const Dialect &other) const { bool Dialect::operator<(const Dialect &other) const {
return getName() < other.getName(); return getName() < other.getName();
} }
} // end namespace tblgen
} // end namespace mlir

View File

@ -21,28 +21,28 @@ using namespace mlir::tblgen;
// Marker to indicate an error happened when replacing a placeholder. // Marker to indicate an error happened when replacing a placeholder.
const char *const kMarkerForNoSubst = "<no-subst-found>"; const char *const kMarkerForNoSubst = "<no-subst-found>";
FmtContext &tblgen::FmtContext::addSubst(StringRef placeholder, Twine subst) { FmtContext &FmtContext::addSubst(StringRef placeholder, Twine subst) {
customSubstMap[placeholder] = subst.str(); customSubstMap[placeholder] = subst.str();
return *this; return *this;
} }
FmtContext &tblgen::FmtContext::withBuilder(Twine subst) { FmtContext &FmtContext::withBuilder(Twine subst) {
builtinSubstMap[PHKind::Builder] = subst.str(); builtinSubstMap[PHKind::Builder] = subst.str();
return *this; return *this;
} }
FmtContext &tblgen::FmtContext::withOp(Twine subst) { FmtContext &FmtContext::withOp(Twine subst) {
builtinSubstMap[PHKind::Op] = subst.str(); builtinSubstMap[PHKind::Op] = subst.str();
return *this; return *this;
} }
FmtContext &tblgen::FmtContext::withSelf(Twine subst) { FmtContext &FmtContext::withSelf(Twine subst) {
builtinSubstMap[PHKind::Self] = subst.str(); builtinSubstMap[PHKind::Self] = subst.str();
return *this; return *this;
} }
Optional<StringRef> Optional<StringRef>
tblgen::FmtContext::getSubstFor(FmtContext::PHKind placeholder) const { FmtContext::getSubstFor(FmtContext::PHKind placeholder) const {
if (placeholder == FmtContext::PHKind::None || if (placeholder == FmtContext::PHKind::None ||
placeholder == FmtContext::PHKind::Custom) placeholder == FmtContext::PHKind::Custom)
return {}; return {};
@ -52,15 +52,14 @@ tblgen::FmtContext::getSubstFor(FmtContext::PHKind placeholder) const {
return StringRef(it->second); return StringRef(it->second);
} }
Optional<StringRef> Optional<StringRef> FmtContext::getSubstFor(StringRef placeholder) const {
tblgen::FmtContext::getSubstFor(StringRef placeholder) const {
auto it = customSubstMap.find(placeholder); auto it = customSubstMap.find(placeholder);
if (it == customSubstMap.end()) if (it == customSubstMap.end())
return {}; return {};
return StringRef(it->second); return StringRef(it->second);
} }
FmtContext::PHKind tblgen::FmtContext::getPlaceHolderKind(StringRef str) { FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) {
return llvm::StringSwitch<FmtContext::PHKind>(str) return llvm::StringSwitch<FmtContext::PHKind>(str)
.Case("_builder", FmtContext::PHKind::Builder) .Case("_builder", FmtContext::PHKind::Builder)
.Case("_op", FmtContext::PHKind::Op) .Case("_op", FmtContext::PHKind::Op)
@ -70,7 +69,7 @@ FmtContext::PHKind tblgen::FmtContext::getPlaceHolderKind(StringRef str) {
} }
std::pair<FmtReplacement, StringRef> std::pair<FmtReplacement, StringRef>
tblgen::FmtObjectBase::splitFmtSegment(StringRef fmt) { FmtObjectBase::splitFmtSegment(StringRef fmt) {
size_t begin = fmt.find_first_of('$'); size_t begin = fmt.find_first_of('$');
if (begin == StringRef::npos) { if (begin == StringRef::npos) {
// No placeholders: the whole format string should be returned as a // No placeholders: the whole format string should be returned as a

View File

@ -13,22 +13,23 @@
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// OpMethodSignature definitions // OpMethodSignature definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
tblgen::OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name, OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name,
StringRef params) StringRef params)
: returnType(retType), methodName(name), parameters(params) {} : returnType(retType), methodName(name), parameters(params) {}
void tblgen::OpMethodSignature::writeDeclTo(raw_ostream &os) const { void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName
<< "(" << parameters << ")"; << "(" << parameters << ")";
} }
void tblgen::OpMethodSignature::writeDefTo(raw_ostream &os, void OpMethodSignature::writeDefTo(raw_ostream &os,
StringRef namePrefix) const { StringRef namePrefix) const {
// We need to remove the default values for parameters in method definition. // We need to remove the default values for parameters in method definition.
// TODO: We are using '=' and ',' as delimiters for parameter // TODO: We are using '=' and ',' as delimiters for parameter
// initializers. This is incorrect for initializer list with more than one // initializers. This is incorrect for initializer list with more than one
@ -50,7 +51,7 @@ void tblgen::OpMethodSignature::writeDefTo(raw_ostream &os,
<< removeParamDefaultValue(parameters) << ")"; << removeParamDefaultValue(parameters) << ")";
} }
bool tblgen::OpMethodSignature::elideSpaceAfterType(StringRef type) { bool OpMethodSignature::elideSpaceAfterType(StringRef type) {
return type.empty() || type.endswith("&") || type.endswith("*"); return type.empty() || type.endswith("&") || type.endswith("*");
} }
@ -58,28 +59,27 @@ bool tblgen::OpMethodSignature::elideSpaceAfterType(StringRef type) {
// OpMethodBody definitions // OpMethodBody definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
tblgen::OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {} OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
tblgen::OpMethodBody &tblgen::OpMethodBody::operator<<(Twine content) { OpMethodBody &OpMethodBody::operator<<(Twine content) {
if (isEffective) if (isEffective)
body.append(content.str()); body.append(content.str());
return *this; return *this;
} }
tblgen::OpMethodBody &tblgen::OpMethodBody::operator<<(int content) { OpMethodBody &OpMethodBody::operator<<(int content) {
if (isEffective) if (isEffective)
body.append(std::to_string(content)); body.append(std::to_string(content));
return *this; return *this;
} }
tblgen::OpMethodBody & OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
tblgen::OpMethodBody::operator<<(const FmtObjectBase &content) {
if (isEffective) if (isEffective)
body.append(content.str()); body.append(content.str());
return *this; return *this;
} }
void tblgen::OpMethodBody::writeTo(raw_ostream &os) const { void OpMethodBody::writeTo(raw_ostream &os) const {
auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; }); auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
os << bodyRef; os << bodyRef;
if (bodyRef.empty() || bodyRef.back() != '\n') if (bodyRef.empty() || bodyRef.back() != '\n')
@ -90,18 +90,11 @@ void tblgen::OpMethodBody::writeTo(raw_ostream &os) const {
// OpMethod definitions // OpMethod definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
tblgen::OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params, OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params,
OpMethod::Property property, bool declOnly) OpMethod::Property property, bool declOnly)
: properties(property), isDeclOnly(declOnly), : properties(property), isDeclOnly(declOnly),
methodSignature(retType, name, params), methodBody(declOnly) {} methodSignature(retType, name, params), methodBody(declOnly) {}
void OpMethod::writeDeclTo(raw_ostream &os) const {
tblgen::OpMethodBody &tblgen::OpMethod::body() { return methodBody; }
bool tblgen::OpMethod::isStatic() const { return properties & MP_Static; }
bool tblgen::OpMethod::isPrivate() const { return properties & MP_Private; }
void tblgen::OpMethod::writeDeclTo(raw_ostream &os) const {
os.indent(2); os.indent(2);
if (isStatic()) if (isStatic())
os << "static "; os << "static ";
@ -109,7 +102,7 @@ void tblgen::OpMethod::writeDeclTo(raw_ostream &os) const {
os << ";"; os << ";";
} }
void tblgen::OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const { void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
if (isDeclOnly) if (isDeclOnly)
return; return;
@ -123,14 +116,12 @@ void tblgen::OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
// OpConstructor definitions // OpConstructor definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void mlir::tblgen::OpConstructor::addMemberInitializer(StringRef name, void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
StringRef value) {
memberInitializers.append(std::string(llvm::formatv( memberInitializers.append(std::string(llvm::formatv(
"{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value))); "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
} }
void mlir::tblgen::OpConstructor::writeDefTo(raw_ostream &os, void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
StringRef namePrefix) const {
if (isDeclOnly) if (isDeclOnly)
return; return;
@ -144,25 +135,21 @@ void mlir::tblgen::OpConstructor::writeDefTo(raw_ostream &os,
// Class definitions // Class definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
tblgen::Class::Class(StringRef name) : className(name) {} Class::Class(StringRef name) : className(name) {}
tblgen::OpMethod &tblgen::Class::newMethod(StringRef retType, StringRef name, OpMethod &Class::newMethod(StringRef retType, StringRef name, StringRef params,
StringRef params, OpMethod::Property property, bool declOnly) {
OpMethod::Property property,
bool declOnly) {
methods.emplace_back(retType, name, params, property, declOnly); methods.emplace_back(retType, name, params, property, declOnly);
return methods.back(); return methods.back();
} }
tblgen::OpConstructor &tblgen::Class::newConstructor(StringRef params, OpConstructor &Class::newConstructor(StringRef params, bool declOnly) {
bool declOnly) {
constructors.emplace_back("", getClassName(), params, constructors.emplace_back("", getClassName(), params,
OpMethod::MP_Constructor, declOnly); OpMethod::MP_Constructor, declOnly);
return constructors.back(); return constructors.back();
} }
void tblgen::Class::newField(StringRef type, StringRef name, void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
StringRef defaultValue) {
std::string varName = formatv("{0} {1}", type, name).str(); std::string varName = formatv("{0} {1}", type, name).str();
std::string field = defaultValue.empty() std::string field = defaultValue.empty()
? varName ? varName
@ -170,7 +157,7 @@ void tblgen::Class::newField(StringRef type, StringRef name,
fields.push_back(std::move(field)); fields.push_back(std::move(field));
} }
void tblgen::Class::writeDeclTo(raw_ostream &os) const { void Class::writeDeclTo(raw_ostream &os) const {
bool hasPrivateMethod = false; bool hasPrivateMethod = false;
os << "class " << className << " {\n"; os << "class " << className << " {\n";
os << "public:\n"; os << "public:\n";
@ -200,7 +187,7 @@ void tblgen::Class::writeDeclTo(raw_ostream &os) const {
os << "};\n"; os << "};\n";
} }
void tblgen::Class::writeDefTo(raw_ostream &os) const { void Class::writeDefTo(raw_ostream &os) const {
for (const auto &method : for (const auto &method :
llvm::concat<const OpMethod>(constructors, methods)) { llvm::concat<const OpMethod>(constructors, methods)) {
method.writeDefTo(os, className); method.writeDefTo(os, className);
@ -212,16 +199,16 @@ void tblgen::Class::writeDefTo(raw_ostream &os) const {
// OpClass definitions // OpClass definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
tblgen::OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
: Class(name), extraClassDeclaration(extraClassDeclaration) {} : Class(name), extraClassDeclaration(extraClassDeclaration) {}
void tblgen::OpClass::addTrait(Twine trait) { void OpClass::addTrait(Twine trait) {
auto traitStr = trait.str(); auto traitStr = trait.str();
if (traitsSet.insert(traitStr).second) if (traitsSet.insert(traitStr).second)
traitsVec.push_back(std::move(traitStr)); traitsVec.push_back(std::move(traitStr));
} }
void tblgen::OpClass::writeDeclTo(raw_ostream &os) const { void OpClass::writeDeclTo(raw_ostream &os) const {
os << "class " << className << " : public ::mlir::Op<" << className; os << "class " << className << " : public ::mlir::Op<" << className;
for (const auto &trait : traitsVec) for (const auto &trait : traitsVec)
os << ", " << trait; os << ", " << trait;

View File

@ -44,7 +44,7 @@ llvm::StringRef InternalOpTrait::getTrait() const {
} }
std::string PredOpTrait::getPredTemplate() const { std::string PredOpTrait::getPredTemplate() const {
auto pred = tblgen::Pred(def->getValueInit("predicate")); auto pred = Pred(def->getValueInit("predicate"));
return pred.getCondition(); return pred.getCondition();
} }

View File

@ -27,12 +27,13 @@
#define DEBUG_TYPE "mlir-tblgen-operator" #define DEBUG_TYPE "mlir-tblgen-operator"
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen;
using llvm::DagInit; using llvm::DagInit;
using llvm::DefInit; using llvm::DefInit;
using llvm::Record; using llvm::Record;
tblgen::Operator::Operator(const llvm::Record &def) Operator::Operator(const llvm::Record &def)
: dialect(def.getValueAsDef("opDialect")), def(def) { : dialect(def.getValueAsDef("opDialect")), def(def) {
// The first `_` in the op's TableGen def name is treated as separating the // The first `_` in the op's TableGen def name is treated as separating the
// dialect prefix and the op class name. The dialect prefix will be ignored if // dialect prefix and the op class name. The dialect prefix will be ignored if
@ -51,7 +52,7 @@ tblgen::Operator::Operator(const llvm::Record &def)
populateOpStructure(); populateOpStructure();
} }
std::string tblgen::Operator::getOperationName() const { std::string Operator::getOperationName() const {
auto prefix = dialect.getName(); auto prefix = dialect.getName();
auto opName = def.getValueAsString("opName"); auto opName = def.getValueAsString("opName");
if (prefix.empty()) if (prefix.empty())
@ -59,62 +60,58 @@ std::string tblgen::Operator::getOperationName() const {
return std::string(llvm::formatv("{0}.{1}", prefix, opName)); return std::string(llvm::formatv("{0}.{1}", prefix, opName));
} }
std::string tblgen::Operator::getAdaptorName() const { std::string Operator::getAdaptorName() const {
return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
} }
StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); } StringRef Operator::getDialectName() const { return dialect.getName(); }
StringRef tblgen::Operator::getCppClassName() const { return cppClassName; } StringRef Operator::getCppClassName() const { return cppClassName; }
std::string tblgen::Operator::getQualCppClassName() const { std::string Operator::getQualCppClassName() const {
auto prefix = dialect.getCppNamespace(); auto prefix = dialect.getCppNamespace();
if (prefix.empty()) if (prefix.empty())
return std::string(cppClassName); return std::string(cppClassName);
return std::string(llvm::formatv("{0}::{1}", prefix, cppClassName)); return std::string(llvm::formatv("{0}::{1}", prefix, cppClassName));
} }
int tblgen::Operator::getNumResults() const { int Operator::getNumResults() const {
DagInit *results = def.getValueAsDag("results"); DagInit *results = def.getValueAsDag("results");
return results->getNumArgs(); return results->getNumArgs();
} }
StringRef tblgen::Operator::getExtraClassDeclaration() const { StringRef Operator::getExtraClassDeclaration() const {
constexpr auto attr = "extraClassDeclaration"; constexpr auto attr = "extraClassDeclaration";
if (def.isValueUnset(attr)) if (def.isValueUnset(attr))
return {}; return {};
return def.getValueAsString(attr); return def.getValueAsString(attr);
} }
const llvm::Record &tblgen::Operator::getDef() const { return def; } const llvm::Record &Operator::getDef() const { return def; }
bool tblgen::Operator::skipDefaultBuilders() const { bool Operator::skipDefaultBuilders() const {
return def.getValueAsBit("skipDefaultBuilders"); return def.getValueAsBit("skipDefaultBuilders");
} }
auto tblgen::Operator::result_begin() -> value_iterator { auto Operator::result_begin() -> value_iterator { return results.begin(); }
return results.begin();
}
auto tblgen::Operator::result_end() -> value_iterator { return results.end(); } auto Operator::result_end() -> value_iterator { return results.end(); }
auto tblgen::Operator::getResults() -> value_range { auto Operator::getResults() -> value_range {
return {result_begin(), result_end()}; return {result_begin(), result_end()};
} }
tblgen::TypeConstraint TypeConstraint Operator::getResultTypeConstraint(int index) const {
tblgen::Operator::getResultTypeConstraint(int index) const {
DagInit *results = def.getValueAsDag("results"); DagInit *results = def.getValueAsDag("results");
return TypeConstraint(cast<DefInit>(results->getArg(index))); return TypeConstraint(cast<DefInit>(results->getArg(index)));
} }
StringRef tblgen::Operator::getResultName(int index) const { StringRef Operator::getResultName(int index) const {
DagInit *results = def.getValueAsDag("results"); DagInit *results = def.getValueAsDag("results");
return results->getArgNameStr(index); return results->getArgNameStr(index);
} }
auto tblgen::Operator::getResultDecorators(int index) const auto Operator::getResultDecorators(int index) const -> var_decorator_range {
-> var_decorator_range {
Record *result = Record *result =
cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef(); cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
if (!result->isSubClassOf("OpVariable")) if (!result->isSubClassOf("OpVariable"))
@ -122,42 +119,37 @@ auto tblgen::Operator::getResultDecorators(int index) const
return *result->getValueAsListInit("decorators"); return *result->getValueAsListInit("decorators");
} }
unsigned tblgen::Operator::getNumVariableLengthResults() const { unsigned Operator::getNumVariableLengthResults() const {
return llvm::count_if(results, [](const NamedTypeConstraint &c) { return llvm::count_if(results, [](const NamedTypeConstraint &c) {
return c.constraint.isVariableLength(); return c.constraint.isVariableLength();
}); });
} }
unsigned tblgen::Operator::getNumVariableLengthOperands() const { unsigned Operator::getNumVariableLengthOperands() const {
return llvm::count_if(operands, [](const NamedTypeConstraint &c) { return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
return c.constraint.isVariableLength(); return c.constraint.isVariableLength();
}); });
} }
bool tblgen::Operator::hasSingleVariadicArg() const { bool Operator::hasSingleVariadicArg() const {
return getNumArgs() == 1 && getArg(0).is<tblgen::NamedTypeConstraint *>() && return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() &&
getOperand(0).isVariadic(); getOperand(0).isVariadic();
} }
tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const { Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }
return arguments.begin();
}
tblgen::Operator::arg_iterator tblgen::Operator::arg_end() const { Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }
return arguments.end();
}
tblgen::Operator::arg_range tblgen::Operator::getArgs() const { Operator::arg_range Operator::getArgs() const {
return {arg_begin(), arg_end()}; return {arg_begin(), arg_end()};
} }
StringRef tblgen::Operator::getArgName(int index) const { StringRef Operator::getArgName(int index) const {
DagInit *argumentValues = def.getValueAsDag("arguments"); DagInit *argumentValues = def.getValueAsDag("arguments");
return argumentValues->getArgName(index)->getValue(); return argumentValues->getArgName(index)->getValue();
} }
auto tblgen::Operator::getArgDecorators(int index) const auto Operator::getArgDecorators(int index) const -> var_decorator_range {
-> var_decorator_range {
Record *arg = Record *arg =
cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef(); cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
if (!arg->isSubClassOf("OpVariable")) if (!arg->isSubClassOf("OpVariable"))
@ -165,15 +157,15 @@ auto tblgen::Operator::getArgDecorators(int index) const
return *arg->getValueAsListInit("decorators"); return *arg->getValueAsListInit("decorators");
} }
const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const { const OpTrait *Operator::getTrait(StringRef trait) const {
for (const auto &t : traits) { for (const auto &t : traits) {
if (const auto *opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) { if (const auto *opTrait = dyn_cast<NativeOpTrait>(&t)) {
if (opTrait->getTrait() == trait) if (opTrait->getTrait() == trait)
return opTrait; return opTrait;
} else if (const auto *opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) { } else if (const auto *opTrait = dyn_cast<InternalOpTrait>(&t)) {
if (opTrait->getTrait() == trait) if (opTrait->getTrait() == trait)
return opTrait; return opTrait;
} else if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&t)) { } else if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&t)) {
if (opTrait->getTrait() == trait) if (opTrait->getTrait() == trait)
return opTrait; return opTrait;
} }
@ -181,100 +173,90 @@ const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const {
return nullptr; return nullptr;
} }
auto tblgen::Operator::region_begin() const -> const_region_iterator { auto Operator::region_begin() const -> const_region_iterator {
return regions.begin(); return regions.begin();
} }
auto tblgen::Operator::region_end() const -> const_region_iterator { auto Operator::region_end() const -> const_region_iterator {
return regions.end(); return regions.end();
} }
auto tblgen::Operator::getRegions() const auto Operator::getRegions() const
-> llvm::iterator_range<const_region_iterator> { -> llvm::iterator_range<const_region_iterator> {
return {region_begin(), region_end()}; return {region_begin(), region_end()};
} }
unsigned tblgen::Operator::getNumRegions() const { return regions.size(); } unsigned Operator::getNumRegions() const { return regions.size(); }
const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const { const NamedRegion &Operator::getRegion(unsigned index) const {
return regions[index]; return regions[index];
} }
unsigned tblgen::Operator::getNumVariadicRegions() const { unsigned Operator::getNumVariadicRegions() const {
return llvm::count_if(regions, return llvm::count_if(regions,
[](const NamedRegion &c) { return c.isVariadic(); }); [](const NamedRegion &c) { return c.isVariadic(); });
} }
auto tblgen::Operator::successor_begin() const -> const_successor_iterator { auto Operator::successor_begin() const -> const_successor_iterator {
return successors.begin(); return successors.begin();
} }
auto tblgen::Operator::successor_end() const -> const_successor_iterator { auto Operator::successor_end() const -> const_successor_iterator {
return successors.end(); return successors.end();
} }
auto tblgen::Operator::getSuccessors() const auto Operator::getSuccessors() const
-> llvm::iterator_range<const_successor_iterator> { -> llvm::iterator_range<const_successor_iterator> {
return {successor_begin(), successor_end()}; return {successor_begin(), successor_end()};
} }
unsigned tblgen::Operator::getNumSuccessors() const { unsigned Operator::getNumSuccessors() const { return successors.size(); }
return successors.size();
}
const tblgen::NamedSuccessor & const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
tblgen::Operator::getSuccessor(unsigned index) const {
return successors[index]; return successors[index];
} }
unsigned tblgen::Operator::getNumVariadicSuccessors() const { unsigned Operator::getNumVariadicSuccessors() const {
return llvm::count_if(successors, return llvm::count_if(successors,
[](const NamedSuccessor &c) { return c.isVariadic(); }); [](const NamedSuccessor &c) { return c.isVariadic(); });
} }
auto tblgen::Operator::trait_begin() const -> const_trait_iterator { auto Operator::trait_begin() const -> const_trait_iterator {
return traits.begin(); return traits.begin();
} }
auto tblgen::Operator::trait_end() const -> const_trait_iterator { auto Operator::trait_end() const -> const_trait_iterator {
return traits.end(); return traits.end();
} }
auto tblgen::Operator::getTraits() const auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
-> llvm::iterator_range<const_trait_iterator> {
return {trait_begin(), trait_end()}; return {trait_begin(), trait_end()};
} }
auto tblgen::Operator::attribute_begin() const -> attribute_iterator { auto Operator::attribute_begin() const -> attribute_iterator {
return attributes.begin(); return attributes.begin();
} }
auto tblgen::Operator::attribute_end() const -> attribute_iterator { auto Operator::attribute_end() const -> attribute_iterator {
return attributes.end(); return attributes.end();
} }
auto tblgen::Operator::getAttributes() const auto Operator::getAttributes() const
-> llvm::iterator_range<attribute_iterator> { -> llvm::iterator_range<attribute_iterator> {
return {attribute_begin(), attribute_end()}; return {attribute_begin(), attribute_end()};
} }
auto tblgen::Operator::operand_begin() -> value_iterator { auto Operator::operand_begin() -> value_iterator { return operands.begin(); }
return operands.begin(); auto Operator::operand_end() -> value_iterator { return operands.end(); }
} auto Operator::getOperands() -> value_range {
auto tblgen::Operator::operand_end() -> value_iterator {
return operands.end();
}
auto tblgen::Operator::getOperands() -> value_range {
return {operand_begin(), operand_end()}; return {operand_begin(), operand_end()};
} }
auto tblgen::Operator::getArg(int index) const -> Argument { auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
return arguments[index];
}
// Mapping from result index to combined argument and result index. Arguments // Mapping from result index to combined argument and result index. Arguments
// are indexed to match getArg index, while the result indexes are mapped to // are indexed to match getArg index, while the result indexes are mapped to
// avoid overlap. // avoid overlap.
static int resultIndex(int i) { return -1 - i; } static int resultIndex(int i) { return -1 - i; }
bool tblgen::Operator::isVariadic() const { bool Operator::isVariadic() const {
return any_of(llvm::concat<const NamedTypeConstraint>(operands, results), return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
[](const NamedTypeConstraint &op) { return op.isVariadic(); }); [](const NamedTypeConstraint &op) { return op.isVariadic(); });
} }
void tblgen::Operator::populateTypeInferenceInfo( void Operator::populateTypeInferenceInfo(
const llvm::StringMap<int> &argumentsAndResultsIndex) { const llvm::StringMap<int> &argumentsAndResultsIndex) {
// If the type inference op interface is not registered, then do not attempt // If the type inference op interface is not registered, then do not attempt
// to determine if the result types an be inferred. // to determine if the result types an be inferred.
@ -340,7 +322,7 @@ void tblgen::Operator::populateTypeInferenceInfo(
if (def.isSubClassOf( if (def.isSubClassOf(
llvm::formatv("{0}::Trait", inferTypeOpInterface).str())) llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
return; return;
if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait)) if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&trait))
if (&opTrait->getDef() == inferTrait) if (&opTrait->getDef() == inferTrait)
return; return;
@ -364,7 +346,7 @@ void tblgen::Operator::populateTypeInferenceInfo(
traits.push_back(OpTrait::create(inferTrait->getDefInit())); traits.push_back(OpTrait::create(inferTrait->getDefInit()));
} }
void tblgen::Operator::populateOpStructure() { void Operator::populateOpStructure() {
auto &recordKeeper = def.getRecords(); auto &recordKeeper = def.getRecords();
auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint"); auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint");
auto *attrClass = recordKeeper.getClass("Attr"); auto *attrClass = recordKeeper.getClass("Attr");
@ -541,42 +523,39 @@ void tblgen::Operator::populateOpStructure() {
LLVM_DEBUG(print(llvm::dbgs())); LLVM_DEBUG(print(llvm::dbgs()));
} }
auto tblgen::Operator::getSameTypeAsResult(int index) const auto Operator::getSameTypeAsResult(int index) const -> ArrayRef<ArgOrType> {
-> ArrayRef<ArgOrType> {
assert(allResultTypesKnown()); assert(allResultTypesKnown());
return resultTypeMapping[index]; return resultTypeMapping[index];
} }
ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); } ArrayRef<llvm::SMLoc> Operator::getLoc() const { return def.getLoc(); }
bool tblgen::Operator::hasDescription() const { bool Operator::hasDescription() const {
return def.getValue("description") != nullptr; return def.getValue("description") != nullptr;
} }
StringRef tblgen::Operator::getDescription() const { StringRef Operator::getDescription() const {
return def.getValueAsString("description"); return def.getValueAsString("description");
} }
bool tblgen::Operator::hasSummary() const { bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; }
return def.getValue("summary") != nullptr;
}
StringRef tblgen::Operator::getSummary() const { StringRef Operator::getSummary() const {
return def.getValueAsString("summary"); return def.getValueAsString("summary");
} }
bool tblgen::Operator::hasAssemblyFormat() const { bool Operator::hasAssemblyFormat() const {
auto *valueInit = def.getValueInit("assemblyFormat"); auto *valueInit = def.getValueInit("assemblyFormat");
return isa<llvm::CodeInit, llvm::StringInit>(valueInit); return isa<llvm::CodeInit, llvm::StringInit>(valueInit);
} }
StringRef tblgen::Operator::getAssemblyFormat() const { StringRef Operator::getAssemblyFormat() const {
return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat")) return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
.Case<llvm::StringInit, llvm::CodeInit>( .Case<llvm::StringInit, llvm::CodeInit>(
[&](auto *init) { return init->getValue(); }); [&](auto *init) { return init->getValue(); });
} }
void tblgen::Operator::print(llvm::raw_ostream &os) const { void Operator::print(llvm::raw_ostream &os) const {
os << "op '" << getOperationName() << "'\n"; os << "op '" << getOperationName() << "'\n";
for (Argument arg : arguments) { for (Argument arg : arguments) {
if (auto *attr = arg.dyn_cast<NamedAttribute *>()) if (auto *attr = arg.dyn_cast<NamedAttribute *>())
@ -586,12 +565,12 @@ void tblgen::Operator::print(llvm::raw_ostream &os) const {
} }
} }
auto tblgen::Operator::VariableDecoratorIterator::unwrap(llvm::Init *init) auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
-> VariableDecorator { -> VariableDecorator {
return VariableDecorator(cast<llvm::DefInit>(init)->getDef()); return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
} }
auto tblgen::Operator::getArgToOperandOrAttribute(int index) const auto Operator::getArgToOperandOrAttribute(int index) const
-> OperandOrAttribute { -> OperandOrAttribute {
return attrOrOperandMapping[index]; return attrOrOperandMapping[index];
} }

View File

@ -22,80 +22,78 @@
#define DEBUG_TYPE "mlir-tblgen-pattern" #define DEBUG_TYPE "mlir-tblgen-pattern"
using namespace mlir; using namespace mlir;
using namespace tblgen;
using llvm::formatv; using llvm::formatv;
using mlir::tblgen::Operator;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// DagLeaf // DagLeaf
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool tblgen::DagLeaf::isUnspecified() const { bool DagLeaf::isUnspecified() const {
return dyn_cast_or_null<llvm::UnsetInit>(def); return dyn_cast_or_null<llvm::UnsetInit>(def);
} }
bool tblgen::DagLeaf::isOperandMatcher() const { bool DagLeaf::isOperandMatcher() const {
// Operand matchers specify a type constraint. // Operand matchers specify a type constraint.
return isSubClassOf("TypeConstraint"); return isSubClassOf("TypeConstraint");
} }
bool tblgen::DagLeaf::isAttrMatcher() const { bool DagLeaf::isAttrMatcher() const {
// Attribute matchers specify an attribute constraint. // Attribute matchers specify an attribute constraint.
return isSubClassOf("AttrConstraint"); return isSubClassOf("AttrConstraint");
} }
bool tblgen::DagLeaf::isNativeCodeCall() const { bool DagLeaf::isNativeCodeCall() const {
return isSubClassOf("NativeCodeCall"); return isSubClassOf("NativeCodeCall");
} }
bool tblgen::DagLeaf::isConstantAttr() const { bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
return isSubClassOf("ConstantAttr");
}
bool tblgen::DagLeaf::isEnumAttrCase() const { bool DagLeaf::isEnumAttrCase() const {
return isSubClassOf("EnumAttrCaseInfo"); return isSubClassOf("EnumAttrCaseInfo");
} }
bool tblgen::DagLeaf::isStringAttr() const { bool DagLeaf::isStringAttr() const {
return isa<llvm::StringInit, llvm::CodeInit>(def); return isa<llvm::StringInit, llvm::CodeInit>(def);
} }
tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const { Constraint DagLeaf::getAsConstraint() const {
assert((isOperandMatcher() || isAttrMatcher()) && assert((isOperandMatcher() || isAttrMatcher()) &&
"the DAG leaf must be operand or attribute"); "the DAG leaf must be operand or attribute");
return Constraint(cast<llvm::DefInit>(def)->getDef()); return Constraint(cast<llvm::DefInit>(def)->getDef());
} }
tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const { ConstantAttr DagLeaf::getAsConstantAttr() const {
assert(isConstantAttr() && "the DAG leaf must be constant attribute"); assert(isConstantAttr() && "the DAG leaf must be constant attribute");
return ConstantAttr(cast<llvm::DefInit>(def)); return ConstantAttr(cast<llvm::DefInit>(def));
} }
tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const { EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
return EnumAttrCase(cast<llvm::DefInit>(def)); return EnumAttrCase(cast<llvm::DefInit>(def));
} }
std::string tblgen::DagLeaf::getConditionTemplate() const { std::string DagLeaf::getConditionTemplate() const {
return getAsConstraint().getConditionTemplate(); return getAsConstraint().getConditionTemplate();
} }
llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const { llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
} }
std::string tblgen::DagLeaf::getStringAttr() const { std::string DagLeaf::getStringAttr() const {
assert(isStringAttr() && "the DAG leaf must be string attribute"); assert(isStringAttr() && "the DAG leaf must be string attribute");
return def->getAsUnquotedString(); return def->getAsUnquotedString();
} }
bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { bool DagLeaf::isSubClassOf(StringRef superclass) const {
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
return defInit->getDef()->isSubClassOf(superclass); return defInit->getDef()->isSubClassOf(superclass);
return false; return false;
} }
void tblgen::DagLeaf::print(raw_ostream &os) const { void DagLeaf::print(raw_ostream &os) const {
if (def) if (def)
def->print(os); def->print(os);
} }
@ -104,28 +102,26 @@ void tblgen::DagLeaf::print(raw_ostream &os) const {
// DagNode // DagNode
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool tblgen::DagNode::isNativeCodeCall() const { bool DagNode::isNativeCodeCall() const {
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
return defInit->getDef()->isSubClassOf("NativeCodeCall"); return defInit->getDef()->isSubClassOf("NativeCodeCall");
return false; return false;
} }
bool tblgen::DagNode::isOperation() const { bool DagNode::isOperation() const {
return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
} }
llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { llvm::StringRef DagNode::getNativeCodeTemplate() const {
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
return cast<llvm::DefInit>(node->getOperator()) return cast<llvm::DefInit>(node->getOperator())
->getDef() ->getDef()
->getValueAsString("expression"); ->getValueAsString("expression");
} }
llvm::StringRef tblgen::DagNode::getSymbol() const { llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
return node->getNameStr();
}
Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
auto it = mapper->find(opDef); auto it = mapper->find(opDef);
if (it != mapper->end()) if (it != mapper->end())
@ -134,7 +130,7 @@ Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
.first->second; .first->second;
} }
int tblgen::DagNode::getNumOps() const { int DagNode::getNumOps() const {
int count = isReplaceWithValue() ? 0 : 1; int count = isReplaceWithValue() ? 0 : 1;
for (int i = 0, e = getNumArgs(); i != e; ++i) { for (int i = 0, e = getNumArgs(); i != e; ++i) {
if (auto child = getArgAsNestedDag(i)) if (auto child = getArgAsNestedDag(i))
@ -143,36 +139,36 @@ int tblgen::DagNode::getNumOps() const {
return count; return count;
} }
int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } int DagNode::getNumArgs() const { return node->getNumArgs(); }
bool tblgen::DagNode::isNestedDagArg(unsigned index) const { bool DagNode::isNestedDagArg(unsigned index) const {
return isa<llvm::DagInit>(node->getArg(index)); return isa<llvm::DagInit>(node->getArg(index));
} }
tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const { DagNode DagNode::getArgAsNestedDag(unsigned index) const {
return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
} }
tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const { DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
assert(!isNestedDagArg(index)); assert(!isNestedDagArg(index));
return DagLeaf(node->getArg(index)); return DagLeaf(node->getArg(index));
} }
StringRef tblgen::DagNode::getArgName(unsigned index) const { StringRef DagNode::getArgName(unsigned index) const {
return node->getArgNameStr(index); return node->getArgNameStr(index);
} }
bool tblgen::DagNode::isReplaceWithValue() const { bool DagNode::isReplaceWithValue() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getName() == "replaceWithValue"; return dagOpDef->getName() == "replaceWithValue";
} }
bool tblgen::DagNode::isLocationDirective() const { bool DagNode::isLocationDirective() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getName() == "location"; return dagOpDef->getName() == "location";
} }
void tblgen::DagNode::print(raw_ostream &os) const { void DagNode::print(raw_ostream &os) const {
if (node) if (node)
node->print(os); node->print(os);
} }
@ -181,8 +177,7 @@ void tblgen::DagNode::print(raw_ostream &os) const {
// SymbolInfoMap // SymbolInfoMap
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol, StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
int *index) {
StringRef name, indexStr; StringRef name, indexStr;
int idx = -1; int idx = -1;
std::tie(name, indexStr) = symbol.rsplit("__"); std::tie(name, indexStr) = symbol.rsplit("__");
@ -197,12 +192,11 @@ StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol,
return name; return name;
} }
tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
SymbolInfo::Kind kind, Optional<int> index)
Optional<int> index)
: op(op), kind(kind), argIndex(index) {} : op(op), kind(kind), argIndex(index) {}
int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const { int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
switch (kind) { switch (kind) {
case Kind::Attr: case Kind::Attr:
case Kind::Operand: case Kind::Operand:
@ -214,8 +208,7 @@ int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
llvm_unreachable("unknown kind"); llvm_unreachable("unknown kind");
} }
std::string std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
switch (kind) { switch (kind) {
case Kind::Attr: { case Kind::Attr: {
@ -240,7 +233,7 @@ tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
llvm_unreachable("unknown kind"); llvm_unreachable("unknown kind");
} }
std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse( std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
StringRef name, int index, const char *fmt, const char *separator) const { StringRef name, int index, const char *fmt, const char *separator) const {
LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
switch (kind) { switch (kind) {
@ -311,7 +304,7 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
llvm_unreachable("unknown kind"); llvm_unreachable("unknown kind");
} }
std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse( std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
StringRef name, int index, const char *fmt, const char *separator) const { StringRef name, int index, const char *fmt, const char *separator) const {
LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
switch (kind) { switch (kind) {
@ -353,8 +346,8 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
llvm_unreachable("unknown kind"); llvm_unreachable("unknown kind");
} }
bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
int argIndex) { int argIndex) {
StringRef name = getValuePackName(symbol); StringRef name = getValuePackName(symbol);
if (name != symbol) { if (name != symbol) {
auto error = formatv( auto error = formatv(
@ -369,26 +362,25 @@ bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
return symbolInfoMap.insert({symbol, symInfo}).second; return symbolInfoMap.insert({symbol, symInfo}).second;
} }
bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
StringRef name = getValuePackName(symbol); StringRef name = getValuePackName(symbol);
return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second; return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
} }
bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) { bool SymbolInfoMap::bindValue(StringRef symbol) {
return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second; return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
} }
bool tblgen::SymbolInfoMap::contains(StringRef symbol) const { bool SymbolInfoMap::contains(StringRef symbol) const {
return find(symbol) != symbolInfoMap.end(); return find(symbol) != symbolInfoMap.end();
} }
tblgen::SymbolInfoMap::const_iterator SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
tblgen::SymbolInfoMap::find(StringRef key) const {
StringRef name = getValuePackName(key); StringRef name = getValuePackName(key);
return symbolInfoMap.find(name); return symbolInfoMap.find(name);
} }
int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const { int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
StringRef name = getValuePackName(symbol); StringRef name = getValuePackName(symbol);
if (name != symbol) { if (name != symbol) {
// If there is a trailing index inside symbol, it references just one // If there is a trailing index inside symbol, it references just one
@ -399,9 +391,9 @@ int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
return find(name)->getValue().getStaticValueCount(); return find(name)->getValue().getStaticValueCount();
} }
std::string std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt, const char *fmt,
const char *separator) const { const char *separator) const {
int index = -1; int index = -1;
StringRef name = getValuePackName(symbol, &index); StringRef name = getValuePackName(symbol, &index);
@ -414,9 +406,8 @@ tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt,
return it->getValue().getValueAndRangeUse(name, index, fmt, separator); return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
} }
std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol, std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
const char *fmt, const char *separator) const {
const char *separator) const {
int index = -1; int index = -1;
StringRef name = getValuePackName(symbol, &index); StringRef name = getValuePackName(symbol, &index);
@ -433,32 +424,30 @@ std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol,
// Pattern // Pattern
//==----------------------------------------------------------------------===// //==----------------------------------------------------------------------===//
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
: def(*def), recordOpMap(mapper) {} : def(*def), recordOpMap(mapper) {}
tblgen::DagNode tblgen::Pattern::getSourcePattern() const { DagNode Pattern::getSourcePattern() const {
return tblgen::DagNode(def.getValueAsDag("sourcePattern")); return DagNode(def.getValueAsDag("sourcePattern"));
} }
int tblgen::Pattern::getNumResultPatterns() const { int Pattern::getNumResultPatterns() const {
auto *results = def.getValueAsListInit("resultPatterns"); auto *results = def.getValueAsListInit("resultPatterns");
return results->size(); return results->size();
} }
tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { DagNode Pattern::getResultPattern(unsigned index) const {
auto *results = def.getValueAsListInit("resultPatterns"); auto *results = def.getValueAsListInit("resultPatterns");
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index))); return DagNode(cast<llvm::DagInit>(results->getElement(index)));
} }
void tblgen::Pattern::collectSourcePatternBoundSymbols( void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
tblgen::SymbolInfoMap &infoMap) {
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
} }
void tblgen::Pattern::collectResultPatternBoundSymbols( void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
tblgen::SymbolInfoMap &infoMap) {
LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
auto pattern = getResultPattern(i); auto pattern = getResultPattern(i);
@ -467,17 +456,17 @@ void tblgen::Pattern::collectResultPatternBoundSymbols(
LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
} }
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { const Operator &Pattern::getSourceRootOp() {
return getSourcePattern().getDialectOp(recordOpMap); return getSourcePattern().getDialectOp(recordOpMap);
} }
tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) { Operator &Pattern::getDialectOp(DagNode node) {
return node.getDialectOp(recordOpMap); return node.getDialectOp(recordOpMap);
} }
std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const { std::vector<AppliedConstraint> Pattern::getConstraints() const {
auto *listInit = def.getValueAsListInit("constraints"); auto *listInit = def.getValueAsListInit("constraints");
std::vector<tblgen::AppliedConstraint> ret; std::vector<AppliedConstraint> ret;
ret.reserve(listInit->size()); ret.reserve(listInit->size());
for (auto it : *listInit) { for (auto it : *listInit) {
@ -503,7 +492,7 @@ std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
return ret; return ret;
} }
int tblgen::Pattern::getBenefit() const { int Pattern::getBenefit() const {
// The initial benefit value is a heuristic with number of ops in the source // The initial benefit value is a heuristic with number of ops in the source
// pattern. // pattern.
int initBenefit = getSourcePattern().getNumOps(); int initBenefit = getSourcePattern().getNumOps();
@ -515,8 +504,7 @@ int tblgen::Pattern::getBenefit() const {
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
} }
std::vector<tblgen::Pattern::IdentifierLine> std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
tblgen::Pattern::getLocation() const {
std::vector<std::pair<StringRef, unsigned>> result; std::vector<std::pair<StringRef, unsigned>> result;
result.reserve(def.getLoc().size()); result.reserve(def.getLoc().size());
for (auto loc : def.getLoc()) { for (auto loc : def.getLoc()) {
@ -529,8 +517,8 @@ tblgen::Pattern::getLocation() const {
return result; return result;
} }
void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
bool isSrcPattern) { bool isSrcPattern) {
auto treeName = tree.getSymbol(); auto treeName = tree.getSymbol();
if (!tree.isOperation()) { if (!tree.isOperation()) {
if (!treeName.empty()) { if (!treeName.empty()) {

View File

@ -19,20 +19,21 @@
#include "llvm/TableGen/Record.h" #include "llvm/TableGen/Record.h"
using namespace mlir; using namespace mlir;
using namespace tblgen;
// Construct a Predicate from a record. // Construct a Predicate from a record.
tblgen::Pred::Pred(const llvm::Record *record) : def(record) { Pred::Pred(const llvm::Record *record) : def(record) {
assert(def->isSubClassOf("Pred") && assert(def->isSubClassOf("Pred") &&
"must be a subclass of TableGen 'Pred' class"); "must be a subclass of TableGen 'Pred' class");
} }
// Construct a Predicate from an initializer. // Construct a Predicate from an initializer.
tblgen::Pred::Pred(const llvm::Init *init) : def(nullptr) { Pred::Pred(const llvm::Init *init) : def(nullptr) {
if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init)) if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
def = defInit->getDef(); def = defInit->getDef();
} }
std::string tblgen::Pred::getCondition() const { std::string Pred::getCondition() const {
// Static dispatch to subclasses. // Static dispatch to subclasses.
if (def->isSubClassOf("CombinedPred")) if (def->isSubClassOf("CombinedPred"))
return static_cast<const CombinedPred *>(this)->getConditionImpl(); return static_cast<const CombinedPred *>(this)->getConditionImpl();
@ -41,44 +42,44 @@ std::string tblgen::Pred::getCondition() const {
llvm_unreachable("Pred::getCondition must be overridden in subclasses"); llvm_unreachable("Pred::getCondition must be overridden in subclasses");
} }
bool tblgen::Pred::isCombined() const { bool Pred::isCombined() const {
return def && def->isSubClassOf("CombinedPred"); return def && def->isSubClassOf("CombinedPred");
} }
ArrayRef<llvm::SMLoc> tblgen::Pred::getLoc() const { return def->getLoc(); } ArrayRef<llvm::SMLoc> Pred::getLoc() const { return def->getLoc(); }
tblgen::CPred::CPred(const llvm::Record *record) : Pred(record) { CPred::CPred(const llvm::Record *record) : Pred(record) {
assert(def->isSubClassOf("CPred") && assert(def->isSubClassOf("CPred") &&
"must be a subclass of Tablegen 'CPred' class"); "must be a subclass of Tablegen 'CPred' class");
} }
tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) { CPred::CPred(const llvm::Init *init) : Pred(init) {
assert((!def || def->isSubClassOf("CPred")) && assert((!def || def->isSubClassOf("CPred")) &&
"must be a subclass of Tablegen 'CPred' class"); "must be a subclass of Tablegen 'CPred' class");
} }
// Get condition of the C Predicate. // Get condition of the C Predicate.
std::string tblgen::CPred::getConditionImpl() const { std::string CPred::getConditionImpl() const {
assert(!isNull() && "null predicate does not have a condition"); assert(!isNull() && "null predicate does not have a condition");
return std::string(def->getValueAsString("predExpr")); return std::string(def->getValueAsString("predExpr"));
} }
tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
assert(def->isSubClassOf("CombinedPred") && assert(def->isSubClassOf("CombinedPred") &&
"must be a subclass of Tablegen 'CombinedPred' class"); "must be a subclass of Tablegen 'CombinedPred' class");
} }
tblgen::CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) { CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
assert((!def || def->isSubClassOf("CombinedPred")) && assert((!def || def->isSubClassOf("CombinedPred")) &&
"must be a subclass of Tablegen 'CombinedPred' class"); "must be a subclass of Tablegen 'CombinedPred' class");
} }
const llvm::Record *tblgen::CombinedPred::getCombinerDef() const { const llvm::Record *CombinedPred::getCombinerDef() const {
assert(def->getValue("kind") && "CombinedPred must have a value 'kind'"); assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
return def->getValueAsDef("kind"); return def->getValueAsDef("kind");
} }
const std::vector<llvm::Record *> tblgen::CombinedPred::getChildren() const { const std::vector<llvm::Record *> CombinedPred::getChildren() const {
assert(def->getValue("children") && assert(def->getValue("children") &&
"CombinedPred must have a value 'children'"); "CombinedPred must have a value 'children'");
return def->getValueAsListOfDefs("children"); return def->getValueAsListOfDefs("children");
@ -101,7 +102,7 @@ enum class PredCombinerKind {
// A node in a logical predicate tree. // A node in a logical predicate tree.
struct PredNode { struct PredNode {
PredCombinerKind kind; PredCombinerKind kind;
const tblgen::Pred *predicate; const Pred *predicate;
SmallVector<PredNode *, 4> children; SmallVector<PredNode *, 4> children;
std::string expr; std::string expr;
@ -113,11 +114,11 @@ struct PredNode {
// Get a predicate tree node kind based on the kind used in the predicate // Get a predicate tree node kind based on the kind used in the predicate
// TableGen record. // TableGen record.
static PredCombinerKind getPredCombinerKind(const tblgen::Pred &pred) { static PredCombinerKind getPredCombinerKind(const Pred &pred) {
if (!pred.isCombined()) if (!pred.isCombined())
return PredCombinerKind::Leaf; return PredCombinerKind::Leaf;
const auto &combinedPred = static_cast<const tblgen::CombinedPred &>(pred); const auto &combinedPred = static_cast<const CombinedPred &>(pred);
return llvm::StringSwitch<PredCombinerKind>( return llvm::StringSwitch<PredCombinerKind>(
combinedPred.getCombinerDef()->getName()) combinedPred.getCombinerDef()->getName())
.Case("PredCombinerAnd", PredCombinerKind::And) .Case("PredCombinerAnd", PredCombinerKind::And)
@ -137,7 +138,7 @@ using Subst = std::pair<StringRef, StringRef>;
// substitution, nodes are still pointing to the original TableGen record. // substitution, nodes are still pointing to the original TableGen record.
// All nodes are created within "allocator". // All nodes are created within "allocator".
static PredNode * static PredNode *
buildPredicateTree(const tblgen::Pred &root, buildPredicateTree(const Pred &root,
llvm::SpecificBumpPtrAllocator<PredNode> &allocator, llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
ArrayRef<Subst> substitutions) { ArrayRef<Subst> substitutions) {
auto *rootNode = allocator.Allocate(); auto *rootNode = allocator.Allocate();
@ -166,22 +167,22 @@ buildPredicateTree(const tblgen::Pred &root,
// list before continuing. // list before continuing.
auto allSubstitutions = llvm::to_vector<4>(substitutions); auto allSubstitutions = llvm::to_vector<4>(substitutions);
if (rootNode->kind == PredCombinerKind::SubstLeaves) { if (rootNode->kind == PredCombinerKind::SubstLeaves) {
const auto &substPred = static_cast<const tblgen::SubstLeavesPred &>(root); const auto &substPred = static_cast<const SubstLeavesPred &>(root);
allSubstitutions.push_back( allSubstitutions.push_back(
{substPred.getPattern(), substPred.getReplacement()}); {substPred.getPattern(), substPred.getReplacement()});
} }
// If the current predicate is a ConcatPred, record the prefix and suffix. // If the current predicate is a ConcatPred, record the prefix and suffix.
else if (rootNode->kind == PredCombinerKind::Concat) { else if (rootNode->kind == PredCombinerKind::Concat) {
const auto &concatPred = static_cast<const tblgen::ConcatPred &>(root); const auto &concatPred = static_cast<const ConcatPred &>(root);
rootNode->prefix = std::string(concatPred.getPrefix()); rootNode->prefix = std::string(concatPred.getPrefix());
rootNode->suffix = std::string(concatPred.getSuffix()); rootNode->suffix = std::string(concatPred.getSuffix());
} }
// Build child subtrees. // Build child subtrees.
auto combined = static_cast<const tblgen::CombinedPred &>(root); auto combined = static_cast<const CombinedPred &>(root);
for (const auto *record : combined.getChildren()) { for (const auto *record : combined.getChildren()) {
auto childTree = auto childTree =
buildPredicateTree(tblgen::Pred(record), allocator, allSubstitutions); buildPredicateTree(Pred(record), allocator, allSubstitutions);
rootNode->children.push_back(childTree); rootNode->children.push_back(childTree);
} }
return rootNode; return rootNode;
@ -192,9 +193,10 @@ buildPredicateTree(const tblgen::Pred &root,
// children is known to be false(true), the result is also false(true). // children is known to be false(true), the result is also false(true).
// Furthermore, for AND(OR) combined predicates, children that are known to be // Furthermore, for AND(OR) combined predicates, children that are known to be
// true(false) don't have to be checked dynamically. // true(false) don't have to be checked dynamically.
static PredNode *propagateGroundTruth( static PredNode *
PredNode *node, const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownTruePreds, propagateGroundTruth(PredNode *node,
const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownFalsePreds) { const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
// If the current predicate is known to be true or false, change the kind of // If the current predicate is known to be true or false, change the kind of
// the node and return immediately. // the node and return immediately.
if (knownTruePreds.count(node->predicate) != 0) { if (knownTruePreds.count(node->predicate) != 0) {
@ -339,29 +341,29 @@ static std::string getCombinedCondition(const PredNode &root) {
llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind"); llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
} }
std::string tblgen::CombinedPred::getConditionImpl() const { std::string CombinedPred::getConditionImpl() const {
llvm::SpecificBumpPtrAllocator<PredNode> allocator; llvm::SpecificBumpPtrAllocator<PredNode> allocator;
auto predicateTree = buildPredicateTree(*this, allocator, {}); auto predicateTree = buildPredicateTree(*this, allocator, {});
predicateTree = propagateGroundTruth( predicateTree =
predicateTree, propagateGroundTruth(predicateTree,
/*knownTruePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>(), /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
/*knownFalsePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>()); /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
return getCombinedCondition(*predicateTree); return getCombinedCondition(*predicateTree);
} }
StringRef tblgen::SubstLeavesPred::getPattern() const { StringRef SubstLeavesPred::getPattern() const {
return def->getValueAsString("pattern"); return def->getValueAsString("pattern");
} }
StringRef tblgen::SubstLeavesPred::getReplacement() const { StringRef SubstLeavesPred::getReplacement() const {
return def->getValueAsString("replacement"); return def->getValueAsString("replacement");
} }
StringRef tblgen::ConcatPred::getPrefix() const { StringRef ConcatPred::getPrefix() const {
return def->getValueAsString("prefix"); return def->getValueAsString("prefix");
} }
StringRef tblgen::ConcatPred::getSuffix() const { StringRef ConcatPred::getSuffix() const {
return def->getValueAsString("suffix"); return def->getValueAsString("suffix");
} }