mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-19 07:46:49 +00:00
[mlir] load dialect in parser for optional parameters (#96667)
https://github.com/llvm/llvm-project/pull/96242 fixed an issue where the auto-generated parsers were not loading dialects whose namespaces are not present in the textual IR. This required the attribute parameter to be a tablegen def with its dialect information attached. This fails when using parameter wrapper classes like `OptionalParameter`. This came up because `RingAttr` uses `OptionalParameter` for its second and third attributes. `OptionalParameter` takes as input the C++ type as a string instead of the tablegen def, and so it doesn't have a dialect member value to trigger the fix from https://github.com/llvm/llvm-project/pull/96242. The docs on this topic say the appropriate solution as overloading `FieldParser` for a particular type. This PR updates `FieldParser` for generic attributes to load the dialect on demand. This requires `mlir-tblgen` to emit a `dialectName` static field on the generated attribute class, and check for it with template metaprogramming, since not all attribute types go through `mlir-tblgen`. --------- Co-authored-by: Jeremy Kun <j2kun@users.noreply.github.com> Co-authored-by: Oleksandr "Alex" Zinenko <ftynse@gmail.com>
This commit is contained in:
parent
c65f8d8816
commit
07c157a435
@ -15,6 +15,22 @@
|
||||
#define MLIR_IR_DIALECTIMPLEMENTATION_H
|
||||
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include <type_traits>
|
||||
|
||||
namespace {
|
||||
|
||||
// reference https://stackoverflow.com/a/16000226
|
||||
template <typename T, typename = void>
|
||||
struct HasStaticDialectName : std::false_type {};
|
||||
|
||||
template <typename T>
|
||||
struct HasStaticDialectName<
|
||||
T, typename std::enable_if<
|
||||
std::is_same<::llvm::StringLiteral,
|
||||
std::decay_t<decltype(T::dialectName)>>::value,
|
||||
void>::type> : std::true_type {};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@ -63,6 +79,9 @@ struct FieldParser<
|
||||
AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
|
||||
AttributeT>> {
|
||||
static FailureOr<AttributeT> parse(AsmParser &parser) {
|
||||
if constexpr (HasStaticDialectName<AttributeT>::value) {
|
||||
parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
|
||||
}
|
||||
AttributeT value;
|
||||
if (parser.parseCustomAttributeWithFallback(value))
|
||||
return failure();
|
||||
@ -112,6 +131,9 @@ struct FieldParser<
|
||||
std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
|
||||
std::optional<AttributeT>>> {
|
||||
static FailureOr<std::optional<AttributeT>> parse(AsmParser &parser) {
|
||||
if constexpr (HasStaticDialectName<AttributeT>::value) {
|
||||
parser.getContext()->getOrLoadDialect(AttributeT::dialectName);
|
||||
}
|
||||
AttributeT attr;
|
||||
OptionalParseResult result = parser.parseOptionalAttribute(attr);
|
||||
if (result.has_value()) {
|
||||
|
@ -1464,15 +1464,3 @@ test.dialect_custom_format_fallback custom_format_fallback
|
||||
// Check that an op with an optional result parses f80 as type.
|
||||
// CHECK: test.format_optional_result_d_op : f80
|
||||
test.format_optional_result_d_op : f80
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// This is a testing that a non-qualified attribute in a custom format
|
||||
// correctly preload the dialect before creating the attribute.
|
||||
#attr = #test.nested_polynomial<<1 + x**2>>
|
||||
// CHECK-lABLE: @parse_correctly
|
||||
llvm.func @parse_correctly() {
|
||||
test.containing_int_polynomial_attr #attr
|
||||
llvm.return
|
||||
}
|
||||
|
19
mlir/test/IR/parser_dialect_loading.mlir
Normal file
19
mlir/test/IR/parser_dialect_loading.mlir
Normal file
@ -0,0 +1,19 @@
|
||||
// RUN: mlir-opt -allow-unregistered-dialect --split-input-file %s | FileCheck %s
|
||||
|
||||
// This is a testing that a non-qualified attribute in a custom format
|
||||
// correctly preload the dialect before creating the attribute.
|
||||
#attr = #test.nested_polynomial<poly=<1 + x**2>>
|
||||
// CHECK-LABEL: @parse_correctly
|
||||
llvm.func @parse_correctly() {
|
||||
test.containing_int_polynomial_attr #attr
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#attr2 = #test.nested_polynomial2<poly=<1 + x**2>>
|
||||
// CHECK-LABEL: @parse_correctly_2
|
||||
llvm.func @parse_correctly_2() {
|
||||
test.containing_int_polynomial_attr2 #attr2
|
||||
llvm.return
|
||||
}
|
@ -356,8 +356,17 @@ def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
|
||||
let mnemonic = "nested_polynomial";
|
||||
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
|
||||
let assemblyFormat = [{
|
||||
`<` $poly `>`
|
||||
`<` struct(params) `>`
|
||||
}];
|
||||
}
|
||||
|
||||
def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
|
||||
let mnemonic = "nested_polynomial2";
|
||||
let parameters = (ins OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">:$poly);
|
||||
let assemblyFormat = [{
|
||||
`<` struct(params) `>`
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
#endif // TEST_ATTRDEFS
|
||||
|
@ -237,6 +237,11 @@ def ContainingIntPolynomialAttrOp : TEST_Op<"containing_int_polynomial_attr"> {
|
||||
let assemblyFormat = "$attr attr-dict";
|
||||
}
|
||||
|
||||
def ContainingIntPolynomialAttr2Op : TEST_Op<"containing_int_polynomial_attr2"> {
|
||||
let arguments = (ins NestedPolynomialAttr2:$attr);
|
||||
let assemblyFormat = "$attr attr-dict";
|
||||
}
|
||||
|
||||
// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
|
||||
// This tests both matching and generating float elements attributes.
|
||||
def UpdateFloatElementsAttr : Pat<
|
||||
|
@ -89,6 +89,8 @@ private:
|
||||
void emitTopLevelDeclarations();
|
||||
/// Emit the function that returns the type or attribute name.
|
||||
void emitName();
|
||||
/// Emit the dialect name as a static member variable.
|
||||
void emitDialectName();
|
||||
/// Emit attribute or type builders.
|
||||
void emitBuilders();
|
||||
/// Emit a verifier for the def.
|
||||
@ -184,6 +186,8 @@ DefGen::DefGen(const AttrOrTypeDef &def)
|
||||
emitBuilders();
|
||||
// Emit the type name.
|
||||
emitName();
|
||||
// Emit the dialect name.
|
||||
emitDialectName();
|
||||
// Emit the verifier.
|
||||
if (storageCls && def.genVerifyDecl())
|
||||
emitVerifier();
|
||||
@ -281,6 +285,13 @@ void DefGen::emitName() {
|
||||
defCls.declare<ExtraClassDeclaration>(std::move(nameDecl));
|
||||
}
|
||||
|
||||
void DefGen::emitDialectName() {
|
||||
std::string decl =
|
||||
strfmt("static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
|
||||
def.getDialect().getName());
|
||||
defCls.declare<ExtraClassDeclaration>(std::move(decl));
|
||||
}
|
||||
|
||||
void DefGen::emitBuilders() {
|
||||
if (!def.skipDefaultBuilders()) {
|
||||
emitDefaultBuilder();
|
||||
|
@ -423,9 +423,11 @@ void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
|
||||
Dialect dialect(dialectInit->getDef());
|
||||
auto cppNamespace = dialect.getCppNamespace();
|
||||
std::string name = dialect.getCppClassName();
|
||||
dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
|
||||
cppNamespace + "::" + name + ">();")
|
||||
.str();
|
||||
if (name != "BuiltinDialect" || cppNamespace != "::mlir") {
|
||||
dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
|
||||
cppNamespace + "::" + name + ">();")
|
||||
.str();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user