Reland "[mlir] Use a type for representing branch points in RegionBranchOpInterface"

This reverts commit b26bb30b467b996c9786e3bd426c07684d84d406.
This commit is contained in:
Markus Böck 2023-08-30 09:22:34 +02:00
parent 82e851a407
commit 4dd744ac9c
23 changed files with 258 additions and 241 deletions

View File

@ -3467,10 +3467,10 @@ void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control.
void fir::IfOp::getSuccessorRegions(
std::optional<unsigned> index,
mlir::RegionBranchPoint point,
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (index) {
if (!point.isParent()) {
regions.push_back(mlir::RegionSuccessor(getResults()));
return;
}

View File

@ -353,8 +353,8 @@ protected:
/// any effect on the lattice that isn't already expressed by the interface
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
@ -382,7 +382,7 @@ private:
/// of the branch operation itself.
void visitRegionBranchOperation(ProgramPoint point,
RegionBranchOpInterface branch,
std::optional<unsigned> regionNo,
RegionBranchPoint branchPoint,
AbstractDenseLattice *before);
/// Visit an operation for which the data flow is described by the
@ -472,9 +472,8 @@ public:
/// nullptr`. The behavior can be further refined for specific pairs of "from"
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
std::optional<unsigned> regionTo, const LatticeT &after,
LatticeT *before) {
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
@ -508,8 +507,8 @@ protected:
static_cast<LatticeT *>(before));
}
void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, std::optional<unsigned> regionForm,
std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),

View File

@ -243,7 +243,7 @@ private:
/// regions or the parent operation itself, and set either the argument or
/// parent result lattices.
void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
std::optional<unsigned> successorIndex,
RegionBranchPoint successor,
ArrayRef<AbstractSparseLattice *> lattices);
};

View File

@ -190,6 +190,68 @@ private:
ValueRange inputs;
};
/// This class represents a point being branched from in the methods of the
/// `RegionBranchOpInterface`.
/// One can branch from one of two kinds of places:
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
/// * A region within the parent operation.
class RegionBranchPoint {
public:
/// Returns an instance of `RegionBranchPoint` representing the parent
/// operation.
static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
/// Creates a `RegionBranchPoint` that branches from the given region.
/// The pointer must not be null.
RegionBranchPoint(Region *region) : maybeRegion(region) {
assert(region && "Region must not be null");
}
RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
/// Explicitly stops users from constructing with `nullptr`.
RegionBranchPoint(std::nullptr_t) = delete;
/// Constructs a `RegionBranchPoint` from the the target of a
/// `RegionSuccessor` instance.
RegionBranchPoint(RegionSuccessor successor) {
if (successor.isParent())
maybeRegion = nullptr;
else
maybeRegion = successor.getSuccessor();
}
/// Assigns a region being branched from.
RegionBranchPoint &operator=(Region &region) {
maybeRegion = &region;
return *this;
}
/// Returns true if branching from the parent op.
bool isParent() const { return maybeRegion == nullptr; }
/// Returns the region if branching from a region.
/// A null pointer otherwise.
Region *getRegionOrNull() const { return maybeRegion; }
/// Returns true if the two branch points are equal.
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return lhs.maybeRegion == rhs.maybeRegion;
}
private:
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
/// Internal encoding. Uses nullptr for representing branching from the parent
/// op and the region being branched from otherwise.
Region *maybeRegion;
};
inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return !(lhs == rhs);
}
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.

View File

