[mlir][transforms] Process RegionBranchOp with empty region (#123895)

This PR adds process for RegionBranchOp with empty region, such as
'else' region of `scf.if`. Fixes #123246.
This commit is contained in:
Longsheng Mou 2025-02-11 14:43:15 +08:00 committed by GitHub
parent 0d8d354b0c
commit be354cf381
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 2 deletions

View File

@ -375,6 +375,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
SmallVector<Value> arguments(region.front().getArguments());
BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
liveArgs[&region] = regionLiveArgs;
@ -420,6 +422,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
auto markNonForwardedReturnValues =
[&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
Operation *terminator = region.front().getTerminator();
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
@ -499,6 +503,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Recompute `resultsToKeep` and `argsToKeep` based on
// `terminatorOperandsToKeep`.
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
Operation *terminator = region.front().getTerminator();
for (const RegionSuccessor &successor : getSuccessors(&region)) {
Region *successorRegion = successor.getSuccessor();
@ -547,6 +553,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Update the terminator operands that need to be kept.
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
updateOperandsOrTerminatorOperandsToKeep(
terminatorOperandsToKeep[region.back().getTerminator()],
resultsToKeep, argsToKeep, &region);
@ -611,8 +619,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Do (2.a) and (2.b).
for (Region &region : regionBranchOp->getRegions()) {
assert(!region.empty() && "expected a non-empty region in an op "
"implementing `RegionBranchOpInterface`");
if (region.empty())
continue;
BitVector argsToRemove = argsToKeep[&region].flip();
cl.blocks.push_back({&region.front(), argsToRemove});
collectNonLiveValues(nonLiveSet, region.front().getArguments(),
@ -621,6 +629,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Do (2.c).
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
Operation *terminator = region.front().getTerminator();
cl.operands.push_back(
{terminator, terminatorOperandsToKeep[terminator].flip()});

View File

@ -408,6 +408,22 @@ func.func @main(%arg3 : i32, %arg4 : i1) {
// -----
// The scf.if operation represents an if-then-else construct for conditionally
// executing two regions of code. The 'the' region has exactly 1 block, and
// the 'else' region may have 0 or 1 block. This case is to ensure 'else' region
// with 0 block not crash.
// CHECK-LABEL: func.func @clean_region_branch_op_with_empty_region
func.func @clean_region_branch_op_with_empty_region(%arg0: i1, %arg1: memref<f32>) {
%cst = arith.constant 1.000000e+00 : f32
scf.if %arg0 {
memref.store %cst, %arg1[] : memref<f32>
}
return
}
// -----
#map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
func.func @kernel(%arg0: memref<18xf32>) {
%c1 = arith.constant 1 : index