mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 04:16:10 +00:00
[mlir][transform] Support results on ForeachOp
Handles can be yielded from the ForeachOp. Differential Revision: https://reviews.llvm.org/D130640
This commit is contained in:
parent
d5a3cc1d88
commit
c1e6caac70
@ -118,12 +118,17 @@ def ForeachOp : TransformDialectOp<"foreach",
|
||||
the entire sequence fails immediately leaving the payload IR in potentially
|
||||
invalid state, i.e., this operation offers no transformation rollback
|
||||
capabilities.
|
||||
|
||||
This op generates as many handles as the terminating YieldOp has operands.
|
||||
For each result, the payload ops of the corresponding YieldOp operand are
|
||||
merged and mapped to the same resulting handle.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target);
|
||||
let results = (outs);
|
||||
let results = (outs Variadic<PDL_Operation>:$results);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
let assemblyFormat = "$target $body attr-dict";
|
||||
let assemblyFormat = "$target (`->` type($results)^)? $body attr-dict";
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Allow the dialect prefix to be omitted.
|
||||
@ -132,6 +137,8 @@ def ForeachOp : TransformDialectOp<"foreach",
|
||||
BlockArgument getIterationVariable() {
|
||||
return getBody().front().getArgument(0);
|
||||
}
|
||||
|
||||
transform::YieldOp getYieldOp();
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -281,18 +281,32 @@ DiagnosedSilenceableFailure
|
||||
transform::ForeachOp::apply(transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
|
||||
SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
|
||||
|
||||
for (Operation *op : payloadOps) {
|
||||
auto scope = state.make_region_scope(getBody());
|
||||
if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
|
||||
// Execute loop body.
|
||||
for (Operation &transform : getBody().front().without_terminator()) {
|
||||
DiagnosedSilenceableFailure result = state.applyTransform(
|
||||
cast<transform::TransformOpInterface>(transform));
|
||||
if (!result.succeeded())
|
||||
return result;
|
||||
}
|
||||
|
||||
// Append yielded payload ops to result list (if any).
|
||||
for (unsigned i = 0; i < getNumResults(); ++i) {
|
||||
ArrayRef<Operation *> yieldedOps =
|
||||
state.getPayloadOps(getYieldOp().getOperand(i));
|
||||
resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < getNumResults(); ++i)
|
||||
results.set(getResult(i).cast<OpResult>(), resultOps[i]);
|
||||
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
@ -306,6 +320,9 @@ void transform::ForeachOp::getEffects(
|
||||
} else {
|
||||
onlyReadsHandle(getTarget(), effects);
|
||||
}
|
||||
|
||||
for (Value result : getResults())
|
||||
producesHandle(result, effects);
|
||||
}
|
||||
|
||||
void transform::ForeachOp::getSuccessorRegions(
|
||||
@ -331,6 +348,21 @@ transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
|
||||
return getOperation()->getOperands();
|
||||
}
|
||||
|
||||
transform::YieldOp transform::ForeachOp::getYieldOp() {
|
||||
return cast<transform::YieldOp>(getBody().front().getTerminator());
|
||||
}
|
||||
|
||||
LogicalResult transform::ForeachOp::verify() {
|
||||
auto yieldOp = getYieldOp();
|
||||
if (getNumResults() != yieldOp.getNumOperands())
|
||||
return emitOpError() << "expects the same number of results as the "
|
||||
"terminator has operands";
|
||||
for (Value v : yieldOp.getOperands())
|
||||
if (!v.getType().isa<pdl::OperationType>())
|
||||
return yieldOp->emitOpError("expects only PDL_Operation operands");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetClosestIsolatedParentOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -627,3 +627,52 @@ transform.with_pdl_patterns {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @bar() {
|
||||
scf.execute_region {
|
||||
// expected-remark @below {{transform applied}}
|
||||
%0 = arith.constant 0 : i32
|
||||
scf.yield
|
||||
}
|
||||
|
||||
scf.execute_region {
|
||||
// expected-remark @below {{transform applied}}
|
||||
%1 = arith.constant 1 : i32
|
||||
// expected-remark @below {{transform applied}}
|
||||
%2 = arith.constant 2 : i32
|
||||
scf.yield
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @const : benefit(1) {
|
||||
%r = pdl.types
|
||||
%0 = pdl.operation "arith.constant" -> (%r : !pdl.range<type>)
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
|
||||
pdl.pattern @execute_region : benefit(1) {
|
||||
%r = pdl.types
|
||||
%0 = pdl.operation "scf.execute_region" -> (%r : !pdl.range<type>)
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%f = pdl_match @execute_region in %arg1
|
||||
%results = transform.foreach %f -> !pdl.operation {
|
||||
^bb2(%arg2: !pdl.operation):
|
||||
%g = transform.pdl_match @const in %arg2
|
||||
transform.yield %g : !pdl.operation
|
||||
}
|
||||
|
||||
// expected-remark @below {{3}}
|
||||
transform.test_print_number_of_associated_payload_ir_ops %results
|
||||
transform.test_print_remark_at_operand %results, "transform applied"
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user