//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the MLIR AsmPrinter class, which is used to implement // the various print() methods on the core IR objects. // //===----------------------------------------------------------------------===// #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Verifier.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Endian.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" #include #include #include using namespace mlir; using namespace mlir::detail; #define DEBUG_TYPE "mlir-asm-printer" void OperationName::print(raw_ostream &os) const { os << getStringRef(); } void OperationName::dump() const { print(llvm::errs()); } //===--------------------------------------------------------------------===// // AsmParser //===--------------------------------------------------------------------===// AsmParser::~AsmParser() = default; DialectAsmParser::~DialectAsmParser() = default; OpAsmParser::~OpAsmParser() = default; MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } /// Parse a type list. /// This is out-of-line to work-around /// https://github.com/llvm/llvm-project/issues/62918 ParseResult AsmParser::parseTypeList(SmallVectorImpl &result) { return parseCommaSeparatedList( [&]() { return parseType(result.emplace_back()); }); } //===----------------------------------------------------------------------===// // DialectAsmPrinter //===----------------------------------------------------------------------===// DialectAsmPrinter::~DialectAsmPrinter() = default; //===----------------------------------------------------------------------===// // OpAsmPrinter //===----------------------------------------------------------------------===// OpAsmPrinter::~OpAsmPrinter() = default; void OpAsmPrinter::printFunctionalType(Operation *op) { auto &os = getStream(); os << '('; llvm::interleaveComma(op->getOperands(), os, [&](Value operand) { // Print the types of null values as <>. *this << (operand ? operand.getType() : Type()); }); os << ") -> "; // Print the result list. We don't parenthesize single result types unless // it is a function (avoiding a grammar ambiguity). bool wrapped = op->getNumResults() != 1; if (!wrapped && op->getResult(0).getType() && llvm::isa(op->getResult(0).getType())) wrapped = true; if (wrapped) os << '('; llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) { // Print the types of null values as <>. *this << (result ? result.getType() : Type()); }); if (wrapped) os << ')'; } //===----------------------------------------------------------------------===// // Operation OpAsm interface. //===----------------------------------------------------------------------===// /// The OpAsmOpInterface, see OpAsmInterface.td for more details. #include "mlir/IR/OpAsmAttrInterface.cpp.inc" #include "mlir/IR/OpAsmOpInterface.cpp.inc" #include "mlir/IR/OpAsmTypeInterface.cpp.inc" LogicalResult OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const { return entry.emitError() << "unknown 'resource' key '" << entry.getKey() << "' for dialect '" << getDialect()->getNamespace() << "'"; } //===----------------------------------------------------------------------===// // OpPrintingFlags //===----------------------------------------------------------------------===// namespace { /// This struct contains command line options that can be used to initialize /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need /// for global command line options. struct AsmPrinterOptions { llvm::cl::opt printElementsAttrWithHexIfLarger{ "mlir-print-elementsattrs-with-hex-if-larger", llvm::cl::desc( "Print DenseElementsAttrs with a hex string that have " "more elements than the given upper limit (use -1 to disable)")}; llvm::cl::opt elideElementsAttrIfLarger{ "mlir-elide-elementsattrs-if-larger", llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " "more elements than the given upper limit")}; llvm::cl::opt elideResourceStringsIfLarger{ "mlir-elide-resource-strings-if-larger", llvm::cl::desc( "Elide printing value of resources if string is too long in chars.")}; llvm::cl::opt printDebugInfoOpt{ "mlir-print-debuginfo", llvm::cl::init(false), llvm::cl::desc("Print debug info in MLIR output")}; llvm::cl::opt printPrettyDebugInfoOpt{ "mlir-pretty-debuginfo", llvm::cl::init(false), llvm::cl::desc("Print pretty debug info in MLIR output")}; // Use the generic op output form in the operation printer even if the custom // form is defined. llvm::cl::opt printGenericOpFormOpt{ "mlir-print-op-generic", llvm::cl::init(false), llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden}; llvm::cl::opt assumeVerifiedOpt{ "mlir-print-assume-verified", llvm::cl::init(false), llvm::cl::desc("Skip op verification when using custom printers"), llvm::cl::Hidden}; llvm::cl::opt printLocalScopeOpt{ "mlir-print-local-scope", llvm::cl::init(false), llvm::cl::desc("Print with local scope and inline information (eliding " "aliases for attributes, types, and locations)")}; llvm::cl::opt skipRegionsOpt{ "mlir-print-skip-regions", llvm::cl::init(false), llvm::cl::desc("Skip regions when printing ops.")}; llvm::cl::opt printValueUsers{ "mlir-print-value-users", llvm::cl::init(false), llvm::cl::desc( "Print users of operation results and block arguments as a comment")}; llvm::cl::opt printUniqueSSAIDs{ "mlir-print-unique-ssa-ids", llvm::cl::init(false), llvm::cl::desc("Print unique SSA ID numbers for values, block arguments " "and naming conflicts across all regions")}; llvm::cl::opt useNameLocAsPrefix{ "mlir-use-nameloc-as-prefix", llvm::cl::init(false), llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")}; }; } // namespace static llvm::ManagedStatic clOptions; /// Register a set of useful command-line options that can be used to configure /// various flags within the AsmPrinter. void mlir::registerAsmPrinterCLOptions() { // Make sure that the options struct has been initialized. *clOptions; } /// Initialize the printing flags with default supplied by the cl::opts above. OpPrintingFlags::OpPrintingFlags() : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), printGenericOpFormFlag(false), skipRegionsFlag(false), assumeVerifiedFlag(false), printLocalScope(false), printValueUsersFlag(false), printUniqueSSAIDsFlag(false), useNameLocAsPrefix(false) { // Initialize based upon command line options, if they are available. if (!clOptions.isConstructed()) return; if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) elementsAttrHexElementLimit = clOptions->printElementsAttrWithHexIfLarger.getValue(); if (clOptions->elideResourceStringsIfLarger.getNumOccurrences()) resourceStringCharLimit = clOptions->elideResourceStringsIfLarger; printDebugInfoFlag = clOptions->printDebugInfoOpt; printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; printGenericOpFormFlag = clOptions->printGenericOpFormOpt; assumeVerifiedFlag = clOptions->assumeVerifiedOpt; printLocalScope = clOptions->printLocalScopeOpt; skipRegionsFlag = clOptions->skipRegionsOpt; printValueUsersFlag = clOptions->printValueUsers; printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs; useNameLocAsPrefix = clOptions->useNameLocAsPrefix; } /// Enable the elision of large elements attributes, by printing a '...' /// instead of the element data, when the number of elements is greater than /// `largeElementLimit`. Note: The IR generated with this option is not /// parsable. OpPrintingFlags & OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { elementsAttrElementLimit = largeElementLimit; return *this; } OpPrintingFlags & OpPrintingFlags::printLargeElementsAttrWithHex(int64_t largeElementLimit) { elementsAttrHexElementLimit = largeElementLimit; return *this; } OpPrintingFlags & OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) { resourceStringCharLimit = largeResourceLimit; return *this; } /// Enable printing of debug information. If 'prettyForm' is set to true, /// debug information is printed in a more readable 'pretty' form. OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable, bool prettyForm) { printDebugInfoFlag = enable; printDebugInfoPrettyFormFlag = prettyForm; return *this; } /// Always print operations in the generic form. OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) { printGenericOpFormFlag = enable; return *this; } /// Always skip Regions. OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) { skipRegionsFlag = skip; return *this; } /// Do not verify the operation when using custom operation printers. OpPrintingFlags &OpPrintingFlags::assumeVerified(bool enable) { assumeVerifiedFlag = enable; return *this; } /// Use local scope when printing the operation. This allows for using the /// printer in a more localized and thread-safe setting, but may not necessarily /// be identical of what the IR will look like when dumping the full module. OpPrintingFlags &OpPrintingFlags::useLocalScope(bool enable) { printLocalScope = enable; return *this; } /// Print users of values as comments. OpPrintingFlags &OpPrintingFlags::printValueUsers(bool enable) { printValueUsersFlag = enable; return *this; } /// Print unique SSA ID numbers for values, block arguments and naming conflicts /// across all regions OpPrintingFlags &OpPrintingFlags::printUniqueSSAIDs(bool enable) { printUniqueSSAIDsFlag = enable; return *this; } /// Return if the given ElementsAttr should be elided. bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { return elementsAttrElementLimit && *elementsAttrElementLimit < int64_t(attr.getNumElements()) && !llvm::isa(attr); } /// Return if the given ElementsAttr should be printed as hex string. bool OpPrintingFlags::shouldPrintElementsAttrWithHex(ElementsAttr attr) const { // -1 is used to disable hex printing. return (elementsAttrHexElementLimit != -1) && (elementsAttrHexElementLimit < int64_t(attr.getNumElements())) && !llvm::isa(attr); } OpPrintingFlags &OpPrintingFlags::printNameLocAsPrefix(bool enable) { useNameLocAsPrefix = enable; return *this; } /// Return the size limit for printing large ElementsAttr. std::optional OpPrintingFlags::getLargeElementsAttrLimit() const { return elementsAttrElementLimit; } /// Return the size limit for printing large ElementsAttr as hex string. int64_t OpPrintingFlags::getLargeElementsAttrHexLimit() const { return elementsAttrHexElementLimit; } /// Return the size limit for printing large ElementsAttr. std::optional OpPrintingFlags::getLargeResourceStringLimit() const { return resourceStringCharLimit; } /// Return if debug information should be printed. bool OpPrintingFlags::shouldPrintDebugInfo() const { return printDebugInfoFlag; } /// Return if debug information should be printed in the pretty form. bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { return printDebugInfoPrettyFormFlag; } /// Return if operations should be printed in the generic form. bool OpPrintingFlags::shouldPrintGenericOpForm() const { return printGenericOpFormFlag; } /// Return if Region should be skipped. bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; } /// Return if operation verification should be skipped. bool OpPrintingFlags::shouldAssumeVerified() const { return assumeVerifiedFlag; } /// Return if the printer should use local scope when dumping the IR. bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } /// Return if the printer should print users of values. bool OpPrintingFlags::shouldPrintValueUsers() const { return printValueUsersFlag; } /// Return if the printer should use unique IDs. bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const { return printUniqueSSAIDsFlag || shouldPrintGenericOpForm(); } /// Return if the printer should use NameLocs as prefixes when printing SSA IDs. bool OpPrintingFlags::shouldUseNameLocAsPrefix() const { return useNameLocAsPrefix; } //===----------------------------------------------------------------------===// // NewLineCounter //===----------------------------------------------------------------------===// namespace { /// This class is a simple formatter that emits a new line when inputted into a /// stream, that enables counting the number of newlines emitted. This class /// should be used whenever emitting newlines in the printer. struct NewLineCounter { unsigned curLine = 1; }; static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { ++newLine.curLine; return os << '\n'; } } // namespace //===----------------------------------------------------------------------===// // AsmPrinter::Impl //===----------------------------------------------------------------------===// namespace mlir { class AsmPrinter::Impl { public: Impl(raw_ostream &os, AsmStateImpl &state); explicit Impl(Impl &other) : Impl(other.os, other.state) {} /// Returns the output stream of the printer. raw_ostream &getStream() { return os; } template inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const { llvm::interleaveComma(c, os, eachFn); } /// This enum describes the different kinds of elision for the type of an /// attribute when printing it. enum class AttrTypeElision { /// The type must not be elided, Never, /// The type may be elided when it matches the default used in the parser /// (for example i64 is the default for integer attributes). May, /// The type must be elided. Must }; /// Print the given attribute or an alias. void printAttribute(Attribute attr, AttrTypeElision typeElision = AttrTypeElision::Never); /// Print the given attribute without considering an alias. void printAttributeImpl(Attribute attr, AttrTypeElision typeElision = AttrTypeElision::Never); /// Print the alias for the given attribute, return failure if no alias could /// be printed. LogicalResult printAlias(Attribute attr); /// Print the given type or an alias. void printType(Type type); /// Print the given type. void printTypeImpl(Type type); /// Print the alias for the given type, return failure if no alias could /// be printed. LogicalResult printAlias(Type type); /// Print the given location to the stream. If `allowAlias` is true, this /// allows for the internal location to use an attribute alias. void printLocation(LocationAttr loc, bool allowAlias = false); /// Print a reference to the given resource that is owned by the given /// dialect. void printResourceHandle(const AsmDialectResourceHandle &resource); void printAffineMap(AffineMap map); void printAffineExpr(AffineExpr expr, function_ref printValueName = nullptr); void printAffineConstraint(AffineExpr expr, bool isEq); void printIntegerSet(IntegerSet set); LogicalResult pushCyclicPrinting(const void *opaquePointer); void popCyclicPrinting(); void printDimensionList(ArrayRef shape); protected: void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}, bool withKeyword = false); void printNamedAttribute(NamedAttribute attr); void printTrailingLocation(Location loc, bool allowAlias = true); void printLocationInternal(LocationAttr loc, bool pretty = false, bool isTopLevel = false); /// Print a dense elements attribute. If 'allowHex' is true, a hex string is /// used instead of individual elements when the elements attr is large. void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); /// Print a dense string elements attribute. void printDenseStringElementsAttr(DenseStringElementsAttr attr); /// Print a dense elements attribute. If 'allowHex' is true, a hex string is /// used instead of individual elements when the elements attr is large. void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, bool allowHex); /// Print a dense array attribute. void printDenseArrayAttr(DenseArrayAttr attr); void printDialectAttribute(Attribute attr); void printDialectType(Type type); /// Print an escaped string, wrapped with "". void printEscapedString(StringRef str); /// Print a hex string, wrapped with "". void printHexString(StringRef str); void printHexString(ArrayRef data); /// This enum is used to represent the binding strength of the enclosing /// context that an AffineExprStorage is being printed in, so we can /// intelligently produce parens. enum class BindingStrength { Weak, // + and - Strong, // All other binary operators. }; void printAffineExprInternal( AffineExpr expr, BindingStrength enclosingTightness, function_ref printValueName = nullptr); /// The output stream for the printer. raw_ostream &os; /// An underlying assembly printer state. AsmStateImpl &state; /// A set of flags to control the printer's behavior. OpPrintingFlags printerFlags; /// A tracker for the number of new lines emitted during printing. NewLineCounter newLine; }; } // namespace mlir //===----------------------------------------------------------------------===// // AliasInitializer //===----------------------------------------------------------------------===// namespace { /// This class represents a specific instance of a symbol Alias. class SymbolAlias { public: SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType, bool isDeferrable) : name(name), suffixIndex(suffixIndex), isType(isType), isDeferrable(isDeferrable) {} /// Print this alias to the given stream. void print(raw_ostream &os) const { os << (isType ? "!" : "#") << name; if (suffixIndex) { if (isdigit(name.back())) os << '_'; os << suffixIndex; } } /// Returns true if this is a type alias. bool isTypeAlias() const { return isType; } /// Returns true if this alias supports deferred resolution when parsing. bool canBeDeferred() const { return isDeferrable; } private: /// The main name of the alias. StringRef name; /// The suffix index of the alias. uint32_t suffixIndex : 30; /// A flag indicating whether this alias is for a type. bool isType : 1; /// A flag indicating whether this alias may be deferred or not. bool isDeferrable : 1; public: /// Used to avoid printing incomplete aliases for recursive types. bool isPrinted = false; }; /// This class represents a utility that initializes the set of attribute and /// type aliases, without the need to store the extra information within the /// main AliasState class or pass it around via function arguments. class AliasInitializer { public: AliasInitializer( DialectInterfaceCollection &interfaces, llvm::BumpPtrAllocator &aliasAllocator) : interfaces(interfaces), aliasAllocator(aliasAllocator), aliasOS(aliasBuffer) {} void initialize(Operation *op, const OpPrintingFlags &printerFlags, llvm::MapVector &attrTypeToAlias); /// Visit the given attribute to see if it has an alias. `canBeDeferred` is /// set to true if the originator of this attribute can resolve the alias /// after parsing has completed (e.g. in the case of operation locations). /// `elideType` indicates if the type of the attribute should be skipped when /// looking for nested aliases. Returns the maximum alias depth of the /// attribute, and the alias index of this attribute. std::pair visit(Attribute attr, bool canBeDeferred = false, bool elideType = false) { return visitImpl(attr, aliases, canBeDeferred, elideType); } /// Visit the given type to see if it has an alias. `canBeDeferred` is /// set to true if the originator of this attribute can resolve the alias /// after parsing has completed. Returns the maximum alias depth of the type, /// and the alias index of this type. std::pair visit(Type type, bool canBeDeferred = false) { return visitImpl(type, aliases, canBeDeferred); } private: struct InProgressAliasInfo { InProgressAliasInfo() : aliasDepth(0), isType(false), canBeDeferred(false) {} InProgressAliasInfo(StringRef alias) : alias(alias), aliasDepth(1), isType(false), canBeDeferred(false) {} bool operator<(const InProgressAliasInfo &rhs) const { // Order first by depth, then by attr/type kind, and then by name. if (aliasDepth != rhs.aliasDepth) return aliasDepth < rhs.aliasDepth; if (isType != rhs.isType) return isType; return alias < rhs.alias; } /// The alias for the attribute or type, or std::nullopt if the value has no /// alias. std::optional alias; /// The alias depth of this attribute or type, i.e. an indication of the /// relative ordering of when to print this alias. unsigned aliasDepth : 30; /// If this alias represents a type or an attribute. bool isType : 1; /// If this alias can be deferred or not. bool canBeDeferred : 1; /// Indices for child aliases. SmallVector childIndices; }; /// Visit the given attribute or type to see if it has an alias. /// `canBeDeferred` is set to true if the originator of this value can resolve /// the alias after parsing has completed (e.g. in the case of operation /// locations). Returns the maximum alias depth of the value, and its alias /// index. template std::pair visitImpl(T value, llvm::MapVector &aliases, bool canBeDeferred, PrintArgs &&...printArgs); /// Mark the given alias as non-deferrable. void markAliasNonDeferrable(size_t aliasIndex); /// Try to generate an alias for the provided symbol. If an alias is /// generated, the provided alias mapping and reverse mapping are updated. template void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred); /// Uniques the given alias name within the printer by generating name index /// used as alias name suffix. static unsigned uniqueAliasNameIndex(StringRef alias, llvm::StringMap &nameCounts, llvm::StringSet &usedAliases); /// Given a collection of aliases and symbols, initialize a mapping from a /// symbol to a given alias. static void initializeAliases( llvm::MapVector &visitedSymbols, llvm::MapVector &symbolToAlias); /// The set of asm interfaces within the context. DialectInterfaceCollection &interfaces; /// An allocator used for alias names. llvm::BumpPtrAllocator &aliasAllocator; /// The set of built aliases. llvm::MapVector aliases; /// Storage and stream used when generating an alias. SmallString<32> aliasBuffer; llvm::raw_svector_ostream aliasOS; }; /// This class implements a dummy OpAsmPrinter that doesn't print any output, /// and merely collects the attributes and types that *would* be printed in a /// normal print invocation so that we can generate proper aliases. This allows /// for us to generate aliases only for the attributes and types that would be /// in the output, and trims down unnecessary output. class DummyAliasOperationPrinter : private OpAsmPrinter { public: explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags, AliasInitializer &initializer) : printerFlags(printerFlags), initializer(initializer) {} /// Prints the entire operation with the custom assembly form, if available, /// or the generic assembly form, otherwise. void printCustomOrGenericOp(Operation *op) override { // Visit the operation location. if (printerFlags.shouldPrintDebugInfo()) initializer.visit(op->getLoc(), /*canBeDeferred=*/true); // If requested, always print the generic form. if (!printerFlags.shouldPrintGenericOpForm()) { op->getName().printAssembly(op, *this, /*defaultDialect=*/""); return; } // Otherwise print with the generic assembly form. printGenericOp(op); } private: /// Print the given operation in the generic form. void printGenericOp(Operation *op, bool printOpName = true) override { // Consider nested operations for aliases. if (!printerFlags.shouldSkipRegions()) { for (Region ®ion : op->getRegions()) printRegion(region, /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); } // Visit all the types used in the operation. for (Type type : op->getOperandTypes()) printType(type); for (Type type : op->getResultTypes()) printType(type); // Consider the attributes of the operation for aliases. for (const NamedAttribute &attr : op->getAttrs()) printAttribute(attr.getValue()); } /// Print the given block. If 'printBlockArgs' is false, the arguments of the /// block are not printed. If 'printBlockTerminator' is false, the terminator /// operation of the block is not printed. void print(Block *block, bool printBlockArgs = true, bool printBlockTerminator = true) { // Consider the types of the block arguments for aliases if 'printBlockArgs' // is set to true. if (printBlockArgs) { for (BlockArgument arg : block->getArguments()) { printType(arg.getType()); // Visit the argument location. if (printerFlags.shouldPrintDebugInfo()) // TODO: Allow deferring argument locations. initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); } } // Consider the operations within this block, ignoring the terminator if // requested. bool hasTerminator = !block->empty() && block->back().hasTrait(); auto range = llvm::make_range( block->begin(), std::prev(block->end(), (!hasTerminator || printBlockTerminator) ? 0 : 1)); for (Operation &op : range) printCustomOrGenericOp(&op); } /// Print the given region. void printRegion(Region ®ion, bool printEntryBlockArgs, bool printBlockTerminators, bool printEmptyBlock = false) override { if (region.empty()) return; if (printerFlags.shouldSkipRegions()) { os << "{...}"; return; } auto *entryBlock = ®ion.front(); print(entryBlock, printEntryBlockArgs, printBlockTerminators); for (Block &b : llvm::drop_begin(region, 1)) print(&b); } void printRegionArgument(BlockArgument arg, ArrayRef argAttrs, bool omitType) override { printType(arg.getType()); // Visit the argument location. if (printerFlags.shouldPrintDebugInfo()) // TODO: Allow deferring argument locations. initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); } /// Consider the given type to be printed for an alias. void printType(Type type) override { initializer.visit(type); } /// Consider the given attribute to be printed for an alias. void printAttribute(Attribute attr) override { initializer.visit(attr); } void printAttributeWithoutType(Attribute attr) override { printAttribute(attr); } LogicalResult printAlias(Attribute attr) override { initializer.visit(attr); return success(); } LogicalResult printAlias(Type type) override { initializer.visit(type); return success(); } /// Consider the given location to be printed for an alias. void printOptionalLocationSpecifier(Location loc) override { printAttribute(loc); } /// Print the given set of attributes with names not included within /// 'elidedAttrs'. void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) override { if (attrs.empty()) return; if (elidedAttrs.empty()) { for (const NamedAttribute &attr : attrs) printAttribute(attr.getValue()); return; } llvm::SmallDenseSet elidedAttrsSet(elidedAttrs.begin(), elidedAttrs.end()); for (const NamedAttribute &attr : attrs) if (!elidedAttrsSet.contains(attr.getName().strref())) printAttribute(attr.getValue()); } void printOptionalAttrDictWithKeyword( ArrayRef attrs, ArrayRef elidedAttrs = {}) override { printOptionalAttrDict(attrs, elidedAttrs); } /// Return a null stream as the output stream, this will ignore any data fed /// to it. raw_ostream &getStream() const override { return os; } /// The following are hooks of `OpAsmPrinter` that are not necessary for /// determining potential aliases. void printFloat(const APFloat &) override {} void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} void printNewline() override {} void increaseIndent() override {} void decreaseIndent() override {} void printOperand(Value) override {} void printOperand(Value, raw_ostream &os) override { // Users expect the output string to have at least the prefixed % to signal // a value name. To maintain this invariant, emit a name even if it is // guaranteed to go unused. os << "%"; } void printKeywordOrString(StringRef) override {} void printString(StringRef) override {} void printResourceHandle(const AsmDialectResourceHandle &) override {} void printSymbolName(StringRef) override {} void printSuccessor(Block *) override {} void printSuccessorAndUseList(Block *, ValueRange) override {} void shadowRegionArgs(Region &, ValueRange) override {} /// The printer flags to use when determining potential aliases. const OpPrintingFlags &printerFlags; /// The initializer to use when identifying aliases. AliasInitializer &initializer; /// A dummy output stream. mutable llvm::raw_null_ostream os; }; class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { public: explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer, bool canBeDeferred, SmallVectorImpl &childIndices) : initializer(initializer), canBeDeferred(canBeDeferred), childIndices(childIndices) {} /// Print the given attribute/type, visiting any nested aliases that would be /// generated as part of printing. Returns the maximum alias depth found while /// printing the given value. template size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) { printAndVisitNestedAliasesImpl(value, printArgs...); return maxAliasDepth; } private: /// Print the given attribute/type, visiting any nested aliases that would be /// generated as part of printing. void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) { if (!isa(attr.getDialect())) { attr.getDialect().printAttribute(attr, *this); // Process the builtin attributes. } else if (llvm::isa(attr)) { return; } else if (auto distinctAttr = dyn_cast(attr)) { printAttribute(distinctAttr.getReferencedAttr()); } else if (auto dictAttr = dyn_cast(attr)) { for (const NamedAttribute &nestedAttr : dictAttr.getValue()) { printAttribute(nestedAttr.getName()); printAttribute(nestedAttr.getValue()); } } else if (auto arrayAttr = dyn_cast(attr)) { for (Attribute nestedAttr : arrayAttr.getValue()) printAttribute(nestedAttr); } else if (auto typeAttr = dyn_cast(attr)) { printType(typeAttr.getValue()); } else if (auto locAttr = dyn_cast(attr)) { printAttribute(locAttr.getFallbackLocation()); } else if (auto locAttr = dyn_cast(attr)) { if (!isa(locAttr.getChildLoc())) printAttribute(locAttr.getChildLoc()); } else if (auto locAttr = dyn_cast(attr)) { printAttribute(locAttr.getCallee()); printAttribute(locAttr.getCaller()); } else if (auto locAttr = dyn_cast(attr)) { if (Attribute metadata = locAttr.getMetadata()) printAttribute(metadata); for (Location nestedLoc : locAttr.getLocations()) printAttribute(nestedLoc); } // Don't print the type if we must elide it, or if it is a None type. if (!elideType) { if (auto typedAttr = llvm::dyn_cast(attr)) { Type attrType = typedAttr.getType(); if (!llvm::isa(attrType)) printType(attrType); } } } void printAndVisitNestedAliasesImpl(Type type) { if (!isa(type.getDialect())) return type.getDialect().printType(type, *this); // Only visit the layout of memref if it isn't the identity. if (auto memrefTy = llvm::dyn_cast(type)) { printType(memrefTy.getElementType()); MemRefLayoutAttrInterface layout = memrefTy.getLayout(); if (!llvm::isa(layout) || !layout.isIdentity()) printAttribute(memrefTy.getLayout()); if (memrefTy.getMemorySpace()) printAttribute(memrefTy.getMemorySpace()); return; } // For most builtin types, we can simply walk the sub elements. auto visitFn = [&](auto element) { if (element) (void)printAlias(element); }; type.walkImmediateSubElements(visitFn, visitFn); } /// Consider the given type to be printed for an alias. void printType(Type type) override { recordAliasResult(initializer.visit(type, canBeDeferred)); } /// Consider the given attribute to be printed for an alias. void printAttribute(Attribute attr) override { recordAliasResult(initializer.visit(attr, canBeDeferred)); } void printAttributeWithoutType(Attribute attr) override { recordAliasResult( initializer.visit(attr, canBeDeferred, /*elideType=*/true)); } LogicalResult printAlias(Attribute attr) override { printAttribute(attr); return success(); } LogicalResult printAlias(Type type) override { printType(type); return success(); } /// Record the alias result of a child element. void recordAliasResult(std::pair aliasDepthAndIndex) { childIndices.push_back(aliasDepthAndIndex.second); if (aliasDepthAndIndex.first > maxAliasDepth) maxAliasDepth = aliasDepthAndIndex.first; } /// Return a null stream as the output stream, this will ignore any data fed /// to it. raw_ostream &getStream() const override { return os; } /// The following are hooks of `DialectAsmPrinter` that are not necessary for /// determining potential aliases. void printFloat(const APFloat &) override {} void printKeywordOrString(StringRef) override {} void printString(StringRef) override {} void printSymbolName(StringRef) override {} void printResourceHandle(const AsmDialectResourceHandle &) override {} LogicalResult pushCyclicPrinting(const void *opaquePointer) override { return success(cyclicPrintingStack.insert(opaquePointer)); } void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); } /// Stack of potentially cyclic mutable attributes or type currently being /// printed. SetVector cyclicPrintingStack; /// The initializer to use when identifying aliases. AliasInitializer &initializer; /// If the aliases visited by this printer can be deferred. bool canBeDeferred; /// The indices of child aliases. SmallVectorImpl &childIndices; /// The maximum alias depth found by the printer. size_t maxAliasDepth = 0; /// A dummy output stream. mutable llvm::raw_null_ostream os; }; } // namespace /// Sanitize the given name such that it can be used as a valid identifier. If /// the string needs to be modified in any way, the provided buffer is used to /// store the new copy, static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, StringRef allowedPunctChars = "$._-") { assert(!name.empty() && "Shouldn't have an empty name here"); auto validChar = [&](char ch) { return llvm::isAlnum(ch) || allowedPunctChars.contains(ch); }; auto copyNameToBuffer = [&] { for (char ch : name) { if (validChar(ch)) buffer.push_back(ch); else if (ch == ' ') buffer.push_back('_'); else buffer.append(llvm::utohexstr((unsigned char)ch)); } }; // Check to see if this name is valid. If it starts with a digit, then it // could conflict with the autogenerated numeric ID's, so add an underscore // prefix to avoid problems. if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) { buffer.push_back('_'); copyNameToBuffer(); return buffer; } // Check to see that the name consists of only valid identifier characters. for (char ch : name) { if (!validChar(ch)) { copyNameToBuffer(); return buffer; } } // If there are no invalid characters, return the original name. return name; } unsigned AliasInitializer::uniqueAliasNameIndex( StringRef alias, llvm::StringMap &nameCounts, llvm::StringSet &usedAliases) { if (!usedAliases.count(alias)) { usedAliases.insert(alias); // 0 is not printed in SymbolAlias. return 0; } // Otherwise, we had a conflict - probe until we find a unique name. SmallString<64> probeAlias(alias); // alias with trailing digit will be printed as _N if (isdigit(alias.back())) probeAlias.push_back('_'); // nameCounts start from 1 because 0 is not printed in SymbolAlias. if (nameCounts[probeAlias] == 0) nameCounts[probeAlias] = 1; // This is guaranteed to terminate (and usually in a single iteration) // because it generates new names by incrementing nameCounts. while (true) { unsigned nameIndex = nameCounts[probeAlias]++; probeAlias += llvm::utostr(nameIndex); if (!usedAliases.count(probeAlias)) { usedAliases.insert(probeAlias); return nameIndex; } // Reset probeAlias to the original alias for the next iteration. probeAlias.resize(alias.size() + isdigit(alias.back()) ? 1 : 0); } } /// Given a collection of aliases and symbols, initialize a mapping from a /// symbol to a given alias. void AliasInitializer::initializeAliases( llvm::MapVector &visitedSymbols, llvm::MapVector &symbolToAlias) { SmallVector, 0> unprocessedAliases = visitedSymbols.takeVector(); llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) { return lhs.second < rhs.second; }); // This keeps track of all of the non-numeric names that are in flight, // allowing us to check for duplicates. llvm::BumpPtrAllocator usedAliasAllocator; llvm::StringSet usedAliases(usedAliasAllocator); llvm::StringMap nameCounts; for (auto &[symbol, aliasInfo] : unprocessedAliases) { if (!aliasInfo.alias) continue; StringRef alias = *aliasInfo.alias; unsigned nameIndex = uniqueAliasNameIndex(alias, nameCounts, usedAliases); symbolToAlias.insert( {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType, aliasInfo.canBeDeferred)}); } } void AliasInitializer::initialize( Operation *op, const OpPrintingFlags &printerFlags, llvm::MapVector &attrTypeToAlias) { // Use a dummy printer when walking the IR so that we can collect the // attributes/types that will actually be used during printing when // considering aliases. DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); aliasPrinter.printCustomOrGenericOp(op); // Initialize the aliases. initializeAliases(aliases, attrTypeToAlias); } template std::pair AliasInitializer::visitImpl( T value, llvm::MapVector &aliases, bool canBeDeferred, PrintArgs &&...printArgs) { auto [it, inserted] = aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()}); size_t aliasIndex = std::distance(aliases.begin(), it); if (!inserted) { // Make sure that the alias isn't deferred if we don't permit it. if (!canBeDeferred) markAliasNonDeferrable(aliasIndex); return {static_cast(it->second.aliasDepth), aliasIndex}; } // Try to generate an alias for this value. generateAlias(value, it->second, canBeDeferred); it->second.isType = std::is_base_of_v; it->second.canBeDeferred = canBeDeferred; // Print the value, capturing any nested elements that require aliases. SmallVector childAliases; DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases); size_t maxAliasDepth = printer.printAndVisitNestedAliases(value, printArgs...); // Make sure to recompute `it` in case the map was reallocated. it = std::next(aliases.begin(), aliasIndex); // If we had sub elements, update to account for the depth. it->second.childIndices = std::move(childAliases); if (maxAliasDepth) it->second.aliasDepth = maxAliasDepth + 1; // Propagate the alias depth of the value. return {(size_t)it->second.aliasDepth, aliasIndex}; } void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) { auto *it = std::next(aliases.begin(), aliasIndex); // If already marked non-deferrable stop the recursion. // All children should already be marked non-deferrable as well. if (!it->second.canBeDeferred) return; it->second.canBeDeferred = false; // Propagate the non-deferrable flag to any child aliases. for (size_t childIndex : it->second.childIndices) markAliasNonDeferrable(childIndex); } template void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred) { SmallString<32> nameBuffer; OpAsmDialectInterface::AliasResult symbolInterfaceResult = OpAsmDialectInterface::AliasResult::NoAlias; using InterfaceT = std::conditional_t, OpAsmAttrInterface, OpAsmTypeInterface>; if (auto symbolInterface = dyn_cast(symbol)) { symbolInterfaceResult = symbolInterface.getAlias(aliasOS); if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::NoAlias) { nameBuffer = std::move(aliasBuffer); assert(!nameBuffer.empty() && "expected valid alias name"); } } if (symbolInterfaceResult != OpAsmDialectInterface::AliasResult::FinalAlias) { for (const auto &interface : interfaces) { OpAsmDialectInterface::AliasResult result = interface.getAlias(symbol, aliasOS); if (result == OpAsmDialectInterface::AliasResult::NoAlias) continue; nameBuffer = std::move(aliasBuffer); assert(!nameBuffer.empty() && "expected valid alias name"); if (result == OpAsmDialectInterface::AliasResult::FinalAlias) break; } } if (nameBuffer.empty()) return; SmallString<16> tempBuffer; StringRef name = sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-"); name = name.copy(aliasAllocator); alias = InProgressAliasInfo(name); } //===----------------------------------------------------------------------===// // AliasState //===----------------------------------------------------------------------===// namespace { /// This class manages the state for type and attribute aliases. class AliasState { public: // Initialize the internal aliases. void initialize(Operation *op, const OpPrintingFlags &printerFlags, DialectInterfaceCollection &interfaces); /// Get an alias for the given attribute if it has one and print it in `os`. /// Returns success if an alias was printed, failure otherwise. LogicalResult getAlias(Attribute attr, raw_ostream &os) const; /// Get an alias for the given type if it has one and print it in `os`. /// Returns success if an alias was printed, failure otherwise. LogicalResult getAlias(Type ty, raw_ostream &os) const; /// Print all of the referenced aliases that can not be resolved in a deferred /// manner. void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { printAliases(p, newLine, /*isDeferred=*/false); } /// Print all of the referenced aliases that support deferred resolution. void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { printAliases(p, newLine, /*isDeferred=*/true); } private: /// Print all of the referenced aliases that support the provided resolution /// behavior. void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, bool isDeferred); /// Mapping between attribute/type and alias. llvm::MapVector attrTypeToAlias; /// An allocator used for alias names. llvm::BumpPtrAllocator aliasAllocator; }; } // namespace void AliasState::initialize( Operation *op, const OpPrintingFlags &printerFlags, DialectInterfaceCollection &interfaces) { AliasInitializer initializer(interfaces, aliasAllocator); initializer.initialize(op, printerFlags, attrTypeToAlias); } LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { const auto *it = attrTypeToAlias.find(attr.getAsOpaquePointer()); if (it == attrTypeToAlias.end()) return failure(); it->second.print(os); return success(); } LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer()); if (it == attrTypeToAlias.end()) return failure(); if (!it->second.isPrinted) return failure(); it->second.print(os); return success(); } void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, bool isDeferred) { auto filterFn = [=](const auto &aliasIt) { return aliasIt.second.canBeDeferred() == isDeferred; }; for (auto &[opaqueSymbol, alias] : llvm::make_filter_range(attrTypeToAlias, filterFn)) { alias.print(p.getStream()); p.getStream() << " = "; if (alias.isTypeAlias()) { Type type = Type::getFromOpaquePointer(opaqueSymbol); p.printTypeImpl(type); alias.isPrinted = true; } else { // TODO: Support nested aliases in mutable attributes. Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol); if (attr.hasTrait()) p.getStream() << attr; else p.printAttributeImpl(attr); } p.getStream() << newLine; } } //===----------------------------------------------------------------------===// // SSANameState //===----------------------------------------------------------------------===// namespace { /// Info about block printing: a number which is its position in the visitation /// order, and a name that is used to print reference to it, e.g. ^bb42. struct BlockInfo { int ordering; StringRef name; }; /// This class manages the state of SSA value names. class SSANameState { public: /// A sentinel value used for values with names set. enum : unsigned { NameSentinel = ~0U }; SSANameState(Operation *op, const OpPrintingFlags &printerFlags); SSANameState() = default; /// Print the SSA identifier for the given value to 'stream'. If /// 'printResultNo' is true, it also presents the result number ('#' number) /// of this value. void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; /// Print the operation identifier. void printOperationID(Operation *op, raw_ostream &stream) const; /// Return the result indices for each of the result groups registered by this /// operation, or empty if none exist. ArrayRef getOpResultGroups(Operation *op); /// Get the info for the given block. BlockInfo getBlockInfo(Block *block); /// Renumber the arguments for the specified region to the same names as the /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for /// details. void shadowRegionArgs(Region ®ion, ValueRange namesToUse); private: /// Number the SSA values within the given IR unit. void numberValuesInRegion(Region ®ion); void numberValuesInBlock(Block &block); void numberValuesInOp(Operation &op); /// Given a result of an operation 'result', find the result group head /// 'lookupValue' and the result of 'result' within that group in /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group /// has more than 1 result. void getResultIDAndNumber(OpResult result, Value &lookupValue, std::optional &lookupResultNo) const; /// Set a special value name for the given value. void setValueName(Value value, StringRef name); /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. StringRef uniqueValueName(StringRef name); /// This is the value ID for each SSA value. If this returns NameSentinel, /// then the valueID has an entry in valueNames. DenseMap valueIDs; DenseMap valueNames; /// When printing users of values, an operation without a result might /// be the user. This map holds ids for such operations. DenseMap operationIDs; /// This is a map of operations that contain multiple named result groups, /// i.e. there may be multiple names for the results of the operation. The /// value of this map are the result numbers that start a result group. DenseMap> opResultGroups; /// This maps blocks to there visitation number in the current region as well /// as the string representing their name. DenseMap blockNames; /// This keeps track of all of the non-numeric names that are in flight, /// allowing us to check for duplicates. /// Note: the value of the map is unused. llvm::ScopedHashTable usedNames; llvm::BumpPtrAllocator usedNameAllocator; /// This is the next value ID to assign in numbering. unsigned nextValueID = 0; /// This is the next ID to assign to a region entry block argument. unsigned nextArgumentID = 0; /// This is the next ID to assign when a name conflict is detected. unsigned nextConflictID = 0; /// These are the printing flags. They control, eg., whether to print in /// generic form. OpPrintingFlags printerFlags; }; } // namespace SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags) : printerFlags(printerFlags) { llvm::SaveAndRestore valueIDSaver(nextValueID); llvm::SaveAndRestore argumentIDSaver(nextArgumentID); llvm::SaveAndRestore conflictIDSaver(nextConflictID); // The naming context includes `nextValueID`, `nextArgumentID`, // `nextConflictID` and `usedNames` scoped HashTable. This information is // carried from the parent region. using UsedNamesScopeTy = llvm::ScopedHashTable::ScopeTy; using NamingContext = std::tuple; // Allocator for UsedNamesScopeTy llvm::BumpPtrAllocator allocator; // Add a scope for the top level operation. auto *topLevelNamesScope = new (allocator.Allocate()) UsedNamesScopeTy(usedNames); SmallVector nameContext; for (Region ®ion : op->getRegions()) nameContext.push_back(std::make_tuple(®ion, nextValueID, nextArgumentID, nextConflictID, topLevelNamesScope)); numberValuesInOp(*op); while (!nameContext.empty()) { Region *region; UsedNamesScopeTy *parentScope; if (printerFlags.shouldPrintUniqueSSAIDs()) // To print unique SSA IDs, ignore saved ID counts from parent regions std::tie(region, std::ignore, std::ignore, std::ignore, parentScope) = nameContext.pop_back_val(); else std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) = nameContext.pop_back_val(); // When we switch from one subtree to another, pop the scopes(needless) // until the parent scope. while (usedNames.getCurScope() != parentScope) { usedNames.getCurScope()->~UsedNamesScopeTy(); assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) && "top level parentScope must be a nullptr"); } // Add a scope for the current region. auto *curNamesScope = new (allocator.Allocate()) UsedNamesScopeTy(usedNames); numberValuesInRegion(*region); for (Operation &op : region->getOps()) for (Region ®ion : op.getRegions()) nameContext.push_back(std::make_tuple(®ion, nextValueID, nextArgumentID, nextConflictID, curNamesScope)); } // Manually remove all the scopes. while (usedNames.getCurScope() != nullptr) usedNames.getCurScope()->~UsedNamesScopeTy(); } void SSANameState::printValueID(Value value, bool printResultNo, raw_ostream &stream) const { if (!value) { stream << "<>"; return; } std::optional resultNo; auto lookupValue = value; // If this is an operation result, collect the head lookup value of the result // group and the result number of 'result' within that group. if (OpResult result = dyn_cast(value)) getResultIDAndNumber(result, lookupValue, resultNo); auto it = valueIDs.find(lookupValue); if (it == valueIDs.end()) { stream << "<>"; return; } stream << '%'; if (it->second != NameSentinel) { stream << it->second; } else { auto nameIt = valueNames.find(lookupValue); assert(nameIt != valueNames.end() && "Didn't have a name entry?"); stream << nameIt->second; } if (resultNo && printResultNo) stream << '#' << *resultNo; } void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const { auto it = operationIDs.find(op); if (it == operationIDs.end()) { stream << "<>"; } else { stream << '%' << it->second; } } ArrayRef SSANameState::getOpResultGroups(Operation *op) { auto it = opResultGroups.find(op); return it == opResultGroups.end() ? ArrayRef() : it->second; } BlockInfo SSANameState::getBlockInfo(Block *block) { auto it = blockNames.find(block); BlockInfo invalidBlock{-1, "INVALIDBLOCK"}; return it != blockNames.end() ? it->second : invalidBlock; } void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { assert(!region.empty() && "cannot shadow arguments of an empty region"); assert(region.getNumArguments() == namesToUse.size() && "incorrect number of names passed in"); assert(region.getParentOp()->hasTrait() && "only KnownIsolatedFromAbove ops can shadow names"); SmallVector nameStr; for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { auto nameToUse = namesToUse[i]; if (nameToUse == nullptr) continue; auto nameToReplace = region.getArgument(i); nameStr.clear(); llvm::raw_svector_ostream nameStream(nameStr); printValueID(nameToUse, /*printResultNo=*/true, nameStream); // Entry block arguments should already have a pretty "arg" name. assert(valueIDs[nameToReplace] == NameSentinel); // Use the name without the leading %. auto name = StringRef(nameStream.str()).drop_front(); // Overwrite the name. valueNames[nameToReplace] = name.copy(usedNameAllocator); } } namespace { /// Try to get value name from value's location, fallback to `name`. StringRef maybeGetValueNameFromLoc(Value value, StringRef name) { if (auto maybeNameLoc = value.getLoc()->findInstanceOf()) return maybeNameLoc.getName(); return name; } } // namespace void SSANameState::numberValuesInRegion(Region ®ion) { // Indicates whether OpAsmOpInterface set a name. bool opAsmOpInterfaceUsed = false; auto setBlockArgNameFn = [&](Value arg, StringRef name) { assert(!valueIDs.count(arg) && "arg numbered multiple times"); assert(llvm::cast(arg).getOwner()->getParent() == ®ion && "arg not defined in current region"); opAsmOpInterfaceUsed = true; if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) name = maybeGetValueNameFromLoc(arg, name); setValueName(arg, name); }; if (!printerFlags.shouldPrintGenericOpForm()) { if (Operation *op = region.getParentOp()) { if (auto asmInterface = dyn_cast(op)) asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); // If the OpAsmOpInterface didn't set a name, get name from the type. if (!opAsmOpInterfaceUsed) { for (BlockArgument arg : region.getArguments()) { if (auto interface = dyn_cast(arg.getType())) { interface.getAsmName( [&](StringRef name) { setBlockArgNameFn(arg, name); }); } } } } } // Number the values within this region in a breadth-first order. unsigned nextBlockID = 0; for (auto &block : region) { // Each block gets a unique ID, and all of the operations within it get // numbered as well. auto blockInfoIt = blockNames.insert({&block, {-1, ""}}); if (blockInfoIt.second) { // This block hasn't been named through `getAsmBlockArgumentNames`, use // default `^bbNNN` format. std::string name; llvm::raw_string_ostream(name) << "^bb" << nextBlockID; blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator); } blockInfoIt.first->second.ordering = nextBlockID++; numberValuesInBlock(block); } } void SSANameState::numberValuesInBlock(Block &block) { // Number the block arguments. We give entry block arguments a special name // 'arg'. bool isEntryBlock = block.isEntryBlock(); SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); llvm::raw_svector_ostream specialName(specialNameBuffer); for (auto arg : block.getArguments()) { if (valueIDs.count(arg)) continue; if (isEntryBlock) { specialNameBuffer.resize(strlen("arg")); specialName << nextArgumentID++; } StringRef specialNameStr = specialName.str(); if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) specialNameStr = maybeGetValueNameFromLoc(arg, specialNameStr); setValueName(arg, specialNameStr); } // Number the operations in this block. for (auto &op : block) numberValuesInOp(op); } void SSANameState::numberValuesInOp(Operation &op) { // Function used to set the special result names for the operation. SmallVector resultGroups(/*Size=*/1, /*Value=*/0); // Indicates whether OpAsmOpInterface set a name. bool opAsmOpInterfaceUsed = false; auto setResultNameFn = [&](Value result, StringRef name) { assert(!valueIDs.count(result) && "result numbered multiple times"); assert(result.getDefiningOp() == &op && "result not defined by 'op'"); opAsmOpInterfaceUsed = true; if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix())) name = maybeGetValueNameFromLoc(result, name); setValueName(result, name); // Record the result number for groups not anchored at 0. if (int resultNo = llvm::cast(result).getResultNumber()) resultGroups.push_back(resultNo); }; // Operations can customize the printing of block names in OpAsmOpInterface. auto setBlockNameFn = [&](Block *block, StringRef name) { assert(block->getParentOp() == &op && "getAsmBlockArgumentNames callback invoked on a block not directly " "nested under the current operation"); assert(!blockNames.count(block) && "block numbered multiple times"); SmallString<16> tmpBuffer{"^"}; name = sanitizeIdentifier(name, tmpBuffer); if (name.data() != tmpBuffer.data()) { tmpBuffer.append(name); name = tmpBuffer.str(); } name = name.copy(usedNameAllocator); blockNames[block] = {-1, name}; }; if (!printerFlags.shouldPrintGenericOpForm()) { if (OpAsmOpInterface asmInterface = dyn_cast(&op)) { asmInterface.getAsmBlockNames(setBlockNameFn); asmInterface.getAsmResultNames(setResultNameFn); } if (!opAsmOpInterfaceUsed) { // If the OpAsmOpInterface didn't set a name, and all results have // OpAsmTypeInterface, get names from types. bool allHaveOpAsmTypeInterface = llvm::all_of(op.getResultTypes(), [&](Type type) { return isa(type); }); if (allHaveOpAsmTypeInterface) { for (OpResult result : op.getResults()) { auto interface = cast(result.getType()); interface.getAsmName( [&](StringRef name) { setResultNameFn(result, name); }); } } } } unsigned numResults = op.getNumResults(); if (numResults == 0) { // If value users should be printed, operations with no result need an id. if (printerFlags.shouldPrintValueUsers()) { if (operationIDs.try_emplace(&op, nextValueID).second) ++nextValueID; } return; } Value resultBegin = op.getResult(0); if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(resultBegin)) { if (auto nameLoc = resultBegin.getLoc()->findInstanceOf()) { setValueName(resultBegin, nameLoc.getName()); } } // If the first result wasn't numbered, give it a default number. if (valueIDs.try_emplace(resultBegin, nextValueID).second) ++nextValueID; // If this operation has multiple result groups, mark it. if (resultGroups.size() != 1) { llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); opResultGroups.try_emplace(&op, std::move(resultGroups)); } } void SSANameState::getResultIDAndNumber( OpResult result, Value &lookupValue, std::optional &lookupResultNo) const { Operation *owner = result.getOwner(); if (owner->getNumResults() == 1) return; int resultNo = result.getResultNumber(); // If this operation has multiple result groups, we will need to find the // one corresponding to this result. auto resultGroupIt = opResultGroups.find(owner); if (resultGroupIt == opResultGroups.end()) { // If not, just use the first result. lookupResultNo = resultNo; lookupValue = owner->getResult(0); return; } // Find the correct index using a binary search, as the groups are ordered. ArrayRef resultGroups = resultGroupIt->second; const auto *it = llvm::upper_bound(resultGroups, resultNo); int groupResultNo = 0, groupSize = 0; // If there are no smaller elements, the last result group is the lookup. if (it == resultGroups.end()) { groupResultNo = resultGroups.back(); groupSize = static_cast(owner->getNumResults()) - resultGroups.back(); } else { // Otherwise, the previous element is the lookup. groupResultNo = *std::prev(it); groupSize = *it - groupResultNo; } // We only record the result number for a group of size greater than 1. if (groupSize != 1) lookupResultNo = resultNo - groupResultNo; lookupValue = owner->getResult(groupResultNo); } void SSANameState::setValueName(Value value, StringRef name) { // If the name is empty, the value uses the default numbering. if (name.empty()) { valueIDs[value] = nextValueID++; return; } valueIDs[value] = NameSentinel; valueNames[value] = uniqueValueName(name); } StringRef SSANameState::uniqueValueName(StringRef name) { SmallString<16> tmpBuffer; name = sanitizeIdentifier(name, tmpBuffer); // Check to see if this name is already unique. if (!usedNames.count(name)) { name = name.copy(usedNameAllocator); } else { // Otherwise, we had a conflict - probe until we find a unique name. This // is guaranteed to terminate (and usually in a single iteration) because it // generates new names by incrementing nextConflictID. SmallString<64> probeName(name); probeName.push_back('_'); while (true) { probeName += llvm::utostr(nextConflictID++); if (!usedNames.count(probeName)) { name = probeName.str().copy(usedNameAllocator); break; } probeName.resize(name.size() + 1); } } usedNames.insert(name, char()); return name; } //===----------------------------------------------------------------------===// // DistinctState //===----------------------------------------------------------------------===// namespace { /// This class manages the state for distinct attributes. class DistinctState { public: /// Returns a unique identifier for the given distinct attribute. uint64_t getId(DistinctAttr distinctAttr); private: uint64_t distinctCounter = 0; DenseMap distinctAttrMap; }; } // namespace uint64_t DistinctState::getId(DistinctAttr distinctAttr) { auto [it, inserted] = distinctAttrMap.try_emplace(distinctAttr, distinctCounter); if (inserted) distinctCounter++; return it->getSecond(); } //===----------------------------------------------------------------------===// // Resources //===----------------------------------------------------------------------===// AsmParsedResourceEntry::~AsmParsedResourceEntry() = default; AsmResourceBuilder::~AsmResourceBuilder() = default; AsmResourceParser::~AsmResourceParser() = default; AsmResourcePrinter::~AsmResourcePrinter() = default; StringRef mlir::toString(AsmResourceEntryKind kind) { switch (kind) { case AsmResourceEntryKind::Blob: return "blob"; case AsmResourceEntryKind::Bool: return "bool"; case AsmResourceEntryKind::String: return "string"; } llvm_unreachable("unknown AsmResourceEntryKind"); } AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) { std::unique_ptr &collection = keyToResources[key.str()]; if (!collection) collection = std::make_unique(key); return *collection; } std::vector> FallbackAsmResourceMap::getPrinters() { std::vector> printers; for (auto &it : keyToResources) { ResourceCollection *collection = it.second.get(); auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) { return collection->buildResources(op, builder); }; printers.emplace_back( AsmResourcePrinter::fromCallable(collection->getName(), buildValues)); } return printers; } LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource( AsmParsedResourceEntry &entry) { switch (entry.getKind()) { case AsmResourceEntryKind::Blob: { FailureOr blob = entry.parseAsBlob(); if (failed(blob)) return failure(); resources.emplace_back(entry.getKey(), std::move(*blob)); return success(); } case AsmResourceEntryKind::Bool: { FailureOr value = entry.parseAsBool(); if (failed(value)) return failure(); resources.emplace_back(entry.getKey(), *value); break; } case AsmResourceEntryKind::String: { FailureOr str = entry.parseAsString(); if (failed(str)) return failure(); resources.emplace_back(entry.getKey(), std::move(*str)); break; } } return success(); } void FallbackAsmResourceMap::ResourceCollection::buildResources( Operation *op, AsmResourceBuilder &builder) const { for (const auto &entry : resources) { if (const auto *value = std::get_if(&entry.value)) builder.buildBlob(entry.key, *value); else if (const auto *value = std::get_if(&entry.value)) builder.buildBool(entry.key, *value); else if (const auto *value = std::get_if(&entry.value)) builder.buildString(entry.key, *value); else llvm_unreachable("unknown AsmResourceEntryKind"); } } //===----------------------------------------------------------------------===// // AsmState //===----------------------------------------------------------------------===// namespace mlir { namespace detail { class AsmStateImpl { public: explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, AsmState::LocationMap *locationMap) : interfaces(op->getContext()), nameState(op, printerFlags), printerFlags(printerFlags), locationMap(locationMap) {} explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, AsmState::LocationMap *locationMap) : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} /// Initialize the alias state to enable the printing of aliases. void initializeAliases(Operation *op) { aliasState.initialize(op, printerFlags, interfaces); } /// Get the state used for aliases. AliasState &getAliasState() { return aliasState; } /// Get the state used for SSA names. SSANameState &getSSANameState() { return nameState; } /// Get the state used for distinct attribute identifiers. DistinctState &getDistinctState() { return distinctState; } /// Return the dialects within the context that implement /// OpAsmDialectInterface. DialectInterfaceCollection &getDialectInterfaces() { return interfaces; } /// Return the non-dialect resource printers. auto getResourcePrinters() { return llvm::make_pointee_range(externalResourcePrinters); } /// Get the printer flags. const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } /// Register the location, line and column, within the buffer that the given /// operation was printed at. void registerOperationLocation(Operation *op, unsigned line, unsigned col) { if (locationMap) (*locationMap)[op] = std::make_pair(line, col); } /// Return the referenced dialect resources within the printer. DenseMap> & getDialectResources() { return dialectResources; } LogicalResult pushCyclicPrinting(const void *opaquePointer) { return success(cyclicPrintingStack.insert(opaquePointer)); } void popCyclicPrinting() { cyclicPrintingStack.pop_back(); } private: /// Collection of OpAsm interfaces implemented in the context. DialectInterfaceCollection interfaces; /// A collection of non-dialect resource printers. SmallVector> externalResourcePrinters; /// A set of dialect resources that were referenced during printing. DenseMap> dialectResources; /// The state used for attribute and type aliases. AliasState aliasState; /// The state used for SSA value names. SSANameState nameState; /// The state used for distinct attribute identifiers. DistinctState distinctState; /// Flags that control op output. OpPrintingFlags printerFlags; /// An optional location map to be populated. AsmState::LocationMap *locationMap; /// Stack of potentially cyclic mutable attributes or type currently being /// printed. SetVector cyclicPrintingStack; // Allow direct access to the impl fields. friend AsmState; }; template void printDimensionList(raw_ostream &stream, Range &&shape) { llvm::interleave( shape, stream, [&stream](const auto &dimSize) { if (ShapedType::isDynamic(dimSize)) stream << "?"; else stream << dimSize; }, "x"); } } // namespace detail } // namespace mlir /// Verifies the operation and switches to generic op printing if verification /// fails. We need to do this because custom print functions may fail for /// invalid ops. static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, OpPrintingFlags printerFlags) { if (printerFlags.shouldPrintGenericOpForm() || printerFlags.shouldAssumeVerified()) return printerFlags; // Ignore errors emitted by the verifier. We check the thread id to avoid // consuming other threads' errors. auto parentThreadId = llvm::get_threadid(); ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) { if (parentThreadId == llvm::get_threadid()) { LLVM_DEBUG({ diag.print(llvm::dbgs()); llvm::dbgs() << "\n"; }); return success(); } return failure(); }); if (failed(verify(op))) { LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": '" << op->getName() << "' failed to verify and will be printed in generic form\n"); printerFlags.printGenericOpForm(); } return printerFlags; } AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, LocationMap *locationMap, FallbackAsmResourceMap *map) : impl(std::make_unique( op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) { if (map) attachFallbackResourcePrinter(*map); } AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags, LocationMap *locationMap, FallbackAsmResourceMap *map) : impl(std::make_unique(ctx, printerFlags, locationMap)) { if (map) attachFallbackResourcePrinter(*map); } AsmState::~AsmState() = default; const OpPrintingFlags &AsmState::getPrinterFlags() const { return impl->getPrinterFlags(); } void AsmState::attachResourcePrinter( std::unique_ptr printer) { impl->externalResourcePrinters.emplace_back(std::move(printer)); } DenseMap> & AsmState::getDialectResources() const { return impl->getDialectResources(); } //===----------------------------------------------------------------------===// // AsmPrinter::Impl //===----------------------------------------------------------------------===// AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state) : os(os), state(state), printerFlags(state.getPrinterFlags()) {} void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { // Check to see if we are printing debug information. if (!printerFlags.shouldPrintDebugInfo()) return; os << " "; printLocation(loc, /*allowAlias=*/allowAlias); } void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, bool isTopLevel) { // If this isn't a top-level location, check for an alias. if (!isTopLevel && succeeded(state.getAliasState().getAlias(loc, os))) return; TypeSwitch(loc) .Case([&](OpaqueLoc loc) { printLocationInternal(loc.getFallbackLocation(), pretty); }) .Case([&](UnknownLoc loc) { if (pretty) os << "[unknown]"; else os << "unknown"; }) .Case([&](FileLineColRange loc) { if (pretty) os << loc.getFilename().getValue(); else printEscapedString(loc.getFilename()); if (loc.getEndColumn() == loc.getStartColumn() && loc.getStartLine() == loc.getEndLine()) { os << ':' << loc.getStartLine() << ':' << loc.getStartColumn(); return; } if (loc.getStartLine() == loc.getEndLine()) { os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() << " to :" << loc.getEndColumn(); return; } os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() << " to " << loc.getEndLine() << ':' << loc.getEndColumn(); }) .Case([&](NameLoc loc) { printEscapedString(loc.getName()); // Print the child if it isn't unknown. auto childLoc = loc.getChildLoc(); if (!llvm::isa(childLoc)) { os << '('; printLocationInternal(childLoc, pretty); os << ')'; } }) .Case([&](CallSiteLoc loc) { Location caller = loc.getCaller(); Location callee = loc.getCallee(); if (!pretty) os << "callsite("; printLocationInternal(callee, pretty); if (pretty) { if (llvm::isa(callee)) { if (llvm::isa(caller)) { os << " at "; } else { os << newLine << " at "; } } else { os << newLine << " at "; } } else { os << " at "; } printLocationInternal(caller, pretty); if (!pretty) os << ")"; }) .Case([&](FusedLoc loc) { if (!pretty) os << "fused"; if (Attribute metadata = loc.getMetadata()) { os << '<'; printAttribute(metadata); os << '>'; } os << '['; interleave( loc.getLocations(), [&](Location loc) { printLocationInternal(loc, pretty); }, [&]() { os << ", "; }); os << ']'; }) .Default([&](LocationAttr loc) { // Assumes that this is a dialect-specific attribute and prints it // directly. printAttribute(loc); }); } /// Print a floating point value in a way that the parser will be able to /// round-trip losslessly. static void printFloatValue(const APFloat &apValue, raw_ostream &os, bool *printedHex = nullptr) { // We would like to output the FP constant value in exponential notation, // but we cannot do this if doing so will lose precision. Check here to // make sure that we only output it in exponential format if we can parse // the value back and get the same value. bool isInf = apValue.isInfinity(); bool isNaN = apValue.isNaN(); if (!isInf && !isNaN) { SmallString<128> strValue; apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, /*TruncateZero=*/false); // Check to make sure that the stringized number is not some string like // "Inf" or NaN, that atof will accept, but the lexer will not. Check // that the string matches the "[-+]?[0-9]" regex. assert(((strValue[0] >= '0' && strValue[0] <= '9') || ((strValue[0] == '-' || strValue[0] == '+') && (strValue[1] >= '0' && strValue[1] <= '9'))) && "[-+]?[0-9] regex does not match!"); // Parse back the stringized version and check that the value is equal // (i.e., there is no precision loss). if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { os << strValue; return; } // If it is not, use the default format of APFloat instead of the // exponential notation. strValue.clear(); apValue.toString(strValue); // Make sure that we can parse the default form as a float. if (strValue.str().contains('.')) { os << strValue; return; } } // Print special values in hexadecimal format. The sign bit should be included // in the literal. if (printedHex) *printedHex = true; SmallVector str; APInt apInt = apValue.bitcastToAPInt(); apInt.toString(str, /*Radix=*/16, /*Signed=*/false, /*formatAsCLiteral=*/true); os << str; } void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { if (printerFlags.shouldPrintDebugInfoPrettyForm()) return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true); os << "loc("; if (!allowAlias || failed(printAlias(loc))) printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true); os << ')'; } /// Returns true if the given dialect symbol data is simple enough to print in /// the pretty form. This is essentially when the symbol takes the form: /// identifier (`<` body `>`)? static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { // The name must start with an identifier. if (symName.empty() || !isalpha(symName.front())) return false; // Ignore all the characters that are valid in an identifier in the symbol // name. symName = symName.drop_while( [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); if (symName.empty()) return true; // If we got to an unexpected character, then it must be a <>. Check that the // rest of the symbol is wrapped within <>. return symName.front() == '<' && symName.back() == '>'; } /// Print the given dialect symbol to the stream. static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, StringRef dialectName, StringRef symString) { os << symPrefix << dialectName; // If this symbol name is simple enough, print it directly in pretty form, // otherwise, we print it as an escaped string. if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { os << '.' << symString; return; } os << '<' << symString << '>'; } /// Returns true if the given string can be represented as a bare identifier. static bool isBareIdentifier(StringRef name) { // By making this unsigned, the value passed in to isalnum will always be // in the range 0-255. This is important when building with MSVC because // its implementation will assert. This situation can arise when dealing // with UTF-8 multibyte characters. if (name.empty() || (!isalpha(name[0]) && name[0] != '_')) return false; return llvm::all_of(name.drop_front(), [](unsigned char c) { return isalnum(c) || c == '_' || c == '$' || c == '.'; }); } /// Print the given string as a keyword, or a quoted and escaped string if it /// has any special or non-printable characters in it. static void printKeywordOrString(StringRef keyword, raw_ostream &os) { // If it can be represented as a bare identifier, write it directly. if (isBareIdentifier(keyword)) { os << keyword; return; } // Otherwise, output the keyword wrapped in quotes with proper escaping. os << "\""; printEscapedString(keyword, os); os << '"'; } /// Print the given string as a symbol reference. A symbol reference is /// represented as a string prefixed with '@'. The reference is surrounded with /// ""'s and escaped if it has any special or non-printable characters in it. static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { if (symbolRef.empty()) { os << "@<>"; return; } os << '@'; printKeywordOrString(symbolRef, os); } // Print out a valid ElementsAttr that is succinct and can represent any // potential shape/type, for use when eliding a large ElementsAttr. // // We choose to use a dense resource ElementsAttr literal with conspicuous // content to hopefully alert readers to the fact that this has been elided. static void printElidedElementsAttr(raw_ostream &os) { os << R"(dense_resource<__elided__>)"; } void AsmPrinter::Impl::printResourceHandle( const AsmDialectResourceHandle &resource) { auto *interface = cast(resource.getDialect()); ::printKeywordOrString(interface->getResourceKey(resource), os); state.getDialectResources()[resource.getDialect()].insert(resource); } LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { return state.getAliasState().getAlias(attr, os); } LogicalResult AsmPrinter::Impl::printAlias(Type type) { return state.getAliasState().getAlias(type, os); } void AsmPrinter::Impl::printAttribute(Attribute attr, AttrTypeElision typeElision) { if (!attr) { os << "<>"; return; } // Try to print an alias for this attribute. if (succeeded(printAlias(attr))) return; return printAttributeImpl(attr, typeElision); } void AsmPrinter::Impl::printAttributeImpl(Attribute attr, AttrTypeElision typeElision) { if (!isa(attr.getDialect())) { printDialectAttribute(attr); } else if (auto opaqueAttr = llvm::dyn_cast(attr)) { printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), opaqueAttr.getAttrData()); } else if (llvm::isa(attr)) { os << "unit"; return; } else if (auto distinctAttr = llvm::dyn_cast(attr)) { os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<"; if (!llvm::isa(distinctAttr.getReferencedAttr())) { printAttribute(distinctAttr.getReferencedAttr()); } os << '>'; return; } else if (auto dictAttr = llvm::dyn_cast(attr)) { os << '{'; interleaveComma(dictAttr.getValue(), [&](NamedAttribute attr) { printNamedAttribute(attr); }); os << '}'; } else if (auto intAttr = llvm::dyn_cast(attr)) { Type intType = intAttr.getType(); if (intType.isSignlessInteger(1)) { os << (intAttr.getValue().getBoolValue() ? "true" : "false"); // Boolean integer attributes always elides the type. return; } // Only print attributes as unsigned if they are explicitly unsigned or are // signless 1-bit values. Indexes, signed values, and multi-bit signless // values print as signed. bool isUnsigned = intType.isUnsignedInteger() || intType.isSignlessInteger(1); intAttr.getValue().print(os, !isUnsigned); // IntegerAttr elides the type if I64. if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64)) return; } else if (auto floatAttr = llvm::dyn_cast(attr)) { bool printedHex = false; printFloatValue(floatAttr.getValue(), os, &printedHex); // FloatAttr elides the type if F64. if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64() && !printedHex) return; } else if (auto strAttr = llvm::dyn_cast(attr)) { printEscapedString(strAttr.getValue()); } else if (auto arrayAttr = llvm::dyn_cast(attr)) { os << '['; interleaveComma(arrayAttr.getValue(), [&](Attribute attr) { printAttribute(attr, AttrTypeElision::May); }); os << ']'; } else if (auto affineMapAttr = llvm::dyn_cast(attr)) { os << "affine_map<"; affineMapAttr.getValue().print(os); os << '>'; // AffineMap always elides the type. return; } else if (auto integerSetAttr = llvm::dyn_cast(attr)) { os << "affine_set<"; integerSetAttr.getValue().print(os); os << '>'; // IntegerSet always elides the type. return; } else if (auto typeAttr = llvm::dyn_cast(attr)) { printType(typeAttr.getValue()); } else if (auto refAttr = llvm::dyn_cast(attr)) { printSymbolReference(refAttr.getRootReference().getValue(), os); for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { os << "::"; printSymbolReference(nestedRef.getValue(), os); } } else if (auto intOrFpEltAttr = llvm::dyn_cast(attr)) { if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) { printElidedElementsAttr(os); } else { os << "dense<"; printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true); os << '>'; } } else if (auto strEltAttr = llvm::dyn_cast(attr)) { if (printerFlags.shouldElideElementsAttr(strEltAttr)) { printElidedElementsAttr(os); } else { os << "dense<"; printDenseStringElementsAttr(strEltAttr); os << '>'; } } else if (auto sparseEltAttr = llvm::dyn_cast(attr)) { if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) || printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) { printElidedElementsAttr(os); } else { os << "sparse<"; DenseIntElementsAttr indices = sparseEltAttr.getIndices(); if (indices.getNumElements() != 0) { printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false); os << ", "; printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true); } os << '>'; } } else if (auto stridedLayoutAttr = llvm::dyn_cast(attr)) { stridedLayoutAttr.print(os); } else if (auto denseArrayAttr = llvm::dyn_cast(attr)) { os << "array<"; printType(denseArrayAttr.getElementType()); if (!denseArrayAttr.empty()) { os << ": "; printDenseArrayAttr(denseArrayAttr); } os << ">"; return; } else if (auto resourceAttr = llvm::dyn_cast(attr)) { os << "dense_resource<"; printResourceHandle(resourceAttr.getRawHandle()); os << ">"; } else if (auto locAttr = llvm::dyn_cast(attr)) { printLocation(locAttr); } else { llvm::report_fatal_error("Unknown builtin attribute"); } // Don't print the type if we must elide it, or if it is a None type. if (typeElision != AttrTypeElision::Must) { if (auto typedAttr = llvm::dyn_cast(attr)) { Type attrType = typedAttr.getType(); if (!llvm::isa(attrType)) { os << " : "; printType(attrType); } } } } /// Print the integer element of a DenseElementsAttr. static void printDenseIntElement(const APInt &value, raw_ostream &os, Type type) { if (type.isInteger(1)) os << (value.getBoolValue() ? "true" : "false"); else value.print(os, !type.isUnsignedInteger()); } static void printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, function_ref printEltFn) { // Special case for 0-d and splat tensors. if (isSplat) return printEltFn(0); // Special case for degenerate tensors. auto numElements = type.getNumElements(); if (numElements == 0) return; // We use a mixed-radix counter to iterate through the shape. When we bump a // non-least-significant digit, we emit a close bracket. When we next emit an // element we re-open all closed brackets. // The mixed-radix counter, with radices in 'shape'. int64_t rank = type.getRank(); SmallVector counter(rank, 0); // The number of brackets that have been opened and not closed. unsigned openBrackets = 0; auto shape = type.getShape(); auto bumpCounter = [&] { // Bump the least significant digit. ++counter[rank - 1]; // Iterate backwards bubbling back the increment. for (unsigned i = rank - 1; i > 0; --i) if (counter[i] >= shape[i]) { // Index 'i' is rolled over. Bump (i-1) and close a bracket. counter[i] = 0; ++counter[i - 1]; --openBrackets; os << ']'; } }; for (unsigned idx = 0, e = numElements; idx != e; ++idx) { if (idx != 0) os << ", "; while (openBrackets++ < rank) os << '['; openBrackets = rank; printEltFn(idx); bumpCounter(); } while (openBrackets-- > 0) os << ']'; } void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr, bool allowHex) { if (auto stringAttr = llvm::dyn_cast(attr)) return printDenseStringElementsAttr(stringAttr); printDenseIntOrFPElementsAttr(llvm::cast(attr), allowHex); } void AsmPrinter::Impl::printDenseIntOrFPElementsAttr( DenseIntOrFPElementsAttr attr, bool allowHex) { auto type = attr.getType(); auto elementType = type.getElementType(); // Check to see if we should format this attribute as a hex string. if (allowHex && printerFlags.shouldPrintElementsAttrWithHex(attr)) { ArrayRef rawData = attr.getRawData(); if (llvm::endianness::native == llvm::endianness::big) { // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE // machines. It is converted here to print in LE format. SmallVector outDataVec(rawData.size()); MutableArrayRef convRawData(outDataVec); DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( rawData, convRawData, type); printHexString(convRawData); } else { printHexString(rawData); } return; } if (ComplexType complexTy = llvm::dyn_cast(elementType)) { Type complexElementType = complexTy.getElementType(); // Note: The if and else below had a common lambda function which invoked // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 // and hence was replaced. if (llvm::isa(complexElementType)) { auto valueIt = attr.value_begin>(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { auto complexValue = *(valueIt + index); os << "("; printDenseIntElement(complexValue.real(), os, complexElementType); os << ","; printDenseIntElement(complexValue.imag(), os, complexElementType); os << ")"; }); } else { auto valueIt = attr.value_begin>(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { auto complexValue = *(valueIt + index); os << "("; printFloatValue(complexValue.real(), os); os << ","; printFloatValue(complexValue.imag(), os); os << ")"; }); } } else if (elementType.isIntOrIndex()) { auto valueIt = attr.value_begin(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { printDenseIntElement(*(valueIt + index), os, elementType); }); } else { assert(llvm::isa(elementType) && "unexpected element type"); auto valueIt = attr.value_begin(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { printFloatValue(*(valueIt + index), os); }); } } void AsmPrinter::Impl::printDenseStringElementsAttr( DenseStringElementsAttr attr) { ArrayRef data = attr.getRawStringData(); auto printFn = [&](unsigned index) { printEscapedString(data[index]); }; printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); } void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) { Type type = attr.getElementType(); unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth(); unsigned byteSize = bitwidth / 8; ArrayRef data = attr.getRawData(); auto printElementAt = [&](unsigned i) { APInt value(bitwidth, 0); if (bitwidth) { llvm::LoadIntFromMemory( value, reinterpret_cast(data.begin() + byteSize * i), byteSize); } // Print the data as-is or as a float. if (type.isIntOrIndex()) { printDenseIntElement(value, getStream(), type); } else { APFloat fltVal(llvm::cast(type).getFloatSemantics(), value); printFloatValue(fltVal, getStream()); } }; llvm::interleaveComma(llvm::seq(0, attr.size()), getStream(), printElementAt); } void AsmPrinter::Impl::printType(Type type) { if (!type) { os << "<>"; return; } // Try to print an alias for this type. if (succeeded(printAlias(type))) return; return printTypeImpl(type); } void AsmPrinter::Impl::printTypeImpl(Type type) { TypeSwitch(type) .Case([&](OpaqueType opaqueTy) { printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), opaqueTy.getTypeData()); }) .Case([&](Type) { os << "index"; }) .Case([&](Type) { os << "f4E2M1FN"; }) .Case([&](Type) { os << "f6E2M3FN"; }) .Case([&](Type) { os << "f6E3M2FN"; }) .Case([&](Type) { os << "f8E5M2"; }) .Case([&](Type) { os << "f8E4M3"; }) .Case([&](Type) { os << "f8E4M3FN"; }) .Case([&](Type) { os << "f8E5M2FNUZ"; }) .Case([&](Type) { os << "f8E4M3FNUZ"; }) .Case([&](Type) { os << "f8E4M3B11FNUZ"; }) .Case([&](Type) { os << "f8E3M4"; }) .Case([&](Type) { os << "f8E8M0FNU"; }) .Case([&](Type) { os << "bf16"; }) .Case([&](Type) { os << "f16"; }) .Case([&](Type) { os << "tf32"; }) .Case([&](Type) { os << "f32"; }) .Case([&](Type) { os << "f64"; }) .Case([&](Type) { os << "f80"; }) .Case([&](Type) { os << "f128"; }) .Case([&](IntegerType integerTy) { if (integerTy.isSigned()) os << 's'; else if (integerTy.isUnsigned()) os << 'u'; os << 'i' << integerTy.getWidth(); }) .Case([&](FunctionType funcTy) { os << '('; interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); os << ") -> "; ArrayRef results = funcTy.getResults(); if (results.size() == 1 && !llvm::isa(results[0])) { printType(results[0]); } else { os << '('; interleaveComma(results, [&](Type ty) { printType(ty); }); os << ')'; } }) .Case([&](VectorType vectorTy) { auto scalableDims = vectorTy.getScalableDims(); os << "vector<"; auto vShape = vectorTy.getShape(); unsigned lastDim = vShape.size(); unsigned dimIdx = 0; for (dimIdx = 0; dimIdx < lastDim; dimIdx++) { if (!scalableDims.empty() && scalableDims[dimIdx]) os << '['; os << vShape[dimIdx]; if (!scalableDims.empty() && scalableDims[dimIdx]) os << ']'; os << 'x'; } printType(vectorTy.getElementType()); os << '>'; }) .Case([&](RankedTensorType tensorTy) { os << "tensor<"; printDimensionList(tensorTy.getShape()); if (!tensorTy.getShape().empty()) os << 'x'; printType(tensorTy.getElementType()); // Only print the encoding attribute value if set. if (tensorTy.getEncoding()) { os << ", "; printAttribute(tensorTy.getEncoding()); } os << '>'; }) .Case([&](UnrankedTensorType tensorTy) { os << "tensor<*x"; printType(tensorTy.getElementType()); os << '>'; }) .Case([&](MemRefType memrefTy) { os << "memref<"; printDimensionList(memrefTy.getShape()); if (!memrefTy.getShape().empty()) os << 'x'; printType(memrefTy.getElementType()); MemRefLayoutAttrInterface layout = memrefTy.getLayout(); if (!llvm::isa(layout) || !layout.isIdentity()) { os << ", "; printAttribute(memrefTy.getLayout(), AttrTypeElision::May); } // Only print the memory space if it is the non-default one. if (memrefTy.getMemorySpace()) { os << ", "; printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); } os << '>'; }) .Case([&](UnrankedMemRefType memrefTy) { os << "memref<*x"; printType(memrefTy.getElementType()); // Only print the memory space if it is the non-default one. if (memrefTy.getMemorySpace()) { os << ", "; printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); } os << '>'; }) .Case([&](ComplexType complexTy) { os << "complex<"; printType(complexTy.getElementType()); os << '>'; }) .Case([&](TupleType tupleTy) { os << "tuple<"; interleaveComma(tupleTy.getTypes(), [&](Type type) { printType(type); }); os << '>'; }) .Case([&](Type) { os << "none"; }) .Default([&](Type type) { return printDialectType(type); }); } void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs, bool withKeyword) { // If there are no attributes, then there is nothing to be done. if (attrs.empty()) return; // Functor used to print a filtered attribute list. auto printFilteredAttributesFn = [&](auto filteredAttrs) { // Print the 'attributes' keyword if necessary. if (withKeyword) os << " attributes"; // Otherwise, print them all out in braces. os << " {"; interleaveComma(filteredAttrs, [&](NamedAttribute attr) { printNamedAttribute(attr); }); os << '}'; }; // If no attributes are elided, we can directly print with no filtering. if (elidedAttrs.empty()) return printFilteredAttributesFn(attrs); // Otherwise, filter out any attributes that shouldn't be included. llvm::SmallDenseSet elidedAttrsSet(elidedAttrs.begin(), elidedAttrs.end()); auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) { return !elidedAttrsSet.contains(attr.getName().strref()); }); if (!filteredAttrs.empty()) printFilteredAttributesFn(filteredAttrs); } void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { // Print the name without quotes if possible. ::printKeywordOrString(attr.getName().strref(), os); // Pretty printing elides the attribute value for unit attributes. if (llvm::isa(attr.getValue())) return; os << " = "; printAttribute(attr.getValue()); } void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { auto &dialect = attr.getDialect(); // Ask the dialect to serialize the attribute to a string. std::string attrName; { llvm::raw_string_ostream attrNameStr(attrName); Impl subPrinter(attrNameStr, state); DialectAsmPrinter printer(subPrinter); dialect.printAttribute(attr, printer); } printDialectSymbol(os, "#", dialect.getNamespace(), attrName); } void AsmPrinter::Impl::printDialectType(Type type) { auto &dialect = type.getDialect(); // Ask the dialect to serialize the type to a string. std::string typeName; { llvm::raw_string_ostream typeNameStr(typeName); Impl subPrinter(typeNameStr, state); DialectAsmPrinter printer(subPrinter); dialect.printType(type, printer); } printDialectSymbol(os, "!", dialect.getNamespace(), typeName); } void AsmPrinter::Impl::printEscapedString(StringRef str) { os << "\""; llvm::printEscapedString(str, os); os << "\""; } void AsmPrinter::Impl::printHexString(StringRef str) { os << "\"0x" << llvm::toHex(str) << "\""; } void AsmPrinter::Impl::printHexString(ArrayRef data) { printHexString(StringRef(data.data(), data.size())); } LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) { return state.pushCyclicPrinting(opaquePointer); } void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); } void AsmPrinter::Impl::printDimensionList(ArrayRef shape) { detail::printDimensionList(os, shape); } //===--------------------------------------------------------------------===// // AsmPrinter //===--------------------------------------------------------------------===// AsmPrinter::~AsmPrinter() = default; raw_ostream &AsmPrinter::getStream() const { assert(impl && "expected AsmPrinter::getStream to be overriden"); return impl->getStream(); } /// Print the given floating point value in a stablized form. void AsmPrinter::printFloat(const APFloat &value) { assert(impl && "expected AsmPrinter::printFloat to be overriden"); printFloatValue(value, impl->getStream()); } void AsmPrinter::printType(Type type) { assert(impl && "expected AsmPrinter::printType to be overriden"); impl->printType(type); } void AsmPrinter::printAttribute(Attribute attr) { assert(impl && "expected AsmPrinter::printAttribute to be overriden"); impl->printAttribute(attr); } LogicalResult AsmPrinter::printAlias(Attribute attr) { assert(impl && "expected AsmPrinter::printAlias to be overriden"); return impl->printAlias(attr); } LogicalResult AsmPrinter::printAlias(Type type) { assert(impl && "expected AsmPrinter::printAlias to be overriden"); return impl->printAlias(type); } void AsmPrinter::printAttributeWithoutType(Attribute attr) { assert(impl && "expected AsmPrinter::printAttributeWithoutType to be overriden"); impl->printAttribute(attr, Impl::AttrTypeElision::Must); } void AsmPrinter::printKeywordOrString(StringRef keyword) { assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden"); ::printKeywordOrString(keyword, impl->getStream()); } void AsmPrinter::printString(StringRef keyword) { assert(impl && "expected AsmPrinter::printString to be overriden"); *this << '"'; printEscapedString(keyword, getStream()); *this << '"'; } void AsmPrinter::printSymbolName(StringRef symbolRef) { assert(impl && "expected AsmPrinter::printSymbolName to be overriden"); ::printSymbolReference(symbolRef, impl->getStream()); } void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) { assert(impl && "expected AsmPrinter::printResourceHandle to be overriden"); impl->printResourceHandle(resource); } void AsmPrinter::printDimensionList(ArrayRef shape) { detail::printDimensionList(getStream(), shape); } LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) { return impl->pushCyclicPrinting(opaquePointer); } void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); } //===----------------------------------------------------------------------===// // Affine expressions and maps //===----------------------------------------------------------------------===// void AsmPrinter::Impl::printAffineExpr( AffineExpr expr, function_ref printValueName) { printAffineExprInternal(expr, BindingStrength::Weak, printValueName); } void AsmPrinter::Impl::printAffineExprInternal( AffineExpr expr, BindingStrength enclosingTightness, function_ref printValueName) { const char *binopSpelling = nullptr; switch (expr.getKind()) { case AffineExprKind::SymbolId: { unsigned pos = cast(expr).getPosition(); if (printValueName) printValueName(pos, /*isSymbol=*/true); else os << 's' << pos; return; } case AffineExprKind::DimId: { unsigned pos = cast(expr).getPosition(); if (printValueName) printValueName(pos, /*isSymbol=*/false); else os << 'd' << pos; return; } case AffineExprKind::Constant: os << cast(expr).getValue(); return; case AffineExprKind::Add: binopSpelling = " + "; break; case AffineExprKind::Mul: binopSpelling = " * "; break; case AffineExprKind::FloorDiv: binopSpelling = " floordiv "; break; case AffineExprKind::CeilDiv: binopSpelling = " ceildiv "; break; case AffineExprKind::Mod: binopSpelling = " mod "; break; } auto binOp = cast(expr); AffineExpr lhsExpr = binOp.getLHS(); AffineExpr rhsExpr = binOp.getRHS(); // Handle tightly binding binary operators. if (binOp.getKind() != AffineExprKind::Add) { if (enclosingTightness == BindingStrength::Strong) os << '('; // Pretty print multiplication with -1. auto rhsConst = dyn_cast(rhsExpr); if (rhsConst && binOp.getKind() == AffineExprKind::Mul && rhsConst.getValue() == -1) { os << "-"; printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); if (enclosingTightness == BindingStrength::Strong) os << ')'; return; } printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); os << binopSpelling; printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); if (enclosingTightness == BindingStrength::Strong) os << ')'; return; } // Print out special "pretty" forms for add. if (enclosingTightness == BindingStrength::Strong) os << '('; // Pretty print addition to a product that has a negative operand as a // subtraction. if (auto rhs = dyn_cast(rhsExpr)) { if (rhs.getKind() == AffineExprKind::Mul) { AffineExpr rrhsExpr = rhs.getRHS(); if (auto rrhs = dyn_cast(rrhsExpr)) { if (rrhs.getValue() == -1) { printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); os << " - "; if (rhs.getLHS().getKind() == AffineExprKind::Add) { printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, printValueName); } else { printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, printValueName); } if (enclosingTightness == BindingStrength::Strong) os << ')'; return; } if (rrhs.getValue() < -1) { printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); os << " - "; printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, printValueName); os << " * " << -rrhs.getValue(); if (enclosingTightness == BindingStrength::Strong) os << ')'; return; } } } } // Pretty print addition to a negative number as a subtraction. if (auto rhsConst = dyn_cast(rhsExpr)) { if (rhsConst.getValue() < 0) { printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); os << " - " << -rhsConst.getValue(); if (enclosingTightness == BindingStrength::Strong) os << ')'; return; } } printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); os << " + "; printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); if (enclosingTightness == BindingStrength::Strong) os << ')'; } void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { printAffineExprInternal(expr, BindingStrength::Weak); isEq ? os << " == 0" : os << " >= 0"; } void AsmPrinter::Impl::printAffineMap(AffineMap map) { // Dimension identifiers. os << '('; for (int i = 0; i < (int)map.getNumDims() - 1; ++i) os << 'd' << i << ", "; if (map.getNumDims() >= 1) os << 'd' << map.getNumDims() - 1; os << ')'; // Symbolic identifiers. if (map.getNumSymbols() != 0) { os << '['; for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) os << 's' << i << ", "; if (map.getNumSymbols() >= 1) os << 's' << map.getNumSymbols() - 1; os << ']'; } // Result affine expressions. os << " -> ("; interleaveComma(map.getResults(), [&](AffineExpr expr) { printAffineExpr(expr); }); os << ')'; } void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { // Dimension identifiers. os << '('; for (unsigned i = 1; i < set.getNumDims(); ++i) os << 'd' << i - 1 << ", "; if (set.getNumDims() >= 1) os << 'd' << set.getNumDims() - 1; os << ')'; // Symbolic identifiers. if (set.getNumSymbols() != 0) { os << '['; for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) os << 's' << i << ", "; if (set.getNumSymbols() >= 1) os << 's' << set.getNumSymbols() - 1; os << ']'; } // Print constraints. os << " : ("; int numConstraints = set.getNumConstraints(); for (int i = 1; i < numConstraints; ++i) { printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); os << ", "; } if (numConstraints >= 1) printAffineConstraint(set.getConstraint(numConstraints - 1), set.isEq(numConstraints - 1)); os << ')'; } //===----------------------------------------------------------------------===// // OperationPrinter //===----------------------------------------------------------------------===// namespace { /// This class contains the logic for printing operations, regions, and blocks. class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { public: using Impl = AsmPrinter::Impl; using Impl::printType; explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) : Impl(os, state), OpAsmPrinter(static_cast(*this)) {} /// Print the given top-level operation. void printTopLevelOperation(Operation *op); /// Print the given operation, including its left-hand side and its right-hand /// side, with its indent and location. void printFullOpWithIndentAndLoc(Operation *op); /// Print the given operation, including its left-hand side and its right-hand /// side, but not including indentation and location. void printFullOp(Operation *op); /// Print the right-hand size of the given operation in the custom or generic /// form. void printCustomOrGenericOp(Operation *op) override; /// Print the right-hand side of the given operation in the generic form. void printGenericOp(Operation *op, bool printOpName) override; /// Print the name of the given block. void printBlockName(Block *block); /// Print the given block. If 'printBlockArgs' is false, the arguments of the /// block are not printed. If 'printBlockTerminator' is false, the terminator /// operation of the block is not printed. void print(Block *block, bool printBlockArgs = true, bool printBlockTerminator = true); /// Print the ID of the given value, optionally with its result number. void printValueID(Value value, bool printResultNo = true, raw_ostream *streamOverride = nullptr) const; /// Print the ID of the given operation. void printOperationID(Operation *op, raw_ostream *streamOverride = nullptr) const; //===--------------------------------------------------------------------===// // OpAsmPrinter methods //===--------------------------------------------------------------------===// /// Print a loc(...) specifier if printing debug info is enabled. Locations /// may be deferred with an alias. void printOptionalLocationSpecifier(Location loc) override { printTrailingLocation(loc); } /// Print a newline and indent the printer to the start of the current /// operation. void printNewline() override { os << newLine; os.indent(currentIndent); } /// Increase indentation. void increaseIndent() override { currentIndent += indentWidth; } /// Decrease indentation. void decreaseIndent() override { currentIndent -= indentWidth; } /// Print a block argument in the usual format of: /// %ssaName : type {attr1=42} loc("here") /// where location printing is controlled by the standard internal option. /// You may pass omitType=true to not print a type, and pass an empty /// attribute list if you don't care for attributes. void printRegionArgument(BlockArgument arg, ArrayRef argAttrs = {}, bool omitType = false) override; /// Print the ID for the given value. void printOperand(Value value) override { printValueID(value); } void printOperand(Value value, raw_ostream &os) override { printValueID(value, /*printResultNo=*/true, &os); } /// Print an optional attribute dictionary with a given set of elided values. void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}) override { Impl::printOptionalAttrDict(attrs, elidedAttrs); } void printOptionalAttrDictWithKeyword( ArrayRef attrs, ArrayRef elidedAttrs = {}) override { Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/true); } /// Print the given successor. void printSuccessor(Block *successor) override; /// Print an operation successor with the operands used for the block /// arguments. void printSuccessorAndUseList(Block *successor, ValueRange succOperands) override; /// Print the given region. void printRegion(Region ®ion, bool printEntryBlockArgs, bool printBlockTerminators, bool printEmptyBlock) override; /// Renumber the arguments for the specified region to the same names as the /// SSA values in namesToUse. This may only be used for IsolatedFromAbove /// operations. If any entry in namesToUse is null, the corresponding /// argument name is left alone. void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { state.getSSANameState().shadowRegionArgs(region, namesToUse); } /// Print the given affine map with the symbol and dimension operands printed /// inline with the map. void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands) override; /// Print the given affine expression with the symbol and dimension operands /// printed inline with the expression. void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands) override; /// Print users of this operation or id of this operation if it has no result. void printUsersComment(Operation *op); /// Print users of this block arg. void printUsersComment(BlockArgument arg); /// Print the users of a value. void printValueUsers(Value value); /// Print either the ids of the result values or the id of the operation if /// the operation has no results. void printUserIDs(Operation *user, bool prefixComma = false); private: /// This class represents a resource builder implementation for the MLIR /// textual assembly format. class ResourceBuilder : public AsmResourceBuilder { public: using ValueFn = function_ref; using PrintFn = function_ref; ResourceBuilder(PrintFn printFn) : printFn(printFn) {} ~ResourceBuilder() override = default; void buildBool(StringRef key, bool data) final { printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false"); }); } void buildString(StringRef key, StringRef data) final { printFn(key, [&](raw_ostream &os) { os << "\""; llvm::printEscapedString(data, os); os << "\""; }); } void buildBlob(StringRef key, ArrayRef data, uint32_t dataAlignment) final { printFn(key, [&](raw_ostream &os) { // Store the blob in a hex string containing the alignment and the data. llvm::support::ulittle32_t dataAlignmentLE(dataAlignment); os << "\"0x" << llvm::toHex(StringRef(reinterpret_cast(&dataAlignmentLE), sizeof(dataAlignment))) << llvm::toHex(StringRef(data.data(), data.size())) << "\""; }); } private: PrintFn printFn; }; /// Print the metadata dictionary for the file, eliding it if it is empty. void printFileMetadataDictionary(Operation *op); /// Print the resource sections for the file metadata dictionary. /// `checkAddMetadataDict` is used to indicate that metadata is going to be /// added, and the file metadata dictionary should be started if it hasn't /// yet. void printResourceFileMetadata(function_ref checkAddMetadataDict, Operation *op); // Contains the stack of default dialects to use when printing regions. // A new dialect is pushed to the stack before parsing regions nested under an // operation implementing `OpAsmOpInterface`, and popped when done. At the // top-level we start with "builtin" as the default, so that the top-level // `module` operation prints as-is. SmallVector defaultDialectStack{"builtin"}; /// The number of spaces used for indenting nested operations. const static unsigned indentWidth = 2; // This is the current indentation level for nested structures. unsigned currentIndent = 0; }; } // namespace void OperationPrinter::printTopLevelOperation(Operation *op) { // Output the aliases at the top level that can't be deferred. state.getAliasState().printNonDeferredAliases(*this, newLine); // Print the module. printFullOpWithIndentAndLoc(op); os << newLine; // Output the aliases at the top level that can be deferred. state.getAliasState().printDeferredAliases(*this, newLine); // Output any file level metadata. printFileMetadataDictionary(op); } void OperationPrinter::printFileMetadataDictionary(Operation *op) { bool sawMetadataEntry = false; auto checkAddMetadataDict = [&] { if (!std::exchange(sawMetadataEntry, true)) os << newLine << "{-#" << newLine; }; // Add the various types of metadata. printResourceFileMetadata(checkAddMetadataDict, op); // If the file dictionary exists, close it. if (sawMetadataEntry) os << newLine << "#-}" << newLine; } void OperationPrinter::printResourceFileMetadata( function_ref checkAddMetadataDict, Operation *op) { // Functor used to add data entries to the file metadata dictionary. bool hadResource = false; bool needResourceComma = false; bool needEntryComma = false; auto processProvider = [&](StringRef dictName, StringRef name, auto &provider, auto &&...providerArgs) { bool hadEntry = false; auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) { checkAddMetadataDict(); std::string resourceStr; auto printResourceStr = [&](raw_ostream &os) { os << resourceStr; }; std::optional charLimit = printerFlags.getLargeResourceStringLimit(); if (charLimit.has_value()) { llvm::raw_string_ostream ss(resourceStr); valueFn(ss); // Only print entry if its string is small enough. if (resourceStr.size() > charLimit.value()) return; // Don't recompute resourceStr when valueFn is called below. valueFn = printResourceStr; } // Emit the top-level resource entry if we haven't yet. if (!std::exchange(hadResource, true)) { if (needResourceComma) os << "," << newLine; os << " " << dictName << "_resources: {" << newLine; } // Emit the parent resource entry if we haven't yet. if (!std::exchange(hadEntry, true)) { if (needEntryComma) os << "," << newLine; os << " " << name << ": {" << newLine; } else { os << "," << newLine; } os << " "; ::printKeywordOrString(key, os); os << ": "; // Call printResourceStr or original valueFn, depending on charLimit. valueFn(os); }; ResourceBuilder entryBuilder(printFn); provider.buildResources(op, providerArgs..., entryBuilder); needEntryComma |= hadEntry; if (hadEntry) os << newLine << " }"; }; // Print the `dialect_resources` section if we have any dialects with // resources. for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) { auto &dialectResources = state.getDialectResources(); StringRef name = interface.getDialect()->getNamespace(); auto it = dialectResources.find(interface.getDialect()); if (it != dialectResources.end()) processProvider("dialect", name, interface, it->second); else processProvider("dialect", name, interface, SetVector()); } if (hadResource) os << newLine << " }"; // Print the `external_resources` section if we have any external clients with // resources. needEntryComma = false; needResourceComma = hadResource; hadResource = false; for (const auto &printer : state.getResourcePrinters()) processProvider("external", printer.getName(), printer); if (hadResource) os << newLine << " }"; } /// Print a block argument in the usual format of: /// %ssaName : type {attr1=42} loc("here") /// where location printing is controlled by the standard internal option. /// You may pass omitType=true to not print a type, and pass an empty /// attribute list if you don't care for attributes. void OperationPrinter::printRegionArgument(BlockArgument arg, ArrayRef argAttrs, bool omitType) { printOperand(arg); if (!omitType) { os << ": "; printType(arg.getType()); } printOptionalAttrDict(argAttrs); // TODO: We should allow location aliases on block arguments. printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); } void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) { // Track the location of this operation. state.registerOperationLocation(op, newLine.curLine, currentIndent); os.indent(currentIndent); printFullOp(op); printTrailingLocation(op->getLoc()); if (printerFlags.shouldPrintValueUsers()) printUsersComment(op); } void OperationPrinter::printFullOp(Operation *op) { if (size_t numResults = op->getNumResults()) { auto printResultGroup = [&](size_t resultNo, size_t resultCount) { printValueID(op->getResult(resultNo), /*printResultNo=*/false); if (resultCount > 1) os << ':' << resultCount; }; // Check to see if this operation has multiple result groups. ArrayRef resultGroups = state.getSSANameState().getOpResultGroups(op); if (!resultGroups.empty()) { // Interleave the groups excluding the last one, this one will be handled // separately. interleaveComma(llvm::seq(0, resultGroups.size() - 1), [&](int i) { printResultGroup(resultGroups[i], resultGroups[i + 1] - resultGroups[i]); }); os << ", "; printResultGroup(resultGroups.back(), numResults - resultGroups.back()); } else { printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); } os << " = "; } printCustomOrGenericOp(op); } void OperationPrinter::printUsersComment(Operation *op) { unsigned numResults = op->getNumResults(); if (!numResults && op->getNumOperands()) { os << " // id: "; printOperationID(op); } else if (numResults && op->use_empty()) { os << " // unused"; } else if (numResults && !op->use_empty()) { // Print "user" if the operation has one result used to compute one other // result, or is used in one operation with no result. unsigned usedInNResults = 0; unsigned usedInNOperations = 0; SmallPtrSet userSet; for (Operation *user : op->getUsers()) { if (userSet.insert(user).second) { ++usedInNOperations; usedInNResults += user->getNumResults(); } } // We already know that users is not empty. bool exactlyOneUniqueUse = usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1; os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": "; bool shouldPrintBrackets = numResults > 1; auto printOpResult = [&](OpResult opResult) { if (shouldPrintBrackets) os << "("; printValueUsers(opResult); if (shouldPrintBrackets) os << ")"; }; interleaveComma(op->getResults(), printOpResult); } } void OperationPrinter::printUsersComment(BlockArgument arg) { os << "// "; printValueID(arg); if (arg.use_empty()) { os << " is unused"; } else { os << " is used by "; printValueUsers(arg); } os << newLine; } void OperationPrinter::printValueUsers(Value value) { if (value.use_empty()) os << "unused"; // One value might be used as the operand of an operation more than once. // Only print the operations results once in that case. SmallPtrSet userSet; for (auto [index, user] : enumerate(value.getUsers())) { if (userSet.insert(user).second) printUserIDs(user, index); } } void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) { if (prefixComma) os << ", "; if (!user->getNumResults()) { printOperationID(user); } else { interleaveComma(user->getResults(), [this](Value result) { printValueID(result); }); } } void OperationPrinter::printCustomOrGenericOp(Operation *op) { // If requested, always print the generic form. if (!printerFlags.shouldPrintGenericOpForm()) { // Check to see if this is a known operation. If so, use the registered // custom printer hook. if (auto opInfo = op->getRegisteredInfo()) { opInfo->printAssembly(op, *this, defaultDialectStack.back()); return; } // Otherwise try to dispatch to the dialect, if available. if (Dialect *dialect = op->getDialect()) { if (auto opPrinter = dialect->getOperationPrinter(op)) { // Print the op name first. StringRef name = op->getName().getStringRef(); // Only drop the default dialect prefix when it cannot lead to // ambiguities. if (name.count('.') == 1) name.consume_front((defaultDialectStack.back() + ".").str()); os << name; // Print the rest of the op now. opPrinter(op, *this); return; } } } // Otherwise print with the generic assembly form. printGenericOp(op, /*printOpName=*/true); } void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { if (printOpName) printEscapedString(op->getName().getStringRef()); os << '('; interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); os << ')'; // For terminators, print the list of successors and their operands. if (op->getNumSuccessors() != 0) { os << '['; interleaveComma(op->getSuccessors(), [&](Block *successor) { printBlockName(successor); }); os << ']'; } // Print the properties. if (Attribute prop = op->getPropertiesAsAttribute()) { os << " <"; Impl::printAttribute(prop); os << '>'; } // Print regions. if (op->getNumRegions() != 0) { os << " ("; interleaveComma(op->getRegions(), [&](Region ®ion) { printRegion(region, /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); }); os << ')'; } printOptionalAttrDict(op->getPropertiesStorage() ? llvm::to_vector(op->getDiscardableAttrs()) : op->getAttrs()); // Print the type signature of the operation. os << " : "; printFunctionalType(op); } void OperationPrinter::printBlockName(Block *block) { os << state.getSSANameState().getBlockInfo(block).name; } void OperationPrinter::print(Block *block, bool printBlockArgs, bool printBlockTerminator) { // Print the block label and argument list if requested. if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); // Print the argument list if non-empty. if (!block->args_empty()) { os << '('; interleaveComma(block->getArguments(), [&](BlockArgument arg) { printValueID(arg); os << ": "; printType(arg.getType()); // TODO: We should allow location aliases on block arguments. printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); }); os << ')'; } os << ':'; // Print out some context information about the predecessors of this block. if (!block->getParent()) { os << " // block is not in a region!"; } else if (block->hasNoPredecessors()) { if (!block->isEntryBlock()) os << " // no predecessors"; } else if (auto *pred = block->getSinglePredecessor()) { os << " // pred: "; printBlockName(pred); } else { // We want to print the predecessors in a stable order, not in // whatever order the use-list is in, so gather and sort them. SmallVector predIDs; for (auto *pred : block->getPredecessors()) predIDs.push_back(state.getSSANameState().getBlockInfo(pred)); llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) { return lhs.ordering < rhs.ordering; }); os << " // " << predIDs.size() << " preds: "; interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; }); } os << newLine; } currentIndent += indentWidth; if (printerFlags.shouldPrintValueUsers()) { for (BlockArgument arg : block->getArguments()) { os.indent(currentIndent); printUsersComment(arg); } } bool hasTerminator = !block->empty() && block->back().hasTrait(); auto range = llvm::make_range( block->begin(), std::prev(block->end(), (!hasTerminator || printBlockTerminator) ? 0 : 1)); for (auto &op : range) { printFullOpWithIndentAndLoc(&op); os << newLine; } currentIndent -= indentWidth; } void OperationPrinter::printValueID(Value value, bool printResultNo, raw_ostream *streamOverride) const { state.getSSANameState().printValueID(value, printResultNo, streamOverride ? *streamOverride : os); } void OperationPrinter::printOperationID(Operation *op, raw_ostream *streamOverride) const { state.getSSANameState().printOperationID(op, streamOverride ? *streamOverride : os); } void OperationPrinter::printSuccessor(Block *successor) { printBlockName(successor); } void OperationPrinter::printSuccessorAndUseList(Block *successor, ValueRange succOperands) { printBlockName(successor); if (succOperands.empty()) return; os << '('; interleaveComma(succOperands, [this](Value operand) { printValueID(operand); }); os << " : "; interleaveComma(succOperands, [this](Value operand) { printType(operand.getType()); }); os << ')'; } void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, bool printBlockTerminators, bool printEmptyBlock) { if (printerFlags.shouldSkipRegions()) { os << "{...}"; return; } os << "{" << newLine; if (!region.empty()) { auto restoreDefaultDialect = llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); }); if (auto iface = dyn_cast(region.getParentOp())) defaultDialectStack.push_back(iface.getDefaultDialect()); else defaultDialectStack.push_back(""); auto *entryBlock = ®ion.front(); // Force printing the block header if printEmptyBlock is set and the block // is empty or if printEntryBlockArgs is set and there are arguments to // print. bool shouldAlwaysPrintBlockHeader = (printEmptyBlock && entryBlock->empty()) || (printEntryBlockArgs && entryBlock->getNumArguments() != 0); print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators); for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) print(&b); } os.indent(currentIndent) << "}"; } void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands) { if (!mapAttr) { os << "<>"; return; } AffineMap map = mapAttr.getValue(); unsigned numDims = map.getNumDims(); auto printValueName = [&](unsigned pos, bool isSymbol) { unsigned index = isSymbol ? numDims + pos : pos; assert(index < operands.size()); if (isSymbol) os << "symbol("; printValueID(operands[index]); if (isSymbol) os << ')'; }; interleaveComma(map.getResults(), [&](AffineExpr expr) { printAffineExpr(expr, printValueName); }); } void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands) { auto printValueName = [&](unsigned pos, bool isSymbol) { if (!isSymbol) return printValueID(dimOperands[pos]); os << "symbol("; printValueID(symOperands[pos]); os << ')'; }; printAffineExpr(expr, printValueName); } //===----------------------------------------------------------------------===// // print and dump methods //===----------------------------------------------------------------------===// void Attribute::print(raw_ostream &os, bool elideType) const { if (!*this) { os << "<>"; return; } AsmState state(getContext()); print(os, state, elideType); } void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const { using AttrTypeElision = AsmPrinter::Impl::AttrTypeElision; AsmPrinter::Impl(os, state.getImpl()) .printAttribute(*this, elideType ? AttrTypeElision::Must : AttrTypeElision::Never); } void Attribute::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } void Attribute::printStripped(raw_ostream &os, AsmState &state) const { if (!*this) { os << "<>"; return; } AsmPrinter::Impl subPrinter(os, state.getImpl()); if (succeeded(subPrinter.printAlias(*this))) return; auto &dialect = this->getDialect(); uint64_t posPrior = os.tell(); DialectAsmPrinter printer(subPrinter); dialect.printAttribute(*this, printer); if (posPrior != os.tell()) return; // Fallback to printing with prefix if the above failed to write anything // to the output stream. print(os, state); } void Attribute::printStripped(raw_ostream &os) const { if (!*this) { os << "<>"; return; } AsmState state(getContext()); printStripped(os, state); } void Type::print(raw_ostream &os) const { if (!*this) { os << "<>"; return; } AsmState state(getContext()); print(os, state); } void Type::print(raw_ostream &os, AsmState &state) const { AsmPrinter::Impl(os, state.getImpl()).printType(*this); } void Type::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } void AffineMap::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } void IntegerSet::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } void AffineExpr::print(raw_ostream &os) const { if (!expr) { os << "<>"; return; } AsmState state(getContext()); AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(*this); } void AffineExpr::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } void AffineMap::print(raw_ostream &os) const { if (!map) { os << "<>"; return; } AsmState state(getContext()); AsmPrinter::Impl(os, state.getImpl()).printAffineMap(*this); } void IntegerSet::print(raw_ostream &os) const { AsmState state(getContext()); AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this); } void Value::print(raw_ostream &os) const { print(os, OpPrintingFlags()); } void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const { if (!impl) { os << "<>"; return; } if (auto *op = getDefiningOp()) return op->print(os, flags); // TODO: Improve BlockArgument print'ing. BlockArgument arg = llvm::cast(*this); os << " of type '" << arg.getType() << "' at index: " << arg.getArgNumber(); } void Value::print(raw_ostream &os, AsmState &state) const { if (!impl) { os << "<>"; return; } if (auto *op = getDefiningOp()) return op->print(os, state); // TODO: Improve BlockArgument print'ing. BlockArgument arg = llvm::cast(*this); os << " of type '" << arg.getType() << "' at index: " << arg.getArgNumber(); } void Value::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } void Value::printAsOperand(raw_ostream &os, AsmState &state) const { // TODO: This doesn't necessarily capture all potential cases. // Currently, region arguments can be shadowed when printing the main // operation. If the IR hasn't been printed, this will produce the old SSA // name and not the shadowed name. state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, os); } static Operation *findParent(Operation *op, bool shouldUseLocalScope) { do { // If we are printing local scope, stop at the first operation that is // isolated from above. if (shouldUseLocalScope && op->hasTrait()) break; // Otherwise, traverse up to the next parent. Operation *parentOp = op->getParentOp(); if (!parentOp) break; op = parentOp; } while (true); return op; } void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) const { Operation *op; if (auto result = llvm::dyn_cast(*this)) { op = result.getOwner(); } else { op = llvm::cast(*this).getOwner()->getParentOp(); if (!op) { os << "<>"; return; } } op = findParent(op, flags.shouldUseLocalScope()); AsmState state(op, flags); printAsOperand(os, state); } void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { // Find the operation to number from based upon the provided flags. Operation *op = findParent(this, printerFlags.shouldUseLocalScope()); AsmState state(op, printerFlags); print(os, state); } void Operation::print(raw_ostream &os, AsmState &state) { OperationPrinter printer(os, state.getImpl()); if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) { state.getImpl().initializeAliases(this); printer.printTopLevelOperation(this); } else { printer.printFullOpWithIndentAndLoc(this); } } void Operation::dump() { print(llvm::errs(), OpPrintingFlags().useLocalScope()); llvm::errs() << "\n"; } void Operation::dumpPretty() { print(llvm::errs(), OpPrintingFlags().useLocalScope().assumeVerified()); llvm::errs() << "\n"; } void Block::print(raw_ostream &os) { Operation *parentOp = getParentOp(); if (!parentOp) { os << "<>\n"; return; } // Get the top-level op. while (auto *nextOp = parentOp->getParentOp()) parentOp = nextOp; AsmState state(parentOp); print(os, state); } void Block::print(raw_ostream &os, AsmState &state) { OperationPrinter(os, state.getImpl()).print(this); } void Block::dump() { print(llvm::errs()); } /// Print out the name of the block without printing its body. void Block::printAsOperand(raw_ostream &os, bool printType) { Operation *parentOp = getParentOp(); if (!parentOp) { os << "<>\n"; return; } AsmState state(parentOp); printAsOperand(os, state); } void Block::printAsOperand(raw_ostream &os, AsmState &state) { OperationPrinter printer(os, state.getImpl()); printer.printBlockName(this); } raw_ostream &mlir::operator<<(raw_ostream &os, Block &block) { block.print(os); return os; } //===--------------------------------------------------------------------===// // Custom printers //===--------------------------------------------------------------------===// namespace mlir { void printDimensionList(OpAsmPrinter &printer, Operation *op, ArrayRef dimensions) { if (dimensions.empty()) printer << "["; printer.printDimensionList(dimensions); if (dimensions.empty()) printer << "]"; } ParseResult parseDimensionList(OpAsmParser &parser, DenseI64ArrayAttr &dimensions) { // Empty list case denoted by "[]". if (succeeded(parser.parseOptionalLSquare())) { if (failed(parser.parseRSquare())) { return parser.emitError(parser.getCurrentLocation()) << "Failed parsing dimension list."; } dimensions = DenseI64ArrayAttr::get(parser.getContext(), ArrayRef()); return success(); } // Non-empty list case. SmallVector shapeArr; if (failed(parser.parseDimensionList(shapeArr, true, false))) { return parser.emitError(parser.getCurrentLocation()) << "Failed parsing dimension list."; } if (shapeArr.empty()) { return parser.emitError(parser.getCurrentLocation()) << "Failed parsing dimension list. Did you mean an empty list? It " "must be denoted by \"[]\"."; } dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr); return success(); } } // namespace mlir