[mlir][IR] Change block/region walkers to enumerate this block/region (#75020)

This change makes block/region walkers consistent with operation
walkers. An operation walk enumerates the current operation. Similarly,
block/region walks should enumerate the current block/region.

Example:
```
// Current behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 is NOT enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 are NOT enumerated */ });
region1->walk([](Region *region2) { /* region1 is NOT enumerated });

// New behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 IS enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 ARE enumerated */ });
region1->walk([](Region *region2) { /* region1 IS enumerated });
```
This commit is contained in:
Matthias Springer 2023-12-20 14:51:45 +09:00 committed by GitHub
parent 207cbbd710
commit c4457e10fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 193 additions and 93 deletions

View File

@ -260,68 +260,91 @@ public:
SuccessorRange getSuccessors() { return SuccessorRange(this); }
//===--------------------------------------------------------------------===//
// Operation Walkers
// Walkers
//===--------------------------------------------------------------------===//
/// Walk the operations in this block. The callback method is called for each
/// nested region, block or operation, depending on the callback provided.
/// The order in which regions, blocks and operations at the same nesting
/// Walk all nested operations, blocks (including this block) or regions,
/// depending on the type of callback.
///
/// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
/// is determined by 'Iterator'. The walk order for enclosing regions, blocks
/// and operations with respect to their nested ones is specified by 'Order'
/// (post-order by default). A callback on a block or operation is allowed to
/// erase that block or operation if either:
/// is determined by `Iterator`. The walk order for enclosing operations,
/// blocks or regions with respect to their nested ones is specified by
/// `Order` (post-order by default).
///
/// A callback on a operation or block is allowed to erase that operation or
/// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename ArgT = detail::first_argument<FnT>,
typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
return walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
if constexpr (std::is_same<ArgT, Block *>::value &&
Order == WalkOrder::PreOrder) {
// Pre-order walk on blocks: invoke the callback on this block.
if constexpr (std::is_same<RetT, void>::value) {
callback(this);
} else {
RetT result = callback(this);
if (result.wasSkipped())
return WalkResult::advance();
if (result.wasInterrupted())
return WalkResult::interrupt();
}
}
// Walk nested operations, blocks or regions.
if constexpr (std::is_same<RetT, void>::value) {
walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
} else {
if (walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback))
.wasInterrupted())
return WalkResult::interrupt();
}
if constexpr (std::is_same<ArgT, Block *>::value &&
Order == WalkOrder::PostOrder) {
// Post-order walk on blocks: invoke the callback on this block.
return callback(this);
}
if constexpr (!std::is_same<RetT, void>::value)
return WalkResult::advance();
}
/// Walk the operations in the specified [begin, end) range of this block. The
/// callback method is called for each nested region, block or operation,
/// depending on the callback provided. The order in which regions, blocks and
/// operations at the same nesting level are visited (e.g., lexicographical or
/// reverse lexicographical order) is determined by 'Iterator'. The walk order
/// for enclosing regions, blocks and operations with respect to their nested
/// ones is specified by 'Order' (post-order by default). This method is
/// invoked for void-returning callbacks. A callback on a block or operation
/// is allowed to erase that block or operation only if the walk is in
/// post-order. See non-void method for pre-order erasure.
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename RetT = detail::walkResultType<FnT>>
std::enable_if_t<std::is_same<RetT, void>::value, RetT>
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
detail::walk<Order, Iterator>(&op, callback);
}
/// Walk the operations in the specified [begin, end) range of this block. The
/// callback method is called for each nested region, block or operation,
/// depending on the callback provided. The order in which regions, blocks and
/// operations at the same nesting level are visited (e.g., lexicographical or
/// reverse lexicographical order) is determined by 'Iterator'. The walk order
/// for enclosing regions, blocks and operations with respect to their nested
/// ones is specified by 'Order' (post-order by default). This method is
/// invoked for skippable or interruptible callbacks. A callback on a block or
/// operation is allowed to erase that block or operation if either:
/// Walk all nested operations, blocks (excluding this block) or regions,
/// depending on the type of callback, in the specified [begin, end) range of
/// this block.
///
/// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
/// is determined by `Iterator`. The walk order for enclosing operations,
/// blocks or regions with respect to their nested ones is specified by
/// `Order` (post-order by default).
///
/// A callback on a operation or block is allowed to erase that operation or
/// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename RetT = detail::walkResultType<FnT>>
std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
return WalkResult::interrupt();
return WalkResult::advance();
RetT walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) {
if constexpr (std::is_same<RetT, WalkResult>::value) {
if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
return WalkResult::interrupt();
} else {
detail::walk<Order, Iterator>(&op, callback);
}
}
if constexpr (std::is_same<RetT, WalkResult>::value)
return WalkResult::advance();
}
//===--------------------------------------------------------------------===//

