mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 06:26:06 +00:00

This PR adds the `f8E8M0FNU` type to MLIR. The `f8E8M0FNU` type is proposed in the [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines an 8-bit floating point number with bit layout S0E8M0. Unlike IEEE-754 types, there are no infinities, denormals, zeros, or negative values.

```c
f8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 - 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```

Related PRs:
- [PR-107127](https://github.com/llvm/llvm-project/pull/107127) [APFloat] Add APFloat support for E8M0 type
- [PR-105573](https://github.com/llvm/llvm-project/pull/105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR
- [PR-107999](https://github.com/llvm/llvm-project/pull/107999) [MLIR] Add f6E2M3FN type
- [PR-108877](https://github.com/llvm/llvm-project/pull/108877) [MLIR] Add f4E2M1FN type
4059 lines
140 KiB
C++
4059 lines
140 KiB
C++
//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the MLIR AsmPrinter class, which is used to implement
|
|
// the various print() methods on the core IR objects.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinDialect.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/DialectResourceBlobManager.h"
|
|
#include "mlir/IR/IntegerSet.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/Verifier.h"
|
|
#include "llvm/ADT/APFloat.h"
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/MapVector.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/ScopeExit.h"
|
|
#include "llvm/ADT/ScopedHashTable.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallString.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/Endian.h"
|
|
#include "llvm/Support/ManagedStatic.h"
|
|
#include "llvm/Support/Regex.h"
|
|
#include "llvm/Support/SaveAndRestore.h"
|
|
#include "llvm/Support/Threading.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include <type_traits>
|
|
|
|
#include <optional>
|
|
#include <tuple>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::detail;
|
|
|
|
#define DEBUG_TYPE "mlir-asm-printer"
|
|
|
|
/// Write the string returned by getStringRef() for this operation name to
/// `os`.
void OperationName::print(raw_ostream &os) const {
  auto name = this->getStringRef();
  os << name;
}
|
|
|
|
/// Debugging helper: print this operation name to llvm::errs().
void OperationName::dump() const {
  auto &errStream = llvm::errs();
  print(errStream);
}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// AsmParser
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
// Destructor defined (defaulted) out of line in this file rather than in the
// header; presumably to anchor the class's vtable in this TU — TODO confirm.
AsmParser::~AsmParser() = default;
|
|
// Defaulted destructor, defined out of line in this file (presumably as a
// vtable anchor — TODO confirm).
DialectAsmParser::~DialectAsmParser() = default;
|
|
// Defaulted destructor, defined out of line in this file (presumably as a
// vtable anchor — TODO confirm).
OpAsmParser::~OpAsmParser() = default;
|
|
|
|
/// Return the MLIRContext owned by this parser's builder.
MLIRContext *AsmParser::getContext() const {
  const auto &builder = getBuilder();
  return builder.getContext();
}
|
|
|
|
/// Parse a type list.
|
|
/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
|
|
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
|
|
return parseCommaSeparatedList(
|
|
[&]() { return parseType(result.emplace_back()); });
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DialectAsmPrinter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Defaulted destructor, defined out of line in this file (presumably as a
// vtable anchor — TODO confirm).
DialectAsmPrinter::~DialectAsmPrinter() = default;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpAsmPrinter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Defaulted destructor, defined out of line in this file (presumably as a
// vtable anchor — TODO confirm).
OpAsmPrinter::~OpAsmPrinter() = default;
|
|
|
|
/// Print the functional type of `op` as `(operand types) -> result types`.
/// The result list is parenthesized unless there is exactly one result and
/// its type is not a FunctionType; a bare function result type would make the
/// grammar ambiguous. Null operands/results are printed as <<NULL TYPE>>.
void OpAsmPrinter::printFunctionalType(Operation *op) {
  auto &os = getStream();

  os << '(';
  llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
    // A null value has no type; streaming a null Type prints <<NULL TYPE>>.
    Type operandType = operand ? operand.getType() : Type();
    *this << operandType;
  });
  os << ") -> ";

  // Decide whether the result list needs surrounding parentheses.
  bool wrapped = true;
  if (op->getNumResults() == 1) {
    Type soleResultType = op->getResult(0).getType();
    wrapped = soleResultType && llvm::isa<FunctionType>(soleResultType);
  }

  if (wrapped)
    os << '(';
  llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
    Type resultType = result ? result.getType() : Type();
    *this << resultType;
  });
  if (wrapped)
    os << ')';
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operation OpAsm interface.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
|
|
#include "mlir/IR/OpAsmInterface.cpp.inc"
|
|
|
|
/// Default resource handler: no resource keys are recognized, so report an
/// error naming the offending key and the dialect's namespace.
LogicalResult
OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
  auto diag = entry.emitError();
  diag << "unknown 'resource' key '" << entry.getKey() << "' for dialect '"
       << getDialect()->getNamespace() << "'";
  return diag;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OpPrintingFlags
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This struct contains command line options that can be used to initialize
|
|
/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
|
|
/// for global command line options.
|
|
struct AsmPrinterOptions {
|
|
llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
|
|
"mlir-print-elementsattrs-with-hex-if-larger",
|
|
llvm::cl::desc(
|
|
"Print DenseElementsAttrs with a hex string that have "
|
|
"more elements than the given upper limit (use -1 to disable)")};
|
|
|
|
llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
|
|
"mlir-elide-elementsattrs-if-larger",
|
|
llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
|
|
"more elements than the given upper limit")};
|
|
|
|
llvm::cl::opt<unsigned> elideResourceStringsIfLarger{
|
|
"mlir-elide-resource-strings-if-larger",
|
|
llvm::cl::desc(
|
|
"Elide printing value of resources if string is too long in chars.")};
|
|
|
|
llvm::cl::opt<bool> printDebugInfoOpt{
|
|
"mlir-print-debuginfo", llvm::cl::init(false),
|
|
llvm::cl::desc("Print debug info in MLIR output")};
|
|
|
|
llvm::cl::opt<bool> printPrettyDebugInfoOpt{
|
|
"mlir-pretty-debuginfo", llvm::cl::init(false),
|
|
llvm::cl::desc("Print pretty debug info in MLIR output")};
|
|
|
|
// Use the generic op output form in the operation printer even if the custom
|
|
// form is defined.
|
|
llvm::cl::opt<bool> printGenericOpFormOpt{
|
|
"mlir-print-op-generic", llvm::cl::init(false),
|
|
llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
|
|
|
|
llvm::cl::opt<bool> assumeVerifiedOpt{
|
|
"mlir-print-assume-verified", llvm::cl::init(false),
|
|
llvm::cl::desc("Skip op verification when using custom printers"),
|
|
llvm::cl::Hidden};
|
|
|
|
llvm::cl::opt<bool> printLocalScopeOpt{
|
|
"mlir-print-local-scope", llvm::cl::init(false),
|
|
llvm::cl::desc("Print with local scope and inline information (eliding "
|
|
"aliases for attributes, types, and locations")};
|
|
|
|
llvm::cl::opt<bool> skipRegionsOpt{
|
|
"mlir-print-skip-regions", llvm::cl::init(false),
|
|
llvm::cl::desc("Skip regions when printing ops.")};
|
|
|
|
llvm::cl::opt<bool> printValueUsers{
|
|
"mlir-print-value-users", llvm::cl::init(false),
|
|
llvm::cl::desc(
|
|
"Print users of operation results and block arguments as a comment")};
|
|
|
|
llvm::cl::opt<bool> printUniqueSSAIDs{
|
|
"mlir-print-unique-ssa-ids", llvm::cl::init(false),
|
|
llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
|
|
"and naming conflicts across all regions")};
|
|
};
|
|
} // namespace
|
|
|
|
/// Lazily-constructed AsmPrinter command-line options. The object is built on
/// first dereference (e.g. by registerAsmPrinterCLOptions()); the
/// OpPrintingFlags constructor reads it only when isConstructed() is true.
static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
|
|
|
|
/// Register a set of useful command-line options that can be used to configure
|
|
/// various flags within the AsmPrinter.
|
|
void mlir::registerAsmPrinterCLOptions() {
|
|
// Make sure that the options struct has been initialized.
|
|
*clOptions;
|
|
}
|
|
|
|
/// Initialize the printing flags with default supplied by the cl::opts above.
|
|
OpPrintingFlags::OpPrintingFlags()
|
|
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
|
|
printGenericOpFormFlag(false), skipRegionsFlag(false),
|
|
assumeVerifiedFlag(false), printLocalScope(false),
|
|
printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
|
|
// Initialize based upon command line options, if they are available.
|
|
if (!clOptions.isConstructed())
|
|
return;
|
|
if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
|
|
elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
|
|
if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences())
|
|
elementsAttrHexElementLimit =
|
|
clOptions->printElementsAttrWithHexIfLarger.getValue();
|
|
if (clOptions->elideResourceStringsIfLarger.getNumOccurrences())
|
|
resourceStringCharLimit = clOptions->elideResourceStringsIfLarger;
|
|
printDebugInfoFlag = clOptions->printDebugInfoOpt;
|
|
printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
|
|
printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
|
|
assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
|
|
printLocalScope = clOptions->printLocalScopeOpt;
|
|
skipRegionsFlag = clOptions->skipRegionsOpt;
|
|
printValueUsersFlag = clOptions->printValueUsers;
|
|
printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
|
|
}
|
|
|
|
/// Enable the elision of large elements attributes, by printing a '...'
|
|
/// instead of the element data, when the number of elements is greater than
|
|
/// `largeElementLimit`. Note: The IR generated with this option is not
|
|
/// parsable.
|
|
OpPrintingFlags &
|
|
OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
|
|
elementsAttrElementLimit = largeElementLimit;
|
|
return *this;
|
|
}
|
|
|
|
/// Print non-splat ElementsAttrs with more than `largeElementLimit` elements
/// as a hex string. A limit of -1 disables hex printing.
OpPrintingFlags &
OpPrintingFlags::printLargeElementsAttrWithHex(int64_t largeElementLimit) {
  this->elementsAttrHexElementLimit = largeElementLimit;
  return *this;
}
|
|
|
|
/// Set the character limit beyond which printed resource strings are elided.
/// NOTE(review): getLargeResourceStringLimit() returns std::optional<uint64_t>,
/// so the limit is presumably stored unsigned and a negative argument would
/// wrap to a very large limit — confirm intended.
OpPrintingFlags &
OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) {
  this->resourceStringCharLimit = largeResourceLimit;
  return *this;
}
|
|
|
|
/// Enable printing of debug information. If 'prettyForm' is set to true,
|
|
/// debug information is printed in a more readable 'pretty' form.
|
|
OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable,
|
|
bool prettyForm) {
|
|
printDebugInfoFlag = enable;
|
|
printDebugInfoPrettyFormFlag = prettyForm;
|
|
return *this;
|
|
}
|
|
|
|
/// Always print operations in the generic form.
|
|
OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) {
|
|
printGenericOpFormFlag = enable;
|
|
return *this;
|
|
}
|
|
|
|
/// Always skip Regions.
|
|
OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) {
|
|
skipRegionsFlag = skip;
|
|
return *this;
|
|
}
|
|
|
|
/// Do not verify the operation when using custom operation printers.
|
|
OpPrintingFlags &OpPrintingFlags::assumeVerified() {
|
|
assumeVerifiedFlag = true;
|
|
return *this;
|
|
}
|
|
|
|
/// Use local scope when printing the operation. This allows for using the
|
|
/// printer in a more localized and thread-safe setting, but may not necessarily
|
|
/// be identical of what the IR will look like when dumping the full module.
|
|
OpPrintingFlags &OpPrintingFlags::useLocalScope() {
|
|
printLocalScope = true;
|
|
return *this;
|
|
}
|
|
|
|
/// Print users of values as comments.
|
|
OpPrintingFlags &OpPrintingFlags::printValueUsers() {
|
|
printValueUsersFlag = true;
|
|
return *this;
|
|
}
|
|
|
|
/// Return if the given ElementsAttr should be elided.
|
|
bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
|
|
return elementsAttrElementLimit &&
|
|
*elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
|
|
!llvm::isa<SplatElementsAttr>(attr);
|
|
}
|
|
|
|
/// Return if the given ElementsAttr should be printed as hex string.
|
|
bool OpPrintingFlags::shouldPrintElementsAttrWithHex(ElementsAttr attr) const {
|
|
// -1 is used to disable hex printing.
|
|
return (elementsAttrHexElementLimit != -1) &&
|
|
(elementsAttrHexElementLimit < int64_t(attr.getNumElements())) &&
|
|
!llvm::isa<SplatElementsAttr>(attr);
|
|
}
|
|
|
|
/// Return the size limit for printing large ElementsAttr.
|
|
std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
|
|
return elementsAttrElementLimit;
|
|
}
|
|
|
|
/// Return the size limit for printing large ElementsAttr as hex string.
|
|
int64_t OpPrintingFlags::getLargeElementsAttrHexLimit() const {
|
|
return elementsAttrHexElementLimit;
|
|
}
|
|
|
|
/// Return the size limit for printing large ElementsAttr.
|
|
std::optional<uint64_t> OpPrintingFlags::getLargeResourceStringLimit() const {
|
|
return resourceStringCharLimit;
|
|
}
|
|
|
|
/// Return if debug information should be printed.
|
|
bool OpPrintingFlags::shouldPrintDebugInfo() const {
|
|
return printDebugInfoFlag;
|
|
}
|
|
|
|
/// Return if debug information should be printed in the pretty form.
|
|
bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
|
|
return printDebugInfoPrettyFormFlag;
|
|
}
|
|
|
|
/// Return if operations should be printed in the generic form.
|
|
bool OpPrintingFlags::shouldPrintGenericOpForm() const {
|
|
return printGenericOpFormFlag;
|
|
}
|
|
|
|
/// Return if Region should be skipped.
|
|
bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; }
|
|
|
|
/// Return if operation verification should be skipped.
|
|
bool OpPrintingFlags::shouldAssumeVerified() const {
|
|
return assumeVerifiedFlag;
|
|
}
|
|
|
|
/// Return if the printer should use local scope when dumping the IR.
|
|
bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
|
|
|
|
/// Return if the printer should print users of values.
|
|
bool OpPrintingFlags::shouldPrintValueUsers() const {
|
|
return printValueUsersFlag;
|
|
}
|
|
|
|
/// Return if the printer should use unique IDs.
|
|
bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
|
|
return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// NewLineCounter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This class is a simple formatter that emits a new line when inputted into a
|
|
/// stream, that enables counting the number of newlines emitted. This class
|
|
/// should be used whenever emitting newlines in the printer.
|
|
struct NewLineCounter {
|
|
unsigned curLine = 1;
|
|
};
|
|
|
|
static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
|
|
++newLine.curLine;
|
|
return os << '\n';
|
|
}
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AsmPrinter::Impl
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace mlir {
|
|
class AsmPrinter::Impl {
|
|
public:
|
|
Impl(raw_ostream &os, AsmStateImpl &state);
|
|
explicit Impl(Impl &other) : Impl(other.os, other.state) {}
|
|
|
|
/// Returns the output stream of the printer.
|
|
raw_ostream &getStream() { return os; }
|
|
|
|
template <typename Container, typename UnaryFunctor>
|
|
inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
|
|
llvm::interleaveComma(c, os, eachFn);
|
|
}
|
|
|
|
/// This enum describes the different kinds of elision for the type of an
|
|
/// attribute when printing it.
|
|
enum class AttrTypeElision {
|
|
/// The type must not be elided,
|
|
Never,
|
|
/// The type may be elided when it matches the default used in the parser
|
|
/// (for example i64 is the default for integer attributes).
|
|
May,
|
|
/// The type must be elided.
|
|
Must
|
|
};
|
|
|
|
/// Print the given attribute or an alias.
|
|
void printAttribute(Attribute attr,
|
|
AttrTypeElision typeElision = AttrTypeElision::Never);
|
|
/// Print the given attribute without considering an alias.
|
|
void printAttributeImpl(Attribute attr,
|
|
AttrTypeElision typeElision = AttrTypeElision::Never);
|
|
|
|
/// Print the alias for the given attribute, return failure if no alias could
|
|
/// be printed.
|
|
LogicalResult printAlias(Attribute attr);
|
|
|
|
/// Print the given type or an alias.
|
|
void printType(Type type);
|
|
/// Print the given type.
|
|
void printTypeImpl(Type type);
|
|
|
|
/// Print the alias for the given type, return failure if no alias could
|
|
/// be printed.
|
|
LogicalResult printAlias(Type type);
|
|
|
|
/// Print the given location to the stream. If `allowAlias` is true, this
|
|
/// allows for the internal location to use an attribute alias.
|
|
void printLocation(LocationAttr loc, bool allowAlias = false);
|
|
|
|
/// Print a reference to the given resource that is owned by the given
|
|
/// dialect.
|
|
void printResourceHandle(const AsmDialectResourceHandle &resource);
|
|
|
|
void printAffineMap(AffineMap map);
|
|
void
|
|
printAffineExpr(AffineExpr expr,
|
|
function_ref<void(unsigned, bool)> printValueName = nullptr);
|
|
void printAffineConstraint(AffineExpr expr, bool isEq);
|
|
void printIntegerSet(IntegerSet set);
|
|
|
|
LogicalResult pushCyclicPrinting(const void *opaquePointer);
|
|
|
|
void popCyclicPrinting();
|
|
|
|
void printDimensionList(ArrayRef<int64_t> shape);
|
|
|
|
protected:
|
|
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
|
ArrayRef<StringRef> elidedAttrs = {},
|
|
bool withKeyword = false);
|
|
void printNamedAttribute(NamedAttribute attr);
|
|
void printTrailingLocation(Location loc, bool allowAlias = true);
|
|
void printLocationInternal(LocationAttr loc, bool pretty = false,
|
|
bool isTopLevel = false);
|
|
|
|
/// Print a dense elements attribute. If 'allowHex' is true, a hex string is
|
|
/// used instead of individual elements when the elements attr is large.
|
|
void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
|
|
|
|
/// Print a dense string elements attribute.
|
|
void printDenseStringElementsAttr(DenseStringElementsAttr attr);
|
|
|
|
/// Print a dense elements attribute. If 'allowHex' is true, a hex string is
|
|
/// used instead of individual elements when the elements attr is large.
|
|
void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
|
|
bool allowHex);
|
|
|
|
/// Print a dense array attribute.
|
|
void printDenseArrayAttr(DenseArrayAttr attr);
|
|
|
|
void printDialectAttribute(Attribute attr);
|
|
void printDialectType(Type type);
|
|
|
|
/// Print an escaped string, wrapped with "".
|
|
void printEscapedString(StringRef str);
|
|
|
|
/// Print a hex string, wrapped with "".
|
|
void printHexString(StringRef str);
|
|
void printHexString(ArrayRef<char> data);
|
|
|
|
/// This enum is used to represent the binding strength of the enclosing
|
|
/// context that an AffineExprStorage is being printed in, so we can
|
|
/// intelligently produce parens.
|
|
enum class BindingStrength {
|
|
Weak, // + and -
|
|
Strong, // All other binary operators.
|
|
};
|
|
void printAffineExprInternal(
|
|
AffineExpr expr, BindingStrength enclosingTightness,
|
|
function_ref<void(unsigned, bool)> printValueName = nullptr);
|
|
|
|
/// The output stream for the printer.
|
|
raw_ostream &os;
|
|
|
|
/// An underlying assembly printer state.
|
|
AsmStateImpl &state;
|
|
|
|
/// A set of flags to control the printer's behavior.
|
|
OpPrintingFlags printerFlags;
|
|
|
|
/// A tracker for the number of new lines emitted during printing.
|
|
NewLineCounter newLine;
|
|
};
|
|
} // namespace mlir
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AliasInitializer
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This class represents a specific instance of a symbol Alias.
|
|
class SymbolAlias {
|
|
public:
|
|
SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType,
|
|
bool isDeferrable)
|
|
: name(name), suffixIndex(suffixIndex), isType(isType),
|
|
isDeferrable(isDeferrable) {}
|
|
|
|
/// Print this alias to the given stream.
|
|
void print(raw_ostream &os) const {
|
|
os << (isType ? "!" : "#") << name;
|
|
if (suffixIndex)
|
|
os << suffixIndex;
|
|
}
|
|
|
|
/// Returns true if this is a type alias.
|
|
bool isTypeAlias() const { return isType; }
|
|
|
|
/// Returns true if this alias supports deferred resolution when parsing.
|
|
bool canBeDeferred() const { return isDeferrable; }
|
|
|
|
private:
|
|
/// The main name of the alias.
|
|
StringRef name;
|
|
/// The suffix index of the alias.
|
|
uint32_t suffixIndex : 30;
|
|
/// A flag indicating whether this alias is for a type.
|
|
bool isType : 1;
|
|
/// A flag indicating whether this alias may be deferred or not.
|
|
bool isDeferrable : 1;
|
|
|
|
public:
|
|
/// Used to avoid printing incomplete aliases for recursive types.
|
|
bool isPrinted = false;
|
|
};
|
|
|
|
/// This class represents a utility that initializes the set of attribute and
|
|
/// type aliases, without the need to store the extra information within the
|
|
/// main AliasState class or pass it around via function arguments.
|
|
class AliasInitializer {
|
|
public:
|
|
AliasInitializer(
|
|
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
|
|
llvm::BumpPtrAllocator &aliasAllocator)
|
|
: interfaces(interfaces), aliasAllocator(aliasAllocator),
|
|
aliasOS(aliasBuffer) {}
|
|
|
|
void initialize(Operation *op, const OpPrintingFlags &printerFlags,
|
|
llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
|
|
|
|
/// Visit the given attribute to see if it has an alias. `canBeDeferred` is
|
|
/// set to true if the originator of this attribute can resolve the alias
|
|
/// after parsing has completed (e.g. in the case of operation locations).
|
|
/// `elideType` indicates if the type of the attribute should be skipped when
|
|
/// looking for nested aliases. Returns the maximum alias depth of the
|
|
/// attribute, and the alias index of this attribute.
|
|
std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false,
|
|
bool elideType = false) {
|
|
return visitImpl(attr, aliases, canBeDeferred, elideType);
|
|
}
|
|
|
|
/// Visit the given type to see if it has an alias. `canBeDeferred` is
|
|
/// set to true if the originator of this attribute can resolve the alias
|
|
/// after parsing has completed. Returns the maximum alias depth of the type,
|
|
/// and the alias index of this type.
|
|
std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) {
|
|
return visitImpl(type, aliases, canBeDeferred);
|
|
}
|
|
|
|
private:
|
|
struct InProgressAliasInfo {
|
|
InProgressAliasInfo()
|
|
: aliasDepth(0), isType(false), canBeDeferred(false) {}
|
|
InProgressAliasInfo(StringRef alias)
|
|
: alias(alias), aliasDepth(1), isType(false), canBeDeferred(false) {}
|
|
|
|
bool operator<(const InProgressAliasInfo &rhs) const {
|
|
// Order first by depth, then by attr/type kind, and then by name.
|
|
if (aliasDepth != rhs.aliasDepth)
|
|
return aliasDepth < rhs.aliasDepth;
|
|
if (isType != rhs.isType)
|
|
return isType;
|
|
return alias < rhs.alias;
|
|
}
|
|
|
|
/// The alias for the attribute or type, or std::nullopt if the value has no
|
|
/// alias.
|
|
std::optional<StringRef> alias;
|
|
/// The alias depth of this attribute or type, i.e. an indication of the
|
|
/// relative ordering of when to print this alias.
|
|
unsigned aliasDepth : 30;
|
|
/// If this alias represents a type or an attribute.
|
|
bool isType : 1;
|
|
/// If this alias can be deferred or not.
|
|
bool canBeDeferred : 1;
|
|
/// Indices for child aliases.
|
|
SmallVector<size_t> childIndices;
|
|
};
|
|
|
|
/// Visit the given attribute or type to see if it has an alias.
|
|
/// `canBeDeferred` is set to true if the originator of this value can resolve
|
|
/// the alias after parsing has completed (e.g. in the case of operation
|
|
/// locations). Returns the maximum alias depth of the value, and its alias
|
|
/// index.
|
|
template <typename T, typename... PrintArgs>
|
|
std::pair<size_t, size_t>
|
|
visitImpl(T value,
|
|
llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
|
|
bool canBeDeferred, PrintArgs &&...printArgs);
|
|
|
|
/// Mark the given alias as non-deferrable.
|
|
void markAliasNonDeferrable(size_t aliasIndex);
|
|
|
|
/// Try to generate an alias for the provided symbol. If an alias is
|
|
/// generated, the provided alias mapping and reverse mapping are updated.
|
|
template <typename T>
|
|
void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);
|
|
|
|
/// Given a collection of aliases and symbols, initialize a mapping from a
|
|
/// symbol to a given alias.
|
|
static void initializeAliases(
|
|
llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
|
|
llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
|
|
|
|
/// The set of asm interfaces within the context.
|
|
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
|
|
|
|
/// An allocator used for alias names.
|
|
llvm::BumpPtrAllocator &aliasAllocator;
|
|
|
|
/// The set of built aliases.
|
|
llvm::MapVector<const void *, InProgressAliasInfo> aliases;
|
|
|
|
/// Storage and stream used when generating an alias.
|
|
SmallString<32> aliasBuffer;
|
|
llvm::raw_svector_ostream aliasOS;
|
|
};
|
|
|
|
/// This class implements a dummy OpAsmPrinter that doesn't print any output,
|
|
/// and merely collects the attributes and types that *would* be printed in a
|
|
/// normal print invocation so that we can generate proper aliases. This allows
|
|
/// for us to generate aliases only for the attributes and types that would be
|
|
/// in the output, and trims down unnecessary output.
|
|
class DummyAliasOperationPrinter : private OpAsmPrinter {
|
|
public:
|
|
explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
|
|
AliasInitializer &initializer)
|
|
: printerFlags(printerFlags), initializer(initializer) {}
|
|
|
|
/// Prints the entire operation with the custom assembly form, if available,
|
|
/// or the generic assembly form, otherwise.
|
|
void printCustomOrGenericOp(Operation *op) override {
|
|
// Visit the operation location.
|
|
if (printerFlags.shouldPrintDebugInfo())
|
|
initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
|
|
|
|
// If requested, always print the generic form.
|
|
if (!printerFlags.shouldPrintGenericOpForm()) {
|
|
op->getName().printAssembly(op, *this, /*defaultDialect=*/"");
|
|
return;
|
|
}
|
|
|
|
// Otherwise print with the generic assembly form.
|
|
printGenericOp(op);
|
|
}
|
|
|
|
private:
|
|
/// Print the given operation in the generic form.
|
|
void printGenericOp(Operation *op, bool printOpName = true) override {
|
|
// Consider nested operations for aliases.
|
|
if (!printerFlags.shouldSkipRegions()) {
|
|
for (Region ®ion : op->getRegions())
|
|
printRegion(region, /*printEntryBlockArgs=*/true,
|
|
/*printBlockTerminators=*/true);
|
|
}
|
|
|
|
// Visit all the types used in the operation.
|
|
for (Type type : op->getOperandTypes())
|
|
printType(type);
|
|
for (Type type : op->getResultTypes())
|
|
printType(type);
|
|
|
|
// Consider the attributes of the operation for aliases.
|
|
for (const NamedAttribute &attr : op->getAttrs())
|
|
printAttribute(attr.getValue());
|
|
}
|
|
|
|
/// Print the given block. If 'printBlockArgs' is false, the arguments of the
|
|
/// block are not printed. If 'printBlockTerminator' is false, the terminator
|
|
/// operation of the block is not printed.
|
|
void print(Block *block, bool printBlockArgs = true,
|
|
bool printBlockTerminator = true) {
|
|
// Consider the types of the block arguments for aliases if 'printBlockArgs'
|
|
// is set to true.
|
|
if (printBlockArgs) {
|
|
for (BlockArgument arg : block->getArguments()) {
|
|
printType(arg.getType());
|
|
|
|
// Visit the argument location.
|
|
if (printerFlags.shouldPrintDebugInfo())
|
|
// TODO: Allow deferring argument locations.
|
|
initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
|
|
}
|
|
}
|
|
|
|
// Consider the operations within this block, ignoring the terminator if
|
|
// requested.
|
|
bool hasTerminator =
|
|
!block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
|
|
auto range = llvm::make_range(
|
|
block->begin(),
|
|
std::prev(block->end(),
|
|
(!hasTerminator || printBlockTerminator) ? 0 : 1));
|
|
for (Operation &op : range)
|
|
printCustomOrGenericOp(&op);
|
|
}
|
|
|
|
/// Print the given region.
|
|
void printRegion(Region ®ion, bool printEntryBlockArgs,
|
|
bool printBlockTerminators,
|
|
bool printEmptyBlock = false) override {
|
|
if (region.empty())
|
|
return;
|
|
if (printerFlags.shouldSkipRegions()) {
|
|
os << "{...}";
|
|
return;
|
|
}
|
|
|
|
auto *entryBlock = ®ion.front();
|
|
print(entryBlock, printEntryBlockArgs, printBlockTerminators);
|
|
for (Block &b : llvm::drop_begin(region, 1))
|
|
print(&b);
|
|
}
|
|
|
|
/// Visit a region argument's type and, when debug info is enabled, its
/// location. `argAttrs` and `omitType` are ignored by this dummy printer.
void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
                         bool omitType) override {
  printType(arg.getType());
  if (!printerFlags.shouldPrintDebugInfo())
    return;
  // TODO: Allow deferring argument locations.
  initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
}
|
|
|
|
/// Consider the given type to be printed for an alias.
|
|
void printType(Type type) override { initializer.visit(type); }
|
|
|
|
/// Consider the given attribute to be printed for an alias.
|
|
void printAttribute(Attribute attr) override { initializer.visit(attr); }
|
|
void printAttributeWithoutType(Attribute attr) override {
|
|
printAttribute(attr);
|
|
}
|
|
LogicalResult printAlias(Attribute attr) override {
|
|
initializer.visit(attr);
|
|
return success();
|
|
}
|
|
LogicalResult printAlias(Type type) override {
|
|
initializer.visit(type);
|
|
return success();
|
|
}
|
|
|
|
/// Consider the given location to be printed for an alias.
|
|
void printOptionalLocationSpecifier(Location loc) override {
|
|
printAttribute(loc);
|
|
}
|
|
|
|
/// Print the given set of attributes with names not included within
|
|
/// 'elidedAttrs'.
|
|
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                           ArrayRef<StringRef> elidedAttrs = {}) override {
  if (attrs.empty())
    return;

  // Fast path: nothing is elided, so every attribute value is visited.
  if (elidedAttrs.empty()) {
    for (const NamedAttribute &namedAttr : attrs)
      printAttribute(namedAttr.getValue());
    return;
  }

  // Otherwise skip the values whose names are listed in `elidedAttrs`. Only
  // attribute values are visited, never the names.
  llvm::SmallDenseSet<StringRef> skipped(elidedAttrs.begin(),
                                         elidedAttrs.end());
  for (const NamedAttribute &namedAttr : attrs) {
    if (skipped.contains(namedAttr.getName().strref()))
      continue;
    printAttribute(namedAttr.getValue());
  }
}
|
|
void printOptionalAttrDictWithKeyword(
    ArrayRef<NamedAttribute> attrs,
    ArrayRef<StringRef> elidedAttrs = {}) override {
  // Alias collection is identical to the keyword-less variant.
  printOptionalAttrDict(attrs, elidedAttrs);
}
|
|
|
|
/// Return a null stream as the output stream, this will ignore any data fed
|
|
/// to it.
|
|
raw_ostream &getStream() const override {
  // `os` is a raw_null_ostream, so anything written through it is discarded.
  return os;
}
|
|
|
|
/// The following are hooks of `OpAsmPrinter` that are not necessary for
|
|
/// determining potential aliases.
|
|
// Pure text-output hooks: they never receive an attribute or type, so there is
// nothing to record for aliasing.
void printFloat(const APFloat &) override {}
void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
// Layout-only hooks.
void printNewline() override {}
void increaseIndent() override {}
void decreaseIndent() override {}
// Operand names are SSA names, not attribute/type aliases.
void printOperand(Value) override {}
void printOperand(Value, raw_ostream &os) override {
  // Users expect the output string to have at least the prefixed % to signal
  // a value name. To maintain this invariant, emit a name even if it is
  // guaranteed to go unused.
  os << "%";
}
void printKeywordOrString(StringRef) override {}
void printString(StringRef) override {}
void printResourceHandle(const AsmDialectResourceHandle &) override {}
void printSymbolName(StringRef) override {}
// Successors are not visited here; blocks nested in regions are walked by
// printRegion instead.
void printSuccessor(Block *) override {}
void printSuccessorAndUseList(Block *, ValueRange) override {}
// SSA renaming has no effect on which aliases are needed.
void shadowRegionArgs(Region &, ValueRange) override {}
|
|
|
|
/// The printer flags to use when determining potential aliases.
|
|
const OpPrintingFlags &printerFlags;

