[mlir][transform] Support results on ForeachOp

Handles can be yielded from the ForeachOp.

Differential Revision: https://reviews.llvm.org/D130640
This commit is contained in:
Matthias Springer 2022-07-27 17:58:24 +02:00
parent d5a3cc1d88
commit c1e6caac70
3 changed files with 90 additions and 2 deletions

View File

@ -118,12 +118,17 @@ def ForeachOp : TransformDialectOp<"foreach",
the entire sequence fails immediately leaving the payload IR in potentially
invalid state, i.e., this operation offers no transformation rollback
capabilities.
This op generates as many handles as the terminating YieldOp has operands.
For each result, the payload ops of the corresponding YieldOp operand are
merged and mapped to the same resulting handle.
}];
let arguments = (ins PDL_Operation:$target);
let results = (outs);
let results = (outs Variadic<PDL_Operation>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat = "$target $body attr-dict";
let assemblyFormat = "$target (`->` type($results)^)? $body attr-dict";
let hasVerifier = 1;
let extraClassDeclaration = [{
/// Allow the dialect prefix to be omitted.
@ -132,6 +137,8 @@ def ForeachOp : TransformDialectOp<"foreach",
BlockArgument getIterationVariable() {
return getBody().front().getArgument(0);
}
transform::YieldOp getYieldOp();
}];
}

View File

@ -281,18 +281,32 @@ DiagnosedSilenceableFailure
transform::ForeachOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
for (Operation *op : payloadOps) {
auto scope = state.make_region_scope(getBody());
if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
return DiagnosedSilenceableFailure::definiteFailure();
// Execute loop body.
for (Operation &transform : getBody().front().without_terminator()) {
DiagnosedSilenceableFailure result = state.applyTransform(
cast<transform::TransformOpInterface>(transform));
if (!result.succeeded())
return result;
}
// Append yielded payload ops to result list (if any).
for (unsigned i = 0; i < getNumResults(); ++i) {
ArrayRef<Operation *> yieldedOps =
state.getPayloadOps(getYieldOp().getOperand(i));
resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
}
}
for (unsigned i = 0; i < getNumResults(); ++i)
results.set(getResult(i).cast<OpResult>(), resultOps[i]);
return DiagnosedSilenceableFailure::success();
}
@ -306,6 +320,9 @@ void transform::ForeachOp::getEffects(
} else {
onlyReadsHandle(getTarget(), effects);
}
for (Value result : getResults())
producesHandle(result, effects);
}
void transform::ForeachOp::getSuccessorRegions(
@ -331,6 +348,21 @@ transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
return getOperation()->getOperands();
}
transform::YieldOp transform::ForeachOp::getYieldOp() {
return cast<transform::YieldOp>(getBody().front().getTerminator());
}
LogicalResult transform::ForeachOp::verify() {
auto yieldOp = getYieldOp();
if (getNumResults() != yieldOp.getNumOperands())
return emitOpError() << "expects the same number of results as the "
"terminator has operands";
for (Value v : yieldOp.getOperands())
if (!v.getType().isa<pdl::OperationType>())
return yieldOp->emitOpError("expects only PDL_Operation operands");
return success();
}
//===----------------------------------------------------------------------===//
// GetClosestIsolatedParentOp
//===----------------------------------------------------------------------===//

View File

@ -627,3 +627,52 @@ transform.with_pdl_patterns {
}
}
}
// -----
func.func @bar() {
scf.execute_region {
// expected-remark @below {{transform applied}}
%0 = arith.constant 0 : i32
scf.yield
}
scf.execute_region {
// expected-remark @below {{transform applied}}
%1 = arith.constant 1 : i32
// expected-remark @below {{transform applied}}
%2 = arith.constant 2 : i32
scf.yield
}
return
}
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @const : benefit(1) {
%r = pdl.types
%0 = pdl.operation "arith.constant" -> (%r : !pdl.range<type>)
pdl.rewrite %0 with "transform.dialect"
}
pdl.pattern @execute_region : benefit(1) {
%r = pdl.types
%0 = pdl.operation "scf.execute_region" -> (%r : !pdl.range<type>)
pdl.rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%f = pdl_match @execute_region in %arg1
%results = transform.foreach %f -> !pdl.operation {
^bb2(%arg2: !pdl.operation):
%g = transform.pdl_match @const in %arg2
transform.yield %g : !pdl.operation
}
// expected-remark @below {{3}}
transform.test_print_number_of_associated_payload_ir_ops %results
transform.test_print_remark_at_operand %results, "transform applied"
}
}