View File

@ -260,48 +260,60 @@ public:
void dropAllReferences();
//===--------------------------------------------------------------------===//
// Operation Walkers
// Walkers
//===--------------------------------------------------------------------===//
/// Walk the operations in this region. The callback method is called for each
/// nested region, block or operation, depending on the callback provided.
/// The order in which regions, blocks and operations at the same nesting
/// Walk all nested operations, blocks or regions (including this region),
/// depending on the type of callback.
///
/// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
/// is determined by 'Iterator'. The walk order for enclosing regions, blocks
/// and operations with respect to their nested ones is specified by 'Order'
/// (post-order by default). This method is invoked for void-returning
/// callbacks. A callback on a block or operation is allowed to erase that
/// block or operation only if the walk is in post-order. See non-void method
/// for pre-order erasure. See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename RetT = detail::walkResultType<FnT>>
std::enable_if_t<std::is_same<RetT, void>::value, RetT> walk(FnT &&callback) {
for (auto &block : *this)
block.walk<Order, Iterator>(callback);
}
/// Walk the operations in this region. The callback method is called for each
/// nested region, block or operation, depending on the callback provided.
/// The order in which regions, blocks and operations at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
/// is determined by 'Iterator'. The walk order for enclosing regions, blocks
/// and operations with respect to their nested ones is specified by 'Order'
/// (post-order by default). This method is invoked for skippable or
/// interruptible callbacks. A callback on a block or operation is allowed to
/// erase that block or operation if either:
/// * the walk is in post-order,
/// * or the walk is in pre-order and the walk is skipped after the erasure.
/// is determined by `Iterator`. The walk order for enclosing operations,
/// blocks or regions with respect to their nested ones is specified by
/// `Order` (post-order by default).
///
/// A callback on a operation or block is allowed to erase that operation or
/// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename ArgT = detail::first_argument<FnT>,
typename RetT = detail::walkResultType<FnT>>
std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
walk(FnT &&callback) {
for (auto &block : *this)
if (block.walk<Order, Iterator>(callback).wasInterrupted())
return WalkResult::interrupt();
return WalkResult::advance();
RetT walk(FnT &&callback) {
if constexpr (std::is_same<ArgT, Region *>::value &&
Order == WalkOrder::PreOrder) {
// Pre-order walk on regions: invoke the callback on this region.
if constexpr (std::is_same<RetT, void>::value) {
callback(this);
} else {
RetT result = callback(this);
if (result.wasSkipped())
return WalkResult::advance();
if (result.wasInterrupted())
return WalkResult::interrupt();
}
}
// Walk nested operations, blocks or regions.
for (auto &block : *this) {
if constexpr (std::is_same<RetT, void>::value) {
block.walk<Order, Iterator>(callback);
} else {
if (block.walk<Order, Iterator>(callback).wasInterrupted())
return WalkResult::interrupt();
}
}
if constexpr (std::is_same<ArgT, Region *>::value &&
Order == WalkOrder::PostOrder) {
// Post-order walk on regions: invoke the callback on this block.
return callback(this);
}
if constexpr (!std::is_same<RetT, void>::value)
return WalkResult::advance();
}
//===--------------------------------------------------------------------===//

View File

@ -463,7 +463,7 @@ BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
}
static bool regionOperatesOnMemrefValues(Region &region) {
auto checkBlock = [](Block *block) {
WalkResult result = region.walk([](Block *block) {
if (llvm::any_of(block->getArguments(), isMemref))
return WalkResult::interrupt();
for (Operation &op : *block) {
@ -473,18 +473,8 @@ static bool regionOperatesOnMemrefValues(Region &region) {
return WalkResult::interrupt();
}
return WalkResult::advance();
};
WalkResult result = region.walk(checkBlock);
if (result.wasInterrupted())
return true;
// Note: Block::walk/Region::walk visits only blocks that are nested under
// nested operations, but not direct children.
for (Block &block : region)
if (checkBlock(&block).wasInterrupted())
return true;
return false;
});
return result.wasInterrupted();
}
LogicalResult

View File

@ -17,7 +17,7 @@ func.func @structured_cfg() {
"use2"(%i) : (index) -> ()
}
"use3"(%i) : (index) -> ()
}
} {walk_blocks, walk_regions}
return
}
@ -88,6 +88,26 @@ func.func @structured_cfg() {
// CHECK: Visiting op 'func.func'
// CHECK: Visiting op 'builtin.module'
// CHECK-LABEL: Invoke block pre-order visits on blocks
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
// CHECK-LABEL: Invoke block post-order visits on blocks
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
// CHECK-LABEL: Invoke region pre-order visits on region
// CHECK: Visiting region 0 from operation 'scf.for'
// CHECK: Visiting region 0 from operation 'scf.if'
// CHECK: Visiting region 1 from operation 'scf.if'
// CHECK-LABEL: Invoke region post-order visits on region
// CHECK: Visiting region 0 from operation 'scf.if'
// CHECK: Visiting region 1 from operation 'scf.if'
// CHECK: Visiting region 0 from operation 'scf.for'
// CHECK-LABEL: Op pre-order erasures
// CHECK: Erasing op 'scf.for'
// CHECK: Erasing op 'func.return'

View File

@ -204,6 +204,60 @@ static void testNoSkipErasureCallbacks(Operation *op) {
cloned->erase();
}
/// Invoke region/block walks on regions/blocks.
static void testBlockAndRegionWalkers(Operation *op) {
auto blockPure = [](Block *block) {
llvm::outs() << "Visiting ";
printBlock(block);
llvm::outs() << "\n";
};
auto regionPure = [](Region *region) {
llvm::outs() << "Visiting ";
printRegion(region);
llvm::outs() << "\n";
};
llvm::outs() << "Invoke block pre-order visits on blocks\n";
op->walk([&](Operation *op) {
if (!op->hasAttr("walk_blocks"))
return;
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
block.walk<WalkOrder::PreOrder>(blockPure);
}
}
});
llvm::outs() << "Invoke block post-order visits on blocks\n";
op->walk([&](Operation *op) {
if (!op->hasAttr("walk_blocks"))
return;
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
block.walk<WalkOrder::PostOrder>(blockPure);
}
}
});
llvm::outs() << "Invoke region pre-order visits on region\n";
op->walk([&](Operation *op) {
if (!op->hasAttr("walk_regions"))
return;
for (Region &region : op->getRegions()) {
region.walk<WalkOrder::PreOrder>(regionPure);
}
});
llvm::outs() << "Invoke region post-order visits on region\n";
op->walk([&](Operation *op) {
if (!op->hasAttr("walk_regions"))
return;
for (Region &region : op->getRegions()) {
region.walk<WalkOrder::PostOrder>(regionPure);
}
});
}
namespace {
/// This pass exercises the different configurations of the IR visitors.
struct TestIRVisitorsPass
@ -215,6 +269,7 @@ struct TestIRVisitorsPass
void runOnOperation() override {
Operation *op = getOperation();
testPureCallbacks(op);
testBlockAndRegionWalkers(op);
testSkipErasureCallbacks(op);
testNoSkipErasureCallbacks(op);
}