[MLIR][TableGen] Use const pointers for various Init objects (#112562)

This reverts commit 0eed3055511381436ee69d1caf64a4af47f8d65c and applies
additional fixes in `verifyArgument` in OmpOpGen.cpp for gcc-7 bot
failures
This commit is contained in:
Rahul Joshi 2024-10-16 11:46:38 -07:00 committed by GitHub
parent 875afa939d
commit e768b076e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 78 additions and 68 deletions

View File

@ -105,7 +105,7 @@ public:
std::optional<StringRef> getDefaultValue() const;
/// Return the underlying def of this parameter.
llvm::Init *getDef() const;
const llvm::Init *getDef() const;
/// The parameter is pointer-comparable.
bool operator==(const AttrOrTypeParameter &other) const {

View File

@ -92,7 +92,7 @@ public:
/// dialect.
bool usePropertiesForAttributes() const;
llvm::DagInit *getDiscardableAttributes() const;
const llvm::DagInit *getDiscardableAttributes() const;
const llvm::Record *getDef() const { return def; }

View File

@ -119,14 +119,15 @@ public:
/// A utility iterator over a list of variable decorators.
struct VariableDecoratorIterator
: public llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)> {
: public llvm::mapped_iterator<const llvm::Init *const *,
VariableDecorator (*)(
const llvm::Init *)> {
/// Initializes the iterator to the specified iterator.
VariableDecoratorIterator(llvm::Init *const *it)
: llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)>(it,
&unwrap) {}
static VariableDecorator unwrap(llvm::Init *init);
VariableDecoratorIterator(const llvm::Init *const *it)
: llvm::mapped_iterator<const llvm::Init *const *,
VariableDecorator (*)(const llvm::Init *)>(
it, &unwrap) {}
static VariableDecorator unwrap(const llvm::Init *init);
};
using var_decorator_iterator = VariableDecoratorIterator;
using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;

View File

@ -40,7 +40,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues()) {
for (const llvm::Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
def->getLoc());
@ -58,8 +58,8 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
if (auto *traitList = def->getValueAsListInit("traits")) {
SmallPtrSet<const llvm::Init *, 32> traitSet;
traits.reserve(traitSet.size());
llvm::unique_function<void(llvm::ListInit *)> processTraitList =
[&](llvm::ListInit *traitList) {
llvm::unique_function<void(const llvm::ListInit *)> processTraitList =
[&](const llvm::ListInit *traitList) {
for (auto *traitInit : *traitList) {
if (!traitSet.insert(traitInit).second)
continue;
@ -335,7 +335,9 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
return result && !result->empty() ? result : std::nullopt;
}
llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
const llvm::Init *AttrOrTypeParameter::getDef() const {
return def->getArg(index);
}
std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
@ -349,7 +351,7 @@ std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
//===----------------------------------------------------------------------===//
bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
llvm::Init *paramDef = param->getDef();
const llvm::Init *paramDef = param->getDef();
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
return false;

View File

@ -126,7 +126,7 @@ StringRef Attribute::getDerivedCodeBody() const {
Dialect Attribute::getDialect() const {
const llvm::RecordVal *record = def->getValue("dialect");
if (record && record->getValue()) {
if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
if (const DefInit *init = dyn_cast<DefInit>(record->getValue()))
return Dialect(init->getDef());
}
return Dialect(nullptr);

View File

@ -106,7 +106,7 @@ bool Dialect::usePropertiesForAttributes() const {
return def->getValueAsBit("usePropertiesForAttributes");
}
llvm::DagInit *Dialect::getDiscardableAttributes() const {
const llvm::DagInit *Dialect::getDiscardableAttributes() const {
return def->getValueAsDag("discardableAttrs");
}

View File

@ -22,7 +22,7 @@ using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
llvm::DagInit *args = def->getValueAsDag("arguments");
const llvm::DagInit *args = def->getValueAsDag("arguments");
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
arguments.push_back(
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
@ -78,7 +78,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
// Initialize the interface methods.
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (llvm::Init *init : listInit->getValues())
for (const llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
// Initialize the interface base classes.
@ -98,7 +98,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
baseInterfaces.push_back(std::make_unique<Interface>(baseInterface));
basesAdded.insert(baseInterface.getName());
};
for (llvm::Init *init : basesInit->getValues())
for (const llvm::Init *init : basesInit->getValues())
addBaseInterfaceFn(Interface(cast<llvm::DefInit>(init)->getDef()));
}

View File

@ -161,7 +161,7 @@ std::string Operator::getQualCppClassName() const {
StringRef Operator::getCppNamespace() const { return cppNamespace; }
int Operator::getNumResults() const {
DagInit *results = def.getValueAsDag("results");
const DagInit *results = def.getValueAsDag("results");
return results->getNumArgs();
}
@ -198,12 +198,12 @@ auto Operator::getResults() const -> const_value_range {
}
TypeConstraint Operator::getResultTypeConstraint(int index) const {
DagInit *results = def.getValueAsDag("results");
const DagInit *results = def.getValueAsDag("results");
return TypeConstraint(cast<DefInit>(results->getArg(index)));
}
StringRef Operator::getResultName(int index) const {
DagInit *results = def.getValueAsDag("results");
const DagInit *results = def.getValueAsDag("results");
return results->getArgNameStr(index);
}
@ -241,7 +241,7 @@ Operator::arg_range Operator::getArgs() const {
}
StringRef Operator::getArgName(int index) const {
DagInit *argumentValues = def.getValueAsDag("arguments");
const DagInit *argumentValues = def.getValueAsDag("arguments");
return argumentValues->getArgNameStr(index);
}
@ -557,7 +557,7 @@ void Operator::populateOpStructure() {
auto *opVarClass = recordKeeper.getClass("OpVariable");
numNativeAttributes = 0;
DagInit *argumentValues = def.getValueAsDag("arguments");
const DagInit *argumentValues = def.getValueAsDag("arguments");
unsigned numArgs = argumentValues->getNumArgs();
// Mapping from name of to argument or result index. Arguments are indexed
@ -721,8 +721,8 @@ void Operator::populateOpStructure() {
" to precede it in traits list");
};
std::function<void(llvm::ListInit *)> insert;
insert = [&](llvm::ListInit *traitList) {
std::function<void(const llvm::ListInit *)> insert;
insert = [&](const llvm::ListInit *traitList) {
for (auto *traitInit : *traitList) {
auto *def = cast<DefInit>(traitInit)->getDef();
if (def->isSubClassOf("TraitList")) {
@ -780,7 +780,7 @@ void Operator::populateOpStructure() {
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues())
for (const llvm::Init *init : builderList->getValues())
builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
} else if (skipDefaultBuilders()) {
PrintFatalError(
@ -818,7 +818,8 @@ bool Operator::hasAssemblyFormat() const {
}
StringRef Operator::getAssemblyFormat() const {
return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
return TypeSwitch<const llvm::Init *, StringRef>(
def.getValueInit("assemblyFormat"))
.Case<llvm::StringInit>([&](auto *init) { return init->getValue(); });
}
@ -832,7 +833,7 @@ void Operator::print(llvm::raw_ostream &os) const {
}
}
auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
auto Operator::VariableDecoratorIterator::unwrap(const llvm::Init *init)
-> VariableDecorator {
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
}

View File

@ -700,7 +700,7 @@ int Pattern::getBenefit() const {
// The initial benefit value is a heuristic with number of ops in the source
// pattern.
int initBenefit = getSourcePattern().getNumOps();
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
const llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
PrintFatalError(&def,
"The 'addBenefit' takes and only takes one integer value");

View File

@ -50,7 +50,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
const llvm::RecordVal *builderCall = baseType->getValue("builderCall");
if (!builderCall || !builderCall->getValue())
return std::nullopt;
return TypeSwitch<llvm::Init *, std::optional<StringRef>>(
return TypeSwitch<const llvm::Init *, std::optional<StringRef>>(
builderCall->getValue())
.Case<llvm::StringInit>([&](auto *init) {
StringRef value = init->getValue();

View File

@ -30,8 +30,8 @@ enum DeprecatedAction { None, Warn, Error };
static DeprecatedAction actionOnDeprecatedValue;
// Returns if there is a use of `deprecatedInit` in `field`.
static bool findUse(Init *field, Init *deprecatedInit,
llvm::DenseMap<Init *, bool> &known) {
static bool findUse(const Init *field, const Init *deprecatedInit,
llvm::DenseMap<const Init *, bool> &known) {
if (field == deprecatedInit)
return true;
@ -64,13 +64,13 @@ static bool findUse(Init *field, Init *deprecatedInit,
if (findUse(dagInit->getOperator(), deprecatedInit, known))
return memoize(true);
return memoize(llvm::any_of(dagInit->getArgs(), [&](Init *arg) {
return memoize(llvm::any_of(dagInit->getArgs(), [&](const Init *arg) {
return findUse(arg, deprecatedInit, known);
}));
}
if (ListInit *li = dyn_cast<ListInit>(field)) {
return memoize(llvm::any_of(li->getValues(), [&](Init *jt) {
if (const ListInit *li = dyn_cast<ListInit>(field)) {
return memoize(llvm::any_of(li->getValues(), [&](const Init *jt) {
return findUse(jt, deprecatedInit, known);
}));
}
@ -83,8 +83,8 @@ static bool findUse(Init *field, Init *deprecatedInit,
}
// Returns if there is a use of `deprecatedInit` in `record`.
static bool findUse(Record &record, Init *deprecatedInit,
llvm::DenseMap<Init *, bool> &known) {
static bool findUse(Record &record, const Init *deprecatedInit,
llvm::DenseMap<const Init *, bool> &known) {
return llvm::any_of(record.getValues(), [&](const RecordVal &val) {
return findUse(val.getValue(), deprecatedInit, known);
});
@ -100,7 +100,7 @@ static void warnOfDeprecatedUses(const RecordKeeper &records) {
if (!r || !r->getValue())
continue;
llvm::DenseMap<Init *, bool> hasUse;
llvm::DenseMap<const Init *, bool> hasUse;
if (auto *si = dyn_cast<StringInit>(r->getValue())) {
for (auto &jt : records.getDefs()) {
// Skip anonymous defs.

View File

@ -46,8 +46,9 @@ public:
private:
/// Emits parse calls to construct given kind.
void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
ArrayRef<Init *> args, ArrayRef<std::string> argNames,
StringRef failure, mlir::raw_indented_ostream &ios);
ArrayRef<const Init *> args,
ArrayRef<std::string> argNames, StringRef failure,
mlir::raw_indented_ostream &ios);
/// Emits print instructions.
void emitPrintHelper(const Record *memberRec, StringRef kind,
@ -135,10 +136,12 @@ void Generator::emitParse(StringRef kind, const Record &x) {
R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
mlir::raw_indented_ostream os(output);
std::string returnType = getCType(&x);
os << formatv(head, kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", x.getName());
DagInit *members = x.getValueAsDag("members");
SmallVector<std::string> argNames =
llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
os << formatv(head,
kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type",
x.getName());
const DagInit *members = x.getValueAsDag("members");
SmallVector<std::string> argNames = llvm::to_vector(
map_range(members->getArgNames(), [](const StringInit *init) {
return init->getAsUnquotedString();
}));
StringRef builder = x.getValueAsString("cBuilder").trim();
@ -148,7 +151,7 @@ void Generator::emitParse(StringRef kind, const Record &x) {
}
void printParseConditional(mlir::raw_indented_ostream &ios,
ArrayRef<Init *> args,
ArrayRef<const Init *> args,
ArrayRef<std::string> argNames) {
ios << "if ";
auto parenScope = ios.scope("(", ") {");
@ -159,7 +162,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
};
auto parsedArgs =
llvm::to_vector(make_filter_range(args, [](Init *const attr) {
llvm::to_vector(make_filter_range(args, [](const Init *const attr) {
const Record *def = cast<DefInit>(attr)->getDef();
if (def->isSubClassOf("Array"))
return true;
@ -168,7 +171,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
interleave(
zip(parsedArgs, argNames),
[&](std::tuple<llvm::Init *&, const std::string &> it) {
[&](std::tuple<const Init *&, const std::string &> it) {
const Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
std::string parser;
if (auto optParser = attr->getValueAsOptionalString("cParser")) {
@ -196,7 +199,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
}
void Generator::emitParseHelper(StringRef kind, StringRef returnType,
StringRef builder, ArrayRef<Init *> args,
StringRef builder, ArrayRef<const Init *> args,
ArrayRef<std::string> argNames,
StringRef failure,
mlir::raw_indented_ostream &ios) {
@ -210,7 +213,7 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
// Print decls.
std::string lastCType = "";
for (auto [arg, name] : zip(args, argNames)) {
DefInit *first = dyn_cast<DefInit>(arg);
const DefInit *first = dyn_cast<DefInit>(arg);
if (!first)
PrintFatalError("Unexpected type for " + name);
const Record *def = first->getDef();
@ -251,13 +254,14 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
std::string returnType = getCType(def);
ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
<< returnType << "> ";
SmallVector<Init *> args;
SmallVector<const Init *> args;
SmallVector<std::string> argNames;
if (def->isSubClassOf("CompositeBytecode")) {
DagInit *members = def->getValueAsDag("members");
args = llvm::to_vector(members->getArgs());
const DagInit *members = def->getValueAsDag("members");
args = llvm::to_vector(map_range(
members->getArgs(), [](Init *init) { return (const Init *)init; }));
argNames = llvm::to_vector(
map_range(members->getArgNames(), [](StringInit *init) {
map_range(members->getArgNames(), [](const StringInit *init) {
return init->getAsUnquotedString();
}));
} else {
@ -332,7 +336,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
auto *members = rec->getValueAsDag("members");
for (auto [arg, name] :
llvm::zip(members->getArgs(), members->getArgNames())) {
DefInit *def = dyn_cast<DefInit>(arg);
const DefInit *def = dyn_cast<DefInit>(arg);
assert(def);
const Record *memberRec = def->getDef();
emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
@ -385,7 +389,7 @@ void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
auto *members = memberRec->getValueAsDag("members");
for (auto [arg, argName] :
zip(members->getArgs(), members->getArgNames())) {
DefInit *def = dyn_cast<DefInit>(arg);
const DefInit *def = dyn_cast<DefInit>(arg);
assert(def);
emitPrintHelper(def->getDef(), kind, parent,
argName->getAsUnquotedString(), ios);

View File

@ -46,10 +46,10 @@ using DialectFilterIterator =
} // namespace
static void populateDiscardableAttributes(
Dialect &dialect, llvm::DagInit *discardableAttrDag,
Dialect &dialect, const llvm::DagInit *discardableAttrDag,
SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
llvm::Init *arg = discardableAttrDag->getArg(i);
const llvm::Init *arg = discardableAttrDag->getArg(i);
StringRef givenName = discardableAttrDag->getArgNameStr(i);
if (givenName.empty())
@ -271,7 +271,8 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
if (dialect.hasOperationInterfaceFallback())
os << operationInterfaceFallbackDecl;
llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
const llvm::DagInit *discardableAttrDag =
dialect.getDiscardableAttributes();
SmallVector<std::pair<std::string, std::string>> discardableAttributes;
populateDiscardableAttributes(dialect, discardableAttrDag,
discardableAttributes);
@ -370,7 +371,7 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
StringRef superClassName =
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
SmallVector<std::pair<std::string, std::string>> discardableAttributes;
populateDiscardableAttributes(dialect, discardableAttrDag,
discardableAttributes);

View File

@ -102,11 +102,11 @@ static StringRef extractOmpClauseName(const Record *clause) {
/// Check that the given argument, identified by its name and initialization
/// value, is present in the \c arguments `dag`.
static bool verifyArgument(DagInit *arguments, StringRef argName,
Init *argInit) {
static bool verifyArgument(const DagInit *arguments, StringRef argName,
const Init *argInit) {
auto range = zip_equal(arguments->getArgNames(), arguments->getArgs());
return llvm::any_of(
range, [&](std::tuple<llvm::StringInit *const &, llvm::Init *const &> v) {
range, [&](std::tuple<const llvm::StringInit *, const llvm::Init *> v) {
return std::get<0>(v)->getAsUnquotedString() == argName &&
std::get<1>(v) == argInit;
});
@ -141,8 +141,8 @@ static void verifyClause(const Record *op, const Record *clause) {
StringRef clauseClassName = extractOmpClauseName(clause);
if (!clause->getValueAsBit("ignoreArgs")) {
DagInit *opArguments = op->getValueAsDag("arguments");
DagInit *arguments = clause->getValueAsDag("arguments");
const DagInit *opArguments = op->getValueAsDag("arguments");
const DagInit *arguments = clause->getValueAsDag("arguments");
for (auto [name, arg] :
zip(arguments->getArgNames(), arguments->getArgs())) {
@ -208,8 +208,9 @@ static void verifyClause(const Record *op, const Record *clause) {
///
/// \return the name of the base type to represent elements of the argument
/// type.
static StringRef translateArgumentType(ArrayRef<SMLoc> loc, StringInit *name,
Init *init, int &nest, int &rank) {
static StringRef translateArgumentType(ArrayRef<SMLoc> loc,
const StringInit *name, const Init *init,
int &nest, int &rank) {
const Record *def = cast<DefInit>(init)->getDef();
llvm::StringSet<> superClasses;
@ -282,7 +283,7 @@ static void genClauseOpsStruct(const Record *clause, raw_ostream &os) {
StringRef clauseName = extractOmpClauseName(clause);
os << "struct " << clauseName << "ClauseOps {\n";
DagInit *arguments = clause->getValueAsDag("arguments");
const DagInit *arguments = clause->getValueAsDag("arguments");
for (auto [name, arg] :
zip_equal(arguments->getArgNames(), arguments->getArgs())) {
int nest = 0, rank = 1;