mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 18:36:05 +00:00
Reland "[mlir] Use a type for representing branch points in RegionBranchOpInterface
"
This reverts commit b26bb30b467b996c9786e3bd426c07684d84d406.
This commit is contained in:
parent
82e851a407
commit
4dd744ac9c
@ -3467,10 +3467,10 @@ void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
|
||||
/// return the successor regions. These are the regions that may be selected
|
||||
/// during the flow of control.
|
||||
void fir::IfOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index,
|
||||
mlir::RegionBranchPoint point,
|
||||
llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) {
|
||||
// The `then` and the `else` region branch back to the parent operation.
|
||||
if (index) {
|
||||
if (!point.isParent()) {
|
||||
regions.push_back(mlir::RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
|
@ -353,8 +353,8 @@ protected:
|
||||
/// any effect on the lattice that isn't already expressed by the interface
|
||||
/// itself.
|
||||
virtual void visitRegionBranchControlFlowTransfer(
|
||||
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
|
||||
std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
|
||||
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
|
||||
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
|
||||
AbstractDenseLattice *before) {
|
||||
meet(before, after);
|
||||
}
|
||||
@ -382,7 +382,7 @@ private:
|
||||
/// of the branch operation itself.
|
||||
void visitRegionBranchOperation(ProgramPoint point,
|
||||
RegionBranchOpInterface branch,
|
||||
std::optional<unsigned> regionNo,
|
||||
RegionBranchPoint branchPoint,
|
||||
AbstractDenseLattice *before);
|
||||
|
||||
/// Visit an operation for which the data flow is described by the
|
||||
@ -472,9 +472,8 @@ public:
|
||||
/// nullptr`. The behavior can be further refined for specific pairs of "from"
|
||||
/// and "to" regions.
|
||||
virtual void visitRegionBranchControlFlowTransfer(
|
||||
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
|
||||
std::optional<unsigned> regionTo, const LatticeT &after,
|
||||
LatticeT *before) {
|
||||
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
|
||||
RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
|
||||
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
|
||||
branch, regionFrom, regionTo, after, before);
|
||||
}
|
||||
@ -508,8 +507,8 @@ protected:
|
||||
static_cast<LatticeT *>(before));
|
||||
}
|
||||
void visitRegionBranchControlFlowTransfer(
|
||||
RegionBranchOpInterface branch, std::optional<unsigned> regionForm,
|
||||
std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
|
||||
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
|
||||
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
|
||||
AbstractDenseLattice *before) final {
|
||||
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
|
||||
static_cast<const LatticeT &>(after),
|
||||
|
@ -243,7 +243,7 @@ private:
|
||||
/// regions or the parent operation itself, and set either the argument or
|
||||
/// parent result lattices.
|
||||
void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
|
||||
std::optional<unsigned> successorIndex,
|
||||
RegionBranchPoint successor,
|
||||
ArrayRef<AbstractSparseLattice *> lattices);
|
||||
};
|
||||
|
||||
|
@ -190,6 +190,68 @@ private:
|
||||
ValueRange inputs;
|
||||
};
|
||||
|
||||
/// This class represents a point being branched from in the methods of the
|
||||
/// `RegionBranchOpInterface`.
|
||||
/// One can branch from one of two kinds of places:
|
||||
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
|
||||
/// * A region within the parent operation.
|
||||
class RegionBranchPoint {
|
||||
public:
|
||||
/// Returns an instance of `RegionBranchPoint` representing the parent
|
||||
/// operation.
|
||||
static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
|
||||
|
||||
/// Creates a `RegionBranchPoint` that branches from the given region.
|
||||
/// The pointer must not be null.
|
||||
RegionBranchPoint(Region *region) : maybeRegion(region) {
|
||||
assert(region && "Region must not be null");
|
||||
}
|
||||
|
||||
RegionBranchPoint(Region ®ion) : RegionBranchPoint(®ion) {}
|
||||
|
||||
/// Explicitly stops users from constructing with `nullptr`.
|
||||
RegionBranchPoint(std::nullptr_t) = delete;
|
||||
|
||||
/// Constructs a `RegionBranchPoint` from the the target of a
|
||||
/// `RegionSuccessor` instance.
|
||||
RegionBranchPoint(RegionSuccessor successor) {
|
||||
if (successor.isParent())
|
||||
maybeRegion = nullptr;
|
||||
else
|
||||
maybeRegion = successor.getSuccessor();
|
||||
}
|
||||
|
||||
/// Assigns a region being branched from.
|
||||
RegionBranchPoint &operator=(Region ®ion) {
|
||||
maybeRegion = ®ion;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns true if branching from the parent op.
|
||||
bool isParent() const { return maybeRegion == nullptr; }
|
||||
|
||||
/// Returns the region if branching from a region.
|
||||
/// A null pointer otherwise.
|
||||
Region *getRegionOrNull() const { return maybeRegion; }
|
||||
|
||||
/// Returns true if the two branch points are equal.
|
||||
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
|
||||
return lhs.maybeRegion == rhs.maybeRegion;
|
||||
}
|
||||
|
||||
private:
|
||||
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
|
||||
constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
|
||||
|
||||
/// Internal encoding. Uses nullptr for representing branching from the parent
|
||||
/// op and the region being branched from otherwise.
|
||||
Region *maybeRegion;
|
||||
};
|
||||
|
||||
inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
/// This class represents upper and lower bounds on the number of times a region
|
||||
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
|
||||
/// zero, but the upper bound may not be known.
|
||||
|
@ -133,14 +133,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
|
||||
let methods = [
|
||||
InterfaceMethod<[{
|
||||
Returns the operands of this operation used as the entry arguments when
|
||||
entering the region at `index`, which was specified as a successor of
|
||||
branching from `point`, which was specified as a successor of
|
||||
this operation by `getEntrySuccessorRegions`, or the operands forwarded
|
||||
to the operation's results when it branches back to itself. These operands
|
||||
should correspond 1-1 with the successor inputs specified in
|
||||
`getEntrySuccessorRegions`.
|
||||
}],
|
||||
"::mlir::OperandRange", "getEntrySuccessorOperands",
|
||||
(ins "::std::optional<unsigned>":$index), [{}],
|
||||
(ins "::mlir::RegionBranchPoint":$point), [{}],
|
||||
/*defaultImplementation=*/[{
|
||||
auto operandEnd = this->getOperation()->operand_end();
|
||||
return ::mlir::OperandRange(operandEnd, operandEnd);
|
||||
@ -162,22 +162,20 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
|
||||
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
|
||||
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
|
||||
[{}], [{
|
||||
$_op.getSuccessorRegions(std::nullopt, regions);
|
||||
$_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns the viable successors of a region at `index`, or the possible
|
||||
successors when branching from the parent op if `index` is None. These
|
||||
are the regions that may be selected during the flow of control. The
|
||||
parent operation, i.e. a null `index`, may specify itself as successor,
|
||||
which indicates that the control flow may not enter any region at all.
|
||||
This method allows for describing which regions may be executed when
|
||||
entering an operation, and which regions are executed after having
|
||||
executed another region of the parent op. The successor region must be
|
||||
non-empty.
|
||||
Returns the viable successors of `point`. These are the regions that may
|
||||
be selected during the flow of control. The parent operation, may
|
||||
specify itself as successor, which indicates that the control flow may
|
||||
not enter any region at all. This method allows for describing which
|
||||
regions may be executed when entering an operation, and which regions
|
||||
are executed after having executed another region of the parent op. The
|
||||
successor region must be non-empty.
|
||||
}],
|
||||
"void", "getSuccessorRegions",
|
||||
(ins "::std::optional<unsigned>":$index,
|
||||
(ins "::mlir::RegionBranchPoint":$point,
|
||||
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
@ -245,12 +243,10 @@ def RegionBranchTerminatorOpInterface :
|
||||
let methods = [
|
||||
InterfaceMethod<[{
|
||||
Returns a mutable range of operands that are semantically "returned" by
|
||||
passing them to the region successor given by `index`. If `index` is
|
||||
None, this function returns the operands that are passed as a result to
|
||||
the parent operation.
|
||||
passing them to the region successor given by `point`.
|
||||
}],
|
||||
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
|
||||
(ins "::std::optional<unsigned>":$index)
|
||||
(ins "::mlir::RegionBranchPoint":$point)
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns the viable region successors that are branched to after this
|
||||
@ -269,8 +265,7 @@ def RegionBranchTerminatorOpInterface :
|
||||
[{
|
||||
::mlir::Operation *op = $_op;
|
||||
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
|
||||
.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
|
||||
regions);
|
||||
.getSuccessorRegions(op->getParentRegion(), regions);
|
||||
}]
|
||||
>,
|
||||
];
|
||||
@ -290,8 +285,8 @@ def RegionBranchTerminatorOpInterface :
|
||||
// them to the region successor given by `index`. If `index` is None, this
|
||||
// function returns the operands that are passed as a result to the parent
|
||||
// operation.
|
||||
::mlir::OperandRange getSuccessorOperands(std::optional<unsigned> index) {
|
||||
return getMutableSuccessorOperands(index);
|
||||
::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) {
|
||||
return getMutableSuccessorOperands(point);
|
||||
}
|
||||
}];
|
||||
}
|
||||
@ -309,7 +304,7 @@ def ReturnLike : TraitList<[
|
||||
/*extraOpDeclaration=*/"",
|
||||
/*extraOpDefinition=*/[{
|
||||
::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
|
||||
::std::optional<unsigned> index) {
|
||||
::mlir::RegionBranchPoint point) {
|
||||
return ::mlir::MutableOperandRange(*this);
|
||||
}
|
||||
}]
|
||||
|
@ -45,9 +45,9 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
|
||||
// this region predecessor that correspond to the input values of `region`. If
|
||||
// an index could not be found, std::nullopt is returned instead.
|
||||
auto getOperandIndexIfPred =
|
||||
[&](std::optional<unsigned> predIndex) -> std::optional<unsigned> {
|
||||
[&](RegionBranchPoint pred) -> std::optional<unsigned> {
|
||||
SmallVector<RegionSuccessor, 2> successors;
|
||||
branch.getSuccessorRegions(predIndex, successors);
|
||||
branch.getSuccessorRegions(pred, successors);
|
||||
for (RegionSuccessor &successor : successors) {
|
||||
if (successor.getSuccessor() != region)
|
||||
continue;
|
||||
@ -75,28 +75,27 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
|
||||
};
|
||||
|
||||
// Check branches from the parent operation.
|
||||
std::optional<unsigned> regionIndex;
|
||||
if (region) {
|
||||
// Determine the actual region number from the passed region.
|
||||
regionIndex = region->getRegionNumber();
|
||||
}
|
||||
auto branchPoint = RegionBranchPoint::parent();
|
||||
if (region)
|
||||
branchPoint = region;
|
||||
|
||||
if (std::optional<unsigned> operandIndex =
|
||||
getOperandIndexIfPred(/*predIndex=*/std::nullopt)) {
|
||||
getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
|
||||
collectUnderlyingAddressValues(
|
||||
branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth,
|
||||
branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
|
||||
visited, output);
|
||||
}
|
||||
// Check branches from each child region.
|
||||
Operation *op = branch.getOperation();
|
||||
for (int i = 0, e = op->getNumRegions(); i != e; ++i) {
|
||||
if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
|
||||
for (Block &block : op->getRegion(i)) {
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
|
||||
for (Block &block : region) {
|
||||
// Try to determine possible region-branch successor operands for the
|
||||
// current region.
|
||||
if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
|
||||
block.getTerminator())) {
|
||||
collectUnderlyingAddressValues(
|
||||
term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth,
|
||||
term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
|
||||
visited, output);
|
||||
} else if (block.getNumSuccessors()) {
|
||||
// Otherwise, if this terminator may exit the region we can't make
|
||||
|
@ -312,7 +312,8 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
|
||||
|
||||
// Special cases where control flow may dictate data flow.
|
||||
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
|
||||
return visitRegionBranchOperation(op, branch, std::nullopt, before);
|
||||
return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
|
||||
before);
|
||||
if (auto call = dyn_cast<CallOpInterface>(op))
|
||||
return visitCallOperation(call, before);
|
||||
|
||||
@ -368,8 +369,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
|
||||
// If this block is exiting from an operation with region-based control
|
||||
// flow, propagate the lattice back along the control flow edge.
|
||||
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
|
||||
visitRegionBranchOperation(block, branch,
|
||||
block->getParent()->getRegionNumber(), before);
|
||||
visitRegionBranchOperation(block, branch, block->getParent(), before);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -396,13 +396,13 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
|
||||
|
||||
void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
|
||||
ProgramPoint point, RegionBranchOpInterface branch,
|
||||
std::optional<unsigned> regionNo, AbstractDenseLattice *before) {
|
||||
RegionBranchPoint branchPoint, AbstractDenseLattice *before) {
|
||||
|
||||
// The successors of the operation may be either the first operation of the
|
||||
// entry block of each possible successor region, or the next operation when
|
||||
// the branch is a successor of itself.
|
||||
SmallVector<RegionSuccessor> successors;
|
||||
branch.getSuccessorRegions(regionNo, successors);
|
||||
branch.getSuccessorRegions(branchPoint, successors);
|
||||
for (const RegionSuccessor &successor : successors) {
|
||||
const AbstractDenseLattice *after;
|
||||
if (successor.isParent() || successor.getSuccessor()->empty()) {
|
||||
@ -423,10 +423,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
|
||||
else
|
||||
after = getLatticeFor(point, &successorBlock->front());
|
||||
}
|
||||
std::optional<unsigned> successorNo =
|
||||
successor.isParent() ? std::optional<unsigned>()
|
||||
: successor.getSuccessor()->getRegionNumber();
|
||||
visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after,
|
||||
|
||||
visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after,
|
||||
before);
|
||||
}
|
||||
}
|
||||
|
@ -99,7 +99,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
|
||||
// The results of a region branch operation are determined by control-flow.
|
||||
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
|
||||
return visitRegionSuccessors({branch}, branch,
|
||||
/*successorIndex=*/std::nullopt,
|
||||
/*successor=*/RegionBranchPoint::parent(),
|
||||
resultLattices);
|
||||
}
|
||||
|
||||
@ -167,8 +167,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
|
||||
|
||||
// Check if the lattices can be determined from region control flow.
|
||||
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
|
||||
return visitRegionSuccessors(
|
||||
block, branch, block->getParent()->getRegionNumber(), argLattices);
|
||||
return visitRegionSuccessors(block, branch, block->getParent(),
|
||||
argLattices);
|
||||
}
|
||||
|
||||
// Otherwise, we can't reason about the data-flow.
|
||||
@ -212,8 +212,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
|
||||
|
||||
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
|
||||
ProgramPoint point, RegionBranchOpInterface branch,
|
||||
std::optional<unsigned> successorIndex,
|
||||
ArrayRef<AbstractSparseLattice *> lattices) {
|
||||
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
|
||||
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
|
||||
assert(predecessors->allPredecessorsKnown() &&
|
||||
"unexpected unresolved region successors");
|
||||
@ -224,11 +223,11 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
|
||||
|
||||
// Check if the predecessor is the parent op.
|
||||
if (op == branch) {
|
||||
operands = branch.getEntrySuccessorOperands(successorIndex);
|
||||
operands = branch.getEntrySuccessorOperands(successor);
|
||||
// Otherwise, try to deduce the operands from a region return-like op.
|
||||
} else if (auto regionTerminator =
|
||||
dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
|
||||
operands = regionTerminator.getSuccessorOperands(successorIndex);
|
||||
operands = regionTerminator.getSuccessorOperands(successor);
|
||||
}
|
||||
|
||||
if (!operands) {
|
||||
@ -501,10 +500,7 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
|
||||
BitVector unaccounted(op->getNumOperands(), true);
|
||||
|
||||
for (RegionSuccessor &successor : successors) {
|
||||
Region *region = successor.getSuccessor();
|
||||
OperandRange operands =
|
||||
region ? branch.getEntrySuccessorOperands(region->getRegionNumber())
|
||||
: branch.getEntrySuccessorOperands({});
|
||||
OperandRange operands = branch.getEntrySuccessorOperands(successor);
|
||||
MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
|
||||
ValueRange inputs = successor.getSuccessorInputs();
|
||||
for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
|
||||
@ -538,9 +534,7 @@ void AbstractSparseBackwardDataFlowAnalysis::
|
||||
|
||||
for (const RegionSuccessor &successor : successors) {
|
||||
ValueRange inputs = successor.getSuccessorInputs();
|
||||
Region *region = successor.getSuccessor();
|
||||
OperandRange operands = terminator.getSuccessorOperands(
|
||||
region ? region->getRegionNumber() : std::optional<unsigned>{});
|
||||
OperandRange operands = terminator.getSuccessorOperands(successor);
|
||||
MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
|
||||
for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
|
||||
meet(getLatticeElement(opOperand.get()),
|
||||
|
@ -2379,9 +2379,9 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
/// correspond to the loop iterator operands, i.e., those excluding the
|
||||
/// induction variable. AffineForOp only has one region, so zero is the only
|
||||
/// valid value for `index`.
|
||||
OperandRange
|
||||
AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
|
||||
assert((!index || *index == 0) && "invalid region index");
|
||||
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||
assert((point.isParent() || point == getLoopBody()) &&
|
||||
"invalid region point");
|
||||
|
||||
// The initial operands map to the loop arguments after the induction
|
||||
// variable or are forwarded to the results when the trip count is zero.
|
||||
@ -2394,14 +2394,15 @@ AffineForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
|
||||
/// correspond to a constant value for each operand, or null if that operand is
|
||||
/// not a constant.
|
||||
void AffineForOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
assert((!index.has_value() || index.value() == 0) && "expected loop region");
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
assert((point.isParent() || point == getLoopBody()) &&
|
||||
"expected loop region");
|
||||
// The loop may typically branch back to its body or to the parent operation.
|
||||
// If the predecessor is the parent op and the trip count is known to be at
|
||||
// least one, branch into the body using the iterator arguments. And in cases
|
||||
// we know the trip count is zero, it can only branch back to its parent.
|
||||
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
|
||||
if (!index.has_value() && tripCount.has_value()) {
|
||||
if (point.isParent() && tripCount.has_value()) {
|
||||
if (tripCount.value() > 0) {
|
||||
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
|
||||
return;
|
||||
@ -2414,7 +2415,7 @@ void AffineForOp::getSuccessorRegions(
|
||||
|
||||
// From the loop body, if the trip count is one, we can only branch back to
|
||||
// the parent.
|
||||
if (index && tripCount && *tripCount == 1) {
|
||||
if (!point.isParent() && tripCount && *tripCount == 1) {
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
@ -2859,10 +2860,10 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
|
||||
/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
|
||||
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
|
||||
void AffineIfOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// If the predecessor is an AffineIfOp, then branching into both `then` and
|
||||
// `else` region is valid.
|
||||
if (!index.has_value()) {
|
||||
if (point.isParent()) {
|
||||
regions.reserve(2);
|
||||
regions.push_back(
|
||||
RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
|
||||
|
@ -38,9 +38,8 @@ void AsyncDialect::initialize() {
|
||||
|
||||
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
|
||||
|
||||
OperandRange
|
||||
ExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
|
||||
assert(index && *index == 0 && "invalid region index");
|
||||
OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||
assert(point == getBodyRegion() && "invalid region index");
|
||||
return getBodyOperands();
|
||||
}
|
||||
|
||||
@ -53,11 +52,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
|
||||
return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
|
||||
}
|
||||
|
||||
void ExecuteOp::getSuccessorRegions(std::optional<unsigned> index,
|
||||
void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// The `body` region branch back to the parent operation.
|
||||
if (index) {
|
||||
assert(*index == 0 && "invalid region index");
|
||||
if (point == getBodyRegion()) {
|
||||
regions.push_back(RegionSuccessor(getBodyResults()));
|
||||
return;
|
||||
}
|
||||
|
@ -372,7 +372,7 @@ private:
|
||||
// parent operation. In this case, we have to introduce an additional clone
|
||||
// for buffer that is passed to the argument.
|
||||
SmallVector<RegionSuccessor, 2> successorRegions;
|
||||
regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
|
||||
regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
|
||||
successorRegions);
|
||||
auto *it =
|
||||
llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) {
|
||||
@ -383,8 +383,7 @@ private:
|
||||
|
||||
// Determine the actual operand to introduce a clone for and rewire the
|
||||
// operand to point to the clone instead.
|
||||
auto operands =
|
||||
regionInterface.getEntrySuccessorOperands(argRegion->getRegionNumber());
|
||||
auto operands = regionInterface.getEntrySuccessorOperands(argRegion);
|
||||
size_t operandIndex =
|
||||
llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
|
||||
operands.getBeginOperandIndex();
|
||||
@ -432,8 +431,7 @@ private:
|
||||
// Query the regionInterface to get all successor regions of the current
|
||||
// one.
|
||||
SmallVector<RegionSuccessor, 2> successorRegions;
|
||||
regionInterface.getSuccessorRegions(region.getRegionNumber(),
|
||||
successorRegions);
|
||||
regionInterface.getSuccessorRegions(region, successorRegions);
|
||||
// Try to find a matching region successor.
|
||||
RegionSuccessor *regionSuccessor =
|
||||
llvm::find_if(successorRegions, regionPredicate);
|
||||
@ -445,10 +443,6 @@ private:
|
||||
llvm::find(regionSuccessor->getSuccessorInputs(), argValue)
|
||||
.getIndex();
|
||||
|
||||
std::optional<unsigned> successorRegionNumber;
|
||||
if (Region *successorRegion = regionSuccessor->getSuccessor())
|
||||
successorRegionNumber = successorRegion->getRegionNumber();
|
||||
|
||||
// Iterate over all immediate terminator operations to introduce
|
||||
// new buffer allocations. Thereby, the appropriate terminator operand
|
||||
// will be adjusted to point to the newly allocated buffer instead.
|
||||
@ -456,8 +450,7 @@ private:
|
||||
®ion, [&](RegionBranchTerminatorOpInterface terminator) {
|
||||
// Get the actual mutable operands for this terminator op.
|
||||
auto terminatorOperands =
|
||||
terminator.getMutableSuccessorOperands(
|
||||
successorRegionNumber);
|
||||
terminator.getMutableSuccessorOperands(*regionSuccessor);
|
||||
// Extract the source value from the current terminator.
|
||||
// This conversion needs to exist on a separate line due to a
|
||||
// bug in GCC conversion analysis.
|
||||
|
@ -123,7 +123,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
|
||||
return true;
|
||||
// Recurses into all region successors.
|
||||
SmallVector<RegionSuccessor, 2> successors;
|
||||
regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
|
||||
regionInterface.getSuccessorRegions(current, successors);
|
||||
for (RegionSuccessor ®ionEntry : successors)
|
||||
if (recurse(regionEntry.getSuccessor()))
|
||||
return true;
|
||||
@ -132,7 +132,8 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
|
||||
|
||||
// Start with all entry regions and test whether they induce a loop.
|
||||
SmallVector<RegionSuccessor, 2> successorRegions;
|
||||
regionInterface.getSuccessorRegions(/*index=*/std::nullopt, successorRegions);
|
||||
regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
|
||||
successorRegions);
|
||||
for (RegionSuccessor ®ionEntry : successorRegions) {
|
||||
if (recurse(regionEntry.getSuccessor()))
|
||||
return true;
|
||||
|
@ -100,16 +100,13 @@ void BufferViewFlowAnalysis::build(Operation *op) {
|
||||
// Query the RegionBranchOpInterface to find potential successor regions.
|
||||
// Extract all entry regions and wire all initial entry successor inputs.
|
||||
SmallVector<RegionSuccessor, 2> entrySuccessors;
|
||||
regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
|
||||
regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
|
||||
entrySuccessors);
|
||||
for (RegionSuccessor &entrySuccessor : entrySuccessors) {
|
||||
// Wire the entry region's successor arguments with the initial
|
||||
// successor inputs.
|
||||
registerDependencies(
|
||||
regionInterface.getEntrySuccessorOperands(
|
||||
entrySuccessor.isParent()
|
||||
? std::optional<unsigned>()
|
||||
: entrySuccessor.getSuccessor()->getRegionNumber()),
|
||||
regionInterface.getEntrySuccessorOperands(entrySuccessor),
|
||||
entrySuccessor.getSuccessorInputs());
|
||||
}
|
||||
|
||||
@ -118,20 +115,15 @@ void BufferViewFlowAnalysis::build(Operation *op) {
|
||||
// Iterate over all successor region entries that are reachable from the
|
||||
// current region.
|
||||
SmallVector<RegionSuccessor, 2> successorRegions;
|
||||
regionInterface.getSuccessorRegions(region.getRegionNumber(),
|
||||
successorRegions);
|
||||
regionInterface.getSuccessorRegions(region, successorRegions);
|
||||
for (RegionSuccessor &successorRegion : successorRegions) {
|
||||
// Determine the current region index (if any).
|
||||
std::optional<unsigned> regionIndex;
|
||||
Region *regionSuccessor = successorRegion.getSuccessor();
|
||||
if (regionSuccessor)
|
||||
regionIndex = regionSuccessor->getRegionNumber();
|
||||
// Iterate over all immediate terminator operations and wire the
|
||||
// successor inputs with the successor operands of each terminator.
|
||||
for (Block &block : region)
|
||||
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
|
||||
block.getTerminator()))
|
||||
registerDependencies(terminator.getSuccessorOperands(regionIndex),
|
||||
registerDependencies(
|
||||
terminator.getSuccessorOperands(successorRegion),
|
||||
successorRegion.getSuccessorInputs());
|
||||
}
|
||||
}
|
||||
|
@ -455,8 +455,8 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
}
|
||||
|
||||
void AllocaScopeOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (index) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (!point.isParent()) {
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
|
@ -266,9 +266,9 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
/// correspond to a constant value for each operand, or null if that operand is
|
||||
/// not a constant.
|
||||
void ExecuteRegionOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// If the predecessor is the ExecuteRegionOp, branch into the body.
|
||||
if (!index) {
|
||||
if (point.isParent()) {
|
||||
regions.push_back(RegionSuccessor(&getRegion()));
|
||||
return;
|
||||
}
|
||||
@ -282,8 +282,8 @@ void ExecuteRegionOp::getSuccessorRegions(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MutableOperandRange
|
||||
ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
|
||||
assert((!index || index == getParentOp().getAfter().getRegionNumber()) &&
|
||||
ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
|
||||
assert((point.isParent() || point == getParentOp().getAfter()) &&
|
||||
"condition op can only exit the loop or branch to the after"
|
||||
"region");
|
||||
// Pass all operands except the condition to the successor region.
|
||||
@ -553,7 +553,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
|
||||
/// Return operands used when entering the region at 'index'. These operands
|
||||
/// correspond to the loop iterator operands, i.e., those excluding the
|
||||
/// induction variable.
|
||||
OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
|
||||
OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||
return getInitArgs();
|
||||
}
|
||||
|
||||
@ -562,7 +562,7 @@ OperandRange ForOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
|
||||
/// during the flow of control. `operands` is a set of optional attributes that
|
||||
/// correspond to a constant value for each operand, or null if that operand is
|
||||
/// not a constant.
|
||||
void ForOp::getSuccessorRegions(std::optional<unsigned> index,
|
||||
void ForOp::getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// Both the operation itself and the region may be branching into the body or
|
||||
// back into the operation itself. It is possible for loop not to enter the
|
||||
@ -1731,7 +1731,7 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
/// during the flow of control. `operands` is a set of optional attributes that
|
||||
/// correspond to a constant value for each operand, or null if that operand is
|
||||
/// not a constant.
|
||||
void ForallOp::getSuccessorRegions(std::optional<unsigned> index,
|
||||
void ForallOp::getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// Both the operation itself and the region may be branching into the body or
|
||||
// back into the operation itself. It is possible for loop not to enter the
|
||||
@ -2011,10 +2011,10 @@ void IfOp::print(OpAsmPrinter &p) {
|
||||
/// during the flow of control. `operands` is a set of optional attributes that
|
||||
/// correspond to a constant value for each operand, or null if that operand is
|
||||
/// not a constant.
|
||||
void IfOp::getSuccessorRegions(std::optional<unsigned> index,
|
||||
void IfOp::getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// The `then` and the `else` region branch back to the parent operation.
|
||||
if (index) {
|
||||
if (!point.isParent()) {
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
@ -3042,7 +3042,7 @@ void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
/// correspond to a constant value for each operand, or null if that operand is
|
||||
/// not a constant.
|
||||
void ParallelOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// Both the operation itself and the region may be branching into the body or
|
||||
// back into the operation itself. It is possible for loop not to enter the
|
||||
// body.
|
||||
@ -3169,8 +3169,8 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
|
||||
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
|
||||
}
|
||||
|
||||
OperandRange WhileOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
|
||||
assert(index && *index == 0 &&
|
||||
OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||
assert(point == getBefore() &&
|
||||
"WhileOp is expected to branch only to the first region");
|
||||
|
||||
return getInits();
|
||||
@ -3192,17 +3192,18 @@ Block::BlockArgListType WhileOp::getAfterArguments() {
|
||||
return getAfterBody()->getArguments();
|
||||
}
|
||||
|
||||
void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
|
||||
void WhileOp::getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// The parent op always branches to the condition region.
|
||||
if (!index) {
|
||||
if (point.isParent()) {
|
||||
regions.emplace_back(&getBefore(), getBefore().getArguments());
|
||||
return;
|
||||
}
|
||||
|
||||
assert(*index < 2 && "there are only two regions in a WhileOp");
|
||||
assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
|
||||
"there are only two regions in a WhileOp");
|
||||
// The body region always branches back to the condition region.
|
||||
if (*index == 1) {
|
||||
if (point == getAfter()) {
|
||||
regions.emplace_back(&getBefore(), getBefore().getArguments());
|
||||
return;
|
||||
}
|
||||
@ -4023,10 +4024,9 @@ Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
|
||||
}
|
||||
|
||||
void IndexSwitchOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index,
|
||||
SmallVectorImpl<RegionSuccessor> &successors) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
|
||||
// All regions branch back to the parent op.
|
||||
if (index) {
|
||||
if (!point.isParent()) {
|
||||
successors.emplace_back(getResults());
|
||||
return;
|
||||
}
|
||||
|
@ -335,11 +335,11 @@ void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
|
||||
// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
|
||||
void AssumingOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// AssumingOp has unconditional control flow into the region and back to the
|
||||
// parent, so return the correct RegionSuccessor purely based on the index
|
||||
// being None or 0.
|
||||
if (index) {
|
||||
if (!point.isParent()) {
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
|
@ -86,23 +86,25 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
|
||||
// AlternativesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
|
||||
std::optional<unsigned> index) {
|
||||
if (index && getOperation()->getNumOperands() == 1)
|
||||
OperandRange
|
||||
transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||
if (!point.isParent() && getOperation()->getNumOperands() == 1)
|
||||
return getOperation()->getOperands();
|
||||
return OperandRange(getOperation()->operand_end(),
|
||||
getOperation()->operand_end());
|
||||
}
|
||||
|
||||
void transform::AlternativesOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
for (Region &alternative : llvm::drop_begin(
|
||||
getAlternatives(), index.has_value() ? *index + 1 : 0)) {
|
||||
getAlternatives(),
|
||||
point.isParent() ? 0
|
||||
: point.getRegionOrNull()->getRegionNumber() + 1)) {
|
||||
regions.emplace_back(&alternative, !getOperands().empty()
|
||||
? alternative.getArguments()
|
||||
: Block::BlockArgListType());
|
||||
}
|
||||
if (index.has_value())
|
||||
if (!point.isParent())
|
||||
regions.emplace_back(getOperation()->getResults());
|
||||
}
|
||||
|
||||
@ -1159,24 +1161,24 @@ void transform::ForeachOp::getEffects(
|
||||
}
|
||||
|
||||
void transform::ForeachOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
Region *bodyRegion = &getBody();
|
||||
if (!index) {
|
||||
if (point.isParent()) {
|
||||
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
|
||||
return;
|
||||
}
|
||||
|
||||
// Branch back to the region or the parent.
|
||||
assert(*index == 0 && "unexpected region index");
|
||||
assert(point == getBody() && "unexpected region index");
|
||||
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
|
||||
regions.emplace_back();
|
||||
}
|
||||
|
||||
OperandRange
|
||||
transform::ForeachOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
|
||||
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||
// The iteration variable op handle is mapped to a subset (one op to be
|
||||
// precise) of the payload ops of the ForeachOp operand.
|
||||
assert(index && *index == 0 && "unexpected region index");
|
||||
assert(point == getBody() && "unexpected region index");
|
||||
return getOperation()->getOperands();
|
||||
}
|
||||
|
||||
@ -2178,9 +2180,9 @@ void transform::SequenceOp::getEffects(
|
||||
getPotentialTopLevelEffects(effects);
|
||||
}
|
||||
|
||||
OperandRange transform::SequenceOp::getEntrySuccessorOperands(
|
||||
std::optional<unsigned> index) {
|
||||
assert(index && *index == 0 && "unexpected region index");
|
||||
OperandRange
|
||||
transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||
assert(point == getBody() && "unexpected region index");
|
||||
if (getOperation()->getNumOperands() > 0)
|
||||
return getOperation()->getOperands();
|
||||
return OperandRange(getOperation()->operand_end(),
|
||||
@ -2188,8 +2190,8 @@ OperandRange transform::SequenceOp::getEntrySuccessorOperands(
|
||||
}
|
||||
|
||||
void transform::SequenceOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (!index) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (point.isParent()) {
|
||||
Region *bodyRegion = &getBody();
|
||||
regions.emplace_back(bodyRegion, getNumOperands() != 0
|
||||
? bodyRegion->getArguments()
|
||||
@ -2197,7 +2199,7 @@ void transform::SequenceOp::getSuccessorRegions(
|
||||
return;
|
||||
}
|
||||
|
||||
assert(*index == 0 && "unexpected region index");
|
||||
assert(point == getBody() && "unexpected region index");
|
||||
regions.emplace_back(getOperation()->getResults());
|
||||
}
|
||||
|
||||
|
@ -5821,8 +5821,8 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
|
||||
}
|
||||
|
||||
void WarpExecuteOnLane0Op::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (index) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (!point.isParent()) {
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
|
@ -84,18 +84,18 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
||||
// RegionBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static InFlightDiagnostic &
|
||||
printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
|
||||
std::optional<unsigned> succRegionNo) {
|
||||
static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
|
||||
RegionBranchPoint sourceNo,
|
||||
RegionBranchPoint succRegionNo) {
|
||||
diag << "from ";
|
||||
if (sourceNo)
|
||||
diag << "Region #" << sourceNo.value();
|
||||
if (Region *region = sourceNo.getRegionOrNull())
|
||||
diag << "Region #" << region->getRegionNumber();
|
||||
else
|
||||
diag << "parent operands";
|
||||
|
||||
diag << " to ";
|
||||
if (succRegionNo)
|
||||
diag << "Region #" << succRegionNo.value();
|
||||
if (Region *region = succRegionNo.getRegionOrNull())
|
||||
diag << "Region #" << region->getRegionNumber();
|
||||
else
|
||||
diag << "parent results";
|
||||
return diag;
|
||||
@ -107,28 +107,24 @@ printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
|
||||
/// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
|
||||
/// the exact type match verification is not necessary (e.g., if the Op verifies
|
||||
/// the match itself).
|
||||
static LogicalResult verifyTypesAlongAllEdges(
|
||||
Operation *op, std::optional<unsigned> sourceNo,
|
||||
function_ref<FailureOr<TypeRange>(std::optional<unsigned>)>
|
||||
static LogicalResult
|
||||
verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
|
||||
function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
|
||||
getInputsTypesForRegion) {
|
||||
auto regionInterface = cast<RegionBranchOpInterface>(op);
|
||||
|
||||
SmallVector<RegionSuccessor, 2> successors;
|
||||
regionInterface.getSuccessorRegions(sourceNo, successors);
|
||||
regionInterface.getSuccessorRegions(sourcePoint, successors);
|
||||
|
||||
for (RegionSuccessor &succ : successors) {
|
||||
std::optional<unsigned> succRegionNo;
|
||||
if (!succ.isParent())
|
||||
succRegionNo = succ.getSuccessor()->getRegionNumber();
|
||||
|
||||
FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
|
||||
FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
|
||||
if (failed(sourceTypes))
|
||||
return failure();
|
||||
|
||||
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
|
||||
if (sourceTypes->size() != succInputsTypes.size()) {
|
||||
InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
|
||||
return printRegionEdgeName(diag, sourceNo, succRegionNo)
|
||||
return printRegionEdgeName(diag, sourcePoint, succ)
|
||||
<< ": source has " << sourceTypes->size()
|
||||
<< " operands, but target successor needs "
|
||||
<< succInputsTypes.size();
|
||||
@ -140,7 +136,7 @@ static LogicalResult verifyTypesAlongAllEdges(
|
||||
Type inputType = std::get<1>(typesIdx.value());
|
||||
if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
|
||||
InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
|
||||
return printRegionEdgeName(diag, sourceNo, succRegionNo)
|
||||
return printRegionEdgeName(diag, sourcePoint, succ)
|
||||
<< ": source type #" << typesIdx.index() << " " << sourceType
|
||||
<< " should match input type #" << typesIdx.index() << " "
|
||||
<< inputType;
|
||||
@ -154,13 +150,13 @@ static LogicalResult verifyTypesAlongAllEdges(
|
||||
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
|
||||
auto regionInterface = cast<RegionBranchOpInterface>(op);
|
||||
|
||||
auto inputTypesFromParent =
|
||||
[&](std::optional<unsigned> regionNo) -> TypeRange {
|
||||
auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange {
|
||||
return regionInterface.getEntrySuccessorOperands(regionNo).getTypes();
|
||||
};
|
||||
|
||||
// Verify types along control flow edges originating from the parent.
|
||||
if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent)))
|
||||
if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
|
||||
inputTypesFromParent)))
|
||||
return failure();
|
||||
|
||||
auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
|
||||
@ -176,8 +172,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
|
||||
};
|
||||
|
||||
// Verify types along control flow edges originating from each region.
|
||||
for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
|
||||
Region ®ion = op->getRegion(regionNo);
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
|
||||
// Since there can be multiple terminators implementing the
|
||||
// `RegionBranchTerminatorOpInterface`, all should have the same operand
|
||||
@ -195,7 +190,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
|
||||
continue;
|
||||
|
||||
auto inputTypesForRegion =
|
||||
[&](std::optional<unsigned> succRegionNo) -> FailureOr<TypeRange> {
|
||||
[&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> {
|
||||
std::optional<OperandRange> regionReturnOperands;
|
||||
for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
|
||||
auto terminatorOperands =
|
||||
@ -211,7 +206,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
|
||||
if (!areTypesCompatible(regionReturnOperands->getTypes(),
|
||||
terminatorOperands.getTypes())) {
|
||||
InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
|
||||
return printRegionEdgeName(diag, regionNo, succRegionNo)
|
||||
return printRegionEdgeName(diag, region, succRegionNo)
|
||||
<< " operands mismatch between return-like terminators";
|
||||
}
|
||||
}
|
||||
@ -220,7 +215,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
|
||||
return TypeRange(regionReturnOperands->getTypes());
|
||||
};
|
||||
|
||||
if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesForRegion)))
|
||||
if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
@ -237,24 +232,24 @@ static bool isRegionReachable(Region *begin, Region *r) {
|
||||
visited[begin->getRegionNumber()] = true;
|
||||
|
||||
// Retrieve all successors of the region and enqueue them in the worklist.
|
||||
SmallVector<unsigned> worklist;
|
||||
auto enqueueAllSuccessors = [&](unsigned index) {
|
||||
SmallVector<Region *> worklist;
|
||||
auto enqueueAllSuccessors = [&](Region *region) {
|
||||
SmallVector<RegionSuccessor> successors;
|
||||
op.getSuccessorRegions(index, successors);
|
||||
op.getSuccessorRegions(region, successors);
|
||||
for (RegionSuccessor successor : successors)
|
||||
if (!successor.isParent())
|
||||
worklist.push_back(successor.getSuccessor()->getRegionNumber());
|
||||
worklist.push_back(successor.getSuccessor());
|
||||
};
|
||||
enqueueAllSuccessors(begin->getRegionNumber());
|
||||
enqueueAllSuccessors(begin);
|
||||
|
||||
// Process all regions in the worklist via DFS.
|
||||
while (!worklist.empty()) {
|
||||
unsigned nextRegion = worklist.pop_back_val();
|
||||
if (nextRegion == r->getRegionNumber())
|
||||
Region *nextRegion = worklist.pop_back_val();
|
||||
if (nextRegion == r)
|
||||
return true;
|
||||
if (visited[nextRegion])
|
||||
if (visited[nextRegion->getRegionNumber()])
|
||||
continue;
|
||||
visited[nextRegion] = true;
|
||||
visited[nextRegion->getRegionNumber()] = true;
|
||||
enqueueAllSuccessors(nextRegion);
|
||||
}
|
||||
|
||||
|
@ -316,15 +316,11 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
|
||||
// Return the successors of `region` if the latter is not null. Else return
|
||||
// the successors of `regionBranchOp`.
|
||||
auto getSuccessors = [&](Region *region = nullptr) {
|
||||
std::optional<unsigned> index =
|
||||
region ? std::optional(region->getRegionNumber()) : std::nullopt;
|
||||
auto point = region ? region : RegionBranchPoint::parent();
|
||||
SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
|
||||
nullptr);
|
||||
SmallVector<RegionSuccessor> successors;
|
||||
if (!index)
|
||||
regionBranchOp.getEntrySuccessorRegions(operandAttributes, successors);
|
||||
else
|
||||
regionBranchOp.getSuccessorRegions(index, successors);
|
||||
regionBranchOp.getSuccessorRegions(point, successors);
|
||||
return successors;
|
||||
};
|
||||
|
||||
@ -333,14 +329,10 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
|
||||
// forwarded to `successor`.
|
||||
auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
|
||||
Operation *terminator = nullptr) {
|
||||
Region *successorRegion = successor.getSuccessor();
|
||||
std::optional<unsigned> index =
|
||||
successorRegion ? std::optional(successorRegion->getRegionNumber())
|
||||
: std::nullopt;
|
||||
OperandRange operands =
|
||||
terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
|
||||
.getSuccessorOperands(index)
|
||||
: regionBranchOp.getEntrySuccessorOperands(index);
|
||||
.getSuccessorOperands(successor)
|
||||
: regionBranchOp.getEntrySuccessorOperands(successor);
|
||||
SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
|
||||
return opOperands;
|
||||
};
|
||||
|
@ -60,8 +60,8 @@ public:
|
||||
NextAccess *before) override;
|
||||
|
||||
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
|
||||
std::optional<unsigned> regionFrom,
|
||||
std::optional<unsigned> regionTo,
|
||||
RegionBranchPoint regionFrom,
|
||||
RegionBranchPoint regionTo,
|
||||
const NextAccess &after,
|
||||
NextAccess *before) override;
|
||||
|
||||
@ -124,15 +124,15 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
|
||||
}
|
||||
|
||||
void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
|
||||
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
|
||||
std::optional<unsigned> regionTo, const NextAccess &after,
|
||||
NextAccess *before) {
|
||||
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
|
||||
RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
|
||||
auto testStoreWithARegion =
|
||||
dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
|
||||
|
||||
if (testStoreWithARegion &&
|
||||
((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
|
||||
(!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
|
||||
((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
|
||||
(regionFrom.isParent() &&
|
||||
testStoreWithARegion.getStoreBeforeRegion()))) {
|
||||
visitOperation(branch, static_cast<const NextAccess &>(after),
|
||||
static_cast<NextAccess *>(before));
|
||||
} else {
|
||||
@ -219,7 +219,7 @@ struct TestNextAccessPass
|
||||
|
||||
SmallVector<Attribute> entryPointNextAccess;
|
||||
SmallVector<RegionSuccessor> regionSuccessors;
|
||||
iface.getSuccessorRegions(std::nullopt, regionSuccessors);
|
||||
iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
|
||||
for (const RegionSuccessor &successor : regionSuccessors) {
|
||||
if (!successor.getSuccessor() || successor.getSuccessor()->empty())
|
||||
continue;
|
||||
|
@ -931,17 +931,17 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
OperandRange
|
||||
RegionIfOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
|
||||
assert(index && *index < 2 && "invalid region index");
|
||||
OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||
assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
|
||||
"invalid region index");
|
||||
return getOperands();
|
||||
}
|
||||
|
||||
void RegionIfOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// We always branch to the join region.
|
||||
if (index.has_value()) {
|
||||
if (index.value() < 2)
|
||||
if (!point.isParent()) {
|
||||
if (point != getJoinRegion())
|
||||
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
|
||||
else
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
@ -964,11 +964,11 @@ void RegionIfOp::getRegionInvocationBounds(
|
||||
// AnyCondOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AnyCondOp::getSuccessorRegions(std::optional<unsigned> index,
|
||||
void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// The parent op branches into the only region, and the region branches back
|
||||
// to the parent op.
|
||||
if (!index)
|
||||
if (point.isParent())
|
||||
regions.emplace_back(&getRegion());
|
||||
else
|
||||
regions.emplace_back(getResults());
|
||||
@ -985,17 +985,16 @@ void AnyCondOp::getRegionInvocationBounds(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void LoopBlockOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
regions.emplace_back(&getBody(), getBody().getArguments());
|
||||
if (!index)
|
||||
if (point.isParent())
|
||||
return;
|
||||
|
||||
regions.emplace_back((*this)->getResults());
|
||||
}
|
||||
|
||||
OperandRange
|
||||
LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
|
||||
assert(index == 0);
|
||||
OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||
assert(point == getBody());
|
||||
return getInitMutable();
|
||||
}
|
||||
|
||||
@ -1003,10 +1002,9 @@ LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
|
||||
// LoopBlockTerminatorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands(
|
||||
std::optional<unsigned> index) {
|
||||
assert(!index || index == 0);
|
||||
if (!index)
|
||||
MutableOperandRange
|
||||
LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
|
||||
if (point.isParent())
|
||||
return getExitArgMutable();
|
||||
return getNextIterArgMutable();
|
||||
}
|
||||
@ -1313,13 +1311,12 @@ MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
|
||||
}
|
||||
|
||||
void TestStoreWithARegion::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (!index) {
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (point.isParent())
|
||||
regions.emplace_back(&getBody(), getBody().front().getArguments());
|
||||
} else {
|
||||
else
|
||||
regions.emplace_back();
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
|
||||
|
@ -37,7 +37,7 @@ struct MutuallyExclusiveRegionsOp
|
||||
}
|
||||
|
||||
// Regions have no successors.
|
||||
void getSuccessorRegions(std::optional<unsigned> index,
|
||||
void getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {}
|
||||
};
|
||||
|
||||
@ -51,14 +51,13 @@ struct LoopRegionsOp
|
||||
|
||||
static StringRef getOperationName() { return "cftest.loop_regions_op"; }
|
||||
|
||||
void getSuccessorRegions(std::optional<unsigned> index,
|
||||
void getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (index) {
|
||||
if (*index == 1)
|
||||
if (Region *region = point.getRegionOrNull()) {
|
||||
if (point == (*this)->getRegion(1))
|
||||
// This region also branches back to the parent.
|
||||
regions.push_back(RegionSuccessor());
|
||||
regions.push_back(
|
||||
RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
|
||||
regions.push_back(RegionSuccessor(region));
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -74,11 +73,11 @@ struct DoubleLoopRegionsOp
|
||||
return "cftest.double_loop_regions_op";
|
||||
}
|
||||
|
||||
void getSuccessorRegions(std::optional<unsigned> index,
|
||||
void getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (index.has_value()) {
|
||||
if (Region *region = point.getRegionOrNull()) {
|
||||
regions.push_back(RegionSuccessor());
|
||||
regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
|
||||
regions.push_back(RegionSuccessor(region));
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -92,9 +91,9 @@ struct SequentialRegionsOp
|
||||
static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
|
||||
|
||||
// Region 0 has Region 1 as a successor.
|
||||
void getSuccessorRegions(std::optional<unsigned> index,
|
||||
void getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (index == 0u) {
|
||||
if (point == (*this)->getRegion(0)) {
|
||||
Operation *thisOp = this->getOperation();
|
||||
regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user