/// The alias initializer that every visited type, attribute, and location is
/// forwarded to.
AliasInitializer &initializer;

/// Sink for any text written through getStream() or by printRegion's "{...}"
/// placeholder. It is `mutable` so that the const getStream() can hand out a
/// non-const reference.
mutable llvm::raw_null_ostream os;
|
|
};
|
|
|
|
/// A DialectAsmPrinter that writes nothing. Every attribute and type it is
/// asked to print is forwarded to an AliasInitializer. The alias index of each
/// such child is appended to `childIndices`, and the largest alias depth seen
/// is tracked in `maxAliasDepth`.
class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
public:
  /// `canBeDeferred` is forwarded to every AliasInitializer::visit call;
  /// `childIndices` receives the alias index of every visited child.
  explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer,
                                       bool canBeDeferred,
                                       SmallVectorImpl<size_t> &childIndices)
      : initializer(initializer), canBeDeferred(canBeDeferred),
        childIndices(childIndices) {}

  /// Print the given attribute/type, visiting any nested aliases that would be
  /// generated as part of printing. Returns the maximum alias depth found while
  /// printing the given value.
  template <typename T, typename... PrintArgs>
  size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) {
    printAndVisitNestedAliasesImpl(value, printArgs...);
    return maxAliasDepth;
  }

