[mlir] Refactor ModuleState into AsmState and expose it to users.

Summary:
This allows for users to cache printer state, which can be costly to recompute. Each of the IR print methods gain a new overload taking this new state class.

Depends On D72293

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D72294
This commit is contained in:
River Riddle 2020-01-14 15:23:05 -08:00
parent 23058f9dd4
commit fa9dd8336b
7 changed files with 136 additions and 39 deletions

View File

@ -0,0 +1,52 @@
//===- AsmState.h - State class for AsmPrinter ------------------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the AsmState class.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_ASMSTATE_H_
#define MLIR_IR_ASMSTATE_H_
#include <memory>
namespace mlir {
class Operation;
namespace detail {
class AsmStateImpl;
} // end namespace detail
/// This class provides management for the lifetime of the state used when
/// printing the IR. It allows for alleviating the cost of recomputing the
/// internal state of the asm printer.
///
/// The IR should not be mutated in-between invocations using this state, and
/// the IR being printed must not be an parent of the IR originally used to
/// initialize this state. This means that if a child operation is provided, a
/// parent operation cannot reuse this state.
class AsmState {
public:
/// Initialize the asm state at the level of the given operation.
AsmState(Operation *op);
~AsmState();
/// Return an instance of the internal implementation. Returns nullptr if the
/// state has not been initialized.
detail::AsmStateImpl &getImpl() { return *impl; }
private:
AsmState() = delete;
/// A pointer to allocated storage for the impl state.
std::unique_ptr<detail::AsmStateImpl> impl;
};
} // end namespace mlir
#endif // MLIR_IR_ASMSTATE_H_

View File

@ -312,12 +312,14 @@ public:
}
void print(raw_ostream &os);
void print(raw_ostream &os, AsmState &state);
void dump();
/// Print out the name of the block without printing its body.
/// NOTE: The printType argument is ignored. We keep it for compatibility
/// with LLVM dominator machinery that expects it to exist.
void printAsOperand(raw_ostream &os, bool printType = true);
void printAsOperand(raw_ostream &os, AsmState &state);
private:
/// Pair of the parent object that owns this block and a bit that signifies if

View File

@ -57,6 +57,8 @@ public:
/// Print the this module in the custom top-level form.
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
void print(raw_ostream &os, AsmState &state,
OpPrintingFlags flags = llvm::None);
void dump();
//===--------------------------------------------------------------------===//

View File

@ -122,6 +122,10 @@ public:
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
state->print(os, flags);
}
void print(raw_ostream &os, AsmState &asmState,
OpPrintingFlags flags = llvm::None) {
state->print(os, asmState, flags);
}
/// Dump this operation.
void dump() { state->dump(); }

View File

@ -187,6 +187,8 @@ public:
bool isBeforeInBlock(Operation *other);
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
void print(raw_ostream &os, AsmState &state,
OpPrintingFlags flags = llvm::None);
void dump();
//===--------------------------------------------------------------------===//

View File