@ -133,14 +133,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let methods = [
InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when
entering the region at `index`, which was specified as a successor of
branching from `point`, which was specified as a successor of
this operation by `getEntrySuccessorRegions`, or the operands forwarded
to the operation's results when it branches back to itself. These operands
should correspond 1-1 with the successor inputs specified in
`getEntrySuccessorRegions`.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
(ins "::std::optional<unsigned>":$index), [{}],
(ins "::mlir::RegionBranchPoint":$point), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
@ -162,22 +162,20 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
[{}], [{
$_op.getSuccessorRegions(std::nullopt, regions);
$_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
}]
>,
InterfaceMethod<[{
Returns the viable successors of a region at `index`, or the possible
successors when branching from the parent op if `index` is None. These
are the regions that may be selected during the flow of control. The
parent operation, i.e. a null `index`, may specify itself as successor,
which indicates that the control flow may not enter any region at all.
This method allows for describing which regions may be executed when
entering an operation, and which regions are executed after having
executed another region of the parent op. The successor region must be
non-empty.
Returns the viable successors of `point`. These are the regions that may
be selected during the flow of control. The parent operation, may
specify itself as successor, which indicates that the control flow may
not enter any region at all. This method allows for describing which
regions may be executed when entering an operation, and which regions
are executed after having executed another region of the parent op. The
successor region must be non-empty.
}],
"void", "getSuccessorRegions",
(ins "::std::optional<unsigned>":$index,
(ins "::mlir::RegionBranchPoint":$point,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
InterfaceMethod<[{
@ -245,12 +243,10 @@ def RegionBranchTerminatorOpInterface :
let methods = [
InterfaceMethod<[{
Returns a mutable range of operands that are semantically "returned" by
passing them to the region successor given by `index`. If `index` is
None, this function returns the operands that are passed as a result to
the parent operation.
passing them to the region successor given by `point`.
}],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
(ins "::std::optional<unsigned>":$index)
(ins "::mlir::RegionBranchPoint":$point)
>,
InterfaceMethod<[{
Returns the viable region successors that are branched to after this
@ -269,8 +265,7 @@ def RegionBranchTerminatorOpInterface :
[{
::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
regions);
.getSuccessorRegions(op->getParentRegion(), regions);
}]
>,
];
@ -290,8 +285,8 @@ def RegionBranchTerminatorOpInterface :
// them to the region successor given by `index`. If `index` is None, this
// function returns the operands that are passed as a result to the parent
// operation.
::mlir::OperandRange getSuccessorOperands(std::optional<unsigned> index) {
return getMutableSuccessorOperands(index);
::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) {
return getMutableSuccessorOperands(point);
}
}];
}
@ -309,7 +304,7 @@ def ReturnLike : TraitList<[
/*extraOpDeclaration=*/"",
/*extraOpDefinition=*/[{
::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
::std::optional<unsigned> index) {
::mlir::RegionBranchPoint point) {
return ::mlir::MutableOperandRange(*this);
}
}]

View File

@ -45,9 +45,9 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
// this region predecessor that correspond to the input values of `region`. If
// an index could not be found, std::nullopt is returned instead.
auto getOperandIndexIfPred =
[&](std::optional<unsigned> predIndex) -> std::optional<unsigned> {
[&](RegionBranchPoint pred) -> std::optional<unsigned> {
SmallVector<RegionSuccessor, 2> successors;
branch.getSuccessorRegions(predIndex, successors);
branch.getSuccessorRegions(pred, successors);
for (RegionSuccessor &successor : successors) {
if (successor.getSuccessor() != region)
continue;
@ -75,28 +75,27 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
};
// Check branches from the parent operation.
std::optional<unsigned> regionIndex;
if (region) {
// Determine the actual region number from the passed region.
regionIndex = region->getRegionNumber();
}
auto branchPoint = RegionBranchPoint::parent();
if (region)
branchPoint = region;
if (std::optional<unsigned> operandIndex =
getOperandIndexIfPred(/*predIndex=*/std::nullopt)) {
getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
collectUnderlyingAddressValues(
branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth,
branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
visited, output);
}
// Check branches from each child region.
Operation *op = branch.getOperation();
for (int i = 0, e = op->getNumRegions(); i != e; ++i) {
if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
for (Block &block : op->getRegion(i)) {
for (Region &region : op->getRegions()) {
if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
for (Block &block : region) {
// Try to determine possible region-branch successor operands for the
// current region.
if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
block.getTerminator())) {
collectUnderlyingAddressValues(
term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth,
term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
visited, output);
} else if (block.getNumSuccessors()) {
// Otherwise, if this terminator may exit the region we can't make

View File

@ -312,7 +312,8 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
// Special cases where control flow may dictate data flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
return visitRegionBranchOperation(op, branch, std::nullopt, before);
return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
before);
if (auto call = dyn_cast<CallOpInterface>(op))
return visitCallOperation(call, before);
@ -368,8 +369,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// If this block is exiting from an operation with region-based control
// flow, propagate the lattice back along the control flow edge.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
visitRegionBranchOperation(block, branch,
block->getParent()->getRegionNumber(), before);
visitRegionBranchOperation(block, branch, block->getParent(), before);
return;
}
@ -396,13 +396,13 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
ProgramPoint point, RegionBranchOpInterface branch,
std::optional<unsigned> regionNo, AbstractDenseLattice *before) {
RegionBranchPoint branchPoint, AbstractDenseLattice *before) {
// The successors of the operation may be either the first operation of the
// entry block of each possible successor region, or the next operation when
// the branch is a successor of itself.
SmallVector<RegionSuccessor> successors;
branch.getSuccessorRegions(regionNo, successors);
branch.getSuccessorRegions(branchPoint, successors);
for (const RegionSuccessor &successor : successors) {
const AbstractDenseLattice *after;
if (successor.isParent() || successor.getSuccessor()->empty()) {
@ -423,10 +423,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
else
after = getLatticeFor(point, &successorBlock->front());
}
std::optional<unsigned> successorNo =
successor.isParent() ? std::optional<unsigned>()
: successor.getSuccessor()->getRegionNumber();
visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after,
visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after,
before);
}
}

View File

@ -99,7 +99,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
return visitRegionSuccessors({branch}, branch,
/*successorIndex=*/std::nullopt,
/*successor=*/RegionBranchPoint::parent(),
resultLattices);
}
@ -167,8 +167,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
return visitRegionSuccessors(
block, branch, block->getParent()->getRegionNumber(), argLattices);
return visitRegionSuccessors(block, branch, block->getParent(),
argLattices);
}
// Otherwise, we can't reason about the data-flow.
@ -212,8 +212,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint point, RegionBranchOpInterface branch,
std::optional<unsigned> successorIndex,
ArrayRef<AbstractSparseLattice *> lattices) {
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
@ -224,11 +223,11 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
// Check if the predecessor is the parent op.
if (op == branch) {
operands = branch.getEntrySuccessorOperands(successorIndex);
operands = branch.getEntrySuccessorOperands(successor);
// Otherwise, try to deduce the operands from a region return-like op.
} else if (auto regionTerminator =
dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
operands = regionTerminator.getSuccessorOperands(successorIndex);
operands = regionTerminator.getSuccessorOperands(successor);
}
if (!operands) {
@ -501,10 +500,7 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
BitVector unaccounted(op->getNumOperands(), true);
for (RegionSuccessor &successor : successors) {
Region *region = successor.getSuccessor();
OperandRange operands =
region ? branch.getEntrySuccessorOperands(region->getRegionNumber())
: branch.getEntrySuccessorOperands({});
OperandRange operands = branch.getEntrySuccessorOperands(successor);
MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
ValueRange inputs = successor.getSuccessorInputs();
for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
@ -538,9 +534,7 @@ void AbstractSparseBackwardDataFlowAnalysis::
for (const RegionSuccessor &successor : successors) {
ValueRange inputs = successor.getSuccessorInputs();
Region *region = successor.getSuccessor();
OperandRange operands = terminator.getSuccessorOperands(
region ? region->getRegionNumber() : std::optional<unsigned>{});
OperandRange operands = terminator.getSuccessorOperands(successor);
MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
meet(getLatticeElement(opOperand.get()),

View File

@ -2379,9 +2379,9 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable. AffineForOp only has one region, so zero is the only
/// valid value for `index`.
OperandRange
AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
assert((!index || *index == 0) && "invalid region index");
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert((point.isParent() || point == getLoopBody()) &&
"invalid region point");
// The initial operands map to the loop arguments after the induction
// variable or are forwarded to the results when the trip count is zero.
@ -2394,14 +2394,15 @@ AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void AffineForOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
assert((!index.has_value() || index.value() == 0) && "expected loop region");
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
assert((point.isParent() || point == getLoopBody()) &&
"expected loop region");
// The loop may typically branch back to its body or to the parent operation.
// If the predecessor is the parent op and the trip count is known to be at
// least one, branch into the body using the iterator arguments. And in cases
// we know the trip count is zero, it can only branch back to its parent.
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
if (!index.has_value() && tripCount.has_value()) {
if (point.isParent() && tripCount.has_value()) {
if (tripCount.value() > 0) {
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
return;
@ -2414,7 +2415,7 @@ void AffineForOp::getSuccessorRegions(
// From the loop body, if the trip count is one, we can only branch back to
// the parent.
if (index && tripCount && *tripCount == 1) {
if (!point.isParent() && tripCount && *tripCount == 1) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
@ -2859,10 +2860,10 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is an AffineIfOp, then branching into both `then` and
// `else` region is valid.
if (!index.has_value()) {
if (point.isParent()) {
regions.reserve(2);
regions.push_back(
RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));

View File

@ -38,9 +38,8 @@ void AsyncDialect::initialize() {
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
OperandRange
ExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
assert(index && *index == 0 && "invalid region index");
OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getBodyRegion() && "invalid region index");
return getBodyOperands();
}
@ -53,11 +52,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
}
void ExecuteOp::getSuccessorRegions(std::optional<unsigned> index,
void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
if (index) {
assert(*index == 0 && "invalid region index");
if (point == getBodyRegion()) {
regions.push_back(RegionSuccessor(getBodyResults()));
return;
}

View File

@ -372,7 +372,7 @@ private:
// parent operation. In this case, we have to introduce an additional clone
// for buffer that is passed to the argument.
SmallVector<RegionSuccessor, 2> successorRegions;
regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
successorRegions);
auto *it =
llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) {
@ -383,8 +383,7 @@ private:
// Determine the actual operand to introduce a clone for and rewire the
// operand to point to the clone instead.
auto operands =
regionInterface.getEntrySuccessorOperands(argRegion->getRegionNumber());
auto operands = regionInterface.getEntrySuccessorOperands(argRegion);
size_t operandIndex =
llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
operands.getBeginOperandIndex();
@ -432,8 +431,7 @@ private:
// Query the regionInterface to get all successor regions of the current
// one.
SmallVector<RegionSuccessor, 2> successorRegions;
regionInterface.getSuccessorRegions(region.getRegionNumber(),
successorRegions);
regionInterface.getSuccessorRegions(region, successorRegions);
// Try to find a matching region successor.
RegionSuccessor *regionSuccessor =
llvm::find_if(successorRegions, regionPredicate);
@ -445,10 +443,6 @@ private:
llvm::find(regionSuccessor->getSuccessorInputs(), argValue)
.getIndex();
std::optional<unsigned> successorRegionNumber;
if (Region *successorRegion = regionSuccessor->getSuccessor())
successorRegionNumber = successorRegion->getRegionNumber();
// Iterate over all immediate terminator operations to introduce
// new buffer allocations. Thereby, the appropriate terminator operand
// will be adjusted to point to the newly allocated buffer instead.
@ -456,8 +450,7 @@ private:
&region, [&](RegionBranchTerminatorOpInterface terminator) {
// Get the actual mutable operands for this terminator op.
auto terminatorOperands =
terminator.getMutableSuccessorOperands(
successorRegionNumber);
terminator.getMutableSuccessorOperands(*regionSuccessor);
// Extract the source value from the current terminator.
// This conversion needs to exist on a separate line due to a
// bug in GCC conversion analysis.

View File

@ -123,7 +123,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
return true;
// Recurses into all region successors.
SmallVector<RegionSuccessor, 2> successors;
regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
regionInterface.getSuccessorRegions(current, successors);
for (RegionSuccessor &regionEntry : successors)
if (recurse(regionEntry.getSuccessor()))
return true;
@ -132,7 +132,8 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
// Start with all entry regions and test whether they induce a loop.
SmallVector<RegionSuccessor, 2> successorRegions;
regionInterface.getSuccessorRegions(/*index=*/std::nullopt, successorRegions);
regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
successorRegions);
for (RegionSuccessor &regionEntry : successorRegions) {
if (recurse(regionEntry.getSuccessor()))
return true;

View File

@ -100,16 +100,13 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Query the RegionBranchOpInterface to find potential successor regions.
// Extract all entry regions and wire all initial entry successor inputs.
SmallVector<RegionSuccessor, 2> entrySuccessors;
regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
entrySuccessors);
for (RegionSuccessor &entrySuccessor : entrySuccessors) {
// Wire the entry region's successor arguments with the initial
// successor inputs.
registerDependencies(
regionInterface.getEntrySuccessorOperands(
entrySuccessor.isParent()
? std::optional<unsigned>()
: entrySuccessor.getSuccessor()->getRegionNumber()),
regionInterface.getEntrySuccessorOperands(entrySuccessor),
entrySuccessor.getSuccessorInputs());
}
@ -118,20 +115,15 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Iterate over all successor region entries that are reachable from the
// current region.
SmallVector<RegionSuccessor, 2> successorRegions;
regionInterface.getSuccessorRegions(region.getRegionNumber(),
successorRegions);
regionInterface.getSuccessorRegions(region, successorRegions);
for (RegionSuccessor &successorRegion : successorRegions) {
// Determine the current region index (if any).
std::optional<unsigned> regionIndex;
Region *regionSuccessor = successorRegion.getSuccessor();
if (regionSuccessor)
regionIndex = regionSuccessor->getRegionNumber();
// Iterate over all immediate terminator operations and wire the
// successor inputs with the successor operands of each terminator.
for (Block &block : region)
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
block.getTerminator()))
registerDependencies(terminator.getSuccessorOperands(regionIndex),
registerDependencies(
terminator.getSuccessorOperands(successorRegion),
successorRegion.getSuccessorInputs());
}
}

View File

@ -455,8 +455,8 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
}
void AllocaScopeOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
if (index) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}

View File

@ -266,9 +266,9 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ExecuteRegionOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is the ExecuteRegionOp, branch into the body.
if (!index) {
if (point.isParent()) {
regions.push_back(RegionSuccessor(&getRegion()));
return;
}
@ -282,8 +282,8 @@ void ExecuteRegionOp::getSuccessorRegions(
//===----------------------------------------------------------------------===//
MutableOperandRange
ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
assert((!index || index == getParentOp().getAfter().getRegionNumber()) &&
ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
assert((point.isParent() || point == getParentOp().getAfter()) &&
"condition op can only exit the loop or branch to the after"
"region");
// Pass all operands except the condition to the successor region.
@ -553,7 +553,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable.
OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getInitArgs();
}
@ -562,7 +562,7 @@ OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ForOp::getSuccessorRegions(std::optional<unsigned> index,
void ForOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
@ -1731,7 +1731,7 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ForallOp::getSuccessorRegions(std::optional<unsigned> index,
void ForallOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
@ -2011,10 +2011,10 @@ void IfOp::print(OpAsmPrinter &p) {
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void IfOp::getSuccessorRegions(std::optional<unsigned> index,
void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (index) {
if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
@ -3042,7 +3042,7 @@ void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ParallelOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
// body.
@ -3169,8 +3169,8 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}
OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
assert(index && *index == 0 &&
OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
@ -3192,17 +3192,18 @@ Block::BlockArgListType WhileOp::getAfterArguments() {
return getAfterBody()->getArguments();
}
void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
void WhileOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op always branches to the condition region.
if (!index) {
if (point.isParent()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
assert(*index < 2 && "there are only two regions in a WhileOp");
assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
"there are only two regions in a WhileOp");
// The body region always branches back to the condition region.
if (*index == 1) {
if (point == getAfter()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
@ -4023,10 +4024,9 @@ Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
}
void IndexSwitchOp::getSuccessorRegions(
std::optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &successors) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
// All regions branch back to the parent op.
if (index) {
if (!point.isParent()) {
successors.emplace_back(getResults());
return;
}

View File

@ -335,11 +335,11 @@ void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// AssumingOp has unconditional control flow into the region and back to the
// parent, so return the correct RegionSuccessor purely based on the index
// being None or 0.
if (index) {
if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}

View File

@ -86,23 +86,25 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
// AlternativesOp
//===----------------------------------------------------------------------===//
OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
std::optional<unsigned> index) {
if (index && getOperation()->getNumOperands() == 1)
OperandRange
transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
if (!point.isParent() && getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
}
void transform::AlternativesOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
for (Region &alternative : llvm::drop_begin(
getAlternatives(), index.has_value() ? *index + 1 : 0)) {
getAlternatives(),
point.isParent() ? 0
: point.getRegionOrNull()->getRegionNumber() + 1)) {
regions.emplace_back(&alternative, !getOperands().empty()
? alternative.getArguments()
: Block::BlockArgListType());
}
if (index.has_value())
if (!point.isParent())
regions.emplace_back(getOperation()->getResults());
}
@ -1159,24 +1161,24 @@ void transform::ForeachOp::getEffects(
}
void transform::ForeachOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
Region *bodyRegion = &getBody();
if (!index) {
if (point.isParent()) {
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
return;
}
// Branch back to the region or the parent.
assert(*index == 0 && "unexpected region index");
assert(point == getBody() && "unexpected region index");
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
regions.emplace_back();
}
OperandRange
transform::ForeachOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
// The iteration variable op handle is mapped to a subset (one op to be
// precise) of the payload ops of the ForeachOp operand.
assert(index && *index == 0 && "unexpected region index");
assert(point == getBody() && "unexpected region index");
return getOperation()->getOperands();
}
@ -2178,9 +2180,9 @@ void transform::SequenceOp::getEffects(
getPotentialTopLevelEffects(effects);
}
OperandRange transform::SequenceOp::getEntrySuccessorOperands(
std::optional<unsigned> index) {
assert(index && *index == 0 && "unexpected region index");
OperandRange
transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getBody() && "unexpected region index");
if (getOperation()->getNumOperands() > 0)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
@ -2188,8 +2190,8 @@ OperandRange transform::SequenceOp::getEntrySuccessorOperands(
}
void transform::SequenceOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
if (!index) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (point.isParent()) {
Region *bodyRegion = &getBody();
regions.emplace_back(bodyRegion, getNumOperands() != 0
? bodyRegion->getArguments()
@ -2197,7 +2199,7 @@ void transform::SequenceOp::getSuccessorRegions(
return;
}
assert(*index == 0 && "unexpected region index");
assert(point == getBody() && "unexpected region index");
regions.emplace_back(getOperation()->getResults());
}

View File

@ -5821,8 +5821,8 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
if (index) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}

View File

@ -84,18 +84,18 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
static InFlightDiagnostic &
printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
std::optional<unsigned> succRegionNo) {
static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
RegionBranchPoint sourceNo,
RegionBranchPoint succRegionNo) {
diag << "from ";
if (sourceNo)
diag << "Region #" << sourceNo.value();
if (Region *region = sourceNo.getRegionOrNull())
diag << "Region #" << region->getRegionNumber();
else
diag << "parent operands";
diag << " to ";
if (succRegionNo)
diag << "Region #" << succRegionNo.value();
if (Region *region = succRegionNo.getRegionOrNull())
diag << "Region #" << region->getRegionNumber();
else
diag << "parent results";
return diag;
@ -107,28 +107,24 @@ printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
/// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
/// the exact type match verification is not necessary (e.g., if the Op verifies
/// the match itself).
static LogicalResult verifyTypesAlongAllEdges(
Operation *op, std::optional<unsigned> sourceNo,
function_ref<FailureOr<TypeRange>(std::optional<unsigned>)>
static LogicalResult
verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
getInputsTypesForRegion) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
SmallVector<RegionSuccessor, 2> successors;
regionInterface.getSuccessorRegions(sourceNo, successors);
regionInterface.getSuccessorRegions(sourcePoint, successors);
for (RegionSuccessor &succ : successors) {
std::optional<unsigned> succRegionNo;
if (!succ.isParent())
succRegionNo = succ.getSuccessor()->getRegionNumber();
FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
if (failed(sourceTypes))
return failure();
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
if (sourceTypes->size() != succInputsTypes.size()) {
InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
return printRegionEdgeName(diag, sourceNo, succRegionNo)
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source has " << sourceTypes->size()
<< " operands, but target successor needs "
<< succInputsTypes.size();
@ -140,7 +136,7 @@ static LogicalResult verifyTypesAlongAllEdges(
Type inputType = std::get<1>(typesIdx.value());
if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
return printRegionEdgeName(diag, sourceNo, succRegionNo)
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source type #" << typesIdx.index() << " " << sourceType
<< " should match input type #" << typesIdx.index() << " "
<< inputType;
@ -154,13 +150,13 @@ static LogicalResult verifyTypesAlongAllEdges(
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
auto inputTypesFromParent =
[&](std::optional<unsigned> regionNo) -> TypeRange {
auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange {
return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
};
// Verify types along control flow edges originating from the parent.
if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent)))
if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
inputTypesFromParent)))
return failure();
auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
@ -176,8 +172,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
};
// Verify types along control flow edges originating from each region.
for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
Region &region = op->getRegion(regionNo);
for (Region &region : op->getRegions()) {
// Since there can be multiple terminators implementing the
// `RegionBranchTerminatorOpInterface`, all should have the same operand
@ -195,7 +190,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
continue;
auto inputTypesForRegion =
[&](std::optional<unsigned> succRegionNo) -> FailureOr<TypeRange> {
[&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> {
std::optional<OperandRange> regionReturnOperands;
for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
auto terminatorOperands =
@ -211,7 +206,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
if (!areTypesCompatible(regionReturnOperands->getTypes(),
terminatorOperands.getTypes())) {
InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
return printRegionEdgeName(diag, regionNo, succRegionNo)
return printRegionEdgeName(diag, region, succRegionNo)
<< " operands mismatch between return-like terminators";
}
}
@ -220,7 +215,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
return TypeRange(regionReturnOperands->getTypes());
};
if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesForRegion)))
if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
return failure();
}
@ -237,24 +232,24 @@ static bool isRegionReachable(Region *begin, Region *r) {
visited[begin->getRegionNumber()] = true;
// Retrieve all successors of the region and enqueue them in the worklist.
SmallVector<unsigned> worklist;
auto enqueueAllSuccessors = [&](unsigned index) {
SmallVector<Region *> worklist;
auto enqueueAllSuccessors = [&](Region *region) {
SmallVector<RegionSuccessor> successors;
op.getSuccessorRegions(index, successors);
op.getSuccessorRegions(region, successors);
for (RegionSuccessor successor : successors)
if (!successor.isParent())
worklist.push_back(successor.getSuccessor()->getRegionNumber());
worklist.push_back(successor.getSuccessor());
};
enqueueAllSuccessors(begin->getRegionNumber());
enqueueAllSuccessors(begin);
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
unsigned nextRegion = worklist.pop_back_val();
if (nextRegion == r->getRegionNumber())
Region *nextRegion = worklist.pop_back_val();
if (nextRegion == r)
return true;
if (visited[nextRegion])
if (visited[nextRegion->getRegionNumber()])
continue;
visited[nextRegion] = true;
visited[nextRegion->getRegionNumber()] = true;
enqueueAllSuccessors(nextRegion);
}

View File

@ -316,15 +316,11 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Return the successors of `region` if the latter is not null. Else return
// the successors of `regionBranchOp`.
auto getSuccessors = [&](Region *region = nullptr) {
std::optional<unsigned> index =
region ? std::optional(region->getRegionNumber()) : std::nullopt;
auto point = region ? region : RegionBranchPoint::parent();
SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
nullptr);
SmallVector<RegionSuccessor> successors;
if (!index)
regionBranchOp.getEntrySuccessorRegions(operandAttributes, successors);
else
regionBranchOp.getSuccessorRegions(index, successors);
regionBranchOp.getSuccessorRegions(point, successors);
return successors;
};
@ -333,14 +329,10 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// forwarded to `successor`.
auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
Operation *terminator = nullptr) {
Region *successorRegion = successor.getSuccessor();
std::optional<unsigned> index =
successorRegion ? std::optional(successorRegion->getRegionNumber())
: std::nullopt;
OperandRange operands =
terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
.getSuccessorOperands(index)
: regionBranchOp.getEntrySuccessorOperands(index);
.getSuccessorOperands(successor)
: regionBranchOp.getEntrySuccessorOperands(successor);
SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
return opOperands;
};

View File

@ -60,8 +60,8 @@ public:
NextAccess *before) override;
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
std::optional<unsigned> regionFrom,
std::optional<unsigned> regionTo,
RegionBranchPoint regionFrom,
RegionBranchPoint regionTo,
const NextAccess &after,
NextAccess *before) override;
@ -124,15 +124,15 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
}
void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
std::optional<unsigned> regionTo, const NextAccess &after,
NextAccess *before) {
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
auto testStoreWithARegion =
dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
if (testStoreWithARegion &&
((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
(!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
(regionFrom.isParent() &&
testStoreWithARegion.getStoreBeforeRegion()))) {
visitOperation(branch, static_cast<const NextAccess &>(after),
static_cast<NextAccess *>(before));
} else {
@ -219,7 +219,7 @@ struct TestNextAccessPass
SmallVector<Attribute> entryPointNextAccess;
SmallVector<RegionSuccessor> regionSuccessors;
iface.getSuccessorRegions(std::nullopt, regionSuccessors);
iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
for (const RegionSuccessor &successor : regionSuccessors) {
if (!successor.getSuccessor() || successor.getSuccessor()->empty())
continue;

View File

@ -931,17 +931,17 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getCurrentLocation(), result.operands);
}
OperandRange
RegionIfOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
assert(index && *index < 2 && "invalid region index");
OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
"invalid region index");
return getOperands();
}
void RegionIfOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// We always branch to the join region.
if (index.has_value()) {
if (index.value() < 2)
if (!point.isParent()) {
if (point != getJoinRegion())
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
else
regions.push_back(RegionSuccessor(getResults()));
@ -964,11 +964,11 @@ void RegionIfOp::getRegionInvocationBounds(
// AnyCondOp
//===----------------------------------------------------------------------===//
void AnyCondOp::getSuccessorRegions(std::optional<unsigned> index,
void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op branches into the only region, and the region branches back
// to the parent op.
if (!index)
if (point.isParent())
regions.emplace_back(&getRegion());
else
regions.emplace_back(getResults());
@ -985,17 +985,16 @@ void AnyCondOp::getRegionInvocationBounds(
//===----------------------------------------------------------------------===//
void LoopBlockOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
regions.emplace_back(&getBody(), getBody().getArguments());
if (!index)
if (point.isParent())
return;
regions.emplace_back((*this)->getResults());
}
OperandRange
LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
assert(index == 0);
OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getBody());
return getInitMutable();
}
@ -1003,10 +1002,9 @@ LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
// LoopBlockTerminatorOp
//===----------------------------------------------------------------------===//
MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands(
std::optional<unsigned> index) {
assert(!index || index == 0);
if (!index)
MutableOperandRange
LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
if (point.isParent())
return getExitArgMutable();
return getNextIterArgMutable();
}
@ -1313,13 +1311,12 @@ MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
}
void TestStoreWithARegion::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
if (!index) {
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (point.isParent())
regions.emplace_back(&getBody(), getBody().front().getArguments());
} else {
else
regions.emplace_back();
}
}
LogicalResult
TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,

View File

@ -37,7 +37,7 @@ struct MutuallyExclusiveRegionsOp
}
// Regions have no successors.
void getSuccessorRegions(std::optional<unsigned> index,
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {}
};
@ -51,14 +51,13 @@ struct LoopRegionsOp
static StringRef getOperationName() { return "cftest.loop_regions_op"; }
void getSuccessorRegions(std::optional<unsigned> index,
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
if (index) {
if (*index == 1)
if (Region *region = point.getRegionOrNull()) {
if (point == (*this)->getRegion(1))
// This region also branches back to the parent.
regions.push_back(RegionSuccessor());
regions.push_back(
RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
regions.push_back(RegionSuccessor(region));
}
}
};
@ -74,11 +73,11 @@ struct DoubleLoopRegionsOp
return "cftest.double_loop_regions_op";
}
void getSuccessorRegions(std::optional<unsigned> index,
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
if (index.has_value()) {
if (Region *region = point.getRegionOrNull()) {
regions.push_back(RegionSuccessor());
regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
regions.push_back(RegionSuccessor(region));
}
}
};
@ -92,9 +91,9 @@ struct SequentialRegionsOp
static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
// Region 0 has Region 1 as a successor.
void getSuccessorRegions(std::optional<unsigned> index,
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
if (index == 0u) {
if (point == (*this)->getRegion(0)) {
Operation *thisOp = this->getOperation();
regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
}