private:
  /// Print the given attribute/type, visiting any nested aliases that would be
  /// generated as part of printing.
  ///
  /// Non-builtin attributes are printed by their own dialect against this
  /// printer. Builtin attributes are decomposed by hand below. Afterwards,
  /// unless `elideType` is set, the attribute's type is visited when it is a
  /// TypedAttr whose type is not NoneType.
  void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) {
    if (!isa<BuiltinDialect>(attr.getDialect())) {
      attr.getDialect().printAttribute(attr, *this);

      // Process the builtin attributes.
    } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr,
                         IntegerSetAttr, UnitAttr>(attr)) {
      // Nothing nested is visited for these. Note that this early return also
      // skips the type visit at the end of the function.
      return;
    } else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) {
      printAttribute(distinctAttr.getReferencedAttr());
    } else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) {
      // Both the name and the value of each entry are visited.
      for (const NamedAttribute &nestedAttr : dictAttr.getValue()) {
        printAttribute(nestedAttr.getName());
        printAttribute(nestedAttr.getValue());
      }
    } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
      for (Attribute nestedAttr : arrayAttr.getValue())
        printAttribute(nestedAttr);
    } else if (auto typeAttr = dyn_cast<TypeAttr>(attr)) {
      printType(typeAttr.getValue());
    } else if (auto locAttr = dyn_cast<OpaqueLoc>(attr)) {
      printAttribute(locAttr.getFallbackLocation());
    } else if (auto locAttr = dyn_cast<NameLoc>(attr)) {
      // An unknown child location is not visited.
      if (!isa<UnknownLoc>(locAttr.getChildLoc()))
        printAttribute(locAttr.getChildLoc());
    } else if (auto locAttr = dyn_cast<CallSiteLoc>(attr)) {
      printAttribute(locAttr.getCallee());
      printAttribute(locAttr.getCaller());
    } else if (auto locAttr = dyn_cast<FusedLoc>(attr)) {
      if (Attribute metadata = locAttr.getMetadata())
        printAttribute(metadata);
      for (Location nestedLoc : locAttr.getLocations())
        printAttribute(nestedLoc);
    }

    // Don't print the type if we must elide it, or if it is a None type.
    if (!elideType) {
      if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
        Type attrType = typedAttr.getType();
        if (!llvm::isa<NoneType>(attrType))
          printType(attrType);
      }
    }
  }

  /// Visit the nested elements of `type`. Non-builtin types are printed by
  /// their dialect against this printer. Memrefs are special-cased. Any other
  /// builtin type has its immediate (non-null) sub-elements visited.
  void printAndVisitNestedAliasesImpl(Type type) {
    if (!isa<BuiltinDialect>(type.getDialect()))
      return type.getDialect().printType(type, *this);

    // Only visit the layout of memref if it isn't the identity.
    if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) {
      printType(memrefTy.getElementType());
      MemRefLayoutAttrInterface layout = memrefTy.getLayout();
      if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity())
        printAttribute(memrefTy.getLayout());
      if (memrefTy.getMemorySpace())
        printAttribute(memrefTy.getMemorySpace());
      return;
    }

    // For most builtin types, we can simply walk the sub elements.
    auto visitFn = [&](auto element) {
      if (element)
        (void)printAlias(element);
    };
    type.walkImmediateSubElements(visitFn, visitFn);
  }

  /// Consider the given type to be printed for an alias.
  void printType(Type type) override {
    recordAliasResult(initializer.visit(type, canBeDeferred));
  }

  /// Consider the given attribute to be printed for an alias.
  void printAttribute(Attribute attr) override {
    recordAliasResult(initializer.visit(attr, canBeDeferred));
  }
  /// Like printAttribute, but asks the initializer to skip the attribute's
  /// type.
  void printAttributeWithoutType(Attribute attr) override {
    recordAliasResult(
        initializer.visit(attr, canBeDeferred, /*elideType=*/true));
  }
  LogicalResult printAlias(Attribute attr) override {
    printAttribute(attr);
    return success();
  }
  LogicalResult printAlias(Type type) override {
    printType(type);
    return success();
  }

  /// Record the alias result of a child element: `first` is the child's alias
  /// depth and `second` is its index in the initializer's alias map.
  void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
    childIndices.push_back(aliasDepthAndIndex.second);
    if (aliasDepthAndIndex.first > maxAliasDepth)
      maxAliasDepth = aliasDepthAndIndex.first;
  }

  /// Return a null stream as the output stream, this will ignore any data fed
  /// to it.
  raw_ostream &getStream() const override { return os; }

  /// The following are hooks of `DialectAsmPrinter` that are not necessary for
  /// determining potential aliases.
  void printFloat(const APFloat &) override {}
  void printKeywordOrString(StringRef) override {}
  void printString(StringRef) override {}
  void printSymbolName(StringRef) override {}
  void printResourceHandle(const AsmDialectResourceHandle &) override {}

  /// Returns failure if `opaquePointer` is already on the stack, i.e. the same
  /// attribute/type is being printed again while still in progress.
  LogicalResult pushCyclicPrinting(const void *opaquePointer) override {
    return success(cyclicPrintingStack.insert(opaquePointer));
  }

  void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); }

  /// Stack of potentially cyclic mutable attributes or type currently being
  /// printed.
  SetVector<const void *> cyclicPrintingStack;

  /// The initializer to use when identifying aliases.
  AliasInitializer &initializer;

  /// If the aliases visited by this printer can be deferred.
  bool canBeDeferred;

  /// The indices of child aliases.
  SmallVectorImpl<size_t> &childIndices;

  /// The maximum alias depth found by the printer.
  size_t maxAliasDepth = 0;

  /// A dummy output stream.
  mutable llvm::raw_null_ostream os;
};
|
|
} // namespace
|
|
|
|
/// Sanitize the given name such that it can be used as a valid identifier. If
|
|
/// the string needs to be modified in any way, the provided buffer is used to
|
|
/// store the new copy,
|
|
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
|
|
StringRef allowedPunctChars = "$._-",
|
|
bool allowTrailingDigit = true) {
|
|
assert(!name.empty() && "Shouldn't have an empty name here");
|
|
|
|
auto validChar = [&](char ch) {
|
|
return llvm::isAlnum(ch) || allowedPunctChars.contains(ch);
|
|
};
|
|
|
|
auto copyNameToBuffer = [&] {
|
|
for (char ch : name) {
|
|
if (validChar(ch))
|
|
buffer.push_back(ch);
|
|
else if (ch == ' ')
|
|
buffer.push_back('_');
|
|
else
|
|
buffer.append(llvm::utohexstr((unsigned char)ch));
|
|
}
|
|
};
|
|
|
|
// Check to see if this name is valid. If it starts with a digit, then it
|
|
// could conflict with the autogenerated numeric ID's, so add an underscore
|
|
// prefix to avoid problems.
|
|
if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) {
|
|
buffer.push_back('_');
|
|
copyNameToBuffer();
|
|
return buffer;
|
|
}
|
|
|
|
// If the name ends with a trailing digit, add a '_' to avoid potential
|
|
// conflicts with autogenerated ID's.
|
|
if (!allowTrailingDigit && isdigit(name.back())) {
|
|
copyNameToBuffer();
|
|
buffer.push_back('_');
|
|
return buffer;
|
|
}
|
|
|
|
// Check to see that the name consists of only valid identifier characters.
|
|
for (char ch : name) {
|
|
if (!validChar(ch)) {
|
|
copyNameToBuffer();
|
|
return buffer;
|
|
}
|
|
}
|
|
|
|
// If there are no invalid characters, return the original name.
|
|
return name;
|
|
}
|
|
|
|
/// Given a collection of aliases and symbols, initialize a mapping from a
|
|
/// symbol to a given alias.
|
|
void AliasInitializer::initializeAliases(
|
|
llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
|
|
llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
|
|
SmallVector<std::pair<const void *, InProgressAliasInfo>, 0>
|
|
unprocessedAliases = visitedSymbols.takeVector();
|
|
llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) {
|
|
return lhs.second < rhs.second;
|
|
});
|
|
|
|
llvm::StringMap<unsigned> nameCounts;
|
|
for (auto &[symbol, aliasInfo] : unprocessedAliases) {
|
|
if (!aliasInfo.alias)
|
|
continue;
|
|
StringRef alias = *aliasInfo.alias;
|
|
unsigned nameIndex = nameCounts[alias]++;
|
|
symbolToAlias.insert(
|
|
{symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
|
|
aliasInfo.canBeDeferred)});
|
|
}
|
|
}
|
|
|
|
void AliasInitializer::initialize(
    Operation *op, const OpPrintingFlags &printerFlags,
    llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
  // "Print" `op` with a printer that writes nothing but forwards every
  // attribute/type it would print to this initializer. That way only entities
  // that actually appear in the output are considered for aliases.
  DummyAliasOperationPrinter collector(printerFlags, *this);
  collector.printCustomOrGenericOp(op);

  // Convert what was collected into the final alias table.
  initializeAliases(aliases, attrTypeToAlias);
}
|
|
|
|
/// Visit `value` for aliasing and return {alias depth, index into `aliases`}.
///
/// On the first visit, an entry is inserted and the dialect interfaces are
/// asked for an alias name (generateAlias). The value is then "printed" with a
/// DummyAliasDialectAsmPrinter so its nested attributes/types are visited too.
/// The entry's depth becomes one more than the deepest nested alias; when that
/// maximum is zero, the depth keeps its default-constructed value.
///
/// Later visits return the cached result. They still clear the deferrable flag
/// when `canBeDeferred` is false.
///
/// NOTE(review): markAliasNonDeferrable indexes the member `aliases`, so this
/// assumes callers pass that same map as the `aliases` parameter.
template <typename T, typename... PrintArgs>
std::pair<size_t, size_t> AliasInitializer::visitImpl(
    T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
    bool canBeDeferred, PrintArgs &&...printArgs) {
  auto [it, inserted] =
      aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()});
  size_t aliasIndex = std::distance(aliases.begin(), it);
  if (!inserted) {
    // Make sure that the alias isn't deferred if we don't permit it.
    if (!canBeDeferred)
      markAliasNonDeferrable(aliasIndex);
    return {static_cast<size_t>(it->second.aliasDepth), aliasIndex};
  }

  // Try to generate an alias for this value.
  generateAlias(value, it->second, canBeDeferred);
  it->second.isType = std::is_base_of_v<Type, T>;
  it->second.canBeDeferred = canBeDeferred;

  // Print the value, capturing any nested elements that require aliases.
  // Nested visits insert into `aliases`, which can invalidate `it`.
  SmallVector<size_t> childAliases;
  DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases);
  size_t maxAliasDepth =
      printer.printAndVisitNestedAliases(value, printArgs...);

  // Make sure to recompute `it` in case the map was reallocated.
  it = std::next(aliases.begin(), aliasIndex);

  // If we had sub elements, update to account for the depth.
  it->second.childIndices = std::move(childAliases);
  if (maxAliasDepth)
    it->second.aliasDepth = maxAliasDepth + 1;

  // Propagate the alias depth of the value.
  return {(size_t)it->second.aliasDepth, aliasIndex};
}
|
|
|
|
void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
|
|
auto *it = std::next(aliases.begin(), aliasIndex);
|
|
|
|
// If already marked non-deferrable stop the recursion.
|
|
// All children should already be marked non-deferrable as well.
|
|
if (!it->second.canBeDeferred)
|
|
return;
|
|
|
|
it->second.canBeDeferred = false;
|
|
|
|
// Propagate the non-deferrable flag to any child aliases.
|
|
for (size_t childIndex : it->second.childIndices)
|
|
markAliasNonDeferrable(childIndex);
|
|
}
|
|
|
|
template <typename T>
|
|
void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
|
|
bool canBeDeferred) {
|
|
SmallString<32> nameBuffer;
|
|
for (const auto &interface : interfaces) {
|
|
OpAsmDialectInterface::AliasResult result =
|
|
interface.getAlias(symbol, aliasOS);
|
|
if (result == OpAsmDialectInterface::AliasResult::NoAlias)
|
|
continue;
|
|
nameBuffer = std::move(aliasBuffer);
|
|
assert(!nameBuffer.empty() && "expected valid alias name");
|
|
if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
|
|
break;
|
|
}
|
|
|
|
if (nameBuffer.empty())
|
|
return;
|
|
|
|
SmallString<16> tempBuffer;
|
|
StringRef name =
|
|
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
|
|
/*allowTrailingDigit=*/false);
|
|
name = name.copy(aliasAllocator);
|
|
alias = InProgressAliasInfo(name);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AliasState
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
/// This class manages the state for type and attribute aliases.
class AliasState {
public:
  // Initialize the internal aliases by walking `op` with the given flags and
  // dialect alias interfaces.
  void
  initialize(Operation *op, const OpPrintingFlags &printerFlags,
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Get an alias for the given attribute if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Attribute attr, raw_ostream &os) const;

  /// Get an alias for the given type if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise. A type alias
  /// is only reported once its definition has been emitted by printAliases.
  LogicalResult getAlias(Type ty, raw_ostream &os) const;

  /// Print all of the referenced aliases that can not be resolved in a deferred
  /// manner.
  void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
    printAliases(p, newLine, /*isDeferred=*/false);
  }

  /// Print all of the referenced aliases that support deferred resolution.
  void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
    printAliases(p, newLine, /*isDeferred=*/true);
  }

