[mlir][ods] Do not print default-valued properties when the value is equal to the default (#87970)

This diff causes the `tblgen`-erated printProperties() function to skip
printing a `DefaultValuedAttr` property when the value is equal to the
default.

Co-authored-by: Biao Wang <biaow@nvidia.com>
This commit is contained in:
Beal Wang 2024-04-12 16:37:50 +08:00 committed by GitHub
parent ad4e1aba3f
commit d488b2225d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 85 additions and 24 deletions

View File

@ -27,7 +27,8 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
code extraBaseClassDeclaration = [{
void printProperties(::mlir::MLIRContext *ctx,
::mlir::OpAsmPrinter &p, const Properties &prop) {
::mlir::OpAsmPrinter &p, const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
Attribute propAttr = getPropertiesAsAttr(ctx, prop);
if (propAttr)
p << "<" << propAttr << ">";

View File

@ -226,8 +226,10 @@ protected:
static ParseResult genericParseProperties(OpAsmParser &parser,
Attribute &result);
/// Print the properties as a Attribute.
static void genericPrintProperties(OpAsmPrinter &p, Attribute properties);
/// Print the properties as a Attribute with names not included within
/// 'elidedProps'
static void genericPrintProperties(OpAsmPrinter &p, Attribute properties,
ArrayRef<StringRef> elidedProps = {});
/// Print an operation name, eliding the dialect prefix if necessary.
static void printOpName(Operation *op, OpAsmPrinter &p,
@ -1805,10 +1807,13 @@ private:
template <typename T>
using detect_has_print = llvm::is_detected<has_print, T>;
/// Trait to check if printProperties(OpAsmPrinter, T) exist
/// Trait to check if printProperties(OpAsmPrinter, T, ArrayRef<StringRef>)
/// exist
template <typename T, typename... Args>
using has_print_properties = decltype(printProperties(
std::declval<OpAsmPrinter &>(), std::declval<T>()));
using has_print_properties =
decltype(printProperties(std::declval<OpAsmPrinter &>(),
std::declval<T>(),
std::declval<ArrayRef<StringRef>>()));
template <typename T>
using detect_has_print_properties =
llvm::is_detected<has_print_properties, T>;
@ -1974,16 +1979,18 @@ public:
static void populateDefaultProperties(OperationName opName,
InferredProperties<T> &properties) {}
/// Print the operation properties. Unless overridden, this method will try to
/// dispatch to a `printProperties` free-function if it exists, and otherwise
/// by converting the properties to an Attribute.
/// Print the operation properties with names not included within
/// 'elidedProps'. Unless overridden, this method will try to dispatch to a
/// `printProperties` free-function if it exists, and otherwise by converting
/// the properties to an Attribute.
template <typename T>
static void printProperties(MLIRContext *ctx, OpAsmPrinter &p,
const T &properties) {
const T &properties,
ArrayRef<StringRef> elidedProps = {}) {
if constexpr (detect_has_print_properties<T>::value)
return printProperties(p, properties);
genericPrintProperties(p,
ConcreteType::getPropertiesAsAttr(ctx, properties));
return printProperties(p, properties, elidedProps);
genericPrintProperties(
p, ConcreteType::getPropertiesAsAttr(ctx, properties), elidedProps);
}
/// Parser the properties. Unless overridden, this method will print by

View File

@ -790,15 +790,33 @@ void OpState::printOpName(Operation *op, OpAsmPrinter &p,
/// Parse properties as a Attribute.
ParseResult OpState::genericParseProperties(OpAsmParser &parser,
Attribute &result) {
if (parser.parseLess() || parser.parseAttribute(result) ||
parser.parseGreater())
return failure();
if (succeeded(parser.parseOptionalLess())) { // The less is optional.
if (parser.parseAttribute(result) || parser.parseGreater())
return failure();
}
return success();
}
/// Print the properties as a Attribute.
void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties) {
p << "<" << properties << ">";
/// Print the properties as a Attribute with names not included within
/// 'elidedProps'
void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties,
ArrayRef<StringRef> elidedProps) {
auto dictAttr = dyn_cast_or_null<::mlir::DictionaryAttr>(properties);
if (dictAttr && !elidedProps.empty()) {
ArrayRef<NamedAttribute> attrs = dictAttr.getValue();
llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedProps.begin(),
elidedProps.end());
bool atLeastOneAttr = llvm::any_of(attrs, [&](NamedAttribute attr) {
return !elidedAttrsSet.contains(attr.getName().strref());
});
if (atLeastOneAttr) {
p << "<";
p.printOptionalAttrDict(dictAttr.getValue(), elidedProps);
p << ">";
}
} else {
p << "<" << properties << ">";
}
}
/// Emit an error about fatal conditions with this operation, reporting up to

View File

@ -33,3 +33,8 @@ test.using_property_in_custom [1, 4, 20]
// GENERIC-SAME: second = 4
// GENERIC-SAME: }>
test.using_property_ref_in_custom 1 + 4 = 5
// CHECK: test.with_default_valued_properties {{$}}
// GENERIC: "test.with_default_valued_properties"()
// GENERIC-SAME: <{a = 0 : i32}>
test.with_default_valued_properties <{a = 0 : i32}>

View File

@ -2909,7 +2909,8 @@ def TestOpWithNiceProperties : TEST_Op<"with_nice_properties"> {
);
let extraClassDeclaration = [{
void printProperties(::mlir::MLIRContext *ctx, ::mlir::OpAsmPrinter &p,
const Properties &prop);
const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps);
static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
::mlir::OperationState &result);
static ::mlir::LogicalResult readFromMlirBytecode(
@ -2938,7 +2939,8 @@ def TestOpWithNiceProperties : TEST_Op<"with_nice_properties"> {
writer.writeVarInt(prop.value);
}
void TestOpWithNiceProperties::printProperties(::mlir::MLIRContext *ctx,
::mlir::OpAsmPrinter &p, const Properties &prop) {
::mlir::OpAsmPrinter &p, const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
customPrintProperties(p, prop.prop);
}
::mlir::ParseResult TestOpWithNiceProperties::parseProperties(
@ -2971,7 +2973,8 @@ def TestOpWithVersionedProperties : TEST_Op<"with_versioned_properties"> {
);
let extraClassDeclaration = [{
void printProperties(::mlir::MLIRContext *ctx, ::mlir::OpAsmPrinter &p,
const Properties &prop);
const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps);
static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
::mlir::OperationState &result);
static ::mlir::LogicalResult readFromMlirBytecode(
@ -2983,7 +2986,8 @@ def TestOpWithVersionedProperties : TEST_Op<"with_versioned_properties"> {
}];
let extraClassDefinition = [{
void TestOpWithVersionedProperties::printProperties(::mlir::MLIRContext *ctx,
::mlir::OpAsmPrinter &p, const Properties &prop) {
::mlir::OpAsmPrinter &p, const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
customPrintProperties(p, prop.prop);
}
::mlir::ParseResult TestOpWithVersionedProperties::parseProperties(
@ -2997,6 +3001,11 @@ def TestOpWithVersionedProperties : TEST_Op<"with_versioned_properties"> {
}];
}
def TestOpWithDefaultValuedProperties : TEST_Op<"with_default_valued_properties"> {
let assemblyFormat = "prop-dict attr-dict";
let arguments = (ins DefaultValuedAttr<I32Attr, "0">:$a);
}
//===----------------------------------------------------------------------===//
// Test Dataflow
//===----------------------------------------------------------------------===//

View File

@ -1775,9 +1775,30 @@ const char *enumAttrBeginPrinterCode = R"(
/// Generate the printer for the 'prop-dict' directive.
static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
MethodBody &body) {
body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n";
// Add code to check attributes for equality with the default value
// for attributes with the elidePrintingDefaultValue bit set.
for (const NamedAttribute &namedAttr : op.getAttributes()) {
const Attribute &attr = namedAttr.attr;
if (!attr.isDerivedAttr() && attr.hasDefaultValue()) {
const StringRef &name = namedAttr.name;
FmtContext fctx;
fctx.withBuilder("odsBuilder");
std::string defaultValue = std::string(
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
body << " {\n";
body << " ::mlir::Builder odsBuilder(getContext());\n";
body << " ::mlir::Attribute attr = " << op.getGetterName(name)
<< "Attr();\n";
body << " if(attr && (attr == " << defaultValue << "))\n";
body << " elidedProps.push_back(\"" << name << "\");\n";
body << " }\n";
}
}
body << " _odsPrinter << \" \";\n"
<< " printProperties(this->getContext(), _odsPrinter, "
"getProperties());\n";
"getProperties(), elidedProps);\n";
}
/// Generate the printer for the 'attr-dict' directive.