diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h new file mode 100644 index 000000000000..9f2c78c511b0 --- /dev/null +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -0,0 +1,191 @@ +//===- Dominance.h - Dominator analysis for CFG Functions -------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_ANALYSIS_DOMINANCE_H +#define MLIR_ANALYSIS_DOMINANCE_H + +#include "mlir/IR/CFGFunction.h" +#include "llvm/Support/GenericDomTree.h" + +namespace llvm { +template <> struct GraphTraits { + using ChildIteratorType = mlir::BasicBlock::succ_iterator; + using Node = mlir::BasicBlock; + using NodeRef = Node *; + + static NodeRef getEntryNode(NodeRef bb) { return bb; } + + static ChildIteratorType child_begin(NodeRef node) { + return node->succ_begin(); + } + static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); } +}; + +template <> struct GraphTraits { + using ChildIteratorType = mlir::BasicBlock::const_succ_iterator; + using Node = const mlir::BasicBlock; + using NodeRef = Node *; + + static NodeRef getEntryNode(NodeRef bb) { return bb; } + + static ChildIteratorType child_begin(NodeRef node) { + return node->succ_begin(); + } + static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); } +}; + +template <> struct GraphTraits> { + using ChildIteratorType = mlir::BasicBlock::pred_iterator; + using Node = mlir::BasicBlock; + using NodeRef = Node *; + static NodeRef getEntryNode(Inverse inverseGraph) { + return inverseGraph.Graph; + } + static inline ChildIteratorType child_begin(NodeRef node) { + return node->pred_begin(); + } + static inline ChildIteratorType child_end(NodeRef node) { + return node->pred_end(); + } +}; + +template <> struct GraphTraits> { + using ChildIteratorType = mlir::BasicBlock::const_pred_iterator; + using Node = const mlir::BasicBlock; + using NodeRef = Node *; + + static NodeRef getEntryNode(Inverse inverseGraph) { + return inverseGraph.Graph; + } + static inline ChildIteratorType child_begin(NodeRef node) { + return node->pred_begin(); + } + static inline ChildIteratorType child_end(NodeRef node) { + return node->pred_end(); + } +}; + +template <> +struct GraphTraits + : public GraphTraits { + using GraphType = mlir::CFGFunction *; + using NodeRef = mlir::BasicBlock *; + + static NodeRef getEntryNode(GraphType fn) { return &fn->front(); } + + using nodes_iterator = pointer_iterator; + static nodes_iterator nodes_begin(GraphType fn) { + return nodes_iterator(fn->begin()); + } + static nodes_iterator nodes_end(GraphType fn) { + return nodes_iterator(fn->end()); + } +}; + +template <> +struct GraphTraits> + : public GraphTraits> { + using GraphType = Inverse; + using NodeRef = NodeRef; + + static NodeRef getEntryNode(GraphType fn) { return &fn.Graph->front(); } + + using nodes_iterator = pointer_iterator; + static nodes_iterator nodes_begin(GraphType fn) { + return nodes_iterator(fn.Graph->begin()); + } + static nodes_iterator nodes_end(GraphType fn) { + return nodes_iterator(fn.Graph->end()); + } +}; +} // namespace llvm + +extern template class llvm::DominatorTreeBase; +extern template class llvm::DominatorTreeBase; +extern template class llvm::DomTreeNodeBase; + +namespace llvm { +namespace DomTreeBuilder { + +using MLIRDomTree = llvm::DomTreeBase; +using MLIRPostDomTree = llvm::PostDomTreeBase; + +// extern template void Calculate(MLIRDomTree &DT); +// extern template void Calculate(MLIRPostDomTree &DT); + +} // namespace DomTreeBuilder +} // namespace llvm + +namespace mlir { +using DominatorTreeBase = llvm::DominatorTreeBase; +using PostDominatorTreeBase = llvm::DominatorTreeBase; +using DominanceInfoNode = llvm::DomTreeNodeBase; + +/// A class for computing basic dominance information. +class DominanceInfo : public DominatorTreeBase { + using super = DominatorTreeBase; + +public: + DominanceInfo(CFGFunction *F); + + /// Return true if instruction A properly dominates instruction B. + bool properlyDominates(const Instruction *a, const Instruction *b); + + /// Return true if instruction A dominates instruction B. + bool dominates(const Instruction *a, const Instruction *b) { + return a == b || properlyDominates(a, b); + } + + /// Return true if value A properly dominates instruction B. + bool properlyDominates(const SSAValue *a, const Instruction *b); + + /// Return true if instruction A dominates instruction B. + bool dominates(const SSAValue *a, const Instruction *b) { + return a->getDefiningInst() == b || properlyDominates(a, b); + } + + // dominates/properlyDominates for basic blocks. + using DominatorTreeBase::dominates; + using DominatorTreeBase::properlyDominates; +}; + +} // end namespace mlir + +namespace llvm { + +/// DominatorTree GraphTraits specialization so the DominatorTree can be +/// iterated by generic graph iterators. +template <> struct GraphTraits { + using ChildIteratorType = mlir::DominanceInfoNode::iterator; + using NodeRef = mlir::DominanceInfoNode *; + + static NodeRef getEntryNode(NodeRef N) { return N; } + static inline ChildIteratorType child_begin(NodeRef N) { return N->begin(); } + static inline ChildIteratorType child_end(NodeRef N) { return N->end(); } +}; + +template <> struct GraphTraits { + using ChildIteratorType = mlir::DominanceInfoNode::const_iterator; + using NodeRef = const mlir::DominanceInfoNode *; + + static NodeRef getEntryNode(NodeRef N) { return N; } + static inline ChildIteratorType child_begin(NodeRef N) { return N->begin(); } + static inline ChildIteratorType child_end(NodeRef N) { return N->end(); } +}; + +} // end namespace llvm +#endif diff --git a/mlir/include/mlir/IR/BasicBlock.h b/mlir/include/mlir/IR/BasicBlock.h index 079849c87224..a38e605ed589 100644 --- a/mlir/include/mlir/IR/BasicBlock.h +++ b/mlir/include/mlir/IR/BasicBlock.h @@ -39,9 +39,12 @@ public: ~BasicBlock(); /// Return the function that a BasicBlock is part of. - CFGFunction *getFunction() const { - return function; - } + CFGFunction *getFunction() { return function; } + const CFGFunction *getFunction() const { return function; } + + /// Return the function that a BasicBlock is part of. + const CFGFunction *getParent() const { return function; } + CFGFunction *getParent() { return function; } //===--------------------------------------------------------------------===// // Block arguments management @@ -194,6 +197,11 @@ public: void print(raw_ostream &os) const; void dump() const; + /// Print out the name of the basic block without printing its body. + /// NOTE: The printType argument is ignored. We keep it for compatibility + /// with LLVM dominator machinery that expects it to exist. + void printAsOperand(raw_ostream &os, bool printType = true); + /// getSublistAccess() - Returns pointer to member of operation list static OperationListType BasicBlock::*getSublistAccess(OperationInst*) { return &BasicBlock::operations; @@ -300,6 +308,9 @@ public: : IndexedAccessorIterator, BlockType, BlockType>(object, index) {} + SuccessorIterator(const SuccessorIterator &other) + : SuccessorIterator(other.object, other.index) {} + /// Support converting to the const variant. This will be a no-op for const /// variant. operator SuccessorIterator() const { diff --git a/mlir/include/mlir/IR/CFGValue.h b/mlir/include/mlir/IR/CFGValue.h index 35d7c028c2a5..939073c23825 100644 --- a/mlir/include/mlir/IR/CFGValue.h +++ b/mlir/include/mlir/IR/CFGValue.h @@ -77,7 +77,10 @@ public: } /// Return the function that this argument is defined in. - CFGFunction *getFunction() const; + CFGFunction *getFunction(); + const CFGFunction *getFunction() const { + return const_cast(this)->getFunction(); + } BasicBlock *getOwner() { return owner; } const BasicBlock *getOwner() const { return owner; } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 40e17cb1581e..21e2aa5a5758 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -287,7 +287,8 @@ private: template class IndexedAccessorIterator : public llvm::iterator_facade_base< - ConcreteType, std::random_access_iterator_tag, ElementType *> { + ConcreteType, std::random_access_iterator_tag, ElementType *, + std::ptrdiff_t, ElementType *, ElementType *> { public: ptrdiff_t operator-(const IndexedAccessorIterator &rhs) const { assert(object == rhs.object && "incompatible iterators"); diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp new file mode 100644 index 000000000000..1eb59e04aa33 --- /dev/null +++ b/mlir/lib/Analysis/Dominance.cpp @@ -0,0 +1,81 @@ +//===- Dominance.cpp - Dominator analysis for CFG Functions ---------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Implementation of dominance related classes and instantiations of extern +// templates. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Dominance.h" +#include "llvm/Support/GenericDomTreeConstruction.h" +using namespace mlir; + +template class llvm::DominatorTreeBase; +template class llvm::DominatorTreeBase; +template class llvm::DomTreeNodeBase; + +/// Compute the immediate-dominators map. +DominanceInfo::DominanceInfo(CFGFunction *function) : DominatorTreeBase() { + // Build the dominator tree for the function. + recalculate(*function); +} + +/// Return true if instruction A properly dominates instruction B. +bool DominanceInfo::properlyDominates(const Instruction *a, + const Instruction *b) { + auto *aBlock = a->getBlock(), *bBlock = b->getBlock(); + + // If the blocks are different, it's as easy as whether A's block + // dominates B's block. + if (aBlock != bBlock) + return properlyDominates(a->getBlock(), b->getBlock()); + + // If a/b are the same, then they don't properly dominate each other. + if (a == b) + return false; + + // If one is a terminator, then the other dominates it. + auto *aOp = dyn_cast(a); + if (!aOp) + return false; + + auto *bOp = dyn_cast(b); + if (!bOp) + return true; + + // Otherwise, do a linear scan to determine whether B comes after A. + auto aIter = BasicBlock::const_iterator(aOp); + auto bIter = BasicBlock::const_iterator(bOp); + auto fIter = aBlock->begin(); + while (bIter != fIter) { + --bIter; + if (aIter == bIter) + return true; + } + + return false; +} + +/// Return true if value A properly dominates instruction B. +bool DominanceInfo::properlyDominates(const SSAValue *a, const Instruction *b) { + if (auto *aInst = a->getDefiningInst()) + return properlyDominates(aInst, b); + + // bbarguments properly dominate all instructions in their own block, so we + // use a dominates check here, not a properlyDominates check. + return dominates(cast(a)->getOwner(), b->getBlock()); +} diff --git a/mlir/lib/Analysis/HyperRectangularSet.cpp b/mlir/lib/Analysis/HyperRectangularSet.cpp index 14e180bb29b9..c385a22b67c9 100644 --- a/mlir/lib/Analysis/HyperRectangularSet.cpp +++ b/mlir/lib/Analysis/HyperRectangularSet.cpp @@ -1,4 +1,4 @@ -//===- HyperRectangularSet.cpp - MLIR HyperRectangularSet Class--*- C++ -*-===// +//===- HyperRectangularSet.cpp - MLIR HyperRectangularSet Class -----------===// // // Copyright 2019 The MLIR Authors. // @@ -20,13 +20,10 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/HyperRectangularSet.h" - -#include - #include "mlir/IR/AffineExpr.h" #include "mlir/IR/IntegerSet.h" #include "llvm/Support/raw_ostream.h" - +#include using namespace mlir; // TODO(bondhugula): clean this code up. diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp similarity index 94% rename from mlir/lib/IR/Verifier.cpp rename to mlir/lib/Analysis/Verifier.cpp index 292d2bfefe16..4ea7fd8da499 100644 --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -33,6 +33,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Analysis/Dominance.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/CFGFunction.h" #include "mlir/IR/MLFunction.h" @@ -163,12 +164,15 @@ bool Verifier::verifyOperation(const Operation &op) { namespace { struct CFGFuncVerifier : public Verifier { const CFGFunction &fn; + DominanceInfo domInfo; - CFGFuncVerifier(const CFGFunction &fn) : Verifier(fn), fn(fn) {} + CFGFuncVerifier(const CFGFunction &fn) + : Verifier(fn), fn(fn), domInfo(const_cast(&fn)) {} bool verify(); bool verifyBlock(const BasicBlock &block); bool verifyTerminator(const TerminatorInst &term); + bool verifyInstOperands(const Instruction &inst); bool verifyBBArguments(ArrayRef operands, const BasicBlock *destBB, const TerminatorInst &term); @@ -212,6 +216,24 @@ bool CFGFuncVerifier::verify() { return false; } +bool CFGFuncVerifier::verifyInstOperands(const Instruction &inst) { + // Check that operands properly dominate this use. + for (unsigned operandNo = 0, e = inst.getNumOperands(); operandNo != e; + ++operandNo) { + auto *op = inst.getOperand(operandNo); + if (domInfo.properlyDominates(op, &inst)) + continue; + + inst.emitError("operand #" + Twine(operandNo) + + " does not dominate this use"); + if (auto *useInst = op->getDefiningInst()) + useInst->emitNote("operand defined here"); + return true; + } + + return false; +} + bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) { if (!block.getTerminator()) return failure("basic block with no terminator", block); @@ -225,7 +247,7 @@ bool CFGFuncVerifier::verifyBlock(const BasicBlock &block) { } for (auto &inst : block) { - if (verifyOperation(inst)) + if (verifyOperation(inst) || verifyInstOperands(inst)) return true; } return false; @@ -244,6 +266,9 @@ bool CFGFuncVerifier::verifyTerminator(const TerminatorInst &term) { return failure("reference to operand defined in another function", term); } + // Verify dominance of values. + verifyInstOperands(term); + // Check that successors are in the right function. for (auto *succ : term.getBlock()->getSuccessors()) { if (succ->getFunction() != &fn) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index c269feba15c0..61e0ed15d13d 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1077,6 +1077,8 @@ public: void print(const BranchInst *inst); void print(const CondBranchInst *inst); + void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); } + unsigned getBBID(const BasicBlock *block) { auto it = basicBlockIDs.find(block); assert(it != basicBlockIDs.end() && "Block not in this function?"); @@ -1129,7 +1131,7 @@ void CFGFunctionPrinter::print() { } void CFGFunctionPrinter::print(const BasicBlock *block) { - os << "bb" << getBBID(block); + printBBName(block); if (!block->args_empty()) { os << '('; @@ -1150,7 +1152,8 @@ void CFGFunctionPrinter::print(const BasicBlock *block) { if (block != &block->getFunction()->front()) os << "\t// no predecessors"; } else if (auto *pred = block->getSinglePredecessor()) { - os << "\t// pred: bb" << getBBID(pred); + os << "\t// pred: "; + printBBName(pred); } else { // We want to print the predecessors in increasing numeric order, not in // whatever order the use-list is in, so gather and sort them. @@ -1198,7 +1201,8 @@ void CFGFunctionPrinter::print(const OperationInst *inst) { } void CFGFunctionPrinter::print(const BranchInst *inst) { - os << "br bb" << getBBID(inst->getDest()); + os << "br "; + printBBName(inst->getDest()); if (inst->getNumOperands() != 0) { os << '('; @@ -1215,7 +1219,8 @@ void CFGFunctionPrinter::print(const CondBranchInst *inst) { os << "cond_br "; printValueID(inst->getCondition()); - os << ", bb" << getBBID(inst->getTrueDest()); + os << ", "; + printBBName(inst->getTrueDest()); if (inst->getNumTrueOperands() != 0) { os << '('; interleaveComma(inst->getTrueOperands(), @@ -1227,7 +1232,8 @@ void CFGFunctionPrinter::print(const CondBranchInst *inst) { os << ")"; } - os << ", bb" << getBBID(inst->getFalseDest()); + os << ", "; + printBBName(inst->getFalseDest()); if (inst->getNumFalseOperands() != 0) { os << '('; interleaveComma(inst->getFalseOperands(), @@ -1555,6 +1561,17 @@ void BasicBlock::print(raw_ostream &os) const { void BasicBlock::dump() const { print(llvm::errs()); } +/// Print out the name of the basic block without printing its body. +void BasicBlock::printAsOperand(raw_ostream &os, bool printType) { + if (!getFunction()) { + os << "<>\n"; + return; + } + ModuleState state(getFunction()->getContext()); + ModulePrinter modulePrinter(os, state); + CFGFunctionPrinter(getFunction(), modulePrinter).printBBName(this); +} + void Statement::print(raw_ostream &os) const { MLFunction *function = findFunction(); if (!function) { diff --git a/mlir/lib/IR/SSAValue.cpp b/mlir/lib/IR/SSAValue.cpp index f111d7126b36..469825fbabcd 100644 --- a/mlir/lib/IR/SSAValue.cpp +++ b/mlir/lib/IR/SSAValue.cpp @@ -77,7 +77,7 @@ CFGFunction *CFGValue::getFunction() { //===----------------------------------------------------------------------===// /// Return the function that this argument is defined in. -CFGFunction *BBArgument::getFunction() const { +CFGFunction *BBArgument::getFunction() { if (auto *owner = getOwner()) return owner->getFunction(); return nullptr; diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index ff04776ee59b..df2d72f0b39c 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -40,7 +40,7 @@ struct ConstantFold : public FunctionPass { /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// -/// This returns true if the operation was folded. +/// This returns false if the operation was successfully folded. bool ConstantFold::foldOperation(Operation *op, SmallVectorImpl &existingConstants, ConstantFactoryType constantFactory) { @@ -49,7 +49,7 @@ bool ConstantFold::foldOperation(Operation *op, // later, and don't try to fold it. if (op->is()) { existingConstants.push_back(op->getResult(0)); - return false; + return true; } // Check to see if each of the operands is a trivial constant. If so, get @@ -63,13 +63,13 @@ bool ConstantFold::foldOperation(Operation *op, } } // If one of the operands was non-constant, then we can't fold it. - return false; + return true; } // Attempt to constant fold the operation. SmallVector resultConstants; if (op->constantFold(operandConstants, resultConstants)) - return false; + return true; // Ok, if everything succeeded, then we can create constants corresponding // to the result of the call. @@ -88,7 +88,7 @@ bool ConstantFold::foldOperation(Operation *op, res->replaceAllUsesWith(cst); } - return true; + return false; } // For now, we do a simple top-down pass over a function folding constants. We @@ -108,7 +108,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { ->getResult(); }; - if (foldOperation(&inst, existingConstants, constantFactory)) { + if (!foldOperation(&inst, existingConstants, constantFactory)) { // At this point the operation is dead, remove it. // TODO: This is assuming that all constant foldable operations have no // side effects. When we have side effect modeling, we should verify @@ -160,7 +160,7 @@ void ConstantFold::foldStmtBlock( ->getResult(); }; - if (foldOperation(opStmt, existingConstants, constantFactory)) { + if (!foldOperation(opStmt, existingConstants, constantFactory)) { // At this point the operation is dead, remove it. // TODO: This is assuming that all constant foldable operations have no // side effects. When we have side effect modeling, we should verify that diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index fd98e2ccabe9..2fcc408ce3ba 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -396,6 +396,18 @@ mlfunc @dominance_failure() { return } +// ----- + +cfgfunc @dominance_failure() { +bb0: + "foo"(%x) : (i32) -> () // expected-error {{operand #0 does not dominate this use}} + br bb1 +bb1: + %x = "bar"() : () -> i32 // expected-error {{operand defined here}} + return +} + + // ----- mlfunc @return_type_mismatch() -> i32 {