private:
  /// Print all of the referenced aliases that support the provided resolution
  /// behavior.
  void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
                    bool isDeferred);

  /// Mapping between attribute/type and alias.
  llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;

  /// An allocator used for alias names; the StringRefs held in
  /// `attrTypeToAlias` point into it.
  llvm::BumpPtrAllocator aliasAllocator;
};
} // namespace
|
|
|
|
void AliasState::initialize(
    Operation *op, const OpPrintingFlags &printerFlags,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
  // The initializer is only needed for the walk itself. The alias names it
  // produces are allocated from this state's `aliasAllocator`.
  AliasInitializer builder(interfaces, aliasAllocator);
  builder.initialize(op, printerFlags, attrTypeToAlias);
}
|
|
|
|
LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
  const auto *entry = attrTypeToAlias.find(attr.getAsOpaquePointer());
  if (entry != attrTypeToAlias.end()) {
    entry->second.print(os);
    return success();
  }
  return failure();
}
|
|
|
|
LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
  // A type alias may only be referenced after printAliases has emitted its
  // definition (which sets `isPrinted`).
  const auto *entry = attrTypeToAlias.find(ty.getAsOpaquePointer());
  if (entry == attrTypeToAlias.end() || !entry->second.isPrinted)
    return failure();

  entry->second.print(os);
  return success();
}
|
|
|
|
void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
                              bool isDeferred) {
  for (auto &entry : attrTypeToAlias) {
    SymbolAlias &alias = entry.second;
    // Emit only the aliases whose resolution behavior matches the request.
    if (alias.canBeDeferred() != isDeferred)
      continue;

    alias.print(p.getStream());
    p.getStream() << " = ";

    if (alias.isTypeAlias()) {
      p.printTypeImpl(Type::getFromOpaquePointer(entry.first));
      // From now on getAlias(Type) may hand out this alias.
      alias.isPrinted = true;
    } else {
      // TODO: Support nested aliases in mutable attributes.
      Attribute attr = Attribute::getFromOpaquePointer(entry.first);
      if (attr.hasTrait<AttributeTrait::IsMutable>())
        p.getStream() << attr;
      else
        p.printAttributeImpl(attr);
    }

    p.getStream() << newLine;
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SSANameState
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
/// Info about block printing: a number which is its position in the visitation
/// order, and a name that is used to print reference to it, e.g. ^bb42.
/// `ordering` is -1 until numberValuesInRegion assigns the block's position.
struct BlockInfo {
  int ordering;
  StringRef name;
};

/// This class manages the state of SSA value names.
class SSANameState {
public:
  /// A sentinel value used for values with names set.
  enum : unsigned { NameSentinel = ~0U };

  SSANameState(Operation *op, const OpPrintingFlags &printerFlags);
  SSANameState() = default;

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also presents the result number ('#' number)
  /// of this value.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Print the operation identifier.
  void printOperationID(Operation *op, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the info for the given block. Unknown blocks yield
  /// {-1, "INVALIDBLOCK"}.
  BlockInfo getBlockInfo(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit.
  void numberValuesInRegion(Region &region);
  void numberValuesInBlock(Block &block);
  void numberValuesInOp(Operation &op);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the result of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            std::optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  /// Custom names of the values whose valueIDs entry is NameSentinel.
  DenseMap<Value, StringRef> valueNames;

  /// When printing users of values, an operation without a result might
  /// be the user. This map holds ids for such operations.
  DenseMap<Operation *, unsigned> operationIDs;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// value of this map are the result numbers that start a result group.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// This maps blocks to their visitation number in the current region as well
  /// as the string representing their name.
  DenseMap<Block *, BlockInfo> blockNames;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Backing storage for the name strings referenced by the maps above.
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument
  /// (used for the "arg<N>" names).
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;

  /// These are the printing flags. They control, eg., whether to print in
  /// generic form.
  OpPrintingFlags printerFlags;
};
} // namespace
|
|
|
|
/// Number every value and block nested under `op`. Result-less operations are
/// numbered too when value users are printed.
///
/// Regions are processed from an explicit worklist (`nameContext`). Each entry
/// records the ID counters and the `usedNames` scope in effect where the region
/// was discovered. As a result, regions found in the same parent region start
/// from the same counters and share the same visible names. With
/// `shouldPrintUniqueSSAIDs()` the saved counters are ignored, so IDs keep
/// increasing across the whole walk.
SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags)
    : printerFlags(printerFlags) {
  // The counters are only working state for the walk below; they are restored
  // to their previous values when the constructor returns.
  llvm::SaveAndRestore valueIDSaver(nextValueID);
  llvm::SaveAndRestore argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy. Scopes are placement-new'd into it and
  // destroyed manually below; the allocator only reclaims the memory.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
                                          nextConflictID, topLevelNamesScope));

  numberValuesInOp(*op);

  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;

    if (printerFlags.shouldPrintUniqueSSAIDs())
      // To print unique SSA IDs, ignore saved ID counts from parent regions
      std::tie(region, std::ignore, std::ignore, std::ignore, parentScope) =
          nameContext.pop_back_val();
    else
      std::tie(region, nextValueID, nextArgumentID, nextConflictID,
               parentScope) = nameContext.pop_back_val();

    // When switching to a region from a different subtree, close every scope
    // that was opened after this region's parent scope.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    // Queue the regions nested directly in this region's operations, snapshotting
    // the counters reached after numbering this region. (The loop variables
    // shadow the outer `op` and `region`.)
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(std::make_tuple(&region, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Manually remove all the scopes.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
|
|
|
|
void SSANameState::printValueID(Value value, bool printResultNo,
                                raw_ostream &stream) const {
  if (!value) {
    stream << "<<NULL VALUE>>";
    return;
  }

  // An operation result is printed through the head value of its result group,
  // optionally followed by "#<index within the group>".
  Value headValue = value;
  std::optional<int> groupIndex;
  if (auto opResult = dyn_cast<OpResult>(value))
    getResultIDAndNumber(opResult, headValue, groupIndex);

  auto idIt = valueIDs.find(headValue);
  if (idIt == valueIDs.end()) {
    stream << "<<UNKNOWN SSA VALUE>>";
    return;
  }

  stream << '%';
  unsigned id = idIt->second;
  if (id == NameSentinel) {
    // Custom-named values keep their text in `valueNames`.
    auto nameIt = valueNames.find(headValue);
    assert(nameIt != valueNames.end() && "Didn't have a name entry?");
    stream << nameIt->second;
  } else {
    stream << id;
  }

  if (printResultNo && groupIndex)
    stream << '#' << *groupIndex;
}
|
|
|
|
void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
  // Only result-less operations numbered while value users are printed have
  // an entry here.
  auto idIt = operationIDs.find(op);
  if (idIt != operationIDs.end()) {
    stream << '%' << idIt->second;
    return;
  }
  stream << "<<UNKNOWN OPERATION>>";
}
|
|
|
|
ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
  auto groupsIt = opResultGroups.find(op);
  if (groupsIt == opResultGroups.end())
    return {};
  return groupsIt->second;
}
|
|
|
|
BlockInfo SSANameState::getBlockInfo(Block *block) {
  auto infoIt = blockNames.find(block);
  if (infoIt != blockNames.end())
    return infoIt->second;
  // Blocks that were never numbered get a recognizable placeholder.
  return BlockInfo{-1, "INVALIDBLOCK"};
}
|
|
|
|
void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
|
|
assert(!region.empty() && "cannot shadow arguments of an empty region");
|
|
assert(region.getNumArguments() == namesToUse.size() &&
|
|
"incorrect number of names passed in");
|
|
assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
|
|
"only KnownIsolatedFromAbove ops can shadow names");
|
|
|
|
SmallVector<char, 16> nameStr;
|
|
for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
|
|
auto nameToUse = namesToUse[i];
|
|
if (nameToUse == nullptr)
|
|
continue;
|
|
auto nameToReplace = region.getArgument(i);
|
|
|
|
nameStr.clear();
|
|
llvm::raw_svector_ostream nameStream(nameStr);
|
|
printValueID(nameToUse, /*printResultNo=*/true, nameStream);
|
|
|
|
// Entry block arguments should already have a pretty "arg" name.
|
|
assert(valueIDs[nameToReplace] == NameSentinel);
|
|
|
|
// Use the name without the leading %.
|
|
auto name = StringRef(nameStream.str()).drop_front();
|
|
|
|
// Overwrite the name.
|
|
valueNames[nameToReplace] = name.copy(usedNameAllocator);
|
|
}
|
|
}
|
|
|
|
void SSANameState::numberValuesInRegion(Region ®ion) {
|
|
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
|
|
assert(!valueIDs.count(arg) && "arg numbered multiple times");
|
|
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion &&
|
|
"arg not defined in current region");
|
|
setValueName(arg, name);
|
|
};
|
|
|
|
if (!printerFlags.shouldPrintGenericOpForm()) {
|
|
if (Operation *op = region.getParentOp()) {
|
|
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
|
|
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
|
|
}
|
|
}
|
|
|
|
// Number the values within this region in a breadth-first order.
|
|
unsigned nextBlockID = 0;
|
|
for (auto &block : region) {
|
|
// Each block gets a unique ID, and all of the operations within it get
|
|
// numbered as well.
|
|
auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
|
|
if (blockInfoIt.second) {
|
|
// This block hasn't been named through `getAsmBlockArgumentNames`, use
|
|
// default `^bbNNN` format.
|
|
std::string name;
|
|
llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
|
|
blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
|
|
}
|
|
blockInfoIt.first->second.ordering = nextBlockID++;
|
|
|
|
numberValuesInBlock(block);
|
|
}
|
|
}
|
|
|
|
void SSANameState::numberValuesInBlock(Block &block) {
  // Entry block arguments are named "arg<N>" from the running
  // `nextArgumentID` counter. Other block arguments get default numeric IDs.
  // Arguments that already have an ID (e.g. named via
  // getAsmBlockArgumentNames) are left alone.
  const bool isEntry = block.isEntryBlock();
  for (BlockArgument arg : block.getArguments()) {
    if (valueIDs.count(arg))
      continue;
    if (!isEntry) {
      setValueName(arg, "");
      continue;
    }
    SmallString<32> argName("arg");
    llvm::raw_svector_ostream(argName) << nextArgumentID++;
    setValueName(arg, argName);
  }

  // Then number every operation in the block.
  for (Operation &op : block)
    numberValuesInOp(op);
}
|
|
|
|
void SSANameState::numberValuesInOp(Operation &op) {
  // Start indices of the op's result groups. Group 0 always exists; a custom
  // name given to a later result opens another group.
  SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);

  // Callback for OpAsmOpInterface::getAsmResultNames.
  auto setResultNameFn = [&](Value result, StringRef name) {
    assert(!valueIDs.count(result) && "result numbered multiple times");
    assert(result.getDefiningOp() == &op && "result not defined by 'op'");
    setValueName(result, name);

    // Record the result number for groups not anchored at 0.
    int resultNo = llvm::cast<OpResult>(result).getResultNumber();
    if (resultNo != 0)
      resultGroups.push_back(resultNo);
  };

  // Callback for OpAsmOpInterface::getAsmBlockNames. Custom block names are
  // sanitized and prefixed with '^'.
  auto setBlockNameFn = [&](Block *block, StringRef name) {
    assert(block->getParentOp() == &op &&
           "getAsmBlockArgumentNames callback invoked on a block not directly "
           "nested under the current operation");
    assert(!blockNames.count(block) && "block numbered multiple times");
    SmallString<16> prefixed{"^"};
    StringRef sanitized = sanitizeIdentifier(name, prefixed);
    // sanitizeIdentifier returns `name` itself when no change was needed; the
    // '^' prefix then still has to be prepended.
    if (sanitized.data() != prefixed.data()) {
      prefixed.append(sanitized);
      sanitized = prefixed.str();
    }
    blockNames[block] = {-1, sanitized.copy(usedNameAllocator)};
  };

  if (!printerFlags.shouldPrintGenericOpForm()) {
    if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
      asmInterface.getAsmBlockNames(setBlockNameFn);
      asmInterface.getAsmResultNames(setResultNameFn);
    }
  }

  if (op.getNumResults() == 0) {
    // A result-less op only needs an ID when value users are printed.
    if (printerFlags.shouldPrintValueUsers() &&
        operationIDs.try_emplace(&op, nextValueID).second)
      ++nextValueID;
    return;
  }

  // The first result anchors group 0; give it the next default ID unless it
  // already received a custom name.
  if (valueIDs.try_emplace(op.getResult(0), nextValueID).second)
    ++nextValueID;

  // Only ops with more than one result group need their groups recorded.
  if (resultGroups.size() > 1) {
    llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
    opResultGroups.try_emplace(&op, std::move(resultGroups));
  }
}
|
|
|
|
void SSANameState::getResultIDAndNumber(
    OpResult result, Value &lookupValue,
    std::optional<int> &lookupResultNo) const {
  Operation *owner = result.getOwner();
  // A single result forms its own group and carries no index.
  if (owner->getNumResults() == 1)
    return;

  const int resultNo = result.getResultNumber();
  auto groupsIt = opResultGroups.find(owner);

  // Without custom groups, every result hangs off result 0.
  if (groupsIt == opResultGroups.end()) {
    lookupResultNo = resultNo;
    lookupValue = owner->getResult(0);
    return;
  }

  // Group starts are sorted. The group containing `resultNo` starts at the
  // last start <= resultNo and ends where the next group starts, or at the
  // end of the op's results.
  ArrayRef<int> groupStarts = groupsIt->second;
  const auto *nextGroup = llvm::upper_bound(groupStarts, resultNo);
  const int groupStart = *std::prev(nextGroup);
  const int groupEnd = nextGroup == groupStarts.end()
                           ? static_cast<int>(owner->getNumResults())
                           : *nextGroup;

  // Only a group with more than one member gets a "#index" suffix.
  if (groupEnd - groupStart != 1)
    lookupResultNo = resultNo - groupStart;
  lookupValue = owner->getResult(groupStart);
}
|
|
|
|
/// Assign a name to `value`.
///
/// An empty `name` means "use the next default numeric ID". Otherwise the
/// value's ID is set to NameSentinel and its textual name is the uniqued form
/// of `name`.
void SSANameState::setValueName(Value value, StringRef name) {
  if (!name.empty()) {
    valueIDs[value] = NameSentinel;
    valueNames[value] = uniqueValueName(name);
    return;
  }
  valueIDs[value] = nextValueID++;
}
|
|
|
|
/// Return a sanitized version of `name` that is not yet in `usedNames`.
///
/// The result is recorded in `usedNames` and returned. It is copied into
/// `usedNameAllocator`, so it outlives this call. On a collision, candidates
/// of the form `<name>_<N>` are tried with N drawn from an ever-increasing
/// `nextConflictID`, so the probe always terminates.
StringRef SSANameState::uniqueValueName(StringRef name) {
  SmallString<16> sanitizeBuffer;
  name = sanitizeIdentifier(name, sanitizeBuffer);

  if (usedNames.count(name)) {
    // Collision: reuse one buffer and truncate it back to the "<name>_"
    // prefix between attempts.
    SmallString<64> candidate(name);
    candidate.push_back('_');
    const size_t prefixLen = name.size() + 1;
    for (;;) {
      candidate += llvm::utostr(nextConflictID++);
      if (!usedNames.count(candidate))
        break;
      candidate.resize(prefixLen);
    }
    name = candidate.str().copy(usedNameAllocator);
  } else {
    name = name.copy(usedNameAllocator);
  }

  usedNames.insert(name, char());
  return name;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DistinctState
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This class manages the state for distinct attributes.
|
|
class DistinctState {
|
|
public:
|
|
/// Returns a unique identifier for the given distinct attribute.
|
|
uint64_t getId(DistinctAttr distinctAttr);
|
|
|
|
private:
|
|
uint64_t distinctCounter = 0;
|
|
DenseMap<DistinctAttr, uint64_t> distinctAttrMap;
|
|
};
|
|
} // namespace
|
|
|
|
uint64_t DistinctState::getId(DistinctAttr distinctAttr) {
  auto insertion = distinctAttrMap.try_emplace(distinctAttr, distinctCounter);
  // Advance the counter only when a fresh identifier was actually handed out.
  if (insertion.second)
    ++distinctCounter;
  return insertion.first->second;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Resources
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
|
|
AsmResourceBuilder::~AsmResourceBuilder() = default;
|
|
AsmResourceParser::~AsmResourceParser() = default;
|
|
AsmResourcePrinter::~AsmResourcePrinter() = default;
|
|
|
|
/// Return the lowercase textual name of a resource entry kind.
StringRef mlir::toString(AsmResourceEntryKind kind) {
  switch (kind) {
  case AsmResourceEntryKind::String:
    return "string";
  case AsmResourceEntryKind::Bool:
    return "bool";
  case AsmResourceEntryKind::Blob:
    return "blob";
  }
  llvm_unreachable("unknown AsmResourceEntryKind");
}
|
|
|
|
/// Return the parser for resource group `key`. An empty collection is
/// created and stored for the key on first use.
AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) {
  auto &slot = keyToResources[key.str()];
  if (slot == nullptr)
    slot = std::make_unique<ResourceCollection>(key);
  return *slot;
}
|
|
|
|
std::vector<std::unique_ptr<AsmResourcePrinter>>
|
|
FallbackAsmResourceMap::getPrinters() {
|
|
std::vector<std::unique_ptr<AsmResourcePrinter>> printers;
|
|
for (auto &it : keyToResources) {
|
|
ResourceCollection *collection = it.second.get();
|
|
auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) {
|
|
return collection->buildResources(op, builder);
|
|
};
|
|
printers.emplace_back(
|
|
AsmResourcePrinter::fromCallable(collection->getName(), buildValues));
|
|
}
|
|
return printers;
|
|
}
|
|
|
|
/// Parse `entry` according to its kind and append the (key, value) pair to
/// `resources`. Returns failure if the entry's value fails to parse.
LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource(
    AsmParsedResourceEntry &entry) {
  switch (entry.getKind()) {
  case AsmResourceEntryKind::Blob: {
    FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
    if (failed(blob))
      return failure();
    resources.emplace_back(entry.getKey(), std::move(*blob));
    break;
  }
  case AsmResourceEntryKind::Bool: {
    FailureOr<bool> flag = entry.parseAsBool();
    if (failed(flag))
      return failure();
    resources.emplace_back(entry.getKey(), *flag);
    break;
  }
  case AsmResourceEntryKind::String: {
    FailureOr<std::string> text = entry.parseAsString();
    if (failed(text))
      return failure();
    resources.emplace_back(entry.getKey(), std::move(*text));
    break;
  }
  }
  return success();
}
|
|
|
|
/// Replay every stored resource into `builder`, dispatching on the payload
/// alternative held by each entry. `op` is not used.
void FallbackAsmResourceMap::ResourceCollection::buildResources(
    Operation *op, AsmResourceBuilder &builder) const {
  for (const auto &resource : resources) {
    const auto &payload = resource.value;
    if (auto *blob = std::get_if<AsmResourceBlob>(&payload)) {
      builder.buildBlob(resource.key, *blob);
      continue;
    }
    if (auto *flag = std::get_if<bool>(&payload)) {
      builder.buildBool(resource.key, *flag);
      continue;
    }
    if (auto *text = std::get_if<std::string>(&payload)) {
      builder.buildString(resource.key, *text);
      continue;
    }
    llvm_unreachable("unknown AsmResourceEntryKind");
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AsmState
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace mlir {
|
|
namespace detail {
|
|
/// Backing state for an AsmState. It holds:
/// - the context's OpAsm dialect interfaces;
/// - the alias, SSA-name and distinct-attribute numbering state;
/// - the dialect resources referenced while printing;
/// - the printing flags;
/// - an optional map recording where each operation was printed;
/// - a stack used to detect cyclic printing.
class AsmStateImpl {
public:
  /// Construct state for printing `op`. The SSA name state is constructed
  /// from `op` and the given flags.
  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, printerFlags),
        printerFlags(printerFlags), locationMap(locationMap) {}
  /// Construct state tied only to a context. `nameState` is default
  /// constructed here.
  explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}

  /// Initialize the alias state to enable the printing of aliases.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Get the state used for aliases.
  AliasState &getAliasState() { return aliasState; }

  /// Get the state used for SSA names.
  SSANameState &getSSANameState() { return nameState; }

  /// Get the state used for distinct attribute identifiers.
  DistinctState &getDistinctState() { return distinctState; }

  /// Return the dialects within the context that implement
  /// OpAsmDialectInterface.
  DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() {
    return interfaces;
  }

  /// Return the non-dialect resource printers, as a range of references
  /// rather than of owning pointers.
  auto getResourcePrinters() {
    return llvm::make_pointee_range(externalResourcePrinters);
  }

  /// Get the printer flags.
  const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }

  /// Register the location, line and column, within the buffer that the given
  /// operation was printed at. This is a no-op when no location map was
  /// provided.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

  /// Return the referenced dialect resources within the printer.
  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
  getDialectResources() {
    return dialectResources;
  }

  /// Push `opaquePointer` onto the stack of objects currently being printed.
  /// Fails if it is already on the stack, i.e. the insertion into the
  /// SetVector did not add a new element.
  LogicalResult pushCyclicPrinting(const void *opaquePointer) {
    return success(cyclicPrintingStack.insert(opaquePointer));
  }

  /// Pop the most recently pushed entry off the cyclic-printing stack.
  void popCyclicPrinting() { cyclicPrintingStack.pop_back(); }

private:
  // NOTE: members are initialized in declaration order; the constructors'
  // initializer lists follow this order.

  /// Collection of OpAsm interfaces implemented in the context.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// A collection of non-dialect resource printers.
  SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;

  /// A set of dialect resources that were referenced during printing.
  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;

  /// The state used for attribute and type aliases.
  AliasState aliasState;

  /// The state used for SSA value names.
  SSANameState nameState;

  /// The state used for distinct attribute identifiers.
  DistinctState distinctState;

  /// Flags that control op output.
  OpPrintingFlags printerFlags;

  /// An optional location map to be populated (may be null).
  AsmState::LocationMap *locationMap;

  /// Stack of potentially cyclic mutable attributes or types currently being
  /// printed.
  SetVector<const void *> cyclicPrintingStack;

  // Allow direct access to the impl fields.
  friend AsmState;
};
|
|
|
|
template <typename Range>
|
|
void printDimensionList(raw_ostream &stream, Range &&shape) {
|
|
llvm::interleave(
|
|
shape, stream,
|
|
[&stream](const auto &dimSize) {
|
|
if (ShapedType::isDynamic(dimSize))
|
|
stream << "?";
|
|
else
|
|
stream << dimSize;
|
|
},
|
|
"x");
|
|
}
|
|
|
|
} // namespace detail
|
|
} // namespace mlir
|
|
|
|
/// Verifies the operation and switches to generic op printing if verification
|
|
/// fails. We need to do this because custom print functions may fail for
|
|
/// invalid ops.
|
|
static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
|
|
OpPrintingFlags printerFlags) {
|
|
if (printerFlags.shouldPrintGenericOpForm() ||
|
|
printerFlags.shouldAssumeVerified())
|
|
return printerFlags;
|
|
|
|
// Ignore errors emitted by the verifier. We check the thread id to avoid
|
|
// consuming other threads' errors.
|
|
auto parentThreadId = llvm::get_threadid();
|
|
ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
|
|
if (parentThreadId == llvm::get_threadid()) {
|
|
LLVM_DEBUG({
|
|
diag.print(llvm::dbgs());
|
|
llvm::dbgs() << "\n";
|
|
});
|
|
return success();
|
|
}
|
|
return failure();
|
|
});
|
|
if (failed(verify(op))) {
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< DEBUG_TYPE << ": '" << op->getName()
|
|
<< "' failed to verify and will be printed in generic form\n");
|
|
printerFlags.printGenericOpForm();
|
|
}
|
|
|
|
return printerFlags;
|
|
}
|
|
|
|
/// Create printer state rooted at `op`. The flags are first passed through
/// verifyOpAndAdjustFlags, which may force generic form for invalid IR. When
/// `map` is non-null, its fallback resource printers are attached.
AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                   LocationMap *locationMap, FallbackAsmResourceMap *map)
    : impl(std::make_unique<AsmStateImpl>(
          op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {
  if (map != nullptr)
    attachFallbackResourcePrinter(*map);
}

/// Create printer state that is tied only to `ctx`. The flags are used
/// as-is (no verification). When `map` is non-null, its fallback resource
/// printers are attached.
AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
                   LocationMap *locationMap, FallbackAsmResourceMap *map)
    : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {
  if (map != nullptr)
    attachFallbackResourcePrinter(*map);
}

AsmState::~AsmState() = default;
|
|
|
|
const OpPrintingFlags &AsmState::getPrinterFlags() const {
|
|
return impl->getPrinterFlags();
|
|
}
|
|
|
|
void AsmState::attachResourcePrinter(
|
|
std::unique_ptr<AsmResourcePrinter> printer) {
|
|
impl->externalResourcePrinters.emplace_back(std::move(printer));
|
|
}
|
|
|
|
DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
|
|
AsmState::getDialectResources() const {
|
|
return impl->getDialectResources();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AsmPrinter::Impl
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state)
|
|
: os(os), state(state), printerFlags(state.getPrinterFlags()) {}
|
|
|
|
void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
|
|
// Check to see if we are printing debug information.
|
|
if (!printerFlags.shouldPrintDebugInfo())
|
|
return;
|
|
|
|
os << " ";
|
|
printLocation(loc, /*allowAlias=*/allowAlias);
|
|
}
|
|
|
|
/// Print `loc` either in the syntax used inside `loc(...)` (`pretty == false`)
/// or in the human-oriented pretty debug-info syntax (`pretty == true`).
/// Non-top-level locations are printed through their alias when one exists.
void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
                                             bool isTopLevel) {
  if (!isTopLevel && succeeded(state.getAliasState().getAlias(loc, os)))
    return;

  TypeSwitch<LocationAttr>(loc)
      .Case<OpaqueLoc>([&](OpaqueLoc opaque) {
        // Opaque locations are printed as their fallback location.
        printLocationInternal(opaque.getFallbackLocation(), pretty);
      })
      .Case<UnknownLoc>(
          [&](UnknownLoc) { os << (pretty ? "[unknown]" : "unknown"); })
      .Case<FileLineColLoc>([&](FileLineColLoc fileLoc) {
        if (pretty)
          os << fileLoc.getFilename().getValue();
        else
          printEscapedString(fileLoc.getFilename());
        os << ':' << fileLoc.getLine() << ':' << fileLoc.getColumn();
      })
      .Case<NameLoc>([&](NameLoc nameLoc) {
        printEscapedString(nameLoc.getName());
        // An unknown child location is omitted entirely.
        auto child = nameLoc.getChildLoc();
        if (llvm::isa<UnknownLoc>(child))
          return;
        os << '(';
        printLocationInternal(child, pretty);
        os << ')';
      })
      .Case<CallSiteLoc>([&](CallSiteLoc callSite) {
        Location callee = callSite.getCallee();
        Location caller = callSite.getCaller();
        if (!pretty)
          os << "callsite(";
        printLocationInternal(callee, pretty);
        // Pretty form starts the caller on a new line, except for the
        // `name at file:line:col` pairing, which stays on one line.
        bool keepOnOneLine =
            !pretty ||
            (llvm::isa<NameLoc>(callee) && llvm::isa<FileLineColLoc>(caller));
        if (!keepOnOneLine)
          os << newLine;
        os << " at ";
        printLocationInternal(caller, pretty);
        if (!pretty)
          os << ")";
      })
      .Case<FusedLoc>([&](FusedLoc fused) {
        if (!pretty)
          os << "fused";
        if (Attribute metadata = fused.getMetadata()) {
          os << '<';
          printAttribute(metadata);
          os << '>';
        }
        os << '[';
        interleave(
            fused.getLocations(),
            [&](Location child) { printLocationInternal(child, pretty); },
            [&]() { os << ", "; });
        os << ']';
      })
      .Default([&](LocationAttr other) {
        // Presumably a dialect-specific location; print it as an ordinary
        // attribute.
        printAttribute(other);
      });
}
|
|
|
|
/// Print a floating point value in a way that the parser will be able to
|
|
/// round-trip losslessly.
|
|
static void printFloatValue(const APFloat &apValue, raw_ostream &os,
|
|
bool *printedHex = nullptr) {
|
|
// We would like to output the FP constant value in exponential notation,
|
|
// but we cannot do this if doing so will lose precision. Check here to
|
|
// make sure that we only output it in exponential format if we can parse
|
|
// the value back and get the same value.
|
|
bool isInf = apValue.isInfinity();
|
|
bool isNaN = apValue.isNaN();
|
|
if (!isInf && !isNaN) {
|
|
SmallString<128> strValue;
|
|
apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
|
|
/*TruncateZero=*/false);
|
|
|
|
// Check to make sure that the stringized number is not some string like
|
|
// "Inf" or NaN, that atof will accept, but the lexer will not. Check
|
|
// that the string matches the "[-+]?[0-9]" regex.
|
|
assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
|
|
((strValue[0] == '-' || strValue[0] == '+') &&
|
|
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
|
|
"[-+]?[0-9] regex does not match!");
|
|
|
|
// Parse back the stringized version and check that the value is equal
|
|
// (i.e., there is no precision loss).
|
|
if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
|
|
os << strValue;
|
|
return;
|
|
}
|
|
|
|
// If it is not, use the default format of APFloat instead of the
|
|
// exponential notation.
|
|
strValue.clear();
|
|
apValue.toString(strValue);
|
|
|
|
// Make sure that we can parse the default form as a float.
|
|
if (strValue.str().contains('.')) {
|
|
os << strValue;
|
|
return;
|
|
}
|
|
}
|
|
|
|
// Print special values in hexadecimal format. The sign bit should be included
|
|
// in the literal.
|
|
if (printedHex)
|
|
*printedHex = true;
|
|
SmallVector<char, 16> str;
|
|
APInt apInt = apValue.bitcastToAPInt();
|
|
apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
|
|
/*formatAsCLiteral=*/true);
|
|
os << str;
|
|
}
|
|
|
|
/// Print `loc` as a top-level location.
///
/// In pretty debug-info mode it is printed bare. Otherwise it is wrapped in
/// `loc(...)` and, if `allowAlias` is set and an alias exists, the alias is
/// printed in place of the full location.
void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
  if (printerFlags.shouldPrintDebugInfoPrettyForm()) {
    printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true);
    return;
  }

  os << "loc(";
  bool printedAlias = allowAlias && succeeded(printAlias(loc));
  if (!printedAlias)
    printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true);
  os << ')';
}
|
|
|
|
void AsmPrinter::Impl::printResourceHandle(
|
|
const AsmDialectResourceHandle &resource) {
|
|
auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
|
|
os << interface->getResourceKey(resource);
|
|
state.getDialectResources()[resource.getDialect()].insert(resource);
|
|
}
|
|
|
|
/// Returns true if the given dialect symbol data is simple enough to print in
|
|
/// the pretty form. This is essentially when the symbol takes the form:
|
|
/// identifier (`<` body `>`)?
|
|
static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
|
|
// The name must start with an identifier.
|
|
if (symName.empty() || !isalpha(symName.front()))
|
|
return false;
|
|
|
|
// Ignore all the characters that are valid in an identifier in the symbol
|
|
// name.
|
|
symName = symName.drop_while(
|
|
[](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
|
|
if (symName.empty())
|
|
return true;
|
|
|
|
// If we got to an unexpected character, then it must be a <>. Check that the
|
|
// rest of the symbol is wrapped within <>.
|
|
return symName.front() == '<' && symName.back() == '>';
|
|
}
|
|
|
|
/// Print the given dialect symbol to the stream.
|
|
static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
|
|
StringRef dialectName, StringRef symString) {
|
|
os << symPrefix << dialectName;
|
|
|
|
// If this symbol name is simple enough, print it directly in pretty form,
|
|
// otherwise, we print it as an escaped string.
|
|
if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
|
|
os << '.' << symString;
|
|
return;
|
|
}
|
|
|
|
os << '<' << symString << '>';
|
|
}
|
|
|
|
/// Returns true if the given string can be represented as a bare identifier.
|
|
static bool isBareIdentifier(StringRef name) {
|
|
// By making this unsigned, the value passed in to isalnum will always be
|
|
// in the range 0-255. This is important when building with MSVC because
|
|
// its implementation will assert. This situation can arise when dealing
|
|
// with UTF-8 multibyte characters.
|
|
if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
|
|
return false;
|
|
return llvm::all_of(name.drop_front(), [](unsigned char c) {
|
|
return isalnum(c) || c == '_' || c == '$' || c == '.';
|
|
});
|
|
}
|
|
|
|
/// Print the given string as a keyword, or a quoted and escaped string if it
|
|
/// has any special or non-printable characters in it.
|
|
static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
|
|
// If it can be represented as a bare identifier, write it directly.
|
|
if (isBareIdentifier(keyword)) {
|
|
os << keyword;
|
|
return;
|
|
}
|
|
|
|
// Otherwise, output the keyword wrapped in quotes with proper escaping.
|
|
os << "\"";
|
|
printEscapedString(keyword, os);
|
|
os << '"';
|
|
}
|
|
|
|
/// Print the given string as a symbol reference. A symbol reference is
|
|
/// represented as a string prefixed with '@'. The reference is surrounded with
|
|
/// ""'s and escaped if it has any special or non-printable characters in it.
|
|
static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
|
|
if (symbolRef.empty()) {
|
|
os << "@<<INVALID EMPTY SYMBOL>>";
|
|
return;
|
|
}
|
|
os << '@';
|
|
printKeywordOrString(symbolRef, os);
|
|
}
|
|
|
|
// Print out a valid ElementsAttr that is succinct and can represent any
|
|
// potential shape/type, for use when eliding a large ElementsAttr.
|
|
//
|
|
// We choose to use a dense resource ElementsAttr literal with conspicuous
|
|
// content to hopefully alert readers to the fact that this has been elided.
|
|
static void printElidedElementsAttr(raw_ostream &os) {
|
|
os << R"(dense_resource<__elided__>)";
|
|
}
|
|
|
|
/// Print the alias registered for `attr`, if any. Fails when there is none.
LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
  AliasState &aliases = state.getAliasState();
  return aliases.getAlias(attr, os);
}

