mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 13:26:09 +00:00
The refactor had a bug where the fused loop was inserted in an incorrect location. This patch fixes the bug and relands the original PR https://github.com/llvm/llvm-project/pull/94391. This patch refactors code related to the LoopFuseSiblingOp transform in an attempt to reduce duplicated common code. The aim is to refactor as much as possible into functions on LoopLikeOpInterface, but this is still a work in progress. A full refactor will require more additions to the LoopLikeOpInterface. In addition, scf.parallel fusion support has been added.
This commit is contained in:
parent
c156d42185
commit
edbc0e30a9
@ -303,7 +303,8 @@ def ForallOp : SCF_Op<"forall", [
|
||||
DeclareOpInterfaceMethods<LoopLikeOpInterface,
|
||||
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
|
||||
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
|
||||
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
|
||||
"replaceWithAdditionalYields", "promoteIfSingleIteration",
|
||||
"yieldTiledValuesAndReplace"]>,
|
||||
RecursiveMemoryEffects,
|
||||
SingleBlockImplicitTerminator<"scf::InParallelOp">,
|
||||
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||
|
@ -181,6 +181,16 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
|
||||
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
|
||||
scf::ForOp root);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fusion related helpers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Check structural compatibility between two loops such as iteration space
|
||||
/// and dominance.
|
||||
bool checkFusionStructuralLegality(LoopLikeOpInterface target,
|
||||
LoopLikeOpInterface source,
|
||||
Diagnostic &diag);
|
||||
|
||||
/// Given two scf.forall loops, `target` and `source`, fuses `target` into
|
||||
/// `source`. Assumes that the given loops are siblings and are independent of
|
||||
/// each other.
|
||||
@ -202,6 +212,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
|
||||
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
|
||||
RewriterBase &rewriter);
|
||||
|
||||
/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
|
||||
/// `source`. Assumes that the given loops are siblings and are independent of
|
||||
/// each other.
|
||||
///
|
||||
/// This function does not perform any legality checks and simply fuses the
|
||||
/// loops. The caller is responsible for ensuring that the loops are legal to
|
||||
/// fuse.
|
||||
scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
|
||||
scf::ParallelOp source,
|
||||
RewriterBase &rewriter);
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
|
||||
|
@ -90,4 +90,24 @@ struct JamBlockGatherer {
|
||||
/// Include the generated interface declarations.
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
/// A function that rewrites `target`'s terminator as a terminator obtained by
|
||||
/// fusing `source` into `target`.
|
||||
using FuseTerminatorFn =
|
||||
function_ref<void(RewriterBase &rewriter, LoopLikeOpInterface source,
|
||||
LoopLikeOpInterface &target, IRMapping mapping)>;
|
||||
|
||||
/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to
|
||||
/// `target`. The `NewYieldValuesFn` callback is used to pass to the
|
||||
/// `replaceWithAdditionalYields` interface method to replace the loop with a
|
||||
/// new loop with (possibly) additional yields, while the `FuseTerminatorFn`
|
||||
/// callback is responsible for updating the fused loop terminator.
|
||||
LoopLikeOpInterface createFused(LoopLikeOpInterface target,
|
||||
LoopLikeOpInterface source,
|
||||
RewriterBase &rewriter,
|
||||
NewYieldValuesFn newYieldValuesFn,
|
||||
FuseTerminatorFn fuseTerminatorFn);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
|
||||
|
@ -618,6 +618,44 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
|
||||
|
||||
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
|
||||
|
||||
/// Replaces this `scf.forall` with a new one whose outputs are the current
/// outputs followed by `newInitOperands`.
///
/// The new loop is created immediately before this op, with the same mixed
/// lower bounds, upper bounds, steps and mapping. The existing body block is
/// merged into the new loop's body, with the old block arguments replaced by
/// the leading block arguments of the new body. When
/// `replaceInitOperandUsesInLoop` is set, uses of each new init operand that
/// are nested inside the new loop are redirected to the matching trailing
/// block argument. This op is then replaced by the leading results of the new
/// loop, and the new loop is returned.
///
/// Note: `newYieldValuesFn` is not consulted by this implementation.
FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
    RewriterBase &rewriter, ValueRange newInitOperands,
    bool replaceInitOperandUsesInLoop,
    const NewYieldValuesFn &newYieldValuesFn) {
  // Build the replacement right before the current loop; the guard restores
  // the caller's insertion point on exit.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(getOperation());

  // Existing outputs first, then the extra init operands.
  SmallVector<Value> allOutputs(getOutputs());
  llvm::append_range(allOutputs, newInitOperands);

  // The body builder is intentionally a no-op: the body is supplied by
  // merging the existing block below.
  auto emptyBodyBuilder = [](OpBuilder &, Location, ValueRange) {};
  scf::ForallOp replacement = rewriter.create<scf::ForallOp>(
      getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
      allOutputs, getMapping(), emptyBodyBuilder);

  // Transfer the old body into the replacement, rebinding the old block
  // arguments to the leading arguments of the new body.
  rewriter.mergeBlocks(getBody(), replacement.getBody(),
                       replacement.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // The trailing block arguments correspond one-to-one to
    // `newInitOperands`.
    auto addedBlockArgs = replacement.getBody()->getArguments().take_back(
        newInitOperands.size());
    for (auto &&[initOperand, blockArg] :
         llvm::zip(newInitOperands, addedBlockArgs)) {
      // Only rewrite uses whose owner is nested inside the replacement loop.
      rewriter.replaceUsesWithIf(
          initOperand, blockArg, [&](OpOperand &use) {
            return replacement->isProperAncestor(use.getOwner());
          });
    }
  }

  // The leading results of the replacement stand in for the old results.
  rewriter.replaceOp(getOperation(),
                     replacement->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(replacement.getOperation());
}
|
||||
|
||||
/// Promotes the loop body of a forallOp to its containing block if it can be
|
||||
/// determined that the loop has a single iteration.
|
||||
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
|
||||
|
@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp,
|
||||
return 1;
|
||||
};
|
||||
|
||||
std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
|
||||
std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
|
||||
std::optional<int64_t> ubConstant =
|
||||
getConstantIntValue(forOp.getUpperBound());
|
||||
std::optional<int64_t> lbConstant =
|
||||
getConstantIntValue(forOp.getLowerBound());
|
||||
DenseMap<Operation *, unsigned> opCycles;
|
||||
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
|
||||
for (Operation &op : forOp.getBody()->getOperations()) {
|
||||
@ -447,113 +449,6 @@ void transform::TakeAssumedBranchOp::getEffects(
|
||||
// LoopFuseSiblingOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Check if `target` and `source` are siblings, in the context that `target`
|
||||
/// is being fused into `source`.
|
||||
///
|
||||
/// This is a simple check that just checks if both operations are in the same
|
||||
/// block and some checks to ensure that the fused IR does not violate
|
||||
/// dominance.
|
||||
static DiagnosedSilenceableFailure isOpSibling(Operation *target,
|
||||
Operation *source) {
|
||||
// Check if both operations are same.
|
||||
if (target == source)
|
||||
return emitSilenceableFailure(source)
|
||||
<< "target and source need to be different loops";
|
||||
|
||||
// Check if both operations are in the same block.
|
||||
if (target->getBlock() != source->getBlock())
|
||||
return emitSilenceableFailure(source)
|
||||
<< "target and source are not in the same block";
|
||||
|
||||
// Check if fusion will violate dominance.
|
||||
DominanceInfo domInfo(source);
|
||||
if (target->isBeforeInBlock(source)) {
|
||||
// Since `target` is before `source`, all users of results of `target`
|
||||
// need to be dominated by `source`.
|
||||
for (Operation *user : target->getUsers()) {
|
||||
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
|
||||
return emitSilenceableFailure(target)
|
||||
<< "user of results of target should be properly dominated by "
|
||||
"source";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Since `target` is after `source`, all values used by `target` need
|
||||
// to dominate `source`.
|
||||
|
||||
// Check if operands of `target` are dominated by `source`.
|
||||
for (Value operand : target->getOperands()) {
|
||||
Operation *operandOp = operand.getDefiningOp();
|
||||
// Operands without defining operations are block arguments. When `target`
|
||||
// and `source` occur in the same block, these operands dominate `source`.
|
||||
if (!operandOp)
|
||||
continue;
|
||||
|
||||
// Operand's defining operation should properly dominate `source`.
|
||||
if (!domInfo.properlyDominates(operandOp, source,
|
||||
/*enclosingOpOk=*/false))
|
||||
return emitSilenceableFailure(target)
|
||||
<< "operands of target should be properly dominated by source";
|
||||
}
|
||||
|
||||
// Check if values used by `target` are dominated by `source`.
|
||||
bool failed = false;
|
||||
OpOperand *failedValue = nullptr;
|
||||
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
|
||||
Operation *operandOp = operand->get().getDefiningOp();
|
||||
if (operandOp && !domInfo.properlyDominates(operandOp, source,
|
||||
/*enclosingOpOk=*/false)) {
|
||||
// `operand` is not an argument of an enclosing block and the defining
|
||||
// op of `operand` is outside `target` but does not dominate `source`.
|
||||
failed = true;
|
||||
failedValue = operand;
|
||||
}
|
||||
});
|
||||
|
||||
if (failed)
|
||||
return emitSilenceableFailure(failedValue->getOwner())
|
||||
<< "values used inside regions of target should be properly "
|
||||
"dominated by source";
|
||||
}
|
||||
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
/// Check if `target` scf.forall can be fused into `source` scf.forall.
|
||||
///
|
||||
/// This simply checks if both loops have the same bounds, steps and mapping.
|
||||
/// No attempt is made at checking that the side effects of `target` and
|
||||
/// `source` are independent of each other.
|
||||
static bool isForallWithIdenticalConfiguration(Operation *target,
|
||||
Operation *source) {
|
||||
auto targetOp = dyn_cast<scf::ForallOp>(target);
|
||||
auto sourceOp = dyn_cast<scf::ForallOp>(source);
|
||||
if (!targetOp || !sourceOp)
|
||||
return false;
|
||||
|
||||
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
|
||||
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
|
||||
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
|
||||
targetOp.getMapping() == sourceOp.getMapping();
|
||||
}
|
||||
|
||||
/// Check if `target` scf.for can be fused into `source` scf.for.
|
||||
///
|
||||
/// This simply checks if both loops have the same bounds and steps. No attempt
|
||||
/// is made at checking that the side effects of `target` and `source` are
|
||||
/// independent of each other.
|
||||
static bool isForWithIdenticalConfiguration(Operation *target,
|
||||
Operation *source) {
|
||||
auto targetOp = dyn_cast<scf::ForOp>(target);
|
||||
auto sourceOp = dyn_cast<scf::ForOp>(source);
|
||||
if (!targetOp || !sourceOp)
|
||||
return false;
|
||||
|
||||
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
|
||||
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
|
||||
targetOp.getStep() == sourceOp.getStep();
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
@ -569,25 +464,32 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
|
||||
<< "source handle (got " << llvm::range_size(sourceOps) << ")";
|
||||
}
|
||||
|
||||
Operation *target = *targetOps.begin();
|
||||
Operation *source = *sourceOps.begin();
|
||||
auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
|
||||
auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
|
||||
// `target` is null when the payload op does not implement
// `LoopLikeOpInterface`, so it must not be dereferenced to obtain a
// location. Report at the location of the original target payload op.
if (!target || !source)
  return emitSilenceableFailure((*targetOps.begin())->getLoc())
         << "target or source is not a loop op";
|
||||
|
||||
// Check if the target and source are siblings.
|
||||
DiagnosedSilenceableFailure diag = isOpSibling(target, source);
|
||||
if (!diag.succeeded())
|
||||
return diag;
|
||||
// Check if loops can be fused
|
||||
Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
|
||||
if (!mlir::checkFusionStructuralLegality(target, source, diag))
|
||||
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
|
||||
|
||||
Operation *fusedLoop;
|
||||
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
|
||||
if (isForWithIdenticalConfiguration(target, source)) {
|
||||
// TODO: Support fusion for loop-like ops besides scf.for, scf.forall
|
||||
// and scf.parallel.
|
||||
if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
|
||||
fusedLoop = fuseIndependentSiblingForLoops(
|
||||
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
|
||||
} else if (isForallWithIdenticalConfiguration(target, source)) {
|
||||
} else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
|
||||
fusedLoop = fuseIndependentSiblingForallLoops(
|
||||
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
|
||||
} else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
|
||||
fusedLoop = fuseIndependentSiblingParallelLoops(
|
||||
cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
|
||||
} else
|
||||
return emitSilenceableFailure(target->getLoc())
|
||||
<< "operations cannot be fused";
|
||||
<< "unsupported loop type for fusion";
|
||||
|
||||
assert(fusedLoop && "failed to fuse operations");
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
@ -37,24 +38,6 @@ static bool hasNestedParallelOp(ParallelOp ploop) {
|
||||
return walkResult.wasInterrupted();
|
||||
}
|
||||
|
||||
/// Verify equal iteration spaces.
|
||||
static bool equalIterationSpaces(ParallelOp firstPloop,
|
||||
ParallelOp secondPloop) {
|
||||
if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
|
||||
return false;
|
||||
|
||||
auto matchOperands = [&](const OperandRange &lhs,
|
||||
const OperandRange &rhs) -> bool {
|
||||
// TODO: Extend this to support aliases and equal constants.
|
||||
return std::equal(lhs.begin(), lhs.end(), rhs.begin());
|
||||
};
|
||||
return matchOperands(firstPloop.getLowerBound(),
|
||||
secondPloop.getLowerBound()) &&
|
||||
matchOperands(firstPloop.getUpperBound(),
|
||||
secondPloop.getUpperBound()) &&
|
||||
matchOperands(firstPloop.getStep(), secondPloop.getStep());
|
||||
}
|
||||
|
||||
/// Checks if the parallel loops have mixed access to the same buffers. Returns
|
||||
/// `true` if the first parallel loop writes to the same indices that the second
|
||||
/// loop reads.
|
||||
@ -153,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
|
||||
const IRMapping &firstToSecondPloopIndices,
|
||||
llvm::function_ref<bool(Value, Value)> mayAlias) {
|
||||
Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
|
||||
return !hasNestedParallelOp(firstPloop) &&
|
||||
!hasNestedParallelOp(secondPloop) &&
|
||||
equalIterationSpaces(firstPloop, secondPloop) &&
|
||||
checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
|
||||
succeeded(verifyDependencies(firstPloop, secondPloop,
|
||||
firstToSecondPloopIndices, mayAlias));
|
||||
}
|
||||
@ -174,61 +158,9 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
|
||||
mayAlias))
|
||||
return;
|
||||
|
||||
DominanceInfo dom;
|
||||
// We are fusing first loop into second, make sure there are no users of the
|
||||
// first loop results between loops.
|
||||
for (Operation *user : firstPloop->getUsers())
|
||||
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
|
||||
return;
|
||||
|
||||
ValueRange inits1 = firstPloop.getInitVals();
|
||||
ValueRange inits2 = secondPloop.getInitVals();
|
||||
|
||||
SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
|
||||
newInitVars.append(inits2.begin(), inits2.end());
|
||||
|
||||
IRRewriter b(builder);
|
||||
b.setInsertionPoint(secondPloop);
|
||||
auto newSecondPloop = b.create<ParallelOp>(
|
||||
secondPloop.getLoc(), secondPloop.getLowerBound(),
|
||||
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
|
||||
|
||||
Block *newBlock = newSecondPloop.getBody();
|
||||
auto term1 = cast<ReduceOp>(block1->getTerminator());
|
||||
auto term2 = cast<ReduceOp>(block2->getTerminator());
|
||||
|
||||
b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
|
||||
newBlock->getArguments());
|
||||
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
|
||||
newBlock->getArguments());
|
||||
|
||||
ValueRange results = newSecondPloop.getResults();
|
||||
if (!results.empty()) {
|
||||
b.setInsertionPointToEnd(newBlock);
|
||||
|
||||
ValueRange reduceArgs1 = term1.getOperands();
|
||||
ValueRange reduceArgs2 = term2.getOperands();
|
||||
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
|
||||
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
|
||||
|
||||
auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
|
||||
|
||||
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
|
||||
term1.getReductions(), term2.getReductions()))) {
|
||||
Block &oldRedBlock = reg.front();
|
||||
Block &newRedBlock = newReduceOp.getReductions()[i].front();
|
||||
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
|
||||
newRedBlock.getArguments());
|
||||
}
|
||||
|
||||
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
|
||||
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
|
||||
}
|
||||
term1->erase();
|
||||
term2->erase();
|
||||
firstPloop.erase();
|
||||
secondPloop.erase();
|
||||
secondPloop = newSecondPloop;
|
||||
IRRewriter rewriter(builder);
|
||||
secondPloop = mlir::fuseIndependentSiblingParallelLoops(
|
||||
firstPloop, secondPloop, rewriter);
|
||||
}
|
||||
|
||||
void mlir::scf::naivelyFuseParallelOps(
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
@ -1262,54 +1263,131 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
|
||||
return tileLoops;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fusion related helpers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Check if `target` and `source` are siblings, in the context that `target`
|
||||
/// is being fused into `source`.
|
||||
///
|
||||
/// This is a simple check that just checks if both operations are in the same
|
||||
/// block and some checks to ensure that the fused IR does not violate
|
||||
/// dominance.
|
||||
static bool isOpSibling(Operation *target, Operation *source,
|
||||
Diagnostic &diag) {
|
||||
// Check if both operations are same.
|
||||
if (target == source) {
|
||||
diag << "target and source need to be different loops";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if both operations are in the same block.
|
||||
if (target->getBlock() != source->getBlock()) {
|
||||
diag << "target and source are not in the same block";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if fusion will violate dominance.
|
||||
DominanceInfo domInfo(source);
|
||||
if (target->isBeforeInBlock(source)) {
|
||||
// Since `target` is before `source`, all users of results of `target`
|
||||
// need to be dominated by `source`.
|
||||
for (Operation *user : target->getUsers()) {
|
||||
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
|
||||
diag << "user of results of target should "
|
||||
"be properly dominated by source";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Since `target` is after `source`, all values used by `target` need
|
||||
// to dominate `source`.
|
||||
|
||||
// Check if operands of `target` are dominated by `source`.
|
||||
for (Value operand : target->getOperands()) {
|
||||
Operation *operandOp = operand.getDefiningOp();
|
||||
// Operands without defining operations are block arguments. When `target`
|
||||
// and `source` occur in the same block, these operands dominate `source`.
|
||||
if (!operandOp)
|
||||
continue;
|
||||
|
||||
// Operand's defining operation should properly dominate `source`.
|
||||
if (!domInfo.properlyDominates(operandOp, source,
|
||||
/*enclosingOpOk=*/false)) {
|
||||
diag << "operands of target should be properly dominated by source";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if values used by `target` are dominated by `source`.
|
||||
bool failed = false;
|
||||
OpOperand *failedValue = nullptr;
|
||||
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
|
||||
Operation *operandOp = operand->get().getDefiningOp();
|
||||
if (operandOp && !domInfo.properlyDominates(operandOp, source,
|
||||
/*enclosingOpOk=*/false)) {
|
||||
// `operand` is not an argument of an enclosing block and the defining
|
||||
// op of `operand` is outside `target` but does not dominate `source`.
|
||||
failed = true;
|
||||
failedValue = operand;
|
||||
}
|
||||
});
|
||||
|
||||
if (failed) {
|
||||
diag << "values used inside regions of target should be properly "
|
||||
"dominated by source";
|
||||
diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Checks that `target` and `source` are structurally compatible for fusion.
///
/// The loops must be the same kind of operation, report equal loop lower
/// bounds, upper bounds and steps, and, for `scf.forall`, carry equal
/// mappings. Finally the sibling/dominance check `isOpSibling` must pass.
/// On failure a message is appended to `diag` and false is returned.
bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
                                         LoopLikeOpInterface source,
                                         Diagnostic &diag) {
  if (target->getName() != source->getName()) {
    diag << "target and source must be same loop type";
    return false;
  }

  // TODO: Decouple checks on concrete loop types and move this function
  // somewhere for general utility for `LoopLikeOpInterface`
  auto haveEqualIterationSpaces = [&]() -> bool {
    bool boundsAndStepsMatch =
        target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
        target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
        target.getLoopSteps() == source.getLoopSteps();
    if (!boundsAndStepsMatch)
      return false;
    // The op names match, so `source` is an `scf.forall` whenever `target`
    // is one; the mappings must agree as well.
    if (auto forallTarget = dyn_cast<scf::ForallOp>(*target))
      return forallTarget.getMapping() ==
             cast<scf::ForallOp>(*source).getMapping();
    return true;
  };

  if (!haveEqualIterationSpaces()) {
    diag << "target and source iteration spaces must be equal";
    return false;
  }

  return isOpSibling(target, source, diag);
}
|
||||
|
||||
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
|
||||
scf::ForallOp source,
|
||||
RewriterBase &rewriter) {
|
||||
unsigned numTargetOuts = target.getNumResults();
|
||||
unsigned numSourceOuts = source.getNumResults();
|
||||
|
||||
// Create fused shared_outs.
|
||||
SmallVector<Value> fusedOuts;
|
||||
llvm::append_range(fusedOuts, target.getOutputs());
|
||||
llvm::append_range(fusedOuts, source.getOutputs());
|
||||
|
||||
// Create a new scf.forall op after the source loop.
|
||||
rewriter.setInsertionPointAfter(source);
|
||||
scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
|
||||
source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
|
||||
source.getMixedStep(), fusedOuts, source.getMapping());
|
||||
|
||||
// Map control operands.
|
||||
IRMapping mapping;
|
||||
mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
|
||||
mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
|
||||
|
||||
// Map shared outs.
|
||||
mapping.map(target.getRegionIterArgs(),
|
||||
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
|
||||
mapping.map(source.getRegionIterArgs(),
|
||||
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
|
||||
|
||||
// Append everything except the terminator into the fused operation.
|
||||
rewriter.setInsertionPointToStart(fusedLoop.getBody());
|
||||
for (Operation &op : target.getBody()->without_terminator())
|
||||
rewriter.clone(op, mapping);
|
||||
for (Operation &op : source.getBody()->without_terminator())
|
||||
rewriter.clone(op, mapping);
|
||||
|
||||
// Fuse the old terminator in_parallel ops into the new one.
|
||||
scf::InParallelOp targetTerm = target.getTerminator();
|
||||
scf::InParallelOp sourceTerm = source.getTerminator();
|
||||
scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
|
||||
rewriter.setInsertionPointToStart(fusedTerm.getBody());
|
||||
for (Operation &op : targetTerm.getYieldingOps())
|
||||
rewriter.clone(op, mapping);
|
||||
for (Operation &op : sourceTerm.getYieldingOps())
|
||||
rewriter.clone(op, mapping);
|
||||
|
||||
// Replace old loops by substituting their uses by results of the fused loop.
|
||||
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
|
||||
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
|
||||
scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
|
||||
target, source, rewriter,
|
||||
[&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
|
||||
// `ForallOp` does not have yields, rather an `InParallelOp` terminator.
|
||||
return ValueRange{};
|
||||
},
|
||||
[&](RewriterBase &b, LoopLikeOpInterface source,
|
||||
LoopLikeOpInterface &target, IRMapping mapping) {
|
||||
auto sourceForall = cast<scf::ForallOp>(source);
|
||||
auto targetForall = cast<scf::ForallOp>(target);
|
||||
scf::InParallelOp fusedTerm = targetForall.getTerminator();
|
||||
b.setInsertionPointToEnd(fusedTerm.getBody());
|
||||
for (Operation &op : sourceForall.getTerminator().getYieldingOps())
|
||||
b.clone(op, mapping);
|
||||
}));
|
||||
rewriter.replaceOp(source,
|
||||
fusedLoop.getResults().take_back(source.getNumResults()));
|
||||
|
||||
return fusedLoop;
|
||||
}
|
||||
@ -1317,49 +1395,74 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
|
||||
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
|
||||
scf::ForOp source,
|
||||
RewriterBase &rewriter) {
|
||||
unsigned numTargetOuts = target.getNumResults();
|
||||
unsigned numSourceOuts = source.getNumResults();
|
||||
scf::ForOp fusedLoop = cast<scf::ForOp>(createFused(
|
||||
target, source, rewriter,
|
||||
[&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
|
||||
return source.getYieldedValues();
|
||||
},
|
||||
[&](RewriterBase &b, LoopLikeOpInterface source,
|
||||
LoopLikeOpInterface &target, IRMapping mapping) {
|
||||
auto targetFor = cast<scf::ForOp>(target);
|
||||
auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping);
|
||||
b.replaceOp(targetFor.getBody()->getTerminator(), newTerm);
|
||||
}));
|
||||
rewriter.replaceOp(source,
|
||||
fusedLoop.getResults().take_back(source.getNumResults()));
|
||||
return fusedLoop;
|
||||
}
|
||||
|
||||
// Create fused init_args, with target's init_args before source's init_args.
|
||||
SmallVector<Value> fusedInitArgs;
|
||||
llvm::append_range(fusedInitArgs, target.getInitArgs());
|
||||
llvm::append_range(fusedInitArgs, source.getInitArgs());
|
||||
// TODO: Finish refactoring this a la the above, but likely requires additional
|
||||
// interface methods.
|
||||
scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
|
||||
scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
Block *block1 = target.getBody();
|
||||
Block *block2 = source.getBody();
|
||||
auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
|
||||
auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
|
||||
|
||||
// Create a new scf.for op after the source loop (with scf.yield terminator
|
||||
// (without arguments) only in case its init_args is empty).
|
||||
rewriter.setInsertionPointAfter(source);
|
||||
scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
|
||||
source.getLoc(), source.getLowerBound(), source.getUpperBound(),
|
||||
source.getStep(), fusedInitArgs);
|
||||
ValueRange inits1 = target.getInitVals();
|
||||
ValueRange inits2 = source.getInitVals();
|
||||
|
||||
// Map original induction variables and operands to those of the fused loop.
|
||||
IRMapping mapping;
|
||||
mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
|
||||
mapping.map(target.getRegionIterArgs(),
|
||||
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
|
||||
mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
|
||||
mapping.map(source.getRegionIterArgs(),
|
||||
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
|
||||
SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
|
||||
newInitVars.append(inits2.begin(), inits2.end());
|
||||
|
||||
// Merge target's body into the new (fused) for loop and then source's body.
|
||||
rewriter.setInsertionPointToStart(fusedLoop.getBody());
|
||||
for (Operation &op : target.getBody()->without_terminator())
|
||||
rewriter.clone(op, mapping);
|
||||
for (Operation &op : source.getBody()->without_terminator())
|
||||
rewriter.clone(op, mapping);
|
||||
rewriter.setInsertionPoint(source);
|
||||
auto fusedLoop = rewriter.create<scf::ParallelOp>(
|
||||
rewriter.getFusedLoc(target.getLoc(), source.getLoc()),
|
||||
source.getLowerBound(), source.getUpperBound(), source.getStep(),
|
||||
newInitVars);
|
||||
Block *newBlock = fusedLoop.getBody();
|
||||
rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(),
|
||||
newBlock->getArguments());
|
||||
rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(),
|
||||
newBlock->getArguments());
|
||||
|
||||
// Build fused yield results by appropriately mapping original yield operands.
|
||||
SmallVector<Value> yieldResults;
|
||||
for (Value operand : target.getBody()->getTerminator()->getOperands())
|
||||
yieldResults.push_back(mapping.lookupOrDefault(operand));
|
||||
for (Value operand : source.getBody()->getTerminator()->getOperands())
|
||||
yieldResults.push_back(mapping.lookupOrDefault(operand));
|
||||
if (!yieldResults.empty())
|
||||
rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
|
||||
ValueRange results = fusedLoop.getResults();
|
||||
if (!results.empty()) {
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
|
||||
// Replace old loops by substituting their uses by results of the fused loop.
|
||||
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
|
||||
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
|
||||
ValueRange reduceArgs1 = term1.getOperands();
|
||||
ValueRange reduceArgs2 = term2.getOperands();
|
||||
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
|
||||
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
|
||||
|
||||
auto newReduceOp = rewriter.create<scf::ReduceOp>(
|
||||
rewriter.getFusedLoc(term1.getLoc(), term2.getLoc()), newReduceArgs);
|
||||
|
||||
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
|
||||
term1.getReductions(), term2.getReductions()))) {
|
||||
Block &oldRedBlock = reg.front();
|
||||
Block &newRedBlock = newReduceOp.getReductions()[i].front();
|
||||
rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock,
|
||||
newRedBlock.begin(),
|
||||
newRedBlock.getArguments());
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(target, results.take_front(inits1.size()));
|
||||
rewriter.replaceOp(source, results.take_back(inits2.size()));
|
||||
rewriter.eraseOp(term1);
|
||||
rewriter.eraseOp(term2);
|
||||
|
||||
return fusedLoop;
|
||||
}
|
||||
|
@ -8,6 +8,8 @@
|
||||
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h"
|
||||
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
@ -113,3 +115,60 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Fuses `source` into `target` and returns the fused loop.
///
/// The fused loop is obtained by asking `target` to replace itself with a
/// loop that additionally carries `source`'s inits
/// (`replaceWithAdditionalYields`); the values yielded for those extra inits
/// are supplied by `newYieldValuesFn`. The fused loop is then moved to
/// `source`'s position, the non-terminator operations of `source`'s first
/// region are cloned in front of the fused loop's terminator, and
/// `fuseTerminatorFn` is invoked to combine the terminators.
///
/// \param target           Loop whose replacement becomes the fused loop.
/// \param source           Loop whose body is cloned into the fused loop.
/// \param rewriter         Rewriter used for every IR modification.
/// \param newYieldValuesFn Callback forwarded to
///                         `replaceWithAdditionalYields`.
/// \param fuseTerminatorFn Callback invoked as
///                         `fuseTerminatorFn(rewriter, source, fusedLoop,
///                         mapping)` after the body has been cloned.
/// \returns The fused loop.
///
/// This function does not erase `source` itself.
LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target,
                                      LoopLikeOpInterface source,
                                      RewriterBase &rewriter,
                                      NewYieldValuesFn newYieldValuesFn,
                                      FuseTerminatorFn fuseTerminatorFn) {
  // Snapshot what is needed from both loops before `target` is replaced
  // below.
  auto targetIterArgs = target.getRegionIterArgs();
  std::optional<SmallVector<Value>> targetInductionVar =
      target.getLoopInductionVars();
  SmallVector<Value> targetYieldOperands(target.getYieldedValues());
  auto sourceIterArgs = source.getRegionIterArgs();
  // Keep the optional as-is (no unconditional dereference): a loop may report
  // no induction variables, and that case is diagnosed below instead of
  // dereferencing an empty optional here.
  std::optional<SmallVector<Value>> sourceInductionVar =
      source.getLoopInductionVars();
  SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
  auto sourceRegion = source.getLoopRegions().front();

  FailureOr<LoopLikeOpInterface> maybeFusedLoop =
      target.replaceWithAdditionalYields(rewriter, source.getInits(),
                                         /*replaceInitOperandUsesInLoop=*/false,
                                         newYieldValuesFn);
  if (failed(maybeFusedLoop))
    llvm_unreachable("failed to replace loop");
  LoopLikeOpInterface fusedLoop = *maybeFusedLoop;
  // Since the target op is rewritten at the original's location, we move it to
  // the source op's location.
  rewriter.moveOpBefore(fusedLoop, source);

  // Map control operands.
  IRMapping mapping;
  std::optional<SmallVector<Value>> fusedInductionVar =
      fusedLoop.getLoopInductionVars();
  if (fusedInductionVar) {
    if (!targetInductionVar || !sourceInductionVar)
      llvm_unreachable(
          "expected target and source loops to have induction vars");
    // Both original loops' induction variables are redirected to the fused
    // loop's induction variables.
    mapping.map(*targetInductionVar, *fusedInductionVar);
    mapping.map(*sourceInductionVar, *fusedInductionVar);
  }
  // Target-related values map to the leading iter args / yielded values of
  // the fused loop, source-related values to the trailing ones.
  mapping.map(targetIterArgs,
              fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
  mapping.map(targetYieldOperands,
              fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
  mapping.map(sourceIterArgs,
              fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
  mapping.map(sourceYieldOperands,
              fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
  // Append everything except the terminator into the fused operation.
  rewriter.setInsertionPoint(
      fusedLoop.getLoopRegions().front()->front().getTerminator());
  for (Operation &op : sourceRegion->front().without_terminator())
    rewriter.clone(op, mapping);

  // TODO: Replace with corresponding interface method if added
  fuseTerminatorFn(rewriter, source, fusedLoop, mapping);

  return fusedLoop;
}
|
||||
|
@ -47,6 +47,169 @@ module attributes {transform.with_named_sequence} {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fuse_two_parallel
|
||||
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
|
||||
func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
|
||||
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
|
||||
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c1fp = arith.constant 1.0 : f32
|
||||
// CHECK: [[SUM:%.*]] = memref.alloc()
|
||||
%sum = memref.alloc() : memref<2x2xf32>
|
||||
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
|
||||
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
|
||||
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
|
||||
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK-NOT: scf.parallel
|
||||
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
|
||||
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: scf.reduce
|
||||
// CHECK: }
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
|
||||
%sum_elem = arith.addf %B_elem, %c1fp : f32
|
||||
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
|
||||
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
|
||||
%product_elem = arith.mulf %sum_elem, %A_elem : f32
|
||||
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
// CHECK: memref.dealloc [[SUM]]
|
||||
memref.dealloc %sum : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fuse_two_parallel_reverse
|
||||
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
|
||||
func.func @fuse_two_parallel_reverse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
|
||||
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
|
||||
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c1fp = arith.constant 1.0 : f32
|
||||
// CHECK: [[SUM:%.*]] = memref.alloc()
|
||||
%sum = memref.alloc() : memref<2x2xf32>
|
||||
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
|
||||
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
|
||||
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
|
||||
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK-NOT: scf.parallel
|
||||
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
|
||||
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: scf.reduce
|
||||
// CHECK: }
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
|
||||
%sum_elem = arith.addf %B_elem, %c1fp : f32
|
||||
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
|
||||
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
|
||||
%product_elem = arith.mulf %sum_elem, %A_elem : f32
|
||||
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
// CHECK: memref.dealloc [[SUM]]
|
||||
memref.dealloc %sum : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%fused = transform.loop.fuse_sibling %parallel#1 into %parallel#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fuse_reductions_two
|
||||
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
|
||||
func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
|
||||
// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
|
||||
// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
|
||||
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
|
||||
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
|
||||
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
|
||||
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
|
||||
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
|
||||
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
|
||||
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK: scf.reduce.return %[[R]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
|
||||
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK: scf.reduce.return %[[R]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%init1 = arith.constant 1.0 : f32
|
||||
%init2 = arith.constant 2.0 : f32
|
||||
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
|
||||
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
|
||||
scf.reduce(%A_elem : f32) {
|
||||
^bb0(%lhs: f32, %rhs: f32):
|
||||
%1 = arith.addf %lhs, %rhs : f32
|
||||
scf.reduce.return %1 : f32
|
||||
}
|
||||
}
|
||||
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
|
||||
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
|
||||
scf.reduce(%B_elem : f32) {
|
||||
^bb0(%lhs: f32, %rhs: f32):
|
||||
%1 = arith.mulf %lhs, %rhs : f32
|
||||
scf.reduce.return %1 : f32
|
||||
}
|
||||
}
|
||||
return %res1, %res2 : f32, f32
|
||||
}
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
|
||||
func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
|
||||
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
|
||||
@ -208,6 +371,62 @@ module attributes {transform.with_named_sequence} {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 32)
|
||||
#map = affine_map<(d0) -> (d0 * 32)>
|
||||
#map1 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
module {
|
||||
// CHECK: func.func @loop_sibling_fusion(%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}
|
||||
func.func @loop_sibling_fusion(%arg0: tensor<128xf32>, %arg1: tensor<128x128xf16>, %arg2: tensor<128x64xf32>, %arg3: tensor<128x128xf32>) -> (tensor<128xf32>, tensor<128x128xf16>) {
|
||||
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<128x128xf16>
|
||||
// CHECK-NEXT: %[[RESULTS:.*]]:2 = scf.forall (%[[I:.*]]) in (4) shared_outs(%[[S1:.*]] = %[[ARG0]], %[[S2:.*]] = %[[ARG1]]) -> (tensor<128xf32>, tensor<128x128xf16>) {
|
||||
// CHECK-NEXT: %[[IDX:.*]] = affine.apply #[[$MAP]](%[[I]])
|
||||
// CHECK-NEXT: %[[SLICE0:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
|
||||
// CHECK-NEXT: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
|
||||
// CHECK-NEXT: %[[SLICE2:.*]] = tensor.extract_slice %[[EMPTY]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
|
||||
// CHECK-NEXT: %[[GENERIC:.*]] = linalg.generic {{.*}} ins(%[[SLICE1]] : {{.*}}) outs(%[[SLICE2]] : {{.*}})
|
||||
// CHECK: scf.forall.in_parallel {
|
||||
// CHECK-NEXT: tensor.parallel_insert_slice %[[SLICE0]] into %[[S1]][%[[IDX]]] [32] [1] : tensor<32xf32> into tensor<128xf32>
|
||||
// CHECK-NEXT: tensor.parallel_insert_slice %[[GENERIC]] into %[[S2]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } {mapping = [#gpu.warp<linear_dim_0>]}
|
||||
// CHECK-NEXT: return %[[RESULTS]]#0, %[[RESULTS]]#1
|
||||
%0 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg0) -> (tensor<128xf32>) {
|
||||
%3 = affine.apply #map(%arg4)
|
||||
%extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
|
||||
scf.forall.in_parallel {
|
||||
tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [32] [1] : tensor<32xf32> into tensor<128xf32>
|
||||
}
|
||||
} {mapping = [#gpu.warp<linear_dim_0>]}
|
||||
%1 = tensor.empty() : tensor<128x128xf16>
|
||||
%2 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg1) -> (tensor<128x128xf16>) {
|
||||
%3 = affine.apply #map(%arg4)
|
||||
%extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
|
||||
%extracted_slice_0 = tensor.extract_slice %1[%3, 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
|
||||
%4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<32x128xf32>) outs(%extracted_slice_0 : tensor<32x128xf16>) {
|
||||
^bb0(%in: f32, %out: f16):
|
||||
%5 = arith.truncf %in : f32 to f16
|
||||
linalg.yield %5 : f16
|
||||
} -> tensor<32x128xf16>
|
||||
scf.forall.in_parallel {
|
||||
tensor.parallel_insert_slice %4 into %arg5[%3, 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
|
||||
}
|
||||
} {mapping = [#gpu.warp<linear_dim_0>]}
|
||||
return %0, %2 : tensor<128xf32>, tensor<128x128xf16>
|
||||
}
|
||||
}
|
||||
|
||||
module attributes { transform.with_named_sequence } {
|
||||
transform.named_sequence @__transform_main(%root: !transform.any_op) {
|
||||
%loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
|
||||
%loop1, %loop2 = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%loop3 = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
|
||||
@ -282,8 +501,9 @@ func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>,
|
||||
%6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
|
||||
scf.yield %6 : tensor<128xf32>
|
||||
}
|
||||
%dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
|
||||
// expected-error @below {{values used inside regions of target should be properly dominated by source}}
|
||||
%dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
|
||||
// expected-note @below {{see operation}}
|
||||
%dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
|
||||
%dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
|
||||
%dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
|
||||
@ -328,6 +548,74 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @non_matching_iteration_spaces_err(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c1fp = arith.constant 1.0 : f32
|
||||
%sum = memref.alloc() : memref<2x2xf32>
|
||||
// expected-error @below {{target and source iteration spaces must be equal}}
|
||||
scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
|
||||
%B_elem = memref.load %B[%i, %c0] : memref<2x2xf32>
|
||||
%sum_elem = arith.addf %B_elem, %c1fp : f32
|
||||
memref.store %sum_elem, %sum[%i, %c0] : memref<2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
|
||||
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
|
||||
%product_elem = arith.mulf %sum_elem, %A_elem : f32
|
||||
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
memref.dealloc %sum : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @non_matching_loop_types_err(%A: memref<2xf32>, %B: memref<2xf32>) {
|
||||
%c2 = arith.constant 2 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c1fp = arith.constant 1.0 : f32
|
||||
%sum = memref.alloc() : memref<2xf32>
|
||||
// expected-error @below {{target and source must be same loop type}}
|
||||
scf.for %i = %c0 to %c2 step %c1 {
|
||||
%B_elem = memref.load %B[%i] : memref<2xf32>
|
||||
%sum_elem = arith.addf %B_elem, %c1fp : f32
|
||||
memref.store %sum_elem, %sum[%i] : memref<2xf32>
|
||||
}
|
||||
scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
|
||||
%sum_elem = memref.load %sum[%i] : memref<2xf32>
|
||||
%A_elem = memref.load %A[%i] : memref<2xf32>
|
||||
%product_elem = arith.mulf %sum_elem, %A_elem : f32
|
||||
memref.store %product_elem, %B[%i] : memref<2xf32>
|
||||
scf.reduce
|
||||
}
|
||||
memref.dealloc %sum : memref<2xf32>
|
||||
return
|
||||
}
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%1 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%fused = transform.loop.fuse_sibling %0 into %1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
|
||||
|
Loading…
x
Reference in New Issue
Block a user