mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 04:56:07 +00:00
[mlir] Fix infinite recursion in alias initializer
The alias initializer keeps a list of child indices around. When an alias is then marked as non-deferrable, all children are also marked non-deferrable. This is currently done naively which leads to an infinite recursion if using mutable types or attributes containing a cycle. This patch fixes this by adding an early return if the alias is already marked non-deferrable. Since this function is the only way to mark an alias as non-deferrable, it is guaranteed that if it is marked non-deferrable, all its children are as well, and it is not required to walk all the children. This incidentally makes the non-deferrable marking also `O(n)` instead of `O(n^2)` (although not performance sensitive obviously). Differential Revision: https://reviews.llvm.org/D158932
This commit is contained in:
parent
57390c914b
commit
de3f7e2f0f
@ -1056,6 +1056,12 @@ std::pair<size_t, size_t> AliasInitializer::visitImpl(
|
||||
|
||||
void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
|
||||
auto it = std::next(aliases.begin(), aliasIndex);
|
||||
|
||||
// If already marked non-deferrable stop the recursion.
|
||||
// All children should already be marked non-deferrable as well.
|
||||
if (!it->second.canBeDeferred)
|
||||
return;
|
||||
|
||||
it->second.canBeDeferred = false;
|
||||
|
||||
// Propagate the non-deferrable flag to any child aliases.
|
||||
|
@ -1,6 +1,8 @@
|
||||
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
|
||||
|
||||
// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
|
||||
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
|
||||
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
|
||||
|
||||
// CHECK-LABEL: @roundtrip
|
||||
func.func @roundtrip() {
|
||||
@ -12,6 +14,16 @@ func.func @roundtrip() {
|
||||
// into inifinite recursion.
|
||||
// CHECK: !testrec
|
||||
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
|
||||
|
||||
// CHECK: () -> ![[$NAME]]
|
||||
// CHECK: () -> ![[$NAME]]
|
||||
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
|
||||
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
|
||||
|
||||
// CHECK: () -> ![[$NAME2]]
|
||||
// CHECK: () -> ![[$NAME2]]
|
||||
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
|
||||
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -218,6 +218,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
}
|
||||
if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
|
||||
os << recAliasType.getName();
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
return AliasResult::NoAlias;
|
||||
}
|
||||
|
||||
|
@ -373,4 +373,22 @@ def TestI32 : Test_Type<"TestI32"> {
|
||||
let mnemonic = "i32";
|
||||
}
|
||||
|
||||
def TestRecursiveAlias
|
||||
: Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> {
|
||||
let mnemonic = "test_rec_alias";
|
||||
let storageClass = "TestRecursiveTypeStorage";
|
||||
let storageNamespace = "test";
|
||||
let genStorageClass = 0;
|
||||
|
||||
let parameters = (ins "llvm::StringRef":$name);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Type getBody() const;
|
||||
|
||||
void setBody(Type type);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TEST_TYPEDEFS
|
||||
|
@ -482,3 +482,54 @@ void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
|
||||
SetVector<Type> stack;
|
||||
printTestType(type, printer, stack);
|
||||
}
|
||||
|
||||
Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
|
||||
|
||||
void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
|
||||
|
||||
StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
|
||||
|
||||
Type TestRecursiveAliasType::parse(AsmParser &parser) {
|
||||
thread_local static SetVector<Type> stack;
|
||||
|
||||
StringRef name;
|
||||
if (parser.parseLess() || parser.parseKeyword(&name))
|
||||
return Type();
|
||||
auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
|
||||
|
||||
// If this type already has been parsed above in the stack, expect just the
|
||||
// name.
|
||||
if (stack.contains(rec)) {
|
||||
if (failed(parser.parseGreater()))
|
||||
return Type();
|
||||
return rec;
|
||||
}
|
||||
|
||||
// Otherwise, parse the body and update the type.
|
||||
if (failed(parser.parseComma()))
|
||||
return Type();
|
||||
stack.insert(rec);
|
||||
Type subtype;
|
||||
if (parser.parseType(subtype))
|
||||
return nullptr;
|
||||
stack.pop_back();
|
||||
if (!subtype || failed(parser.parseGreater()))
|
||||
return Type();
|
||||
|
||||
rec.setBody(subtype);
|
||||
|
||||
return rec;
|
||||
}
|
||||
|
||||
void TestRecursiveAliasType::print(AsmPrinter &printer) const {
|
||||
thread_local static SetVector<Type> stack;
|
||||
|
||||
printer << "<" << getName();
|
||||
if (!stack.contains(*this)) {
|
||||
printer << ", ";
|
||||
stack.insert(*this);
|
||||
printer << getBody();
|
||||
stack.pop_back();
|
||||
}
|
||||
printer << ">";
|
||||
}
|
||||
|
@ -91,9 +91,6 @@ struct FieldParser<std::optional<int>> {
|
||||
|
||||
#include "TestTypeInterfaces.h.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "TestTypeDefs.h.inc"
|
||||
|
||||
namespace test {
|
||||
|
||||
/// Storage for simple named recursive types, where the type is identified by
|
||||
@ -150,4 +147,7 @@ public:
|
||||
|
||||
} // namespace test
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "TestTypeDefs.h.inc"
|
||||
|
||||
#endif // MLIR_TESTTYPES_H
|
||||
|
Loading…
x
Reference in New Issue
Block a user