/// Print the alias registered for `type`, if any. Fails when there is none.
LogicalResult AsmPrinter::Impl::printAlias(Type type) {
  AliasState &aliases = state.getAliasState();
  return aliases.getAlias(type, os);
}
|
|
|
|
/// Print `attr`, using its alias when available. A null attribute prints a
/// `<<NULL ATTRIBUTE>>` marker.
void AsmPrinter::Impl::printAttribute(Attribute attr,
                                      AttrTypeElision typeElision) {
  if (!attr) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }
  if (failed(printAlias(attr)))
    printAttributeImpl(attr, typeElision);
}
|
|
|
|
/// Print `attr` in full (no alias lookup for `attr` itself).
///
/// Non-builtin attributes are delegated to printDialectAttribute. Builtin
/// attributes are dispatched by kind below. Branches that `return` early skip
/// the trailing type suffix. Branches that fall through print ` : <type>` for
/// TypedAttrs whose type is not NoneType, unless `typeElision` is Must.
void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
                                          AttrTypeElision typeElision) {
  if (!isa<BuiltinDialect>(attr.getDialect())) {
    printDialectAttribute(attr);
  } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) {
    printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                       opaqueAttr.getAttrData());
  } else if (llvm::isa<UnitAttr>(attr)) {
    os << "unit";
    return;
  } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) {
    // Printed as `distinct[<id>]<referenced>`; a UnitAttr referent is left
    // empty.
    os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<";
    if (!llvm::isa<UnitAttr>(distinctAttr.getReferencedAttr())) {
      printAttribute(distinctAttr.getReferencedAttr());
    }
    os << '>';
    return;
  } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
    os << '{';
    interleaveComma(dictAttr.getValue(),
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';

  } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
    Type intType = intAttr.getType();
    if (intType.isSignlessInteger(1)) {
      os << (intAttr.getValue().getBoolValue() ? "true" : "false");

      // Boolean integer attributes always elide the type.
      return;
    }

    // Only print attributes as unsigned if they are explicitly unsigned or are
    // signless 1-bit values. Indexes, signed values, and multi-bit signless
    // values print as signed.
    // NOTE(review): the signless-i1 case already returned above, so the
    // isSignlessInteger(1) term is redundant here.
    bool isUnsigned =
        intType.isUnsignedInteger() || intType.isSignlessInteger(1);
    intAttr.getValue().print(os, !isUnsigned);

    // IntegerAttr elides the type if I64.
    if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64))
      return;

  } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
    bool printedHex = false;
    printFloatValue(floatAttr.getValue(), os, &printedHex);

    // FloatAttr elides the type if F64, unless the value was printed as a
    // hex bit pattern (which needs the type to be interpreted).
    if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64() &&
        !printedHex)
      return;

  } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) {
    printEscapedString(strAttr.getValue());

  } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
    os << '[';
    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
      printAttribute(attr, AttrTypeElision::May);
    });
    os << ']';

  } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
    os << "affine_map<";
    affineMapAttr.getValue().print(os);
    os << '>';

    // AffineMap always elides the type.
    return;

  } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) {
    os << "affine_set<";
    integerSetAttr.getValue().print(os);
    os << '>';

    // IntegerSet always elides the type.
    return;

  } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) {
    printType(typeAttr.getValue());

  } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) {
    // Printed as `@root::@nested1::@nested2...`.
    printSymbolReference(refAttr.getRootReference().getValue(), os);
    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
      os << "::";
      printSymbolReference(nestedRef.getValue(), os);
    }

  } else if (auto intOrFpEltAttr =
                 llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
      os << '>';
    }

  } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
      printElidedElementsAttr(os);
    } else {
      os << "dense<";
      printDenseStringElementsAttr(strEltAttr);
      os << '>';
    }

  } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
    // Elide if either the indices or the values would be elided.
    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
      printElidedElementsAttr(os);
    } else {
      os << "sparse<";
      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
      // With no indices, the body is left empty (`sparse<>`).
      if (indices.getNumElements() != 0) {
        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
        os << ", ";
        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
      }
      os << '>';
    }
  } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) {
    stridedLayoutAttr.print(os);
  } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) {
    os << "array<";
    printType(denseArrayAttr.getElementType());
    if (!denseArrayAttr.empty()) {
      os << ": ";
      printDenseArrayAttr(denseArrayAttr);
    }
    os << ">";
    return;
  } else if (auto resourceAttr =
                 llvm::dyn_cast<DenseResourceElementsAttr>(attr)) {
    os << "dense_resource<";
    printResourceHandle(resourceAttr.getRawHandle());
    os << ">";
  } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(attr)) {
    printLocation(locAttr);
  } else {
    llvm::report_fatal_error("Unknown builtin attribute");
  }
  // Don't print the type if we must elide it, or if it is a None type.
  if (typeElision != AttrTypeElision::Must) {
    if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
      Type attrType = typedAttr.getType();
      if (!llvm::isa<NoneType>(attrType)) {
        os << " : ";
        printType(attrType);
      }
    }
  }
}
|
|
|
|
/// Print the integer element of a DenseElementsAttr.
|
|
static void printDenseIntElement(const APInt &value, raw_ostream &os,
|
|
Type type) {
|
|
if (type.isInteger(1))
|
|
os << (value.getBoolValue() ? "true" : "false");
|
|
else
|
|
value.print(os, !type.isUnsignedInteger());
|
|
}
|
|
|
|
/// Print the elements of a shaped attribute in nested-bracket form, e.g.
/// `[[a, b], [c, d]]` for a 2x2 shape. `printEltFn(i)` is called for each flat
/// element index `i` in order (last dimension varying fastest).
/// Splat attributes print only element 0 with no brackets. Shapes with zero
/// elements print nothing.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.
  // NOTE(review): with rank 0 this would index counter[-1]; presumably
  // unreachable because 0-d attributes take the splat path above — confirm.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  // NOTE(review): `numElements` is int64_t but `idx`/`e` are unsigned, so
  // element counts above UINT_MAX would be truncated here.
  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket closed by the previous bump.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close whatever is still open after the last element.
  while (openBrackets-- > 0)
    os << ']';
}
|
|
|
|
/// Print the body of a DenseElementsAttr. String elements go to the string
/// printer; all other elements go to the int/float printer, which honours
/// `allowHex`.
void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
                                              bool allowHex) {
  auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr);
  if (!stringAttr) {
    printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr),
                                  allowHex);
    return;
  }
  printDenseStringElementsAttr(stringAttr);
}
|
|
|
|
/// Print the body of a DenseIntOrFPElementsAttr (without `dense<...>`).
///
/// If `allowHex` is set and the printer flags request it for `attr`, the raw
/// data is printed as a hex string. On big-endian hosts the data is first
/// converted so it is printed in little-endian form.
///
/// Otherwise elements are printed one by one:
/// - complex elements as `(real,imag)`;
/// - integer/index elements via printDenseIntElement;
/// - float elements via printFloatValue.
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
    DenseIntOrFPElementsAttr attr, bool allowHex) {
  auto type = attr.getType();
  auto elementType = type.getElementType();

  // Check to see if we should format this attribute as a hex string.
  if (allowHex && printerFlags.shouldPrintElementsAttrWithHex(attr)) {
    ArrayRef<char> rawData = attr.getRawData();
    if (llvm::endianness::native == llvm::endianness::big) {
      // Convert endianness on big-endian (BE) machines. `rawData` is BE on BE
      // machines. It is converted here to print in LE format.
      SmallVector<char, 64> outDataVec(rawData.size());
      MutableArrayRef<char> convRawData(outDataVec);
      DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
          rawData, convRawData, type);
      printHexString(convRawData);
    } else {
      printHexString(rawData);
    }

    return;
  }

  if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
    Type complexElementType = complexTy.getElementType();
    // Note: The if and else below had a common lambda function which invoked
    // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
    // and hence was replaced.
    if (llvm::isa<IntegerType>(complexElementType)) {
      auto valueIt = attr.value_begin<std::complex<APInt>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printDenseIntElement(complexValue.real(), os, complexElementType);
        os << ",";
        printDenseIntElement(complexValue.imag(), os, complexElementType);
        os << ")";
      });
    } else {
      auto valueIt = attr.value_begin<std::complex<APFloat>>();
      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
        auto complexValue = *(valueIt + index);
        os << "(";
        printFloatValue(complexValue.real(), os);
        os << ",";
        printFloatValue(complexValue.imag(), os);
        os << ")";
      });
    }
  } else if (elementType.isIntOrIndex()) {
    auto valueIt = attr.value_begin<APInt>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printDenseIntElement(*(valueIt + index), os, elementType);
    });
  } else {
    assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
    auto valueIt = attr.value_begin<APFloat>();
    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
      printFloatValue(*(valueIt + index), os);
    });
  }
}
|
|
|
|
/// Print the body of a DenseStringElementsAttr. Each element is printed as an
/// escaped string, in the nested-bracket layout of the attribute's shape.
void AsmPrinter::Impl::printDenseStringElementsAttr(
    DenseStringElementsAttr attr) {
  ArrayRef<StringRef> strings = attr.getRawStringData();
  printDenseElementsAttrImpl(
      attr.isSplat(), attr.getType(), os,
      [&](unsigned index) { printEscapedString(strings[index]); });
}
|
|
|
|
/// Print the elements of a DenseArrayAttr as a comma-separated list. The
/// surrounding `array<type: ...>` syntax is printed by the caller.
///
/// Each element is loaded from the raw data and printed as an integer (via
/// printDenseIntElement) or as a float of the element type's semantics (via
/// printFloatValue).
///
/// NOTE(review): assumes each element occupies ceil(bitwidth / 8) bytes of
/// raw storage, with i1 stored as one byte, matching the DenseArrayAttr
/// verifier's element size. The previous `bitwidth / 8` computed 0 bytes for
/// widths that are not a multiple of 8 (e.g. i3, f4E2M1FN, f6E2M3FN). Every
/// such element then loaded nothing from offset 0 and printed as zero. For
/// byte-multiple widths the behavior is unchanged.
void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
  Type type = attr.getElementType();
  unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
  unsigned byteSize = (bitwidth + 7) / 8;
  ArrayRef<char> data = attr.getRawData();

  auto printElementAt = [&](unsigned i) {
    // Load the whole stored element, then drop any padding bits above the
    // element width so the APInt matches `type` exactly.
    APInt value(byteSize * 8, 0);
    if (bitwidth) {
      llvm::LoadIntFromMemory(
          value, reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i),
          byteSize);
      if (value.getBitWidth() != bitwidth)
        value = value.trunc(bitwidth);
    }
    // Print the data as-is or as a float.
    if (type.isIntOrIndex()) {
      printDenseIntElement(value, getStream(), type);
    } else {
      APFloat fltVal(llvm::cast<FloatType>(type).getFloatSemantics(), value);
      printFloatValue(fltVal, getStream());
    }
  };
  llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
                        printElementAt);
}
|
|
|
|
/// Print `type`. A null type prints a placeholder. Otherwise the alias is
/// printed if one exists, and the full type spelling is printed if not.
void AsmPrinter::Impl::printType(Type type) {
  if (!type) {
    os << "<<NULL TYPE>>";
    return;
  }

  // Fall back to the full spelling only when no alias could be printed.
  if (failed(printAlias(type)))
    printTypeImpl(type);
}
|
|
|
|
/// Print the full textual spelling of `type`, without consulting the alias
/// table. Builtin types are spelled directly below. Any type not matched by
/// one of the cases is handed to its owning dialect via printDialectType.
/// Nested element/input/result types go through printType, so they may
/// still print as aliases.
void AsmPrinter::Impl::printTypeImpl(Type type) {
  TypeSwitch<Type>(type)
      // Opaque types carry their dialect namespace and raw type data verbatim.
      .Case<OpaqueType>([&](OpaqueType opaqueTy) {
        printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
                           opaqueTy.getTypeData());
      })
      .Case<IndexType>([&](Type) { os << "index"; })
      // Builtin floating-point types: each one prints a fixed keyword.
      .Case<Float4E2M1FNType>([&](Type) { os << "f4E2M1FN"; })
      .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; })
      .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
      .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
      .Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; })
      .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
      .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
      .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
      .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
      .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
      .Case<Float8E8M0FNUType>([&](Type) { os << "f8E8M0FNU"; })
      .Case<BFloat16Type>([&](Type) { os << "bf16"; })
      .Case<Float16Type>([&](Type) { os << "f16"; })
      .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
      .Case<Float32Type>([&](Type) { os << "f32"; })
      .Case<Float64Type>([&](Type) { os << "f64"; })
      .Case<Float80Type>([&](Type) { os << "f80"; })
      .Case<Float128Type>([&](Type) { os << "f128"; })
      // Integers print as `i<width>`, prefixed with 's' if signed or 'u' if
      // unsigned; no prefix otherwise.
      .Case<IntegerType>([&](IntegerType integerTy) {
        if (integerTy.isSigned())
          os << 's';
        else if (integerTy.isUnsigned())
          os << 'u';
        os << 'i' << integerTy.getWidth();
      })
      // Function types: `(inputs) -> result` or `(inputs) -> (results)`.
      .Case<FunctionType>([&](FunctionType funcTy) {
        os << '(';
        interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
        os << ") -> ";
        ArrayRef<Type> results = funcTy.getResults();
        // A single result is printed bare unless it is itself a function
        // type, which is parenthesized (presumably so the nested `->` stays
        // unambiguous).
        if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) {
          printType(results[0]);
        } else {
          os << '(';
          interleaveComma(results, [&](Type ty) { printType(ty); });
          os << ')';
        }
      })
      // Vectors: `vector<4x[8]xf32>`; scalable dimensions are bracketed.
      .Case<VectorType>([&](VectorType vectorTy) {
        auto scalableDims = vectorTy.getScalableDims();
        os << "vector<";
        auto vShape = vectorTy.getShape();
        unsigned lastDim = vShape.size();
        unsigned dimIdx = 0;
        for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
          if (!scalableDims.empty() && scalableDims[dimIdx])
            os << '[';
          os << vShape[dimIdx];
          if (!scalableDims.empty() && scalableDims[dimIdx])
            os << ']';
          os << 'x';
        }
        printType(vectorTy.getElementType());
        os << '>';
      })
      // Ranked tensors: `tensor<dims x elt[, encoding]>`.
      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
        os << "tensor<";
        printDimensionList(tensorTy.getShape());
        if (!tensorTy.getShape().empty())
          os << 'x';
        printType(tensorTy.getElementType());
        // Only print the encoding attribute value if set.
        if (tensorTy.getEncoding()) {
          os << ", ";
          printAttribute(tensorTy.getEncoding());
        }
        os << '>';
      })
      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
        os << "tensor<*x";
        printType(tensorTy.getElementType());
        os << '>';
      })
      // Ranked memrefs: `memref<dims x elt[, layout][, memory space]>`.
      .Case<MemRefType>([&](MemRefType memrefTy) {
        os << "memref<";
        printDimensionList(memrefTy.getShape());
        if (!memrefTy.getShape().empty())
          os << 'x';
        printType(memrefTy.getElementType());
        MemRefLayoutAttrInterface layout = memrefTy.getLayout();
        // The layout is elided only when it is an identity affine map.
        if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
          os << ", ";
          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
        }
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
        os << "memref<*x";
        printType(memrefTy.getElementType());
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<ComplexType>([&](ComplexType complexTy) {
        os << "complex<";
        printType(complexTy.getElementType());
        os << '>';
      })
      .Case<TupleType>([&](TupleType tupleTy) {
        os << "tuple<";
        interleaveComma(tupleTy.getTypes(),
                        [&](Type type) { printType(type); });
        os << '>';
      })
      .Case<NoneType>([&](Type) { os << "none"; })
      // Anything else belongs to a dialect and is printed by that dialect.
      .Default([&](Type type) { return printDialectType(type); });
}
|
|
|
|
/// Print `attrs` as ` {a = 1, b}`, skipping any attribute whose name appears
/// in `elidedAttrs`. If `withKeyword` is set, the dictionary is preceded by
/// ` attributes`. Nothing is printed when no attribute survives filtering.
void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                             ArrayRef<StringRef> elidedAttrs,
                                             bool withKeyword) {
  if (attrs.empty())
    return;

  // Emit the braces (and optional keyword) around an already-filtered range.
  auto emitDict = [&](auto range) {
    if (withKeyword)
      os << " attributes";
    os << " {";
    interleaveComma(range,
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';
  };

  // Fast path: nothing to elide, so no filtering is necessary.
  if (elidedAttrs.empty()) {
    emitDict(attrs);
    return;
  }

  llvm::SmallDenseSet<StringRef> skipped(elidedAttrs.begin(),
                                         elidedAttrs.end());
  auto kept = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
    return !skipped.contains(attr.getName().strref());
  });
  if (kept.empty())
    return;
  emitDict(kept);
}
|
|
/// Print `name = value` for a named attribute. The name is printed bare when
/// it is a valid keyword and quoted otherwise. Unit attributes print only
/// their name.
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
  ::printKeywordOrString(attr.getName().strref(), os);

  Attribute value = attr.getValue();
  if (!llvm::isa<UnitAttr>(value)) {
    os << " = ";
    printAttribute(value);
  }
}
|
|
|
|
/// Print a dialect-defined attribute as `#dialect<...>`. The dialect renders
/// the attribute body into a temporary buffer through a nested printer that
/// shares this printer's state.
void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
  auto &owningDialect = attr.getDialect();

  std::string body;
  {
    // Scope the stream and nested printers so they are finished before the
    // buffer is consumed below.
    llvm::raw_string_ostream bodyStream(body);
    Impl nestedPrinter(bodyStream, state);
    DialectAsmPrinter dialectPrinter(nestedPrinter);
    owningDialect.printAttribute(attr, dialectPrinter);
  }

  printDialectSymbol(os, "#", owningDialect.getNamespace(), body);
}
|
|
|
|
/// Print a dialect-defined type as `!dialect<...>`. The dialect renders the
/// type body into a temporary buffer through a nested printer that shares
/// this printer's state.
void AsmPrinter::Impl::printDialectType(Type type) {
  auto &owningDialect = type.getDialect();

  std::string body;
  {
    // Scope the stream and nested printers so they are finished before the
    // buffer is consumed below.
    llvm::raw_string_ostream bodyStream(body);
    Impl nestedPrinter(bodyStream, state);
    DialectAsmPrinter dialectPrinter(nestedPrinter);
    owningDialect.printType(type, dialectPrinter);
  }

  printDialectSymbol(os, "!", owningDialect.getNamespace(), body);
}
|
|
|
|
/// Print `str` escaped and surrounded by double quotes.
void AsmPrinter::Impl::printEscapedString(StringRef str) {
  // The opening quote is written first; llvm::printEscapedString then appends
  // the escaped contents to the same stream.
  llvm::printEscapedString(str, os << "\"");
  os << "\"";
}
|
|
|
|
/// Print the bytes of `str` as a quoted hexadecimal literal: "0x...".
void AsmPrinter::Impl::printHexString(StringRef str) {
  std::string hexDigits = llvm::toHex(str);
  os << "\"0x" << hexDigits << "\"";
}
|
|
/// Print a raw byte buffer as a quoted hexadecimal literal.
void AsmPrinter::Impl::printHexString(ArrayRef<char> data) {
  StringRef bytes(data.begin(), data.size());
  printHexString(bytes);
}
|
|
|
|
/// Record `opaquePointer` as currently being printed; the result comes from
/// the shared printer state.
LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {
  LogicalResult pushed = state.pushCyclicPrinting(opaquePointer);
  return pushed;
}
|
|
|
|
/// Undo the most recent pushCyclicPrinting in the shared printer state.
void AsmPrinter::Impl::popCyclicPrinting() {
  state.popCyclicPrinting();
}
|
|
|
|
/// Print a shape as a dimension list by delegating to the shared helper.
void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) {
  raw_ostream &stream = os;
  detail::printDimensionList(stream, shape);
}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// AsmPrinter
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// Out-of-line defaulted destructor for AsmPrinter.
AsmPrinter::~AsmPrinter() = default;
|
|
|
|
/// Return the stream of the underlying implementation. An AsmPrinter
/// constructed without an implementation must override this method.
raw_ostream &AsmPrinter::getStream() const {
  assert(impl && "expected AsmPrinter::getStream to be overriden");
  raw_ostream &stream = impl->getStream();
  return stream;
}
|
|
|
|
/// Print the given floating point value in a stablized form.
|
|
void AsmPrinter::printFloat(const APFloat &value) {
|
|
assert(impl && "expected AsmPrinter::printFloat to be overriden");
|
|
printFloatValue(value, impl->getStream());
|
|
}
|
|
|
|
/// Print `type` through the underlying implementation.
void AsmPrinter::printType(Type type) {
  assert(impl && "expected AsmPrinter::printType to be overriden");
  Impl &printerImpl = *impl;
  printerImpl.printType(type);
}
|
|
|
|
/// Print `attr` through the underlying implementation.
void AsmPrinter::printAttribute(Attribute attr) {
  assert(impl && "expected AsmPrinter::printAttribute to be overriden");
  Impl &printerImpl = *impl;
  printerImpl.printAttribute(attr);
}
|
|
|
|
/// Try to print an alias for `attr`; the result reports whether one was
/// printed.
LogicalResult AsmPrinter::printAlias(Attribute attr) {
  assert(impl && "expected AsmPrinter::printAlias to be overriden");
  Impl &printerImpl = *impl;
  return printerImpl.printAlias(attr);
}
|
|
|
|
/// Try to print an alias for `type`; the result reports whether one was
/// printed.
LogicalResult AsmPrinter::printAlias(Type type) {
  assert(impl && "expected AsmPrinter::printAlias to be overriden");
  Impl &printerImpl = *impl;
  return printerImpl.printAlias(type);
}
|
|
|
|
/// Print `attr`, requesting that its type be elided
/// (AttrTypeElision::Must).
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
  assert(impl &&
         "expected AsmPrinter::printAttributeWithoutType to be overriden");
  Impl &printerImpl = *impl;
  printerImpl.printAttribute(attr, Impl::AttrTypeElision::Must);
}
|
|
|
|
/// Print `keyword` bare if it is a valid keyword, or as a quoted string
/// otherwise (delegates to the file-level helper).
void AsmPrinter::printKeywordOrString(StringRef keyword) {
  assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
  raw_ostream &stream = impl->getStream();
  ::printKeywordOrString(keyword, stream);
}
|
|
|
|
/// Print `keyword` as a double-quoted string with its contents escaped.
void AsmPrinter::printString(StringRef keyword) {
  assert(impl && "expected AsmPrinter::printString to be overriden");
  raw_ostream &stream = getStream();
  stream << '"';
  printEscapedString(keyword, stream);
  stream << '"';
}
|
|
|
|
/// Print `symbolRef` as a symbol reference (delegates to the file-level
/// helper).
void AsmPrinter::printSymbolName(StringRef symbolRef) {
  assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
  raw_ostream &stream = impl->getStream();
  ::printSymbolReference(symbolRef, stream);
}
|
|
|
|
void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
|
|
assert(impl && "expected AsmPrinter::printResourceHandle to be overriden");
|
|
impl->printResourceHandle(resource);
|
|
}
|
|
|
|
/// Print a shape as a dimension list on this printer's stream.
void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
  raw_ostream &stream = getStream();
  detail::printDimensionList(stream, shape);
}
|
|
|
|
/// Record `opaquePointer` as currently being printed, via the underlying
/// implementation. Like every other forwarding method on AsmPrinter, this
/// requires an implementation unless the method is overridden.
LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
  // Guard the dereference the same way the sibling forwarding methods do, so
  // a missing implementation fails loudly in debug builds instead of crashing.
  assert(impl && "expected AsmPrinter::pushCyclicPrinting to be overridden");
  return impl->pushCyclicPrinting(opaquePointer);
}
|
|
|
|
/// Undo the most recent pushCyclicPrinting via the underlying implementation.
void AsmPrinter::popCyclicPrinting() {
  // Guard the dereference consistently with the other forwarding methods.
  assert(impl && "expected AsmPrinter::popCyclicPrinting to be overridden");
  impl->popCyclicPrinting();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Affine expressions and maps
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Print `expr`. When `printValueName` is provided, it is used to spell
/// dimension/symbol identifiers instead of the default `dN`/`sN`.
void AsmPrinter::Impl::printAffineExpr(
    AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
  // The outermost expression has no enclosing operator, so it starts with
  // weak binding and is never parenthesized.
  BindingStrength outermost = BindingStrength::Weak;
  printAffineExprInternal(expr, outermost, printValueName);
}
|
|
|
|
/// Recursive worker for printAffineExpr.
///
/// `enclosingTightness` describes the surrounding context: when it is Strong,
/// any binary expression printed here is wrapped in parentheses. The
/// following pretty forms are applied:
///   - `e * -1`              prints as `-e`
///   - `a + b * -1`          prints as `a - b`
///   - `a + b * c` (c < -1)  prints as `a - b * |c|`
///   - `a + c`     (c < 0)   prints as `a - |c|`
/// `printValueName`, when non-null, spells dimension/symbol identifiers in
/// place of the default `dN` / `sN`.
void AsmPrinter::Impl::printAffineExprInternal(
    AffineExpr expr, BindingStrength enclosingTightness,
    function_ref<void(unsigned, bool)> printValueName) {
  const char *binopSpelling = nullptr;
  switch (expr.getKind()) {
  case AffineExprKind::SymbolId: {
    unsigned pos = cast<AffineSymbolExpr>(expr).getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/true);
    else
      os << 's' << pos;
    return;
  }
  case AffineExprKind::DimId: {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/false);
    else
      os << 'd' << pos;
    return;
  }
  case AffineExprKind::Constant:
    os << cast<AffineConstantExpr>(expr).getValue();
    return;
  case AffineExprKind::Add:
    binopSpelling = " + ";
    break;
  case AffineExprKind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExprKind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExprKind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExprKind::Mod:
    binopSpelling = " mod ";
    break;
  }

  auto binOp = cast<AffineBinaryOpExpr>(expr);
  AffineExpr lhsExpr = binOp.getLHS();
  AffineExpr rhsExpr = binOp.getRHS();

  // Handle tightly binding binary operators.
  if (binOp.getKind() != AffineExprKind::Add) {
    if (enclosingTightness == BindingStrength::Strong)
      os << '(';

    // Pretty print multiplication with -1.
    auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr);
    if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
        rhsConst.getValue() == -1) {
      os << "-";
      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }

    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);

    os << binopSpelling;
    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);

    if (enclosingTightness == BindingStrength::Strong)
      os << ')';
    return;
  }

  // Print the magnitude of a negative constant. The negation is done in
  // unsigned arithmetic: `-value` on int64_t is undefined behavior when
  // value == INT64_MIN, whereas unsigned negation is well defined and yields
  // the correct magnitude for every negative int64_t.
  auto printNegated = [&](int64_t value) {
    os << (uint64_t(0) - static_cast<uint64_t>(value));
  };

  // Print out special "pretty" forms for add.
  if (enclosingTightness == BindingStrength::Strong)
    os << '(';

  // Pretty print addition to a product that has a negative operand as a
  // subtraction.
  if (auto rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) {
    if (rhs.getKind() == AffineExprKind::Mul) {
      AffineExpr rrhsExpr = rhs.getRHS();
      if (auto rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) {
        if (rrhs.getValue() == -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          // A sum being subtracted needs parentheses: `a - (b + c)`.
          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                    printValueName);
          } else {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
                                    printValueName);
          }

          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }

        if (rrhs.getValue() < -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                  printValueName);
          os << " * ";
          printNegated(rrhs.getValue());
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
      }
    }
  }

  // Pretty print addition to a negative number as a subtraction.
  if (auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr)) {
    if (rhsConst.getValue() < 0) {
      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
      os << " - ";
      printNegated(rhsConst.getValue());
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }
  }

  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);

  os << " + ";
  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);

  if (enclosingTightness == BindingStrength::Strong)
    os << ')';
}
|
|
|
|
/// Print one integer-set constraint as `expr == 0` (equality) or
/// `expr >= 0` (inequality).
void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
  printAffineExprInternal(expr, BindingStrength::Weak);
  os << (isEq ? " == 0" : " >= 0");
}
|
|
|
|
/// Print an affine map as `(d0, d1)[s0] -> (results...)`. The symbol list
/// is omitted entirely when the map has no symbols.
void AsmPrinter::Impl::printAffineMap(AffineMap map) {
  // Dimension identifiers.
  os << '(';
  llvm::interleaveComma(llvm::seq<unsigned>(0, map.getNumDims()), os,
                        [&](unsigned dim) { os << 'd' << dim; });
  os << ')';

  // Symbolic identifiers, bracketed, only when present.
  if (unsigned numSyms = map.getNumSymbols()) {
    os << '[';
    llvm::interleaveComma(llvm::seq<unsigned>(0, numSyms), os,
                          [&](unsigned sym) { os << 's' << sym; });
    os << ']';
  }

  // Result affine expressions.
  os << " -> (";
  interleaveComma(map.getResults(),
                  [&](AffineExpr expr) { printAffineExpr(expr); });
  os << ')';
}
|
|
|
|
/// Print an integer set as `(d0, d1)[s0] : (c0, c1, ...)`. Each constraint
/// prints as `expr == 0` or `expr >= 0`. The symbol list is omitted entirely
/// when the set has no symbols.
void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
  // Dimension identifiers.
  os << '(';
  llvm::interleaveComma(llvm::seq<unsigned>(0, set.getNumDims()), os,
                        [&](unsigned dim) { os << 'd' << dim; });
  os << ')';

  // Symbolic identifiers, bracketed, only when present.
  if (unsigned numSyms = set.getNumSymbols()) {
    os << '[';
    llvm::interleaveComma(llvm::seq<unsigned>(0, numSyms), os,
                          [&](unsigned sym) { os << 's' << sym; });
    os << ']';
  }

  // Constraints.
  os << " : (";
  llvm::interleaveComma(llvm::seq<unsigned>(0, set.getNumConstraints()), os,
                        [&](unsigned idx) {
                          printAffineConstraint(set.getConstraint(idx),
                                                set.isEq(idx));
                        });
  os << ')';
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OperationPrinter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This class contains the logic for printing operations, regions, and blocks.
|
|
class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
|
|
public:
|
|
using Impl = AsmPrinter::Impl;
|
|
using Impl::printType;
|
|
|
|
explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
|
|
: Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
|
|
|
|
/// Print the given top-level operation.
|
|
void printTopLevelOperation(Operation *op);
|
|
|
|
/// Print the given operation, including its left-hand side and its right-hand
|
|
/// side, with its indent and location.
|
|
void printFullOpWithIndentAndLoc(Operation *op);
|
|
/// Print the given operation, including its left-hand side and its right-hand
|
|
/// side, but not including indentation and location.
|
|
void printFullOp(Operation *op);
|
|
/// Print the right-hand size of the given operation in the custom or generic
|
|
/// form.
|
|
void printCustomOrGenericOp(Operation *op) override;
|
|
/// Print the right-hand side of the given operation in the generic form.
|
|
void printGenericOp(Operation *op, bool printOpName) override;
|
|
|
|
/// Print the name of the given block.
|
|
void printBlockName(Block *block);
|
|
|
|
/// Print the given block. If 'printBlockArgs' is false, the arguments of the
|
|
/// block are not printed. If 'printBlockTerminator' is false, the terminator
|
|
/// operation of the block is not printed.
|
|
void print(Block *block, bool printBlockArgs = true,
|
|
bool printBlockTerminator = true);
|
|
|
|
/// Print the ID of the given value, optionally with its result number.
|
|
void printValueID(Value value, bool printResultNo = true,
|
|
raw_ostream *streamOverride = nullptr) const;
|
|
|
|
/// Print the ID of the given operation.
|
|
void printOperationID(Operation *op,
|
|
raw_ostream *streamOverride = nullptr) const;
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// OpAsmPrinter methods
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// Print a loc(...) specifier if printing debug info is enabled. Locations
|
|
/// may be deferred with an alias.
|
|
void printOptionalLocationSpecifier(Location loc) override {
|
|
printTrailingLocation(loc);
|
|
}
|
|
|
|
/// Print a newline and indent the printer to the start of the current
|
|
/// operation.
|
|
void printNewline() override {
|
|
os << newLine;
|
|
os.indent(currentIndent);
|
|
}
|
|
|
|
/// Increase indentation.
|
|
void increaseIndent() override { currentIndent += indentWidth; }
|
|
|
|
/// Decrease indentation.
|
|
void decreaseIndent() override { currentIndent -= indentWidth; }
|
|
|
|
/// Print a block argument in the usual format of:
|
|
/// %ssaName : type {attr1=42} loc("here")
|
|
/// where location printing is controlled by the standard internal option.
|
|
/// You may pass omitType=true to not print a type, and pass an empty
|
|
/// attribute list if you don't care for attributes.
|
|
void printRegionArgument(BlockArgument arg,
|
|
ArrayRef<NamedAttribute> argAttrs = {},
|
|
bool omitType = false) override;
|
|
|
|
/// Print the ID for the given value.
|
|
void printOperand(Value value) override { printValueID(value); }
|
|
void printOperand(Value value, raw_ostream &os) override {
|
|
printValueID(value, /*printResultNo=*/true, &os);
|
|
}
|
|
|
|
/// Print an optional attribute dictionary with a given set of elided values.
|
|
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
|
ArrayRef<StringRef> elidedAttrs = {}) override {
|
|
Impl::printOptionalAttrDict(attrs, elidedAttrs);
|
|
}
|
|
void printOptionalAttrDictWithKeyword(
|
|
ArrayRef<NamedAttribute> attrs,
|
|
ArrayRef<StringRef> elidedAttrs = {}) override {
|
|
Impl::printOptionalAttrDict(attrs, elidedAttrs,
|
|
/*withKeyword=*/true);
|
|
}
|
|
|
|
/// Print the given successor.
|
|
void printSuccessor(Block *successor) override;
|
|
|
|
/// Print an operation successor with the operands used for the block
|
|
/// arguments.
|
|
void printSuccessorAndUseList(Block *successor,
|
|
ValueRange succOperands) override;
|
|
|
|
/// Print the given region.
|
|
void printRegion(Region ®ion, bool printEntryBlockArgs,
|
|
bool printBlockTerminators, bool printEmptyBlock) override;
|
|
|
|
/// Renumber the arguments for the specified region to the same names as the
|
|
/// SSA values in namesToUse. This may only be used for IsolatedFromAbove
|
|
/// operations. If any entry in namesToUse is null, the corresponding
|
|
/// argument name is left alone.
|
|
void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override {
|
|
state.getSSANameState().shadowRegionArgs(region, namesToUse);
|
|
}
|
|
|
|
/// Print the given affine map with the symbol and dimension operands printed
|
|
/// inline with the map.
|
|
void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
|
|
ValueRange operands) override;
|
|
|
|
/// Print the given affine expression with the symbol and dimension operands
|
|
/// printed inline with the expression.
|
|
void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
|
|
ValueRange symOperands) override;
|
|
|
|
/// Print users of this operation or id of this operation if it has no result.
|
|
void printUsersComment(Operation *op);
|
|
|
|
/// Print users of this block arg.
|
|
void printUsersComment(BlockArgument arg);
|
|
|
|
/// Print the users of a value.
|
|
void printValueUsers(Value value);
|
|
|
|
/// Print either the ids of the result values or the id of the operation if
|
|
/// the operation has no results.
|
|
void printUserIDs(Operation *user, bool prefixComma = false);
|
|
|
|
private:
|
|
/// This class represents a resource builder implementation for the MLIR
|
|
/// textual assembly format.
|
|
class ResourceBuilder : public AsmResourceBuilder {
|
|
public:
|
|
using ValueFn = function_ref<void(raw_ostream &)>;
|
|
using PrintFn = function_ref<void(StringRef, ValueFn)>;
|
|
|
|
ResourceBuilder(PrintFn printFn) : printFn(printFn) {}
|
|
~ResourceBuilder() override = default;
|
|
|
|
void buildBool(StringRef key, bool data) final {
|
|
printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false"); });
|
|
}
|
|
|
|
void buildString(StringRef key, StringRef data) final {
|
|
printFn(key, [&](raw_ostream &os) {
|
|
os << "\"";
|
|
llvm::printEscapedString(data, os);
|
|
os << "\"";
|
|
});
|
|
}
|
|
|
|
void buildBlob(StringRef key, ArrayRef<char> data,
|
|
uint32_t dataAlignment) final {
|
|
printFn(key, [&](raw_ostream &os) {
|
|
// Store the blob in a hex string containing the alignment and the data.
|
|
llvm::support::ulittle32_t dataAlignmentLE(dataAlignment);
|
|
os << "\"0x"
|
|
<< llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignmentLE),
|
|
sizeof(dataAlignment)))
|
|
<< llvm::toHex(StringRef(data.data(), data.size())) << "\"";
|
|
});
|
|
}
|
|
|
|
private:
|
|
PrintFn printFn;
|
|
};
|
|
|
|
/// Print the metadata dictionary for the file, eliding it if it is empty.
|
|
void printFileMetadataDictionary(Operation *op);
|
|
|
|
/// Print the resource sections for the file metadata dictionary.
|
|
/// `checkAddMetadataDict` is used to indicate that metadata is going to be
|
|
/// added, and the file metadata dictionary should be started if it hasn't
|
|
/// yet.
|
|
void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
|
|
Operation *op);
|
|
|
|
// Contains the stack of default dialects to use when printing regions.
|
|
// A new dialect is pushed to the stack before parsing regions nested under an
|
|
// operation implementing `OpAsmOpInterface`, and popped when done. At the
|
|
// top-level we start with "builtin" as the default, so that the top-level
|
|
// `module` operation prints as-is.
|
|
SmallVector<StringRef> defaultDialectStack{"builtin"};
|
|
|
|
/// The number of spaces used for indenting nested operations.
|
|
const static unsigned indentWidth = 2;
|
|
|
|
// This is the current indentation level for nested structures.
|
|
unsigned currentIndent = 0;
|
|
};
|
|
} // namespace
|
|
|
|
/// Print a top-level operation. The output is: non-deferred aliases, the
/// operation itself, deferred aliases, and finally the file metadata
/// dictionary (if any).
void OperationPrinter::printTopLevelOperation(Operation *op) {
  auto &aliases = state.getAliasState();

  // Aliases that cannot be deferred are emitted ahead of the body.
  aliases.printNonDeferredAliases(*this, newLine);

  printFullOpWithIndentAndLoc(op);
  os << newLine;

  // Deferrable aliases are emitted after the body.
  aliases.printDeferredAliases(*this, newLine);

  printFileMetadataDictionary(op);
}
|
|
|
|
/// Print the `{-# ... #-}` file metadata dictionary. The dictionary is opened
/// lazily, on the first metadata entry, so nothing is printed when there is
/// no metadata.
void OperationPrinter::printFileMetadataDictionary(Operation *op) {
  bool dictOpened = false;
  auto openDictOnce = [&] {
    if (dictOpened)
      return;
    dictOpened = true;
    os << newLine << "{-#" << newLine;
  };

  printResourceFileMetadata(openDictOnce, op);

  if (dictOpened)
    os << newLine << "#-}" << newLine;
}
|
|
|
|
/// Print the `dialect_resources` and `external_resources` sections of the
/// file metadata dictionary. `checkAddMetadataDict` is invoked before each
/// entry so the enclosing dictionary can be opened lazily. A section header
/// and a provider header are printed only once that section/provider
/// produces an entry.
void OperationPrinter::printResourceFileMetadata(
    function_ref<void()> checkAddMetadataDict, Operation *op) {
  // Functor used to add data entries to the file metadata dictionary.
  // hadResource:       the current `<dict>_resources: {` header was printed.
  // needResourceComma: an earlier section was printed, so a comma must
  //                    precede the next section header.
  // needEntryComma:    an earlier provider in this section printed entries, so
  //                    a comma must precede the next provider header.
  bool hadResource = false;
  bool needResourceComma = false;
  bool needEntryComma = false;
  auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
                             auto &&...providerArgs) {
    // hadEntry: this provider's `name: {` header has been printed.
    bool hadEntry = false;
    auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
      checkAddMetadataDict();

      // Print whatever headers/separators must precede this entry.
      auto printFormatting = [&]() {
        // Emit the top-level resource entry if we haven't yet.
        if (!std::exchange(hadResource, true)) {
          if (needResourceComma)
            os << "," << newLine;
          os << " " << dictName << "_resources: {" << newLine;
        }
        // Emit the parent resource entry if we haven't yet.
        if (!std::exchange(hadEntry, true)) {
          if (needEntryComma)
            os << "," << newLine;
          os << " " << name << ": {" << newLine;
        } else {
          os << "," << newLine;
        }
      };

      std::optional<uint64_t> charLimit =
          printerFlags.getLargeResourceStringLimit();
      if (charLimit.has_value()) {
        // Render the value first so its length can be checked against the
        // limit; entries over the limit are skipped entirely (no headers or
        // separators are printed for them).
        std::string resourceStr;
        llvm::raw_string_ostream ss(resourceStr);
        valueFn(ss);

        // Only print entry if its string is small enough
        if (resourceStr.size() > charLimit.value())
          return;

        printFormatting();
        os << " " << key << ": " << resourceStr;
      } else {
        printFormatting();
        os << " " << key << ": ";
        valueFn(os);
      }
    };
    ResourceBuilder entryBuilder(printFn);
    provider.buildResources(op, providerArgs..., entryBuilder);

    // Close this provider's block if it printed anything.
    needEntryComma |= hadEntry;
    if (hadEntry)
      os << newLine << " }";
  };

  // Print the `dialect_resources` section if we have any dialects with
  // resources.
  for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) {
    auto &dialectResources = state.getDialectResources();
    StringRef name = interface.getDialect()->getNamespace();
    auto it = dialectResources.find(interface.getDialect());
    // Dialects with no recorded handles still get a chance to print, with an
    // empty handle set.
    if (it != dialectResources.end())
      processProvider("dialect", name, interface, it->second);
    else
      processProvider("dialect", name, interface,
                      SetVector<AsmDialectResourceHandle>());
  }
  if (hadResource)
    os << newLine << " }";

  // Print the `external_resources` section if we have any external clients with
  // resources. Reset the per-section state; a comma is needed before this
  // section only if the dialect section was printed.
  needEntryComma = false;
  needResourceComma = hadResource;
  hadResource = false;
  for (const auto &printer : state.getResourcePrinters())
    processProvider("external", printer.getName(), printer);
  if (hadResource)
    os << newLine << " }";
}
|
|
|
|
/// Print a block argument in the usual format of:
|
|
/// %ssaName : type {attr1=42} loc("here")
|
|
/// where location printing is controlled by the standard internal option.
|
|
/// You may pass omitType=true to not print a type, and pass an empty
|
|
/// attribute list if you don't care for attributes.
|
|
void OperationPrinter::printRegionArgument(BlockArgument arg,
|
|
ArrayRef<NamedAttribute> argAttrs,
|
|
bool omitType) {
|
|
printOperand(arg);
|
|
if (!omitType) {
|
|
os << ": ";
|
|
printType(arg.getType());
|
|
}
|
|
printOptionalAttrDict(argAttrs);
|
|
// TODO: We should allow location aliases on block arguments.
|
|
printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
|
|
}
|
|
|
|
/// Print `op` at the current indentation, followed by its trailing location
/// and, if enabled by the printer flags, a comment listing its users.
void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) {
  // Record the line and indentation where this operation starts.
  state.registerOperationLocation(op, newLine.curLine, currentIndent);

  os.indent(currentIndent);
  printFullOp(op);
  printTrailingLocation(op->getLoc());

  if (!printerFlags.shouldPrintValueUsers())
    return;
  printUsersComment(op);
}
|
|
|
|
/// Print `op` as `%results = <op body>`. When the SSA name state records
/// result groups for `op`, each group is printed as `%name` or
/// `%name:count`.
void OperationPrinter::printFullOp(Operation *op) {
  size_t numResults = op->getNumResults();
  if (numResults != 0) {
    // Print one contiguous group of results starting at `first`.
    auto emitGroup = [&](size_t first, size_t count) {
      printValueID(op->getResult(first), /*printResultNo=*/false);
      if (count > 1)
        os << ':' << count;
    };

    ArrayRef<int> groupStarts = state.getSSANameState().getOpResultGroups(op);
    if (groupStarts.empty()) {
      emitGroup(/*first=*/0, /*count=*/numResults);
    } else {
      // Every group but the last ends where the next one begins; the last
      // group runs to the end of the result list.
      for (size_t g = 0; g + 1 < groupStarts.size(); ++g) {
        if (g != 0)
          os << ", ";
        emitGroup(groupStarts[g], groupStarts[g + 1] - groupStarts[g]);
      }
      os << ", ";
      emitGroup(groupStarts.back(), numResults - groupStarts.back());
    }

    os << " = ";
  }

  printCustomOrGenericOp(op);
}
|
|
|
|
/// Append a trailing comment describing how `op` is used:
///   - no results but some operands: `// id: <op id>`
///   - results, none used:            `// unused`
///   - results with users:            `// user(s): ...`, listing each
///     result's users (bracketed per result when there are several results).
/// Nothing is printed for an operation with neither results nor operands.
void OperationPrinter::printUsersComment(Operation *op) {
  unsigned numResults = op->getNumResults();

  if (numResults == 0) {
    if (op->getNumOperands() != 0) {
      os << " // id: ";
      printOperationID(op);
    }
    return;
  }

  if (op->use_empty()) {
    os << " // unused";
    return;
  }

  // Count the distinct user operations and the total number of results they
  // define; together these decide between the singular and plural label.
  SmallPtrSet<Operation *, 1> distinctUsers;
  unsigned userResultCount = 0;
  for (Operation *user : op->getUsers())
    if (distinctUsers.insert(user).second)
      userResultCount += user->getNumResults();
  unsigned userOpCount = distinctUsers.size();

  bool singleUse =
      numResults == 1 && userOpCount <= 1 && userResultCount <= 1;
  os << " // " << (singleUse ? "user" : "users") << ": ";

  bool bracketEach = numResults > 1;
  interleaveComma(op->getResults(), [&](OpResult result) {
    if (bracketEach)
      os << "(";
    printValueUsers(result);
    if (bracketEach)
      os << ")";
  });
}
|
|
|
|
/// Print a full-line comment naming block argument `arg` and either its
/// users or the fact that it is unused, terminated by a newline.
void OperationPrinter::printUsersComment(BlockArgument arg) {
  os << "// ";
  printValueID(arg);

  if (!arg.use_empty()) {
    os << " is used by ";
    printValueUsers(arg);
  } else {
    os << " is unused";
  }

  os << newLine;
}
|
|
|
|
/// Print a comma-separated list of the users of `value`, identified by
/// their result IDs (or operation IDs for result-less users). Prints
/// "unused" when `value` has no uses.
void OperationPrinter::printValueUsers(Value value) {
  if (value.use_empty())
    os << "unused";

  // A user may consume `value` through several operands; list each user once.
  // The first user is always new, so every later printed user needs a comma.
  SmallPtrSet<Operation *, 1> printedUsers;
  for (auto [position, user] : enumerate(value.getUsers()))
    if (printedUsers.insert(user).second)
      printUserIDs(user, /*prefixComma=*/position != 0);
}
|
|
|
|
/// Print the identifiers of `user`: all of its result IDs if it has
/// results, otherwise its operation ID. Optionally preceded by ", ".
void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
  if (prefixComma)
    os << ", ";

  if (user->getNumResults() != 0) {
    interleaveComma(user->getResults(),
                    [this](Value result) { printValueID(result); });
    return;
  }

  printOperationID(user);
}
|
|
|
|
/// Print `op` using the most specific printer available: the registered
/// op's custom assembly hook, then the owning dialect's operation printer,
/// and finally the generic form. The generic form is always used when the
/// printer flags request it.
void OperationPrinter::printCustomOrGenericOp(Operation *op) {
  if (printerFlags.shouldPrintGenericOpForm()) {
    printGenericOp(op, /*printOpName=*/true);
    return;
  }

  // Registered operations carry their own custom printer hook.
  if (auto opInfo = op->getRegisteredInfo()) {
    opInfo->printAssembly(op, *this, defaultDialectStack.back());
    return;
  }

  // Otherwise the op's dialect may provide a printer for it.
  if (Dialect *dialect = op->getDialect()) {
    if (auto opPrinter = dialect->getOperationPrinter(op)) {
      StringRef opName = op->getName().getStringRef();
      // Strip the default-dialect prefix only when the name has exactly one
      // '.', so the shortened name cannot be ambiguous.
      if (opName.count('.') == 1)
        opName.consume_front((defaultDialectStack.back() + ".").str());
      os << opName;

      // The dialect hook prints everything after the op name.
      opPrinter(op, *this);
      return;
    }
  }

  printGenericOp(op, /*printOpName=*/true);
}
|
|
|
|
/// Print `op` in the generic form, in this order: optional escaped name,
/// operands, successors, properties, regions, attribute dictionary and
/// functional type signature.
void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
  if (printOpName)
    printEscapedString(op->getName().getStringRef());

  // Operand list, always present (possibly empty).
  os << '(';
  interleaveComma(op->getOperands(),
                  [&](Value operand) { printValueID(operand); });
  os << ')';

  // Successor list, only for ops with successors (terminators).
  if (op->getNumSuccessors()) {
    os << '[';
    interleaveComma(op->getSuccessors(),
                    [&](Block *succ) { printBlockName(succ); });
    os << ']';
  }

  // Inherent properties, wrapped in `<...>`.
  if (Attribute properties = op->getPropertiesAsAttribute()) {
    os << " <";
    Impl::printAttribute(properties);
    os << '>';
  }

  // Regions are printed in full: entry block headers, terminators and empty
  // blocks included.
  if (op->getNumRegions()) {
    os << " (";
    interleaveComma(op->getRegions(), [&](Region &region) {
      printRegion(region, /*printEntryBlockArgs=*/true,
                  /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
    });
    os << ')';
  }

  // When the op has properties storage, only its discardable attributes go
  // into the dictionary; otherwise all attributes do.
  if (op->getPropertiesStorage())
    printOptionalAttrDict(llvm::to_vector(op->getDiscardableAttrs()));
  else
    printOptionalAttrDict(op->getAttrs());

  // Type signature.
  os << " : ";
  printFunctionalType(op);
}
|
|
|
|
/// Print the name the SSA name state assigned to `block`.
void OperationPrinter::printBlockName(Block *block) {
  auto blockInfo = state.getSSANameState().getBlockInfo(block);
  os << blockInfo.name;
}
|
|
|
|
/// Print `block`.
///  - When `printBlockArgs` is set, first print the label line: name,
///    argument list and a comment describing the predecessors.
///  - Then print the block's operations one indentation level deeper,
///    omitting the terminator when `printBlockTerminator` is false.
void OperationPrinter::print(Block *block, bool printBlockArgs,
                             bool printBlockTerminator) {
  if (printBlockArgs) {
    os.indent(currentIndent);
    printBlockName(block);

    // Argument list, only when the block has arguments.
    if (!block->args_empty()) {
      os << '(';
      interleaveComma(block->getArguments(), [&](BlockArgument blockArg) {
        printValueID(blockArg);
        os << ": ";
        printType(blockArg.getType());
        // TODO: We should allow location aliases on block arguments.
        printTrailingLocation(blockArg.getLoc(), /*allowAlias*/ false);
      });
      os << ')';
    }
    os << ':';

    // Context comment describing the predecessors of this block.
    if (!block->getParent()) {
      os << " // block is not in a region!";
    } else if (block->hasNoPredecessors()) {
      if (!block->isEntryBlock())
        os << " // no predecessors";
    } else if (Block *onlyPred = block->getSinglePredecessor()) {
      os << " // pred: ";
      printBlockName(onlyPred);
    } else {
      // Sort predecessors by their `ordering` so the output is stable rather
      // than following use-list order.
      SmallVector<BlockInfo, 4> preds;
      for (Block *pred : block->getPredecessors())
        preds.push_back(state.getSSANameState().getBlockInfo(pred));
      llvm::sort(preds, [](BlockInfo a, BlockInfo b) {
        return a.ordering < b.ordering;
      });

      os << " // " << preds.size() << " preds: ";
      interleaveComma(preds, [&](BlockInfo pred) { os << pred.name; });
    }
    os << newLine;
  }

  currentIndent += indentWidth;

  if (printerFlags.shouldPrintValueUsers()) {
    for (BlockArgument blockArg : block->getArguments()) {
      os.indent(currentIndent);
      printUsersComment(blockArg);
    }
  }

  // Drop the final op from the range only when it is a terminator that the
  // caller asked us not to print.
  bool hasTerminator =
      !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
  bool skipTerminator = hasTerminator && !printBlockTerminator;
  auto opsToPrint = llvm::make_range(
      block->begin(), std::prev(block->end(), skipTerminator ? 1 : 0));
  for (Operation &nestedOp : opsToPrint) {
    printFullOpWithIndentAndLoc(&nestedOp);
    os << newLine;
  }

  currentIndent -= indentWidth;
}
|
|
|
|
/// Print the SSA identifier of `value` to `streamOverride` if provided,
/// otherwise to the printer's own stream.
void OperationPrinter::printValueID(Value value, bool printResultNo,
                                    raw_ostream *streamOverride) const {
  raw_ostream &target = streamOverride ? *streamOverride : os;
  state.getSSANameState().printValueID(value, printResultNo, target);
}
|
|
|
|
/// Print the identifier of `op` to `streamOverride` if provided, otherwise
/// to the printer's own stream.
void OperationPrinter::printOperationID(Operation *op,
                                        raw_ostream *streamOverride) const {
  raw_ostream &target = streamOverride ? *streamOverride : os;
  state.getSSANameState().printOperationID(op, target);
}
|
|
|
|
/// Print a successor reference, which is just the successor block's name.
void OperationPrinter::printSuccessor(Block *successor) {
  printBlockName(successor);
}
|
|
|
|
/// Print `successor`'s name followed, when `succOperands` is non-empty, by
/// `(<value ids> : <types>)`.
void OperationPrinter::printSuccessorAndUseList(Block *successor,
                                                ValueRange succOperands) {
  printBlockName(successor);
  if (!succOperands.empty()) {
    os << '(';
    interleaveComma(succOperands,
                    [this](Value operand) { printValueID(operand); });
    os << " : ";
    interleaveComma(succOperands,
                    [this](Value operand) { printType(operand.getType()); });
    os << ')';
  }
}
|
|
|
|
void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs,
|
|
bool printBlockTerminators,
|
|
bool printEmptyBlock) {
|
|
if (printerFlags.shouldSkipRegions()) {
|
|
os << "{...}";
|
|
return;
|
|
}
|
|
os << "{" << newLine;
|
|
if (!region.empty()) {
|
|
auto restoreDefaultDialect =
|
|
llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
|
|
if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
|
|
defaultDialectStack.push_back(iface.getDefaultDialect());
|
|
else
|
|
defaultDialectStack.push_back("");
|
|
|
|
auto *entryBlock = ®ion.front();
|
|
// Force printing the block header if printEmptyBlock is set and the block
|
|
// is empty or if printEntryBlockArgs is set and there are arguments to
|
|
// print.
|
|
bool shouldAlwaysPrintBlockHeader =
|
|
(printEmptyBlock && entryBlock->empty()) ||
|
|
(printEntryBlockArgs && entryBlock->getNumArguments() != 0);
|
|
print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
|
|
for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
|
|
print(&b);
|
|
}
|
|
os.indent(currentIndent) << "}";
|
|
}
|
|
|
|
/// Print the results of `mapAttr`'s affine map, with dimension and symbol
/// positions replaced by the SSA IDs of `operands`. The operand list holds
/// all dimensions first, then all symbols; symbols are wrapped in
/// `symbol(...)`.
void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                                              ValueRange operands) {
  if (!mapAttr) {
    os << "<<NULL AFFINE MAP>>";
    return;
  }

  AffineMap map = mapAttr.getValue();
  unsigned numDims = map.getNumDims();

  auto printOperand = [&](unsigned pos, bool isSymbol) {
    unsigned operandIdx = pos + (isSymbol ? numDims : 0);
    assert(operandIdx < operands.size());
    if (!isSymbol) {
      printValueID(operands[operandIdx]);
      return;
    }
    os << "symbol(";
    printValueID(operands[operandIdx]);
    os << ')';
  };

  interleaveComma(map.getResults(), [&](AffineExpr result) {
    printAffineExpr(result, printOperand);
  });
}
|
|
|
|
/// Print `expr`, replacing dimension positions with IDs from `dimOperands`
/// and symbol positions with `symbol(<id>)` built from `symOperands`.
void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
                                               ValueRange dimOperands,
                                               ValueRange symOperands) {
  auto printOperand = [&](unsigned pos, bool isSymbol) {
    if (isSymbol) {
      os << "symbol(";
      printValueID(symOperands[pos]);
      os << ')';
    } else {
      printValueID(dimOperands[pos]);
    }
  };
  printAffineExpr(expr, printOperand);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// print and dump methods
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Print this attribute to `os` using a fresh AsmState. A null attribute is
/// printed as `<<NULL ATTRIBUTE>>`.
void Attribute::print(raw_ostream &os, bool elideType) const {
  if (*this) {
    AsmState freshState(getContext());
    print(os, freshState, elideType);
    return;
  }
  os << "<<NULL ATTRIBUTE>>";
}
|
|
/// Print this attribute with `state`. When `elideType` is set, the
/// attribute's type is required to be elided; otherwise it is never elided.
void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const {
  auto elision = elideType ? AsmPrinter::Impl::AttrTypeElision::Must
                           : AsmPrinter::Impl::AttrTypeElision::Never;
  AsmPrinter::Impl printer(os, state.getImpl());
  printer.printAttribute(*this, elision);
}
|
|
|
|
void Attribute::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Print this attribute without its dialect prefix where possible:
///  - use the attribute's alias if it has one,
///  - otherwise let the owning dialect print it directly,
///  - if the dialect writes nothing, fall back to the normal prefixed form.
void Attribute::printStripped(raw_ostream &os, AsmState &state) const {
  if (!*this) {
    os << "<<NULL ATTRIBUTE>>";
    return;
  }

  AsmPrinter::Impl subPrinter(os, state.getImpl());
  if (succeeded(subPrinter.printAlias(*this)))
    return;

  // Detect whether the dialect hook emitted anything by comparing stream
  // positions before and after the call.
  uint64_t startPos = os.tell();
  DialectAsmPrinter dialectPrinter(subPrinter);
  getDialect().printAttribute(*this, dialectPrinter);
  if (os.tell() == startPos)
    print(os, state);
}
|
|
/// Stripped printing with a fresh AsmState. A null attribute is printed as
/// `<<NULL ATTRIBUTE>>`.
void Attribute::printStripped(raw_ostream &os) const {
  if (*this) {
    AsmState freshState(getContext());
    printStripped(os, freshState);
    return;
  }
  os << "<<NULL ATTRIBUTE>>";
}
|
|
|
|
/// Print this type with a fresh AsmState. A null type is printed as
/// `<<NULL TYPE>>`.
void Type::print(raw_ostream &os) const {
  if (*this) {
    AsmState freshState(getContext());
    print(os, freshState);
    return;
  }
  os << "<<NULL TYPE>>";
}
|
|
/// Print this type using the printer state held by `state`.
void Type::print(raw_ostream &os, AsmState &state) const {
  AsmPrinter::Impl printer(os, state.getImpl());
  printer.printType(*this);
}
|
|
|
|
void Type::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
void AffineMap::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
void IntegerSet::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Print this affine expression with a fresh AsmState. A null expression is
/// printed as `<<NULL AFFINE EXPR>>`.
void AffineExpr::print(raw_ostream &os) const {
  if (!expr) {
    os << "<<NULL AFFINE EXPR>>";
    return;
  }
  AsmState freshState(getContext());
  AsmPrinter::Impl printer(os, freshState.getImpl());
  printer.printAffineExpr(*this);
}
|
|
|
|
void AffineExpr::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Print this affine map with a fresh AsmState. A null map is printed as
/// `<<NULL AFFINE MAP>>`.
void AffineMap::print(raw_ostream &os) const {
  if (!map) {
    os << "<<NULL AFFINE MAP>>";
    return;
  }
  AsmState freshState(getContext());
  AsmPrinter::Impl printer(os, freshState.getImpl());
  printer.printAffineMap(*this);
}
|
|
|
|
/// Print this integer set with a fresh AsmState.
/// A null set is printed as `<<NULL INTEGER SET>>` instead of being
/// dereferenced. This matches the null handling of AffineMap::print and
/// AffineExpr::print, and means dump() on a default-constructed set no
/// longer crashes in getContext().
void IntegerSet::print(raw_ostream &os) const {
  if (!*this) {
    os << "<<NULL INTEGER SET>>";
    return;
  }
  AsmState state(getContext());
  AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this);
}
|
|
|
|
/// Print this value using default printing flags.
void Value::print(raw_ostream &os) const {
  OpPrintingFlags defaultFlags;
  print(os, defaultFlags);
}
|
|
/// Print this value using `flags`:
///  - an op result prints its whole defining operation,
///  - a block argument prints a short description (type and index),
///  - a null value prints `<<NULL VALUE>>`.
void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const {
  if (!impl) {
    os << "<<NULL VALUE>>";
    return;
  }

  if (Operation *definingOp = getDefiningOp()) {
    definingOp->print(os, flags);
    return;
  }

  // TODO: Improve BlockArgument print'ing.
  auto blockArg = llvm::cast<BlockArgument>(*this);
  os << "<block argument> of type '" << blockArg.getType()
     << "' at index: " << blockArg.getArgNumber();
}
|
|
/// Print this value using an existing `state`:
///  - an op result prints its whole defining operation,
///  - a block argument prints a short description (type and index),
///  - a null value prints `<<NULL VALUE>>`.
void Value::print(raw_ostream &os, AsmState &state) const {
  if (!impl) {
    os << "<<NULL VALUE>>";
    return;
  }

  if (Operation *definingOp = getDefiningOp()) {
    definingOp->print(os, state);
    return;
  }

  // TODO: Improve BlockArgument print'ing.
  auto blockArg = llvm::cast<BlockArgument>(*this);
  os << "<block argument> of type '" << blockArg.getType()
     << "' at index: " << blockArg.getArgNumber();
}
|
|
|
|
void Value::dump() const {
|
|
print(llvm::errs());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Print only the SSA identifier of this value (including the result number)
/// as recorded in `state`.
void Value::printAsOperand(raw_ostream &os, AsmState &state) const {
  // TODO: This doesn't necessarily capture all potential cases.
  // Currently, region arguments can be shadowed when printing the main
  // operation. If the IR hasn't been printed, this will produce the old SSA
  // name and not the shadowed name.
  auto &nameState = state.getImpl().getSSANameState();
  nameState.printValueID(*this, /*printResultNo=*/true, os);
}
|
|
|
|
/// Walk up from `op` to the operation that should serve as the numbering
/// root. This is the outermost ancestor, or, when `shouldUseLocalScope` is
/// set, the first operation (starting at `op` itself) that is isolated from
/// above.
static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
  while (true) {
    if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
      return op;
    Operation *parentOp = op->getParentOp();
    if (!parentOp)
      return op;
    op = parentOp;
  }
}
|
|
|
|
/// Print this value's SSA identifier, numbering values relative to a root
/// found from the value's definition point according to `flags`.
/// Prints `<<UNKNOWN SSA VALUE>>` for a block argument whose block has no
/// parent operation.
void Value::printAsOperand(raw_ostream &os,
                           const OpPrintingFlags &flags) const {
  // The starting op is the result's owner, or the op that owns the block of
  // a block argument.
  Operation *scopeOp = nullptr;
  if (auto result = llvm::dyn_cast<OpResult>(*this))
    scopeOp = result.getOwner();
  else
    scopeOp = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();

  if (!scopeOp) {
    os << "<<UNKNOWN SSA VALUE>>";
    return;
  }

  AsmState state(findParent(scopeOp, flags.shouldUseLocalScope()), flags);
  printAsOperand(os, state);
}
|
|
|
|
/// Print this operation. Values are numbered relative to the root chosen by
/// findParent according to `printerFlags`.
void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
  Operation *numberingRoot =
      findParent(this, printerFlags.shouldUseLocalScope());
  AsmState state(numberingRoot, printerFlags);
  print(os, state);
}
|
|
/// Print this operation using `state`. An unparented op printed without
/// local scope is printed as a top-level operation, after initializing
/// aliases. Any other op is printed as a single indented op with its
/// location.
void Operation::print(raw_ostream &os, AsmState &state) {
  OperationPrinter printer(os, state.getImpl());
  bool printAsTopLevel =
      !getParent() && !state.getPrinterFlags().shouldUseLocalScope();
  if (!printAsTopLevel) {
    printer.printFullOpWithIndentAndLoc(this);
    return;
  }
  state.getImpl().initializeAliases(this);
  printer.printTopLevelOperation(this);
}
|
|
|
|
void Operation::dump() {
|
|
print(llvm::errs(), OpPrintingFlags().useLocalScope());
|
|
llvm::errs() << "\n";
|
|
}
|
|
|
|
/// Print this block with an AsmState rooted at the outermost enclosing
/// operation. Prints `<<UNLINKED BLOCK>>` plus a newline when the block has
/// no parent operation.
void Block::print(raw_ostream &os) {
  Operation *rootOp = getParentOp();
  if (!rootOp) {
    os << "<<UNLINKED BLOCK>>\n";
    return;
  }

  // Climb to the top-level op.
  for (Operation *next = rootOp->getParentOp(); next;
       next = next->getParentOp())
    rootOp = next;

  AsmState state(rootOp);
  print(os, state);
}
|
|
/// Print this block, label and all operations, using `state`.
void Block::print(raw_ostream &os, AsmState &state) {
  OperationPrinter printer(os, state.getImpl());
  printer.print(this);
}
|
|
|
|
void Block::dump() { print(llvm::errs()); }
|
|
|
|
/// Print out the name of the block without printing its body.
|
|
void Block::printAsOperand(raw_ostream &os, bool printType) {
|
|
Operation *parentOp = getParentOp();
|
|
if (!parentOp) {
|
|
os << "<<UNLINKED BLOCK>>\n";
|
|
return;
|
|
}
|
|
AsmState state(parentOp);
|
|
printAsOperand(os, state);
|
|
}
|
|
/// Print only this block's name, as assigned in `state`.
void Block::printAsOperand(raw_ostream &os, AsmState &state) {
  OperationPrinter(os, state.getImpl()).printBlockName(this);
}
|
|
|
|
/// Stream insertion for blocks; equivalent to calling `block.print(os)`.
raw_ostream &mlir::operator<<(raw_ostream &os, Block &block) {
  block.print(os);
  return os;
}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Custom printers
|
|
//===--------------------------------------------------------------------===//
|
|
namespace mlir {
|
|
|
|
void printDimensionList(OpAsmPrinter &printer, Operation *op,
|
|
ArrayRef<int64_t> dimensions) {
|
|
if (dimensions.empty())
|
|
printer << "[";
|
|
printer.printDimensionList(dimensions);
|
|
if (dimensions.empty())
|
|
printer << "]";
|
|
}
|
|
|
|
ParseResult parseDimensionList(OpAsmParser &parser,
|
|
DenseI64ArrayAttr &dimensions) {
|
|
// Empty list case denoted by "[]".
|
|
if (succeeded(parser.parseOptionalLSquare())) {
|
|
if (failed(parser.parseRSquare())) {
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "Failed parsing dimension list.";
|
|
}
|
|
dimensions =
|
|
DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>());
|
|
return success();
|
|
}
|
|
|
|
// Non-empty list case.
|
|
SmallVector<int64_t> shapeArr;
|
|
if (failed(parser.parseDimensionList(shapeArr, true, false))) {
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "Failed parsing dimension list.";
|
|
}
|
|
if (shapeArr.empty()) {
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "Failed parsing dimension list. Did you mean an empty list? It "
|
|
"must be denoted by \"[]\".";
|
|
}
|
|
dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
|
|
return success();
|
|
}
|
|
|
|
} // namespace mlir
|