@ -18,6 +18,7 @@
#include "mlir/Support/LLVM.h"
namespace mlir {
class AsmState;
class BlockArgument;
class Operation;
class OpResult;
@ -172,8 +173,12 @@ public:
Kind getKind() const { return ownerAndKind.getInt(); }
void print(raw_ostream &os);
void print(raw_ostream &os, AsmState &state);
void dump();
/// Print this value as if it were an operand.
void printAsOperand(raw_ostream &os, AsmState &state);
/// Methods for supporting PointerLikeTypeTraits.
void *getAsOpaquePointer() const { return ownerAndKind.getOpaqueValue(); }
static Value getFromOpaquePointer(const void *pointer) {

View File

@ -13,6 +13,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
@ -37,6 +38,7 @@
#include "llvm/Support/Regex.h"
#include "llvm/Support/SaveAndRestore.h"
using namespace mlir;
using namespace mlir::detail;
void Identifier::print(raw_ostream &os) const { os << str(); }
@ -756,13 +758,14 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
}
//===----------------------------------------------------------------------===//
// ModuleState
// AsmState
//===----------------------------------------------------------------------===//
namespace {
class ModuleState {
namespace mlir {
namespace detail {
class AsmStateImpl {
public:
explicit ModuleState(Operation *op)
explicit AsmStateImpl(Operation *op)
: interfaces(op->getContext()), nameState(op, interfaces) {}
/// Initialize the alias state to enable the printing of aliases.
@ -792,7 +795,11 @@ private:
/// The state used for SSA value names.
SSANameState nameState;
};
} // end anonymous namespace
} // end namespace detail
} // end namespace mlir
AsmState::AsmState(Operation *op) : impl(std::make_unique<AsmStateImpl>(op)) {}
AsmState::~AsmState() {}
//===----------------------------------------------------------------------===//
// ModulePrinter
@ -802,7 +809,7 @@ namespace {
class ModulePrinter {
public:
ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
ModuleState *state = nullptr)
AsmStateImpl *state = nullptr)
: os(os), printerFlags(flags), state(state) {}
explicit ModulePrinter(ModulePrinter &printer)
: os(printer.os), printerFlags(printer.printerFlags),
@ -816,8 +823,6 @@ public:
mlir::interleaveComma(c, os, each_fn);
}
void print(ModuleOp module);
/// Print the given attribute. If 'mayElideType' is true, some attributes are
/// printed without the type when the type matches the default used in the
/// parser (for example i64 is the default for integer attributes).
@ -862,7 +867,7 @@ protected:
OpPrintingFlags printerFlags;
/// An optional printer state for the module.
ModuleState *state;
AsmStateImpl *state;
};
} // end anonymous namespace
@ -1815,10 +1820,12 @@ namespace {
/// This class contains the logic for printing operations, regions, and blocks.
class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
public:
explicit OperationPrinter(ModulePrinter &other) : ModulePrinter(other) {
assert(state && "expected valid state when printing operation");
}
explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
AsmStateImpl &state)
: ModulePrinter(os, flags, &state) {}
/// Print the given top-level module.
void print(ModuleOp op);
/// Print the given operation with its indent and location.
void print(Operation *op);
/// Print the bare location, not including indentation/location/etc.
@ -1903,6 +1910,15 @@ private:
};
} // end anonymous namespace
void OperationPrinter::print(ModuleOp op) {
// Output the aliases at the top level.
state->getAliasState().printAttributeAliases(os);
state->getAliasState().printTypeAliases(os);
// Print the module.
print(op.getOperation());
}
void OperationPrinter::print(Operation *op) {
os.indent(currentIndent);
printOperation(op);
@ -2108,18 +2124,6 @@ void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
});
}
void ModulePrinter::print(ModuleOp module) {
assert(state && "expected valid state when printing an operation");
// Output the aliases at the top level.
state->getAliasState().printAttributeAliases(os);
state->getAliasState().printTypeAliases(os);
// Print the module.
OperationPrinter(*this).print(module);
os << '\n';
}
//===----------------------------------------------------------------------===//
// print and dump methods
//===----------------------------------------------------------------------===//
@ -2179,18 +2183,34 @@ void Value::print(raw_ostream &os) {
assert(isa<BlockArgument>());
os << "<block argument>\n";
}
void Value::print(raw_ostream &os, AsmState &state) {
if (auto *op = getDefiningOp())
return op->print(os, state);
// TODO: Improve this.
assert(isa<BlockArgument>());
os << "<block argument>\n";
}
void Value::dump() {
print(llvm::errs());
llvm::errs() << "\n";
}
void Value::printAsOperand(raw_ostream &os, AsmState &state) {
// TODO(riverriddle) This doesn't necessarily capture all potential cases.
// Currently, region arguments can be shadowed when printing the main
// operation. If the IR hasn't been printed, this will produce the old SSA
// name and not the shadowed name.
state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
os);
}
void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
// Handle top-level operations or local printing.
if (!getParent() || flags.shouldUseLocalScope()) {
ModuleState state(this);
ModulePrinter modulePrinter(os, flags, &state);
OperationPrinter(modulePrinter).print(this);
AsmState state(this);
OperationPrinter(os, flags, state.getImpl()).print(this);
return;
}
@ -2203,9 +2223,11 @@ void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
while (auto *nextOp = parentOp->getParentOp())
parentOp = nextOp;
ModuleState state(parentOp);
ModulePrinter modulePrinter(os, flags, &state);
OperationPrinter(modulePrinter).print(this);
AsmState state(parentOp);
print(os, state, flags);
}
void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
OperationPrinter(os, flags, state.getImpl()).print(this);
}
void Operation::dump() {
@ -2223,9 +2245,11 @@ void Block::print(raw_ostream &os) {
while (auto *nextOp = parentOp->getParentOp())
parentOp = nextOp;
ModuleState state(parentOp);
ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state);
OperationPrinter(modulePrinter).print(this);
AsmState state(parentOp);
print(os, state);
}
void Block::print(raw_ostream &os, AsmState &state) {
OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
}
void Block::dump() { print(llvm::errs()); }
@ -2241,18 +2265,24 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
while (auto *nextOp = parentOp->getParentOp())
parentOp = nextOp;
ModuleState state(parentOp);
ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state);
OperationPrinter(modulePrinter).printBlockName(this);
AsmState state(parentOp);
printAsOperand(os, state);
}
void Block::printAsOperand(raw_ostream &os, AsmState &state) {
OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
printer.printBlockName(this);
}
void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
ModuleState state(*this);
AsmState state(*this);
// Don't populate aliases when printing at local scope.
if (!flags.shouldUseLocalScope())
state.initializeAliases(*this);
ModulePrinter(os, flags, &state).print(*this);
state.getImpl().initializeAliases(*this);
print(os, state, flags);
}
void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
OperationPrinter(os, flags, state.getImpl()).print(*this);
}
void ModuleOp::dump() { print(llvm::errs()); }