mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 17:26:06 +00:00
[mlir] Refactor ModuleState into AsmState and expose it to users.
Summary: This allows for users to cache printer state, which can be costly to recompute. Each of the IR print methods gain a new overload taking this new state class. Depends On D72293 Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D72294
This commit is contained in:
parent
23058f9dd4
commit
fa9dd8336b
52
mlir/include/mlir/IR/AsmState.h
Normal file
52
mlir/include/mlir/IR/AsmState.h
Normal file
@ -0,0 +1,52 @@
|
||||
//===- AsmState.h - State class for AsmPrinter ------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the AsmState class.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_IR_ASMSTATE_H_
|
||||
#define MLIR_IR_ASMSTATE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class Operation;
|
||||
|
||||
namespace detail {
|
||||
class AsmStateImpl;
|
||||
} // end namespace detail
|
||||
|
||||
/// This class provides management for the lifetime of the state used when
|
||||
/// printing the IR. It allows for alleviating the cost of recomputing the
|
||||
/// internal state of the asm printer.
|
||||
///
|
||||
/// The IR should not be mutated in-between invocations using this state, and
|
||||
/// the IR being printed must not be an parent of the IR originally used to
|
||||
/// initialize this state. This means that if a child operation is provided, a
|
||||
/// parent operation cannot reuse this state.
|
||||
class AsmState {
|
||||
public:
|
||||
/// Initialize the asm state at the level of the given operation.
|
||||
AsmState(Operation *op);
|
||||
~AsmState();
|
||||
|
||||
/// Return an instance of the internal implementation. Returns nullptr if the
|
||||
/// state has not been initialized.
|
||||
detail::AsmStateImpl &getImpl() { return *impl; }
|
||||
|
||||
private:
|
||||
AsmState() = delete;
|
||||
|
||||
/// A pointer to allocated storage for the impl state.
|
||||
std::unique_ptr<detail::AsmStateImpl> impl;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_IR_ASMSTATE_H_
|
@ -312,12 +312,14 @@ public:
|
||||
}
|
||||
|
||||
void print(raw_ostream &os);
|
||||
void print(raw_ostream &os, AsmState &state);
|
||||
void dump();
|
||||
|
||||
/// Print out the name of the block without printing its body.
|
||||
/// NOTE: The printType argument is ignored. We keep it for compatibility
|
||||
/// with LLVM dominator machinery that expects it to exist.
|
||||
void printAsOperand(raw_ostream &os, bool printType = true);
|
||||
void printAsOperand(raw_ostream &os, AsmState &state);
|
||||
|
||||
private:
|
||||
/// Pair of the parent object that owns this block and a bit that signifies if
|
||||
|
@ -57,6 +57,8 @@ public:
|
||||
|
||||
/// Print the this module in the custom top-level form.
|
||||
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
|
||||
void print(raw_ostream &os, AsmState &state,
|
||||
OpPrintingFlags flags = llvm::None);
|
||||
void dump();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -122,6 +122,10 @@ public:
|
||||
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
|
||||
state->print(os, flags);
|
||||
}
|
||||
void print(raw_ostream &os, AsmState &asmState,
|
||||
OpPrintingFlags flags = llvm::None) {
|
||||
state->print(os, asmState, flags);
|
||||
}
|
||||
|
||||
/// Dump this operation.
|
||||
void dump() { state->dump(); }
|
||||
|
@ -187,6 +187,8 @@ public:
|
||||
bool isBeforeInBlock(Operation *other);
|
||||
|
||||
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
|
||||
void print(raw_ostream &os, AsmState &state,
|
||||
OpPrintingFlags flags = llvm::None);
|
||||
void dump();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
class AsmState;
|
||||
class BlockArgument;
|
||||
class Operation;
|
||||
class OpResult;
|
||||
@ -172,8 +173,12 @@ public:
|
||||
Kind getKind() const { return ownerAndKind.getInt(); }
|
||||
|
||||
void print(raw_ostream &os);
|
||||
void print(raw_ostream &os, AsmState &state);
|
||||
void dump();
|
||||
|
||||
/// Print this value as if it were an operand.
|
||||
void printAsOperand(raw_ostream &os, AsmState &state);
|
||||
|
||||
/// Methods for supporting PointerLikeTypeTraits.
|
||||
void *getAsOpaquePointer() const { return ownerAndKind.getOpaqueValue(); }
|
||||
static Value getFromOpaquePointer(const void *pointer) {
|
||||
|
@ -13,6 +13,7 @@
|
||||
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
@ -37,6 +38,7 @@
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/SaveAndRestore.h"
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
void Identifier::print(raw_ostream &os) const { os << str(); }
|
||||
|
||||
@ -756,13 +758,14 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModuleState
|
||||
// AsmState
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class ModuleState {
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
class AsmStateImpl {
|
||||
public:
|
||||
explicit ModuleState(Operation *op)
|
||||
explicit AsmStateImpl(Operation *op)
|
||||
: interfaces(op->getContext()), nameState(op, interfaces) {}
|
||||
|
||||
/// Initialize the alias state to enable the printing of aliases.
|
||||
@ -792,7 +795,11 @@ private:
|
||||
/// The state used for SSA value names.
|
||||
SSANameState nameState;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
} // end namespace detail
|
||||
} // end namespace mlir
|
||||
|
||||
AsmState::AsmState(Operation *op) : impl(std::make_unique<AsmStateImpl>(op)) {}
|
||||
AsmState::~AsmState() {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModulePrinter
|
||||
@ -802,7 +809,7 @@ namespace {
|
||||
class ModulePrinter {
|
||||
public:
|
||||
ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
|
||||
ModuleState *state = nullptr)
|
||||
AsmStateImpl *state = nullptr)
|
||||
: os(os), printerFlags(flags), state(state) {}
|
||||
explicit ModulePrinter(ModulePrinter &printer)
|
||||
: os(printer.os), printerFlags(printer.printerFlags),
|
||||
@ -816,8 +823,6 @@ public:
|
||||
mlir::interleaveComma(c, os, each_fn);
|
||||
}
|
||||
|
||||
void print(ModuleOp module);
|
||||
|
||||
/// Print the given attribute. If 'mayElideType' is true, some attributes are
|
||||
/// printed without the type when the type matches the default used in the
|
||||
/// parser (for example i64 is the default for integer attributes).
|
||||
@ -862,7 +867,7 @@ protected:
|
||||
OpPrintingFlags printerFlags;
|
||||
|
||||
/// An optional printer state for the module.
|
||||
ModuleState *state;
|
||||
AsmStateImpl *state;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
@ -1815,10 +1820,12 @@ namespace {
|
||||
/// This class contains the logic for printing operations, regions, and blocks.
|
||||
class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
|
||||
public:
|
||||
explicit OperationPrinter(ModulePrinter &other) : ModulePrinter(other) {
|
||||
assert(state && "expected valid state when printing operation");
|
||||
}
|
||||
explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
|
||||
AsmStateImpl &state)
|
||||
: ModulePrinter(os, flags, &state) {}
|
||||
|
||||
/// Print the given top-level module.
|
||||
void print(ModuleOp op);
|
||||
/// Print the given operation with its indent and location.
|
||||
void print(Operation *op);
|
||||
/// Print the bare location, not including indentation/location/etc.
|
||||
@ -1903,6 +1910,15 @@ private:
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void OperationPrinter::print(ModuleOp op) {
|
||||
// Output the aliases at the top level.
|
||||
state->getAliasState().printAttributeAliases(os);
|
||||
state->getAliasState().printTypeAliases(os);
|
||||
|
||||
// Print the module.
|
||||
print(op.getOperation());
|
||||
}
|
||||
|
||||
void OperationPrinter::print(Operation *op) {
|
||||
os.indent(currentIndent);
|
||||
printOperation(op);
|
||||
@ -2108,18 +2124,6 @@ void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
|
||||
});
|
||||
}
|
||||
|
||||
void ModulePrinter::print(ModuleOp module) {
|
||||
assert(state && "expected valid state when printing an operation");
|
||||
|
||||
// Output the aliases at the top level.
|
||||
state->getAliasState().printAttributeAliases(os);
|
||||
state->getAliasState().printTypeAliases(os);
|
||||
|
||||
// Print the module.
|
||||
OperationPrinter(*this).print(module);
|
||||
os << '\n';
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// print and dump methods
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2179,18 +2183,34 @@ void Value::print(raw_ostream &os) {
|
||||
assert(isa<BlockArgument>());
|
||||
os << "<block argument>\n";
|
||||
}
|
||||
void Value::print(raw_ostream &os, AsmState &state) {
|
||||
if (auto *op = getDefiningOp())
|
||||
return op->print(os, state);
|
||||
|
||||
// TODO: Improve this.
|
||||
assert(isa<BlockArgument>());
|
||||
os << "<block argument>\n";
|
||||
}
|
||||
|
||||
void Value::dump() {
|
||||
print(llvm::errs());
|
||||
llvm::errs() << "\n";
|
||||
}
|
||||
|
||||
void Value::printAsOperand(raw_ostream &os, AsmState &state) {
|
||||
// TODO(riverriddle) This doesn't necessarily capture all potential cases.
|
||||
// Currently, region arguments can be shadowed when printing the main
|
||||
// operation. If the IR hasn't been printed, this will produce the old SSA
|
||||
// name and not the shadowed name.
|
||||
state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
|
||||
os);
|
||||
}
|
||||
|
||||
void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
|
||||
// Handle top-level operations or local printing.
|
||||
if (!getParent() || flags.shouldUseLocalScope()) {
|
||||
ModuleState state(this);
|
||||
ModulePrinter modulePrinter(os, flags, &state);
|
||||
OperationPrinter(modulePrinter).print(this);
|
||||
AsmState state(this);
|
||||
OperationPrinter(os, flags, state.getImpl()).print(this);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -2203,9 +2223,11 @@ void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
|
||||
while (auto *nextOp = parentOp->getParentOp())
|
||||
parentOp = nextOp;
|
||||
|
||||
ModuleState state(parentOp);
|
||||
ModulePrinter modulePrinter(os, flags, &state);
|
||||
OperationPrinter(modulePrinter).print(this);
|
||||
AsmState state(parentOp);
|
||||
print(os, state, flags);
|
||||
}
|
||||
void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
|
||||
OperationPrinter(os, flags, state.getImpl()).print(this);
|
||||
}
|
||||
|
||||
void Operation::dump() {
|
||||
@ -2223,9 +2245,11 @@ void Block::print(raw_ostream &os) {
|
||||
while (auto *nextOp = parentOp->getParentOp())
|
||||
parentOp = nextOp;
|
||||
|
||||
ModuleState state(parentOp);
|
||||
ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state);
|
||||
OperationPrinter(modulePrinter).print(this);
|
||||
AsmState state(parentOp);
|
||||
print(os, state);
|
||||
}
|
||||
void Block::print(raw_ostream &os, AsmState &state) {
|
||||
OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
|
||||
}
|
||||
|
||||
void Block::dump() { print(llvm::errs()); }
|
||||
@ -2241,18 +2265,24 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
|
||||
while (auto *nextOp = parentOp->getParentOp())
|
||||
parentOp = nextOp;
|
||||
|
||||
ModuleState state(parentOp);
|
||||
ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state);
|
||||
OperationPrinter(modulePrinter).printBlockName(this);
|
||||
AsmState state(parentOp);
|
||||
printAsOperand(os, state);
|
||||
}
|
||||
void Block::printAsOperand(raw_ostream &os, AsmState &state) {
|
||||
OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
|
||||
printer.printBlockName(this);
|
||||
}
|
||||
|
||||
void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
|
||||
ModuleState state(*this);
|
||||
AsmState state(*this);
|
||||
|
||||
// Don't populate aliases when printing at local scope.
|
||||
if (!flags.shouldUseLocalScope())
|
||||
state.initializeAliases(*this);
|
||||
ModulePrinter(os, flags, &state).print(*this);
|
||||
state.getImpl().initializeAliases(*this);
|
||||
print(os, state, flags);
|
||||
}
|
||||
void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
|
||||
OperationPrinter(os, flags, state.getImpl()).print(*this);
|
||||
}
|
||||
|
||||
void ModuleOp::dump() { print(llvm::errs()); }
|
||||
|
Loading…
x
Reference in New Issue
Block a user