[mlir] Fix infinite recursion in alias initializer

The alias initializer keeps a list of child indices around. When an alias is then marked as non-deferrable, all children are also marked non-deferrable.

This is currently done naively which leads to an infinite recursion if using mutable types or attributes containing a cycle.

This patch fixes this by adding an early return if the alias is already marked non-deferrable. Since this function is the only way to mark an alias as non-deferrable, it is guaranteed that if it is marked non-deferrable, all its children are as well, and it is not required to walk all the children.
This incidentally makes the non-deferrable marking also `O(n)` instead of `O(n^2)` (although not performance sensitive obviously).

Differential Revision: https://reviews.llvm.org/D158932
This commit is contained in:
Markus Böck 2023-08-26 16:10:52 +02:00
parent 57390c914b
commit de3f7e2f0f
6 changed files with 94 additions and 3 deletions

View File

@ -1056,6 +1056,12 @@ std::pair<size_t, size_t> AliasInitializer::visitImpl(
void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
auto it = std::next(aliases.begin(), aliasIndex);
// If already marked non-deferrable stop the recursion.
// All children should already be marked non-deferrable as well.
if (!it->second.canBeDeferred)
return;
it->second.canBeDeferred = false;
// Propagate the non-deferrable flag to any child aliases.

View File

@ -1,6 +1,8 @@
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
@ -12,6 +14,16 @@ func.func @roundtrip() {
// into inifinite recursion.
// CHECK: !testrec
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
// CHECK: () -> ![[$NAME]]
// CHECK: () -> ![[$NAME]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: () -> ![[$NAME2]]
// CHECK: () -> ![[$NAME2]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
return
}

View File

@ -218,6 +218,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
os << recAliasType.getName();
return AliasResult::FinalAlias;
}
return AliasResult::NoAlias;
}

View File

@ -373,4 +373,22 @@ def TestI32 : Test_Type<"TestI32"> {
let mnemonic = "i32";
}
def TestRecursiveAlias
: Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> {
let mnemonic = "test_rec_alias";
let storageClass = "TestRecursiveTypeStorage";
let storageNamespace = "test";
let genStorageClass = 0;
let parameters = (ins "llvm::StringRef":$name);
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
Type getBody() const;
void setBody(Type type);
}];
}
#endif // TEST_TYPEDEFS

View File

@ -482,3 +482,54 @@ void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
SetVector<Type> stack;
printTestType(type, printer, stack);
}
Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
Type TestRecursiveAliasType::parse(AsmParser &parser) {
thread_local static SetVector<Type> stack;
StringRef name;
if (parser.parseLess() || parser.parseKeyword(&name))
return Type();
auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
// If this type already has been parsed above in the stack, expect just the
// name.
if (stack.contains(rec)) {
if (failed(parser.parseGreater()))
return Type();
return rec;
}
// Otherwise, parse the body and update the type.
if (failed(parser.parseComma()))
return Type();
stack.insert(rec);
Type subtype;
if (parser.parseType(subtype))
return nullptr;
stack.pop_back();
if (!subtype || failed(parser.parseGreater()))
return Type();
rec.setBody(subtype);
return rec;
}
void TestRecursiveAliasType::print(AsmPrinter &printer) const {
thread_local static SetVector<Type> stack;
printer << "<" << getName();
if (!stack.contains(*this)) {
printer << ", ";
stack.insert(*this);
printer << getBody();
stack.pop_back();
}
printer << ">";
}

View File

@ -91,9 +91,6 @@ struct FieldParser<std::optional<int>> {
#include "TestTypeInterfaces.h.inc"
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"
namespace test {
/// Storage for simple named recursive types, where the type is identified by
@ -150,4 +147,7 @@ public:
} // namespace test
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"
#endif // MLIR_TESTTYPES_H