mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 06:06:07 +00:00
[MLIR][Affine] Make affine fusion MDG API const correct (#125994)
Make affine fusion MDG API const correct. NFC changes otherwise.
This commit is contained in:
parent
998f2422a5
commit
001ba42fe0
@ -139,9 +139,11 @@ public:
|
||||
|
||||
// Map from node id to Node.
|
||||
DenseMap<unsigned, Node> nodes;
|
||||
// Map from node id to list of input edges.
|
||||
// Map from node id to list of input edges. The absence of an entry for a key
|
||||
// is also equivalent to the absence of any edges.
|
||||
DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
|
||||
// Map from node id to list of output edges.
|
||||
// Map from node id to list of output edges. The absence of an entry for a
|
||||
// node is also equivalent to the absence of any edges.
|
||||
DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
|
||||
// Map from memref to a count on the dependence edges associated with that
|
||||
// memref.
|
||||
@ -156,10 +158,21 @@ public:
|
||||
bool init();
|
||||
|
||||
// Returns the graph node for 'id'.
|
||||
Node *getNode(unsigned id);
|
||||
const Node *getNode(unsigned id) const;
|
||||
Node *getNode(unsigned id) {
|
||||
return const_cast<Node *>(
|
||||
static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
|
||||
}
|
||||
|
||||
// Returns true if the graph has node with ID `id`.
|
||||
bool hasNode(unsigned id) const { return nodes.contains(id); }
|
||||
|
||||
// Returns the graph node for 'forOp'.
|
||||
Node *getForOpNode(AffineForOp forOp);
|
||||
const Node *getForOpNode(AffineForOp forOp) const;
|
||||
Node *getForOpNode(AffineForOp forOp) {
|
||||
return const_cast<Node *>(
|
||||
static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
|
||||
}
|
||||
|
||||
// Adds a node with 'op' to the graph and returns its unique identifier.
|
||||
unsigned addNode(Operation *op);
|
||||
@ -169,12 +182,12 @@ public:
|
||||
|
||||
// Returns true if node 'id' writes to any memref which escapes (or is an
|
||||
// argument to) the block. Returns false otherwise.
|
||||
bool writesToLiveInOrEscapingMemrefs(unsigned id);
|
||||
bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
|
||||
|
||||
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
|
||||
// is for 'value' if non-null, or for any value otherwise. Returns false
|
||||
// otherwise.
|
||||
bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
|
||||
bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
|
||||
|
||||
// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
|
||||
void addEdge(unsigned srcId, unsigned dstId, Value value);
|
||||
@ -185,23 +198,25 @@ public:
|
||||
// Returns true if there is a path in the dependence graph from node 'srcId'
|
||||
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
|
||||
// operations that the edges connected are expected to be from the same block.
|
||||
bool hasDependencePath(unsigned srcId, unsigned dstId);
|
||||
bool hasDependencePath(unsigned srcId, unsigned dstId) const;
|
||||
|
||||
// Returns the input edge count for node 'id' and 'memref' from src nodes
|
||||
// which access 'memref' with a store operation.
|
||||
unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
|
||||
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
|
||||
|
||||
// Returns the output edge count for node 'id' and 'memref' (if non-null),
|
||||
// otherwise returns the total output edge count from node 'id'.
|
||||
unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
|
||||
unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
|
||||
|
||||
/// Return all nodes which define SSA values used in node 'id'.
|
||||
void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
|
||||
void gatherDefiningNodes(unsigned id,
|
||||
DenseSet<unsigned> &definingNodes) const;
|
||||
|
||||
// Computes and returns an insertion point operation, before which the
|
||||
// the fused <srcId, dstId> loop nest can be inserted while preserving
|
||||
// dependences. Returns nullptr if no such insertion point is found.
|
||||
Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
|
||||
Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
|
||||
unsigned dstId) const;
|
||||
|
||||
// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
|
||||
// taking into account that:
|
||||
|
@ -188,8 +188,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
|
||||
|
||||
/// Add `op` to MDG creating a new node and adding its memory accesses (affine
|
||||
/// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
|
||||
Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
|
||||
DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
|
||||
static Node *
|
||||
addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
|
||||
DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
|
||||
auto &nodes = mdg.nodes;
|
||||
// Create graph node 'id' to represent top-level 'forOp' and record
|
||||
// all loads and store accesses it contains.
|
||||
@ -359,14 +360,14 @@ bool MemRefDependenceGraph::init() {
|
||||
}
|
||||
|
||||
// Returns the graph node for 'id'.
|
||||
Node *MemRefDependenceGraph::getNode(unsigned id) {
|
||||
const Node *MemRefDependenceGraph::getNode(unsigned id) const {
|
||||
auto it = nodes.find(id);
|
||||
assert(it != nodes.end());
|
||||
return &it->second;
|
||||
}
|
||||
|
||||
// Returns the graph node for 'forOp'.
|
||||
Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
|
||||
const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
|
||||
for (auto &idAndNode : nodes)
|
||||
if (idAndNode.second.op == forOp)
|
||||
return &idAndNode.second;
|
||||
@ -390,7 +391,7 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
|
||||
}
|
||||
}
|
||||
// Remove each edge in 'outEdges[id]'.
|
||||
if (outEdges.count(id) > 0) {
|
||||
if (outEdges.contains(id)) {
|
||||
SmallVector<Edge, 2> oldOutEdges = outEdges[id];
|
||||
for (auto &outEdge : oldOutEdges) {
|
||||
removeEdge(id, outEdge.id, outEdge.value);
|
||||
@ -404,8 +405,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
|
||||
|
||||
// Returns true if node 'id' writes to any memref which escapes (or is an
|
||||
// argument to) the block. Returns false otherwise.
|
||||
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
|
||||
Node *node = getNode(id);
|
||||
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
|
||||
const Node *node = getNode(id);
|
||||
for (auto *storeOpInst : node->stores) {
|
||||
auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
|
||||
auto *op = memref.getDefiningOp();
|
||||
@ -425,14 +426,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
|
||||
// is for 'value' if non-null, or for any value otherwise. Returns false
|
||||
// otherwise.
|
||||
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
|
||||
Value value) {
|
||||
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
|
||||
Value value) const {
|
||||
if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
|
||||
return false;
|
||||
}
|
||||
bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
|
||||
bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
|
||||
return edge.id == dstId && (!value || edge.value == value);
|
||||
});
|
||||
bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
|
||||
bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
|
||||
return edge.id == srcId && (!value || edge.value == value);
|
||||
});
|
||||
return hasOutEdge && hasInEdge;
|
||||
@ -477,7 +478,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
|
||||
// Returns true if there is a path in the dependence graph from node 'srcId'
|
||||
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
|
||||
// operations that the edges connected are expected to be from the same block.
|
||||
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
|
||||
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
|
||||
unsigned dstId) const {
|
||||
// Worklist state is: <node-id, next-output-edge-index-to-visit>
|
||||
SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
|
||||
worklist.push_back({srcId, 0});
|
||||
@ -490,13 +492,13 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
|
||||
return true;
|
||||
// Pop and continue if node has no out edges, or if all out edges have
|
||||
// already been visited.
|
||||
if (outEdges.count(idAndIndex.first) == 0 ||
|
||||
idAndIndex.second == outEdges[idAndIndex.first].size()) {
|
||||
if (!outEdges.contains(idAndIndex.first) ||
|
||||
idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
|
||||
worklist.pop_back();
|
||||
continue;
|
||||
}
|
||||
// Get graph edge to traverse.
|
||||
Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
|
||||
const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
|
||||
// Increment next output edge index for 'idAndIndex'.
|
||||
++idAndIndex.second;
|
||||
// Add node at 'edge.id' to the worklist. We don't need to consider
|
||||
@ -512,34 +514,34 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
|
||||
// Returns the input edge count for node 'id' and 'memref' from src nodes
|
||||
// which access 'memref' with a store operation.
|
||||
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
|
||||
Value memref) {
|
||||
Value memref) const {
|
||||
unsigned inEdgeCount = 0;
|
||||
if (inEdges.count(id) > 0)
|
||||
for (auto &inEdge : inEdges[id])
|
||||
if (inEdge.value == memref) {
|
||||
Node *srcNode = getNode(inEdge.id);
|
||||
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
|
||||
if (srcNode->getStoreOpCount(memref) > 0)
|
||||
++inEdgeCount;
|
||||
}
|
||||
for (const Edge &inEdge : inEdges.lookup(id)) {
|
||||
if (inEdge.value == memref) {
|
||||
const Node *srcNode = getNode(inEdge.id);
|
||||
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
|
||||
if (srcNode->getStoreOpCount(memref) > 0)
|
||||
++inEdgeCount;
|
||||
}
|
||||
}
|
||||
return inEdgeCount;
|
||||
}
|
||||
|
||||
// Returns the output edge count for node 'id' and 'memref' (if non-null),
|
||||
// otherwise returns the total output edge count from node 'id'.
|
||||
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
|
||||
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
|
||||
Value memref) const {
|
||||
unsigned outEdgeCount = 0;
|
||||
if (outEdges.count(id) > 0)
|
||||
for (auto &outEdge : outEdges[id])
|
||||
if (!memref || outEdge.value == memref)
|
||||
++outEdgeCount;
|
||||
for (const auto &outEdge : outEdges.lookup(id))
|
||||
if (!memref || outEdge.value == memref)
|
||||
++outEdgeCount;
|
||||
return outEdgeCount;
|
||||
}
|
||||
|
||||
/// Return all nodes which define SSA values used in node 'id'.
|
||||
void MemRefDependenceGraph::gatherDefiningNodes(
|
||||
unsigned id, DenseSet<unsigned> &definingNodes) {
|
||||
for (MemRefDependenceGraph::Edge edge : inEdges[id])
|
||||
unsigned id, DenseSet<unsigned> &definingNodes) const {
|
||||
for (const Edge &edge : inEdges.lookup(id))
|
||||
// By definition of edge, if the edge value is a non-memref value,
|
||||
// then the dependence is between a graph node which defines an SSA value
|
||||
// and another graph node which uses the SSA value.
|
||||
@ -552,8 +554,8 @@ void MemRefDependenceGraph::gatherDefiningNodes(
|
||||
// dependences. Returns nullptr if no such insertion point is found.
|
||||
Operation *
|
||||
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
|
||||
unsigned dstId) {
|
||||
if (outEdges.count(srcId) == 0)
|
||||
unsigned dstId) const {
|
||||
if (!outEdges.contains(srcId))
|
||||
return getNode(dstId)->op;
|
||||
|
||||
// Skip if there is any defining node of 'dstId' that depends on 'srcId'.
|
||||
@ -569,13 +571,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
|
||||
|
||||
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
|
||||
SmallPtrSet<Operation *, 2> srcDepInsts;
|
||||
for (auto &outEdge : outEdges[srcId])
|
||||
for (auto &outEdge : outEdges.lookup(srcId))
|
||||
if (outEdge.id != dstId)
|
||||
srcDepInsts.insert(getNode(outEdge.id)->op);
|
||||
|
||||
// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
|
||||
SmallPtrSet<Operation *, 2> dstDepInsts;
|
||||
for (auto &inEdge : inEdges[dstId])
|
||||
for (auto &inEdge : inEdges.lookup(dstId))
|
||||
if (inEdge.id != srcId)
|
||||
dstDepInsts.insert(getNode(inEdge.id)->op);
|
||||
|
||||
@ -635,7 +637,7 @@ void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
|
||||
SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
|
||||
for (auto &inEdge : oldInEdges) {
|
||||
// Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
|
||||
if (privateMemRefs.count(inEdge.value) == 0)
|
||||
if (!privateMemRefs.contains(inEdge.value))
|
||||
addEdge(inEdge.id, dstId, inEdge.value);
|
||||
}
|
||||
}
|
||||
|
@ -78,13 +78,13 @@ struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
|
||||
static bool canRemoveSrcNodeAfterFusion(
|
||||
unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
|
||||
Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
|
||||
MemRefDependenceGraph *mdg) {
|
||||
const MemRefDependenceGraph &mdg) {
|
||||
|
||||
Operation *dstNodeOp = mdg->getNode(dstId)->op;
|
||||
Operation *dstNodeOp = mdg.getNode(dstId)->op;
|
||||
bool hasOutDepsAfterFusion = false;
|
||||
|
||||
for (auto &outEdge : mdg->outEdges[srcId]) {
|
||||
Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
|
||||
for (auto &outEdge : mdg.outEdges.lookup(srcId)) {
|
||||
Operation *depNodeOp = mdg.getNode(outEdge.id)->op;
|
||||
// Skip dependence with dstOp since it will be removed after fusion.
|
||||
if (depNodeOp == dstNodeOp)
|
||||
continue;
|
||||
@ -134,22 +134,23 @@ static bool canRemoveSrcNodeAfterFusion(
|
||||
/// held if the 'mdg' is reused from a previous fusion step or if the node
|
||||
/// creation order changes in the future to support more advance cases.
|
||||
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
|
||||
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
|
||||
static void getProducerCandidates(unsigned dstId,
|
||||
const MemRefDependenceGraph &mdg,
|
||||
SmallVectorImpl<unsigned> &srcIdCandidates) {
|
||||
// Skip if no input edges along which to fuse.
|
||||
if (mdg->inEdges.count(dstId) == 0)
|
||||
if (mdg.inEdges.count(dstId) == 0)
|
||||
return;
|
||||
|
||||
// Gather memrefs from loads in 'dstId'.
|
||||
auto *dstNode = mdg->getNode(dstId);
|
||||
auto *dstNode = mdg.getNode(dstId);
|
||||
DenseSet<Value> consumedMemrefs;
|
||||
for (Operation *load : dstNode->loads)
|
||||
consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
|
||||
|
||||
// Traverse 'dstId' incoming edges and gather the nodes that contain a store
|
||||
// to one of the consumed memrefs.
|
||||
for (auto &srcEdge : mdg->inEdges[dstId]) {
|
||||
auto *srcNode = mdg->getNode(srcEdge.id);
|
||||
for (const auto &srcEdge : mdg.inEdges.lookup(dstId)) {
|
||||
const auto *srcNode = mdg.getNode(srcEdge.id);
|
||||
// Skip if 'srcNode' is not a loop nest.
|
||||
if (!isa<AffineForOp>(srcNode->op))
|
||||
continue;
|
||||
@ -169,10 +170,10 @@ static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
|
||||
/// producer-consumer dependence between 'srcId' and 'dstId'.
|
||||
static void
|
||||
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
|
||||
MemRefDependenceGraph *mdg,
|
||||
const MemRefDependenceGraph &mdg,
|
||||
DenseSet<Value> &producerConsumerMemrefs) {
|
||||
auto *dstNode = mdg->getNode(dstId);
|
||||
auto *srcNode = mdg->getNode(srcId);
|
||||
auto *dstNode = mdg.getNode(dstId);
|
||||
auto *srcNode = mdg.getNode(srcId);
|
||||
gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
|
||||
producerConsumerMemrefs);
|
||||
}
|
||||
@ -214,14 +215,14 @@ static bool isEscapingMemref(Value memref, Block *block) {
|
||||
|
||||
/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
|
||||
/// that escape the block or are accessed in a non-affine way.
|
||||
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
|
||||
static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg,
|
||||
DenseSet<Value> &escapingMemRefs) {
|
||||
auto *node = mdg->getNode(id);
|
||||
auto *node = mdg.getNode(id);
|
||||
for (Operation *storeOp : node->stores) {
|
||||
auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
|
||||
if (escapingMemRefs.count(memref))
|
||||
continue;
|
||||
if (isEscapingMemref(memref, &mdg->block))
|
||||
if (isEscapingMemref(memref, &mdg.block))
|
||||
escapingMemRefs.insert(memref);
|
||||
}
|
||||
}
|
||||
@ -826,7 +827,7 @@ public:
|
||||
// in 'srcIdCandidates'.
|
||||
dstNodeChanged = false;
|
||||
SmallVector<unsigned, 16> srcIdCandidates;
|
||||
getProducerCandidates(dstId, mdg, srcIdCandidates);
|
||||
getProducerCandidates(dstId, *mdg, srcIdCandidates);
|
||||
|
||||
for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
|
||||
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
|
||||
@ -841,7 +842,7 @@ public:
|
||||
continue;
|
||||
|
||||
DenseSet<Value> producerConsumerMemrefs;
|
||||
gatherProducerConsumerMemrefs(srcId, dstId, mdg,
|
||||
gatherProducerConsumerMemrefs(srcId, dstId, *mdg,
|
||||
producerConsumerMemrefs);
|
||||
|
||||
// Skip if 'srcNode' out edge count on any memref is greater than
|
||||
@ -856,7 +857,7 @@ public:
|
||||
// block (e.g., memref block arguments, returned memrefs,
|
||||
// memrefs passed to function calls, etc.).
|
||||
DenseSet<Value> srcEscapingMemRefs;
|
||||
gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
|
||||
gatherEscapingMemrefs(srcNode->id, *mdg, srcEscapingMemRefs);
|
||||
|
||||
// Compute an operation list insertion point for the fused loop
|
||||
// nest which preserves dependences.
|
||||
@ -950,7 +951,7 @@ public:
|
||||
// insertion point.
|
||||
bool removeSrcNode = canRemoveSrcNodeAfterFusion(
|
||||
srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
|
||||
mdg);
|
||||
*mdg);
|
||||
|
||||
DenseSet<Value> privateMemrefs;
|
||||
for (Value memref : producerConsumerMemrefs) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user