[MLIR][Affine] Make affine fusion MDG API const correct (#125994)

Make affine fusion MDG API const correct. NFC changes otherwise.
This commit is contained in:
Uday Bondhugula 2025-02-11 05:28:15 +05:30 committed by GitHub
parent 998f2422a5
commit 001ba42fe0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 84 additions and 66 deletions

View File

@ -139,9 +139,11 @@ public:
// Map from node id to Node.
DenseMap<unsigned, Node> nodes;
// Map from node id to list of input edges.
// Map from node id to list of input edges. The absence of an entry for a key
// is also equivalent to the absence of any edges.
DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
// Map from node id to list of output edges.
// Map from node id to list of output edges. The absence of an entry for a
// node is also equivalent to the absence of any edges.
DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
// Map from memref to a count on the dependence edges associated with that
// memref.
@ -156,10 +158,21 @@ public:
bool init();
// Returns the graph node for 'id'.
Node *getNode(unsigned id);
const Node *getNode(unsigned id) const;
Node *getNode(unsigned id) {
return const_cast<Node *>(
static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
}
// Returns true if the graph has node with ID `id`.
bool hasNode(unsigned id) const { return nodes.contains(id); }
// Returns the graph node for 'forOp'.
Node *getForOpNode(AffineForOp forOp);
const Node *getForOpNode(AffineForOp forOp) const;
Node *getForOpNode(AffineForOp forOp) {
return const_cast<Node *>(
static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
}
// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned addNode(Operation *op);
@ -169,12 +182,12 @@ public:
// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool writesToLiveInOrEscapingMemrefs(unsigned id);
bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void addEdge(unsigned srcId, unsigned dstId, Value value);
@ -185,23 +198,25 @@ public:
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
bool hasDependencePath(unsigned srcId, unsigned dstId);
bool hasDependencePath(unsigned srcId, unsigned dstId) const;
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
/// Return all nodes which define SSA values used in node 'id'.
void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
void gatherDefiningNodes(unsigned id,
DenseSet<unsigned> &definingNodes) const;
// Computes and returns an insertion point operation, before which the
// the fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
unsigned dstId) const;
// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:

View File

@ -188,8 +188,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
/// Add `op` to MDG creating a new node and adding its memory accesses (affine
/// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
static Node *
addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
auto &nodes = mdg.nodes;
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
@ -359,14 +360,14 @@ bool MemRefDependenceGraph::init() {
}
// Returns the graph node for 'id'.
Node *MemRefDependenceGraph::getNode(unsigned id) {
const Node *MemRefDependenceGraph::getNode(unsigned id) const {
auto it = nodes.find(id);
assert(it != nodes.end());
return &it->second;
}
// Returns the graph node for 'forOp'.
Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
for (auto &idAndNode : nodes)
if (idAndNode.second.op == forOp)
return &idAndNode.second;
@ -390,7 +391,7 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
}
}
// Remove each edge in 'outEdges[id]'.
if (outEdges.count(id) > 0) {
if (outEdges.contains(id)) {
SmallVector<Edge, 2> oldOutEdges = outEdges[id];
for (auto &outEdge : oldOutEdges) {
removeEdge(id, outEdge.id, outEdge.value);
@ -404,8 +405,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
Node *node = getNode(id);
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
const Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
auto *op = memref.getDefiningOp();
@ -425,14 +426,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
Value value) {
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
Value value) const {
if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
return false;
}
bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
return edge.id == dstId && (!value || edge.value == value);
});
bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
return edge.id == srcId && (!value || edge.value == value);
});
return hasOutEdge && hasInEdge;
@ -477,7 +478,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
unsigned dstId) const {
// Worklist state is: <node-id, next-output-edge-index-to-visit>
SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
worklist.push_back({srcId, 0});
@ -490,13 +492,13 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
return true;
// Pop and continue if node has no out edges, or if all out edges have
// already been visited.
if (outEdges.count(idAndIndex.first) == 0 ||
idAndIndex.second == outEdges[idAndIndex.first].size()) {
if (!outEdges.contains(idAndIndex.first) ||
idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
worklist.pop_back();
continue;
}
// Get graph edge to traverse.
Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
// Increment next output edge index for 'idAndIndex'.
++idAndIndex.second;
// Add node at 'edge.id' to the worklist. We don't need to consider
@ -512,34 +514,34 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
Value memref) {
Value memref) const {
unsigned inEdgeCount = 0;
if (inEdges.count(id) > 0)
for (auto &inEdge : inEdges[id])
if (inEdge.value == memref) {
Node *srcNode = getNode(inEdge.id);
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
if (srcNode->getStoreOpCount(memref) > 0)
++inEdgeCount;
}
for (const Edge &inEdge : inEdges.lookup(id)) {
if (inEdge.value == memref) {
const Node *srcNode = getNode(inEdge.id);
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
if (srcNode->getStoreOpCount(memref) > 0)
++inEdgeCount;
}
}
return inEdgeCount;
}
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
Value memref) const {
unsigned outEdgeCount = 0;
if (outEdges.count(id) > 0)
for (auto &outEdge : outEdges[id])
if (!memref || outEdge.value == memref)
++outEdgeCount;
for (const auto &outEdge : outEdges.lookup(id))
if (!memref || outEdge.value == memref)
++outEdgeCount;
return outEdgeCount;
}
/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
unsigned id, DenseSet<unsigned> &definingNodes) {
for (MemRefDependenceGraph::Edge edge : inEdges[id])
unsigned id, DenseSet<unsigned> &definingNodes) const {
for (const Edge &edge : inEdges.lookup(id))
// By definition of edge, if the edge value is a non-memref value,
// then the dependence is between a graph node which defines an SSA value
// and another graph node which uses the SSA value.
@ -552,8 +554,8 @@ void MemRefDependenceGraph::gatherDefiningNodes(
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
unsigned dstId) {
if (outEdges.count(srcId) == 0)
unsigned dstId) const {
if (!outEdges.contains(srcId))
return getNode(dstId)->op;
// Skip if there is any defining node of 'dstId' that depends on 'srcId'.
@ -569,13 +571,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
SmallPtrSet<Operation *, 2> srcDepInsts;
for (auto &outEdge : outEdges[srcId])
for (auto &outEdge : outEdges.lookup(srcId))
if (outEdge.id != dstId)
srcDepInsts.insert(getNode(outEdge.id)->op);
// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
SmallPtrSet<Operation *, 2> dstDepInsts;
for (auto &inEdge : inEdges[dstId])
for (auto &inEdge : inEdges.lookup(dstId))
if (inEdge.id != srcId)
dstDepInsts.insert(getNode(inEdge.id)->op);
@ -635,7 +637,7 @@ void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
for (auto &inEdge : oldInEdges) {
// Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
if (privateMemRefs.count(inEdge.value) == 0)
if (!privateMemRefs.contains(inEdge.value))
addEdge(inEdge.id, dstId, inEdge.value);
}
}

View File

@ -78,13 +78,13 @@ struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
static bool canRemoveSrcNodeAfterFusion(
unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
MemRefDependenceGraph *mdg) {
const MemRefDependenceGraph &mdg) {
Operation *dstNodeOp = mdg->getNode(dstId)->op;
Operation *dstNodeOp = mdg.getNode(dstId)->op;
bool hasOutDepsAfterFusion = false;
for (auto &outEdge : mdg->outEdges[srcId]) {
Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
for (auto &outEdge : mdg.outEdges.lookup(srcId)) {
Operation *depNodeOp = mdg.getNode(outEdge.id)->op;
// Skip dependence with dstOp since it will be removed after fusion.
if (depNodeOp == dstNodeOp)
continue;
@ -134,22 +134,23 @@ static bool canRemoveSrcNodeAfterFusion(
/// held if the 'mdg' is reused from a previous fusion step or if the node
/// creation order changes in the future to support more advance cases.
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
static void getProducerCandidates(unsigned dstId,
const MemRefDependenceGraph &mdg,
SmallVectorImpl<unsigned> &srcIdCandidates) {
// Skip if no input edges along which to fuse.
if (mdg->inEdges.count(dstId) == 0)
if (mdg.inEdges.count(dstId) == 0)
return;
// Gather memrefs from loads in 'dstId'.
auto *dstNode = mdg->getNode(dstId);
auto *dstNode = mdg.getNode(dstId);
DenseSet<Value> consumedMemrefs;
for (Operation *load : dstNode->loads)
consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
// Traverse 'dstId' incoming edges and gather the nodes that contain a store
// to one of the consumed memrefs.
for (auto &srcEdge : mdg->inEdges[dstId]) {
auto *srcNode = mdg->getNode(srcEdge.id);
for (const auto &srcEdge : mdg.inEdges.lookup(dstId)) {
const auto *srcNode = mdg.getNode(srcEdge.id);
// Skip if 'srcNode' is not a loop nest.
if (!isa<AffineForOp>(srcNode->op))
continue;
@ -169,10 +170,10 @@ static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
/// producer-consumer dependence between 'srcId' and 'dstId'.
static void
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
MemRefDependenceGraph *mdg,
const MemRefDependenceGraph &mdg,
DenseSet<Value> &producerConsumerMemrefs) {
auto *dstNode = mdg->getNode(dstId);
auto *srcNode = mdg->getNode(srcId);
auto *dstNode = mdg.getNode(dstId);
auto *srcNode = mdg.getNode(srcId);
gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
producerConsumerMemrefs);
}
@ -214,14 +215,14 @@ static bool isEscapingMemref(Value memref, Block *block) {
/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
/// that escape the block or are accessed in a non-affine way.
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg,
DenseSet<Value> &escapingMemRefs) {
auto *node = mdg->getNode(id);
auto *node = mdg.getNode(id);
for (Operation *storeOp : node->stores) {
auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
if (escapingMemRefs.count(memref))
continue;
if (isEscapingMemref(memref, &mdg->block))
if (isEscapingMemref(memref, &mdg.block))
escapingMemRefs.insert(memref);
}
}
@ -826,7 +827,7 @@ public:
// in 'srcIdCandidates'.
dstNodeChanged = false;
SmallVector<unsigned, 16> srcIdCandidates;
getProducerCandidates(dstId, mdg, srcIdCandidates);
getProducerCandidates(dstId, *mdg, srcIdCandidates);
for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
@ -841,7 +842,7 @@ public:
continue;
DenseSet<Value> producerConsumerMemrefs;
gatherProducerConsumerMemrefs(srcId, dstId, mdg,
gatherProducerConsumerMemrefs(srcId, dstId, *mdg,
producerConsumerMemrefs);
// Skip if 'srcNode' out edge count on any memref is greater than
@ -856,7 +857,7 @@ public:
// block (e.g., memref block arguments, returned memrefs,
// memrefs passed to function calls, etc.).
DenseSet<Value> srcEscapingMemRefs;
gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
gatherEscapingMemrefs(srcNode->id, *mdg, srcEscapingMemRefs);
// Compute an operation list insertion point for the fused loop
// nest which preserves dependences.
@ -950,7 +951,7 @@ public:
// insertion point.
bool removeSrcNode = canRemoveSrcNodeAfterFusion(
srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
mdg);
*mdg);
DenseSet<Value> privateMemrefs;
for (Value memref : producerConsumerMemrefs) {