[mlir][Interfaces][NFC] Better documentation for RegionBranchOpInterface (#66920)

Update outdated documentation and add an example.
This commit is contained in:
Matthias Springer 2023-09-21 18:17:14 +02:00 committed by GitHub
parent 991cb14715
commit d56537a516
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 108 additions and 92 deletions

View File

@ -117,27 +117,58 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let description = [{ let description = [{
This interface provides information for region operations that contain This interface provides information for region operations that exhibit
branching behavior between held regions, i.e. this interface allows for branching behavior between held regions. I.e., this interface allows for
expressing control flow information for region holding operations. expressing control flow information for region holding operations.
This interface is meant to model well-defined cases of control-flow of This interface is meant to model well-defined cases of control-flow and
value propagation, where what occurs along control-flow edges is assumed to value propagation, where what occurs along control-flow edges is assumed to
be side-effect free. For example, corresponding successor operands and be side-effect free.
successor block arguments may have different types. In such cases,
`areTypesCompatible` can be implemented to compare types along control-flow A "region branch point" indicates a point from which a branch originates. It
edges. By default, type equality is used. can indicate either a region of this op or `RegionBranchPoint::parent()`. In
the latter case, the branch originates from outside of the op, i.e., when
first executing this op.
A "region successor" indicates the target of a branch. It can indicate
either a region of this op or this op. In the former case, the region
successor is a region pointer and a range of block arguments to which the
"successor operands" are forwarded to. In the latter case, the control flow
leaves this op and the region successor is a range of results of this op to
which the successor operands are forwarded to.
By default, successor operands and successor block arguments/successor
results must have the same type. `areTypesCompatible` can be implemented to
allow non-equal types.
Example:
```
%r = scf.for %iv = %lb to %ub step %step iter_args(%a = %b)
-> tensor<5xf32> {
...
scf.yield %c : tensor<5xf32>
}
```
`scf.for` has one region. The region has two region successors: the region
itself and the `scf.for` op. %b is an entry successor operand. %c is a
successor operand. %a is a successor block argument. %r is a successor
result.
}]; }];
let cppNamespace = "::mlir"; let cppNamespace = "::mlir";
let methods = [ let methods = [
InterfaceMethod<[{ InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when Returns the operands of this operation that are forwarded to the region
branching from `point`, which was specified as a successor of successor's block arguments or this operation's results when branching
this operation by `getEntrySuccessorRegions`, or the operands forwarded to `point`. `point` is guaranteed to be among the successors that are
to the operation's results when it branches back to itself. These operands returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
should correspond 1-1 with the successor inputs specified in
`getEntrySuccessorRegions`. Example: In the above example, this method returns the operand %b of the
`scf.for` op, regardless of the value of `point`. I.e., this op always
forwards the same operands, regardless of whether the loop has 0 or more
iterations.
}], }],
"::mlir::OperandRange", "getEntrySuccessorOperands", "::mlir::OperandRange", "getEntrySuccessorOperands",
(ins "::mlir::RegionBranchPoint":$point), [{}], (ins "::mlir::RegionBranchPoint":$point), [{}],
@ -147,32 +178,47 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
}] }]
>, >,
InterfaceMethod<[{ InterfaceMethod<[{
Returns the viable region successors that are branched to when first Returns the potential region successors when first executing the op.
executing the op.
Unlike `getSuccessorRegions`, this method also passes along the
constant operands of this op. Based on these, different region
successors can be determined.
`operands` contains an entry for every operand of the implementing
op with a null attribute if the operand has no constant value or
the corresponding attribute if it is a constant.
By default, simply dispatches to `getSuccessorRegions`. Unlike `getSuccessorRegions`, this method also passes along the
constant operands of this op. Based on these, the implementation may
filter out certain successors. By default, simply dispatches to
`getSuccessorRegions`. `operands` contains an entry for every
operand of this op, with a null attribute if the operand has no constant
value.
Note: The control flow does not necessarily have to enter any region of
this op.
Example: In the above example, this method may return two region
region successors: the single region of the `scf.for` op and the
`scf.for` operation (that implements this interface). If %lb, %ub, %step
are constants and it can be determined the loop does not have any
iterations, this method may choose to return only this operation.
Similarly, if it can be determined that the loop has at least one
iteration, this method may choose to return only the region of the loop.
}], }],
"void", "getEntrySuccessorRegions", "void", "getEntrySuccessorRegions",
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands, (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
[{}], [{ /*defaultImplementation=*/[{
$_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions); $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
}] }]
>, >,
InterfaceMethod<[{ InterfaceMethod<[{
Returns the viable successors of `point`. These are the regions that may Returns the potential region successors when branching from `point`.
be selected during the flow of control. The parent operation, may These are the regions that may be selected during the flow of control.
specify itself as successor, which indicates that the control flow may
not enter any region at all. This method allows for describing which When `point = RegionBranchPoint::parent()`, this method returns the
regions may be executed when entering an operation, and which regions region successors when entering the operation. Otherwise, this method
are executed after having executed another region of the parent op. The returns the successor regions when branching from the region indicated
successor region must be non-empty. by `point`.
Example: In the above example, this method returns the region of the
`scf.for` and this operation for either region branch point (`parent`
and the region of the `scf.for`). An implementation may choose to filter
out region successors when it is statically known (e.g., by examining
the operands of this op) that those successors are not branched to.
}], }],
"void", "getSuccessorRegions", "void", "getSuccessorRegions",
(ins "::mlir::RegionBranchPoint":$point, (ins "::mlir::RegionBranchPoint":$point,
@ -183,12 +229,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
times this operation will invoke the attached regions (assuming the times this operation will invoke the attached regions (assuming the
regions yield normally, i.e. do not abort or invoke an infinite loop). regions yield normally, i.e. do not abort or invoke an infinite loop).
The minimum number of invocations is at least 0. If the maximum number The minimum number of invocations is at least 0. If the maximum number
of invocations cannot be statically determined, then it will not have a of invocations cannot be statically determined, then it will be set to
value (i.e., it is set to `std::nullopt`). `InvocationBounds::getUnknown()`.
`operands` is a set of optional attributes that either correspond to This method also passes along the constant operands of this op.
constant values for each operand of this operation or null if that `operands` contains an entry for every operand of this op, with a null
operand is not a constant. attribute if the operand has no constant value.
This method may be called speculatively on operations where the provided This method may be called speculatively on operations where the provided
operands are not necessarily the same as the operation's current operands are not necessarily the same as the operation's current
@ -199,8 +245,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands, (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::InvocationBounds> &" "::llvm::SmallVectorImpl<::mlir::InvocationBounds> &"
:$invocationBounds), [{}], :$invocationBounds), [{}],
[{ invocationBounds.append($_op->getNumRegions(), /*defaultImplementation=*/[{
::mlir::InvocationBounds::getUnknown()); }] invocationBounds.append($_op->getNumRegions(),
::mlir::InvocationBounds::getUnknown());
}]
>, >,
InterfaceMethod<[{ InterfaceMethod<[{
This method is called to compare types along control-flow edges. By This method is called to compare types along control-flow edges. By
@ -208,7 +256,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
}], }],
"bool", "areTypesCompatible", "bool", "areTypesCompatible",
(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}], (ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
[{ return lhs == rhs; }] /*defaultImplementation=*/[{ return lhs == rhs; }]
>, >,
]; ];
@ -235,7 +283,7 @@ def RegionBranchTerminatorOpInterface :
OpInterface<"RegionBranchTerminatorOpInterface"> { OpInterface<"RegionBranchTerminatorOpInterface"> {
let description = [{ let description = [{
This interface provides information for branching terminator operations This interface provides information for branching terminator operations
in the presence of a parent RegionBranchOpInterface implementation. It in the presence of a parent `RegionBranchOpInterface` implementation. It
specifies which operands are passed to which successor region. specifies which operands are passed to which successor region.
}]; }];
let cppNamespace = "::mlir"; let cppNamespace = "::mlir";
@ -243,26 +291,26 @@ def RegionBranchTerminatorOpInterface :
let methods = [ let methods = [
InterfaceMethod<[{ InterfaceMethod<[{
Returns a mutable range of operands that are semantically "returned" by Returns a mutable range of operands that are semantically "returned" by
passing them to the region successor given by `point`. passing them to the region successor indicated by `point`.
}], }],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands", "::mlir::MutableOperandRange", "getMutableSuccessorOperands",
(ins "::mlir::RegionBranchPoint":$point) (ins "::mlir::RegionBranchPoint":$point)
>, >,
InterfaceMethod<[{ InterfaceMethod<[{
Returns the viable region successors that are branched to after this Returns the potential region successors that are branched to after this
terminator based on the given constant operands. terminator based on the given constant operands.
`operands` contains an entry for every operand of the This method also passes along the constant operands of this op.
implementing op with a null attribute if the operand has no constant `operands` contains an entry for every operand of this op, with a null
value or the corresponding attribute if it is a constant. attribute if the operand has no constant value.
Default implementation simply dispatches to the parent The default implementation simply dispatches to the parent
`RegionBranchOpInterface`'s `getSuccessorRegions` implementation. `RegionBranchOpInterface`'s `getSuccessorRegions` implementation.
}], }],
"void", "getSuccessorRegions", "void", "getSuccessorRegions",
(ins "::llvm::ArrayRef<::mlir::Attribute>":$operands, (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}], "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
[{ /*defaultImplementation=*/[{
::mlir::Operation *op = $_op; ::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp()) ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
.getSuccessorRegions(op->getParentRegion(), regions); .getSuccessorRegions(op->getParentRegion(), regions);

View File

@ -2375,10 +2375,6 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<AffineForEmptyLoopFolder>(context); results.add<AffineForEmptyLoopFolder>(context);
} }
/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable. AffineForOp only has one region, so zero is the only
/// valid value for `index`.
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert((point.isParent() || point == getRegion()) && "invalid region point"); assert((point.isParent() || point == getRegion()) && "invalid region point");
@ -2387,11 +2383,6 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getInits(); return getInits();
} }
/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void AffineForOp::getSuccessorRegions( void AffineForOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
assert((point.isParent() || point == getRegion()) && "expected loop region"); assert((point.isParent() || point == getRegion()) && "expected loop region");

View File

@ -260,11 +260,6 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
} }
/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ExecuteRegionOp::getSuccessorRegions( void ExecuteRegionOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// If the predecessor is the ExecuteRegionOp, branch into the body. // If the predecessor is the ExecuteRegionOp, branch into the body.
@ -543,18 +538,10 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
return dyn_cast_or_null<ForOp>(containingOp); return dyn_cast_or_null<ForOp>(containingOp);
} }
/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable.
OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getInitArgs(); return getInitArgs();
} }
/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ForOp::getSuccessorRegions(RegionBranchPoint point, void ForOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) { SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or // Both the operation itself and the region may be branching into the body or
@ -1999,11 +1986,6 @@ void IfOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs()); p.printOptionalAttrDict((*this)->getAttrs());
} }
/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void IfOp::getSuccessorRegions(RegionBranchPoint point, void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) { SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation. // The `then` and the `else` region branch back to the parent operation.
@ -3162,13 +3144,6 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments()); afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
} }
OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
}
ConditionOp WhileOp::getConditionOp() { ConditionOp WhileOp::getConditionOp() {
return cast<ConditionOp>(getBeforeBody()->getTerminator()); return cast<ConditionOp>(getBeforeBody()->getTerminator());
} }
@ -3189,6 +3164,12 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() {
return getBeforeArguments(); return getBeforeArguments();
} }
OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
}
void WhileOp::getSuccessorRegions(RegionBranchPoint point, void WhileOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) { SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op always branches to the condition region. // The parent op always branches to the condition region.

View File

@ -102,11 +102,8 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
} }
/// Verify that types match along all region control flow edges originating from /// Verify that types match along all region control flow edges originating from
/// `sourceNo` (region # if source is a region, std::nullopt if source is parent /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
/// op). `getInputsTypesForRegion` is a function that returns the types of the /// types of the inputs that flow to a successor region.
/// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
/// the exact type match verification is not necessary (e.g., if the Op verifies
/// the match itself).
static LogicalResult static LogicalResult
verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
function_ref<FailureOr<TypeRange>(RegionBranchPoint)> function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
@ -150,8 +147,8 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op); auto regionInterface = cast<RegionBranchOpInterface>(op);
auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange { auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
return regionInterface.getEntrySuccessorOperands(regionNo).getTypes(); return regionInterface.getEntrySuccessorOperands(point).getTypes();
}; };
// Verify types along control flow edges originating from the parent. // Verify types along control flow edges originating from the parent.
@ -190,11 +187,10 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
continue; continue;
auto inputTypesForRegion = auto inputTypesForRegion =
[&](RegionBranchPoint succRegionNo) -> FailureOr<TypeRange> { [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
std::optional<OperandRange> regionReturnOperands; std::optional<OperandRange> regionReturnOperands;
for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
auto terminatorOperands = auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
regionReturnOp.getSuccessorOperands(succRegionNo);
if (!regionReturnOperands) { if (!regionReturnOperands) {
regionReturnOperands = terminatorOperands; regionReturnOperands = terminatorOperands;
@ -206,7 +202,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
if (!areTypesCompatible(regionReturnOperands->getTypes(), if (!areTypesCompatible(regionReturnOperands->getTypes(),
terminatorOperands.getTypes())) { terminatorOperands.getTypes())) {
InFlightDiagnostic diag = op->emitOpError(" along control flow edge"); InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
return printRegionEdgeName(diag, region, succRegionNo) return printRegionEdgeName(diag, region, point)
<< " operands mismatch between return-like terminators"; << " operands mismatch between return-like terminators";
} }
} }