Merge ext/cfg/ml function printing logic in the AsmPrinter (shrinking it

by about 100 LOC), without changing any existing behavior.

This is step 20/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 227155000
This commit is contained in:
Chris Lattner 2018-12-28 11:41:56 -08:00 committed by jpienaar
parent 69d9e990fa
commit 69f9f6e21c

View File

@ -276,9 +276,6 @@ public:
void printAttribute(Attribute attr);
void printType(Type type);
void print(const Function *fn);
void printExt(const Function *fn);
void printCFG(const Function *fn);
void printML(const Function *fn);
void printAffineMap(AffineMap map);
void printAffineExpr(AffineExpr expr);
@ -289,8 +286,6 @@ protected:
raw_ostream &os;
ModuleState &state;
void printFunctionSignature(const Function *fn);
void printFunctionAttributes(const Function *fn);
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {});
void printFunctionResultType(FunctionType type);
@ -312,18 +307,6 @@ protected:
};
} // end anonymous namespace
// Prints function with initialized module state.
void ModulePrinter::print(const Function *fn) {
switch (fn->getKind()) {
case Function::Kind::ExtFunc:
return printExt(fn);
case Function::Kind::CFGFunc:
return printCFG(fn);
case Function::Kind::MLFunc:
return printML(fn);
}
}
// Prints affine map identifier.
void ModulePrinter::printAffineMapId(int affineMapId) const {
os << "#map" << affineMapId;
@ -872,24 +855,6 @@ void ModulePrinter::printFunctionResultType(FunctionType type) {
}
}
void ModulePrinter::printFunctionAttributes(const Function *fn) {
auto attrs = fn->getAttrs();
if (attrs.empty())
return;
os << "\n attributes ";
printOptionalAttrDict(attrs);
}
void ModulePrinter::printFunctionSignature(const Function *fn) {
auto type = fn->getType();
os << "@" << fn->getName() << '(';
interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
os << ')';
printFunctionResultType(type);
}
void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs) {
// If there are no attributes, then there is nothing to be done.
@ -929,20 +894,27 @@ void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
os << '}';
}
void ModulePrinter::printExt(const Function *fn) {
os << "extfunc ";
printFunctionSignature(fn);
printFunctionAttributes(fn);
os << '\n';
}
namespace {
// FunctionPrinter contains common functionality for printing
// CFG and ML functions.
class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
public:
FunctionPrinter(const ModulePrinter &other) : ModulePrinter(other) {}
FunctionPrinter(const Function *function, const ModulePrinter &other);
// Prints the function as a whole.
void print();
// Print the function signature.
void printMLFunctionSignature();
void printOtherFunctionSignature();
// Methods to print statements.
void print(const Statement *stmt);
void print(const OperationInst *inst);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
void print(const StmtBlock *block);
void printOperation(const OperationInst *op);
void printDefaultOp(const OperationInst *op);
@ -963,9 +935,6 @@ public:
void printFunctionReference(const Function *func) {
return ModulePrinter::printFunctionReference(func);
}
void printFunctionAttributes(const Function *func) {
return ModulePrinter::printFunctionAttributes(func);
}
void printOperand(const Value *value) { printValueID(value); }
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
@ -975,142 +944,373 @@ public:
enum { nameSentinel = ~0U };
void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); }
unsigned getBBID(const BasicBlock *block) {
auto it = basicBlockIDs.find(block);
assert(it != basicBlockIDs.end() && "Block not in this function?");
return it->second;
}
void printSuccessorAndUseList(const OperationInst *term,
unsigned index) override;
// Print if and loop bounds.
void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims);
void printBound(AffineBound bound, const char *prefix);
// Number of spaces used for indenting nested statements.
const static unsigned indentWidth = 2;
protected:
void numberValueID(const Value *value) {
assert(!valueIDs.count(value) && "Value numbered multiple times");
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
// Give constant integers special names.
if (auto *op = value->getDefiningInst()) {
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
// i1 constants get special names.
if (intOp->getType().isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false");
} else {
specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
}
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue();
} else if (auto constant = op->dyn_cast<ConstantOp>()) {
if (constant->getValue().isa<FunctionAttr>())
specialName << 'f';
else
specialName << "cst";
}
}
if (specialNameBuffer.empty()) {
switch (value->getKind()) {
case Value::Kind::BlockArgument:
// If this is an argument to the function, give it an 'arg' name.
if (auto *block = cast<BlockArgument>(value)->getOwner())
if (auto *fn = block->getFunction())
if (&fn->getBlockList().front() == block) {
specialName << "arg" << nextArgumentID++;
break;
}
// Otherwise number it normally.
valueIDs[value] = nextValueID++;
return;
case Value::Kind::InstResult:
// This is an uninteresting result, give it a boring number and be
// done with it.
valueIDs[value] = nextValueID++;
return;
case Value::Kind::ForStmt:
specialName << 'i' << nextLoopID++;
break;
}
}
// Ok, this value had an interesting name. Remember it with a sentinel.
valueIDs[value] = nameSentinel;
// Remember that we've used this name, checking to see if we had a conflict.
auto insertRes = usedNames.insert(specialName.str());
if (insertRes.second) {
// If this is the first use of the name, then we're successful!
valueNames[value] = insertRes.first->first();
return;
}
// Otherwise, we had a conflict - probe until we find a unique name. This
// is guaranteed to terminate (and usually in a single iteration) because it
// generates new names by incrementing nextConflictID.
while (1) {
std::string probeName =
specialName.str().str() + "_" + llvm::utostr(nextConflictID++);
insertRes = usedNames.insert(probeName);
if (insertRes.second) {
// If this is the first use of the name, then we're successful!
valueNames[value] = insertRes.first->first();
return;
}
}
}
void printValueID(const Value *value, bool printResultNo = true) const {
int resultNo = -1;
auto lookupValue = value;
// If this is a reference to the result of a multi-result instruction or
// statement, print out the # identifier and make sure to map our lookup
// to the first result of the instruction.
if (auto *result = dyn_cast<InstResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
resultNo = result->getResultNumber();
lookupValue = result->getOwner()->getResult(0);
}
} else if (auto *result = dyn_cast<InstResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
resultNo = result->getResultNumber();
lookupValue = result->getOwner()->getResult(0);
}
}
auto it = valueIDs.find(lookupValue);
if (it == valueIDs.end()) {
os << "<<INVALID SSA VALUE>>";
return;
}
os << '%';
if (it->second != nameSentinel) {
os << it->second;
} else {
auto nameIt = valueNames.find(lookupValue);
assert(nameIt != valueNames.end() && "Didn't have a name entry?");
os << nameIt->second;
}
if (resultNo != -1 && printResultNo)
os << '#' << resultNo;
}
void numberValueID(const Value *value);
void numberValuesInBlock(const StmtBlock &block);
void printValueID(const Value *value, bool printResultNo = true) const;
private:
const Function *function;
/// This is the value ID for each SSA value in the current function. If this
/// returns ~0, then the valueID has an entry in valueNames.
DenseMap<const Value *, unsigned> valueIDs;
DenseMap<const Value *, StringRef> valueNames;
/// This is the block ID for each block in the current function.
DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
/// This keeps track of all of the non-numeric names that are in flight,
/// allowing us to check for duplicates.
llvm::StringSet<> usedNames;
// This is the current indentation level for nested structures.
unsigned currentIndent = 0;
/// This is the next value ID to assign in numbering.
unsigned nextValueID = 0;
/// This is the ID to assign to the next induction variable.
unsigned nextLoopID = 0;
/// This is the next ID to assign to a Function argument.
unsigned nextArgumentID = 0;
/// This is the next ID to assign when a name conflict is detected.
unsigned nextConflictID = 0;
/// This is the next block ID to assign in numbering.
unsigned nextBlockID = 0;
};
} // end anonymous namespace
FunctionPrinter::FunctionPrinter(const Function *function,
const ModulePrinter &other)
: ModulePrinter(other), function(function) {
for (auto &block : *function)
numberValuesInBlock(block);
}
/// Number all of the SSA values in the specified block list.
void FunctionPrinter::numberValuesInBlock(const StmtBlock &block) {
// Each block gets a unique ID, and all of the instructions within it get
// numbered as well.
basicBlockIDs[&block] = nextBlockID++;
for (auto *arg : block.getArguments())
numberValueID(arg);
for (auto &inst : block) {
// We number instruction that have results, and we only number the first
// result.
switch (inst.getKind()) {
case Statement::Kind::OperationInst: {
auto *opInst = cast<OperationInst>(&inst);
if (opInst->getNumResults() != 0)
numberValueID(opInst->getResult(0));
break;
}
case Statement::Kind::For: {
auto *forInst = cast<ForStmt>(&inst);
// Number the induction variable.
numberValueID(forInst);
// Recursively number the stuff in the body.
numberValuesInBlock(*forInst->getBody());
break;
}
case Statement::Kind::If: {
auto *ifInst = cast<IfStmt>(&inst);
numberValuesInBlock(*ifInst->getThen());
if (auto *elseBlock = ifInst->getElse())
numberValuesInBlock(*elseBlock);
}
}
}
}
void FunctionPrinter::numberValueID(const Value *value) {
assert(!valueIDs.count(value) && "Value numbered multiple times");
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
// Give constant integers special names.
if (auto *op = value->getDefiningInst()) {
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
// i1 constants get special names.
if (intOp->getType().isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false");
} else {
specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
}
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue();
} else if (auto constant = op->dyn_cast<ConstantOp>()) {
if (constant->getValue().isa<FunctionAttr>())
specialName << 'f';
else
specialName << "cst";
}
}
if (specialNameBuffer.empty()) {
switch (value->getKind()) {
case Value::Kind::BlockArgument:
// If this is an argument to the function, give it an 'arg' name.
if (auto *block = cast<BlockArgument>(value)->getOwner())
if (auto *fn = block->getFunction())
if (&fn->getBlockList().front() == block) {
specialName << "arg" << nextArgumentID++;
break;
}
// Otherwise number it normally.
valueIDs[value] = nextValueID++;
return;
case Value::Kind::InstResult:
// This is an uninteresting result, give it a boring number and be
// done with it.
valueIDs[value] = nextValueID++;
return;
case Value::Kind::ForStmt:
specialName << 'i' << nextLoopID++;
break;
}
}
// Ok, this value had an interesting name. Remember it with a sentinel.
valueIDs[value] = nameSentinel;
// Remember that we've used this name, checking to see if we had a conflict.
auto insertRes = usedNames.insert(specialName.str());
if (insertRes.second) {
// If this is the first use of the name, then we're successful!
valueNames[value] = insertRes.first->first();
return;
}
// Otherwise, we had a conflict - probe until we find a unique name. This
// is guaranteed to terminate (and usually in a single iteration) because it
// generates new names by incrementing nextConflictID.
while (1) {
std::string probeName =
specialName.str().str() + "_" + llvm::utostr(nextConflictID++);
insertRes = usedNames.insert(probeName);
if (insertRes.second) {
// If this is the first use of the name, then we're successful!
valueNames[value] = insertRes.first->first();
return;
}
}
}
void FunctionPrinter::print() {
// TODO(clattner): merge the syntax of functions.
if (function->isML())
printMLFunctionSignature();
else
printOtherFunctionSignature();
// Print out function attributes, if present.
auto attrs = function->getAttrs();
if (!attrs.empty()) {
os << "\n attributes ";
printOptionalAttrDict(attrs);
}
if (!function->empty()) {
os << " {\n";
for (const auto &block : *function)
print(&block);
os << "}\n";
}
os << '\n';
}
void FunctionPrinter::printMLFunctionSignature() {
auto type = function->getType();
os << "mlfunc @" << function->getName() << '(';
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
if (i > 0)
os << ", ";
auto *arg = function->getArgument(i);
printOperand(arg);
os << " : ";
printType(arg->getType());
}
os << ')';
printFunctionResultType(type);
}
// This prints the signature for CFG and External functions.
void FunctionPrinter::printOtherFunctionSignature() {
auto type = function->getType();
if (function->isCFG())
os << "cfgfunc ";
else
os << "extfunc ";
os << '@' << function->getName() << '(';
interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
os << ')';
printFunctionResultType(type);
}
void FunctionPrinter::print(const StmtBlock *block) {
// Print the block label and argument list, unless we are in an ML function.
if (!block->getFunction()->isML()) {
os.indent(currentIndent);
printBBName(block);
// Print the argument list if non-empty.
if (!block->args_empty()) {
os << '(';
interleaveComma(block->getArguments(), [&](const BlockArgument *arg) {
printValueID(arg);
os << ": ";
printType(arg->getType());
});
os << ')';
}
os << ':';
// Print out some context information about the predecessors of this block.
if (!block->getFunction()) {
os << "\t// block is not in a function!";
} else if (block->hasNoPredecessors()) {
// Don't print "no predecessors" for the entry block.
if (block != &block->getFunction()->front())
os << "\t// no predecessors";
} else if (auto *pred = block->getSinglePredecessor()) {
os << "\t// pred: ";
printBBName(pred);
} else {
// We want to print the predecessors in increasing numeric order, not in
// whatever order the use-list is in, so gather and sort them.
SmallVector<unsigned, 4> predIDs;
for (auto *pred : block->getPredecessors())
predIDs.push_back(getBBID(pred));
llvm::array_pod_sort(predIDs.begin(), predIDs.end());
os << "\t// " << predIDs.size() << " preds: ";
interleaveComma(predIDs, [&](unsigned predID) { os << "bb" << predID; });
}
os << '\n';
}
currentIndent += indentWidth;
for (auto &stmt : block->getStatements()) {
print(&stmt);
os << '\n';
}
currentIndent -= indentWidth;
}
void FunctionPrinter::print(const Statement *stmt) {
switch (stmt->getKind()) {
case Statement::Kind::OperationInst:
return print(cast<OperationInst>(stmt));
case Statement::Kind::For:
return print(cast<ForStmt>(stmt));
case Statement::Kind::If:
return print(cast<IfStmt>(stmt));
}
}
void FunctionPrinter::print(const OperationInst *inst) {
os.indent(currentIndent);
printOperation(inst);
}
void FunctionPrinter::print(const ForStmt *stmt) {
os.indent(currentIndent) << "for ";
printOperand(stmt);
os << " = ";
printBound(stmt->getLowerBound(), "max");
os << " to ";
printBound(stmt->getUpperBound(), "min");
if (stmt->getStep() != 1)
os << " step " << stmt->getStep();
os << " {\n";
print(stmt->getBody());
os.indent(currentIndent) << "}";
}
void FunctionPrinter::print(const IfStmt *stmt) {
os.indent(currentIndent) << "if ";
IntegerSet set = stmt->getIntegerSet();
printIntegerSetReference(set);
printDimAndSymbolList(stmt->getInstOperands(), set.getNumDims());
os << " {\n";
print(stmt->getThen());
os.indent(currentIndent) << "}";
if (stmt->hasElse()) {
os << " else {\n";
print(stmt->getElse());
os.indent(currentIndent) << "}";
}
}
void FunctionPrinter::printValueID(const Value *value,
bool printResultNo) const {
int resultNo = -1;
auto lookupValue = value;
// If this is a reference to the result of a multi-result instruction or
// statement, print out the # identifier and make sure to map our lookup
// to the first result of the instruction.
if (auto *result = dyn_cast<InstResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
resultNo = result->getResultNumber();
lookupValue = result->getOwner()->getResult(0);
}
} else if (auto *result = dyn_cast<InstResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
resultNo = result->getResultNumber();
lookupValue = result->getOwner()->getResult(0);
}
}
auto it = valueIDs.find(lookupValue);
if (it == valueIDs.end()) {
os << "<<INVALID SSA VALUE>>";
return;
}
os << '%';
if (it->second != nameSentinel) {
os << it->second;
} else {
auto nameIt = valueNames.find(lookupValue);
assert(nameIt != valueNames.end() && "Didn't have a name entry?");
os << nameIt->second;
}
if (resultNo != -1 && printResultNo)
os << '#' << resultNo;
}
void FunctionPrinter::printOperation(const OperationInst *op) {
if (op->getNumResults()) {
printValueID(op->getResult(0), /*printResultNo=*/false);
@ -1156,303 +1356,27 @@ void FunctionPrinter::printDefaultOp(const OperationInst *op) {
}
}
//===----------------------------------------------------------------------===//
// CFG Function printing
//===----------------------------------------------------------------------===//
void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
unsigned index) {
printBBName(term->getSuccessor(index));
namespace {
class CFGFunctionPrinter : public FunctionPrinter {
public:
CFGFunctionPrinter(const Function *function, const ModulePrinter &other);
auto succOperands = term->getSuccessorOperands(index);
const Function *getFunction() const { return function; }
void print();
void print(const BasicBlock *block);
void print(const Instruction *inst);
void printSuccessorAndUseList(const OperationInst *term, unsigned index);
void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); }
unsigned getBBID(const BasicBlock *block) {
auto it = basicBlockIDs.find(block);
assert(it != basicBlockIDs.end() && "Block not in this function?");
return it->second;
}
private:
const Function *function;
DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
void numberValuesInBlock(const BasicBlock *block);
template <typename Range> void printBranchOperands(const Range &range);
};
} // end anonymous namespace
CFGFunctionPrinter::CFGFunctionPrinter(const Function *function,
const ModulePrinter &other)
: FunctionPrinter(other), function(function) {
// Each basic block gets a unique ID per function.
unsigned blockID = 0;
for (auto &block : *function) {
basicBlockIDs[&block] = blockID++;
numberValuesInBlock(&block);
}
}
/// Number all of the SSA values in the specified basic block.
void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
for (auto *arg : block->getArguments()) {
numberValueID(arg);
}
for (auto &op : *block) {
// We number instruction that have results, and we only number the first
// result.
if (auto *opInst = dyn_cast<OperationInst>(&op))
if (opInst->getNumResults() != 0)
numberValueID(opInst->getResult(0));
}
// Terminators do not define values.
}
void CFGFunctionPrinter::print() {
os << "cfgfunc ";
printFunctionSignature(getFunction());
printFunctionAttributes(getFunction());
os << " {\n";
for (auto &block : *function)
print(&block);
os << "}\n\n";
}
void CFGFunctionPrinter::print(const BasicBlock *block) {
printBBName(block);
if (!block->args_empty()) {
os << '(';
interleaveComma(block->getArguments(), [&](const BlockArgument *arg) {
printValueID(arg);
os << ": ";
printType(arg->getType());
});
os << ')';
}
os << ':';
// Print out some context information about the predecessors of this block.
if (!block->getFunction()) {
os << "\t// block is not in a function!";
} else if (block->hasNoPredecessors()) {
// Don't print "no predecessors" for the entry block.
if (block != &block->getFunction()->front())
os << "\t// no predecessors";
} else if (auto *pred = block->getSinglePredecessor()) {
os << "\t// pred: ";
printBBName(pred);
} else {
// We want to print the predecessors in increasing numeric order, not in
// whatever order the use-list is in, so gather and sort them.
SmallVector<unsigned, 4> predIDs;
for (auto *pred : block->getPredecessors())
predIDs.push_back(getBBID(pred));
llvm::array_pod_sort(predIDs.begin(), predIDs.end());
os << "\t// " << predIDs.size() << " preds: ";
interleaveComma(predIDs, [&](unsigned predID) { os << "bb" << predID; });
}
os << '\n';
for (auto &inst : block->getStatements()) {
os << " ";
print(&inst);
os << '\n';
}
}
void CFGFunctionPrinter::print(const Instruction *inst) {
if (!inst) {
os << "<<null instruction>>\n";
return;
}
auto *opInst = dyn_cast<OperationInst>(inst);
assert(opInst && "IfStmt/ForStmt aren't supported in CFG functions yet");
printOperation(opInst);
}
// Print the operands from "container" to "os", followed by a colon and their
// respective types, everything in parentheses. Do nothing if the container is
// empty.
template <typename Range>
void CFGFunctionPrinter::printBranchOperands(const Range &range) {
if (llvm::empty(range))
if (succOperands.begin() == succOperands.end())
return;
os << '(';
interleaveComma(range,
interleaveComma(succOperands,
[this](const Value *operand) { printValueID(operand); });
os << " : ";
interleaveComma(
range, [this](const Value *operand) { printType(operand->getType()); });
interleaveComma(succOperands, [this](const Value *operand) {
printType(operand->getType());
});
os << ')';
}
void CFGFunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
unsigned index) {
printBBName(term->getSuccessor(index));
printBranchOperands(term->getSuccessorOperands(index));
}
void ModulePrinter::printCFG(const Function *fn) {
CFGFunctionPrinter(fn, *this).print();
}
//===----------------------------------------------------------------------===//
// ML Function printing
//===----------------------------------------------------------------------===//
namespace {
class MLFunctionPrinter : public FunctionPrinter {
public:
MLFunctionPrinter(const Function *function, const ModulePrinter &other);
const Function *getFunction() const { return function; }
// Prints ML function.
void print();
// Prints ML function signature.
void printFunctionSignature();
// Methods to print ML function statements.
void print(const Statement *stmt);
void print(const OperationInst *stmt);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
void print(const StmtBlock *block);
void printSuccessorAndUseList(const OperationInst *term, unsigned index) {
assert(false && "MLFunctions do not have terminators with successors.");
}
// Print loop bounds.
void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims);
void printBound(AffineBound bound, const char *prefix);
// Number of spaces used for indenting nested statements.
const static unsigned indentWidth = 2;
private:
void numberValues();
const Function *function;
int numSpaces;
};
} // end anonymous namespace
MLFunctionPrinter::MLFunctionPrinter(const Function *function,
const ModulePrinter &other)
: FunctionPrinter(other), function(function), numSpaces(0) {
assert(function && "Cannot print nullptr function");
numberValues();
}
/// Number all of the SSA values in this ML function.
void MLFunctionPrinter::numberValues() {
// Numbers ML function arguments.
for (auto *arg : function->getArguments())
numberValueID(arg);
// Walks ML function statements and numbers for statements and
// the first result of the operation statements.
struct NumberValuesPass : public StmtWalker<NumberValuesPass> {
NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
void visitOperationInst(OperationInst *stmt) {
if (stmt->getNumResults() != 0)
printer->numberValueID(stmt->getResult(0));
}
void visitForStmt(ForStmt *stmt) { printer->numberValueID(stmt); }
MLFunctionPrinter *printer;
};
NumberValuesPass pass(this);
// TODO: it'd be cleaner to have constant visitor instead of using const_cast.
pass.walk(const_cast<Function *>(function));
}
void MLFunctionPrinter::print() {
os << "mlfunc ";
printFunctionSignature();
printFunctionAttributes(getFunction());
os << " {\n";
print(function->getBody());
os << "}\n\n";
}
void MLFunctionPrinter::printFunctionSignature() {
auto type = function->getType();
os << "@" << function->getName() << '(';
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
if (i > 0)
os << ", ";
auto *arg = function->getArgument(i);
printOperand(arg);
os << " : ";
printType(arg->getType());
}
os << ")";
printFunctionResultType(type);
}
void MLFunctionPrinter::print(const StmtBlock *block) {
numSpaces += indentWidth;
for (auto &stmt : block->getStatements()) {
print(&stmt);
os << "\n";
}
numSpaces -= indentWidth;
}
void MLFunctionPrinter::print(const Statement *stmt) {
switch (stmt->getKind()) {
case Statement::Kind::OperationInst:
return print(cast<OperationInst>(stmt));
case Statement::Kind::For:
return print(cast<ForStmt>(stmt));
case Statement::Kind::If:
return print(cast<IfStmt>(stmt));
}
}
void MLFunctionPrinter::print(const OperationInst *stmt) {
os.indent(numSpaces);
printOperation(stmt);
}
void MLFunctionPrinter::print(const ForStmt *stmt) {
os.indent(numSpaces) << "for ";
printOperand(stmt);
os << " = ";
printBound(stmt->getLowerBound(), "max");
os << " to ";
printBound(stmt->getUpperBound(), "min");
if (stmt->getStep() != 1)
os << " step " << stmt->getStep();
os << " {\n";
print(stmt->getBody());
os.indent(numSpaces) << "}";
}
void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<InstOperand> ops,
unsigned numDims) {
void FunctionPrinter::printDimAndSymbolList(ArrayRef<InstOperand> ops,
unsigned numDims) {
auto printComma = [&]() { os << ", "; };
os << '(';
interleave(
@ -1469,7 +1393,7 @@ void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<InstOperand> ops,
}
}
void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
void FunctionPrinter::printBound(AffineBound bound, const char *prefix) {
AffineMap map = bound.getMap();
// Check if this bound should be printed using short-hand notation.
@ -1507,23 +1431,9 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
printDimAndSymbolList(bound.getInstOperands(), map.getNumDims());
}
void MLFunctionPrinter::print(const IfStmt *stmt) {
os.indent(numSpaces) << "if ";
IntegerSet set = stmt->getIntegerSet();
printIntegerSetReference(set);
printDimAndSymbolList(stmt->getInstOperands(), set.getNumDims());
os << " {\n";
print(stmt->getThen());
os.indent(numSpaces) << "}";
if (stmt->hasElse()) {
os << " else {\n";
print(stmt->getElse());
os.indent(numSpaces) << "}";
}
}
void ModulePrinter::printML(const Function *fn) {
MLFunctionPrinter(fn, *this).print();
// Prints function with initialized module state.
void ModulePrinter::print(const Function *fn) {
FunctionPrinter(fn, *this).print();
}
//===----------------------------------------------------------------------===//
@ -1595,15 +1505,10 @@ void Instruction::print(raw_ostream &os) const {
os << "<<UNLINKED INSTRUCTION>>\n";
return;
}
if (function->isCFG()) {
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
CFGFunctionPrinter(function, modulePrinter).print(this);
} else {
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
MLFunctionPrinter(function, modulePrinter).print(this);
}
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(function, modulePrinter).print(this);
}
void Instruction::dump() const {
@ -1618,15 +1523,9 @@ void BasicBlock::print(raw_ostream &os) const {
return;
}
if (function->isCFG()) {
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
CFGFunctionPrinter(function, modulePrinter).print(this);
} else {
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
MLFunctionPrinter(function, modulePrinter).print(this);
}
ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(function, modulePrinter).print(this);
}
void BasicBlock::dump() const { print(llvm::errs()); }
@ -1639,7 +1538,7 @@ void StmtBlock::printAsOperand(raw_ostream &os, bool printType) {
}
ModuleState state(getFunction()->getContext());
ModulePrinter modulePrinter(os, state);
CFGFunctionPrinter(getFunction(), modulePrinter).printBBName(this);
FunctionPrinter(getFunction(), modulePrinter).printBBName(this);
}
void Function::print(raw_ostream &os) const {