//===- Transforms.cpp - Implementation of the linalg Transformations ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements analyses and transformations for the linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "linalg3/Transforms.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;

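// Folds away every SliceOp in `f`: each slice result is replaced by the fully
// composed ViewOp it indexes into, and the slice is erased.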
void linalg::composeSliceOps(mlir::Function *f) {
  f->walk<SliceOp>([](SliceOp sliceOp) {
    auto *sliceResult = sliceOp.getResult();
    auto viewOp = emitAndReturnFullyComposedView(sliceResult);
    sliceResult->replaceAllUsesWith(viewOp.getResult());
    sliceOp.erase();
  });
}

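// Rewrites each MatmulOp and MatvecOp in `f` as its finer-grained form (see
// writeAsFinerGrainTensorContraction on the respective op) and erases the
// original coarser-grained op.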
void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
  f->walk([](Operation *op) {
    if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
      matmulOp.writeAsFinerGrainTensorContraction();
    } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
      matvecOp.writeAsFinerGrainTensorContraction();
    } else {
      return;
    }
    op->erase();
  });
}

// Folding eagerly is necessary to abide by the affine.for static step
// requirement. Returns nullptr if folding is not trivially feasible.
static Value *tryFold(AffineMap map, SmallVector<Value *, 4> operands) {
  assert(map.getNumResults() == 1 && "single result map expected");
  auto expr = map.getResult(0);
  if (auto dim = expr.dyn_cast<AffineDimExpr>())
    return operands[dim.getPosition()];
  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
    return operands[map.getNumDims() + sym.getPosition()];
  if (auto cst = expr.dyn_cast<AffineConstantExpr>())
    return constant_index(cst.getValue());
  return nullptr;
}

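// Composes `map` with the maps producing `operandsRef`, then either folds the
// result to an existing Value via tryFold or materializes a single
// AffineApplyOp.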
Value *linalg::makeFoldedComposedAffineApply(AffineMap map,
                                             ArrayRef<Value *> operandsRef) {
  SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
  fullyComposeAffineMapAndOperands(&map, &operands);
  if (auto *v = tryFold(map, operands)) {
    return v;
  }
  auto *b = ScopedContext::getBuilder();
  auto loc = ScopedContext::getLocation();
  return b->create<AffineApplyOp>(loc, map, operands).getResult();
}

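// Pre-allocates storage for `reserved` min/max/step triples.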
linalg::RangeParts::RangeParts(unsigned reserved) {
  mins.reserve(reserved);
  maxes.reserve(reserved);
  steps.reserve(reserved);
}

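// Extracts one component (min, max or step) from the RangeOp defining each
// value in `ranges`.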
static SmallVector<Value *, 4>
extractFromRanges(ArrayRef<Value *> ranges,
                  std::function<Value *(RangeOp)> extract) {
  SmallVector<Value *, 4> res;
  res.reserve(ranges.size());
  for (auto *v : ranges) {
    auto r = cast<RangeOp>(v->getDefiningOp());
    res.push_back(extract(r));
  }
  return res;
}

linalg::RangeParts::RangeParts(ArrayRef<Value *> ranges)
    : mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })),
      maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })),
      steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {}

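// Recombines the stored min/max/step components into a list of range values.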
SmallVector<Value *, 4> linalg::RangeParts::makeRanges() {
  SmallVector<Value *, 4> res;
  res.reserve(mins.size());
  for (auto z : llvm::zip(mins, maxes, steps)) {
    res.push_back(range(std::get<0>(z), std::get<1>(z), std::get<2>(z)));
  }
  return res;
}

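// Applies each single-result submap of `map` independently to the mins, maxes
// and steps of `ranges`, producing the RangeParts of the mapped ranges.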
static RangeParts makeGenericRangeParts(AffineMap map,
                                        ArrayRef<Value *> ranges) {
  assert(map.getNumInputs() == ranges.size());
  unsigned numDims = map.getNumDims();
  assert(map.getNumSymbols() == 0);

  RangeParts res(map.getNumResults());
  RangeParts rangeParts(ranges);
  for (auto expr : map.getResults()) {
    AffineMap map = AffineMap::get(numDims, 0, expr);
    res.mins.push_back(makeFoldedComposedAffineApply(map, rangeParts.mins));
    res.maxes.push_back(makeFoldedComposedAffineApply(map, rangeParts.maxes));
    res.steps.push_back(makeFoldedComposedAffineApply(map, rangeParts.steps));
  }
  return res;
}

SmallVector<Value *, 4> makeGenericRanges(AffineMap map,
                                          ArrayRef<Value *> ranges) {
  return makeGenericRangeParts(map, ranges).makeRanges();
}

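// Derives loop ranges from operand ranges through `operandRangesToLoopMaps`.
// When `tileSizes` is non-empty, each loop step is scaled by the matching tile
// size; both are expected to be constant indices here, so a step `s` tiled by
// `t` becomes the constant step `s * t`.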
SmallVector<Value *, 4>
linalg::makeGenericLoopRanges(AffineMap operandRangesToLoopMaps,
                              ArrayRef<Value *> ranges,
                              ArrayRef<Value *> tileSizes) {
  RangeParts res = makeGenericRangeParts(operandRangesToLoopMaps, ranges);
  if (tileSizes.empty())
    return res.makeRanges();
  SmallVector<Value *, 4> tiledSteps;
  for (auto z : llvm::zip(res.steps, tileSizes)) {
    auto *step = std::get<0>(z);
    auto tileSize = std::get<1>(z);
    auto stepValue = cast<ConstantIndexOp>(step->getDefiningOp()).getValue();
    auto tileSizeValue =
        cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue();
    assert(stepValue > 0);
    tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
  }
  res.steps = tiledSteps;
  return res.makeRanges();
}

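// Emits the loop nest implementing `contraction`: outer loops over the
// parallel dimensions, inner loops over the reduction dimensions, with the
// op's scalar implementation emitted in the innermost body (e.g. a matmul
// yields two parallel loops and one reduction loop).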
template <class ContractionOp>
static SmallVector<mlir::AffineForOp, 4>
writeContractionAsLoops(ContractionOp contraction) {
  OpBuilder builder(contraction.getOperation());
  ScopedContext scope(builder, contraction.getLoc());
  auto allRanges = getRanges(contraction);
  auto loopRanges =
      makeGenericLoopRanges(operandRangesToLoopsMap(contraction), allRanges);

  SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
  SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
  auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
  auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
  assert(loopRanges.size() == pivs.size() + rivs.size());

  // clang-format off
  using linalg::common::LoopNestRangeBuilder;
  ArrayRef<Value *> ranges(loopRanges);
  LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&]{
    LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))(
      [&contraction, &parallelIvs, &reductionIvs] {
        SmallVector<mlir::Value *, 4> parallel(
            parallelIvs.begin(), parallelIvs.end());
        SmallVector<mlir::Value *, 4> reduction(
            reductionIvs.begin(), reductionIvs.end());
        contraction.emitScalarImplementation(parallel, reduction);
      });
  });
  // clang-format on

  // Return the AffineForOp for better compositionality (e.g. tiling).
  SmallVector<mlir::AffineForOp, 4> loops;
  loops.reserve(pivs.size() + rivs.size());
  for (auto iv : parallelIvs)
    loops.push_back(getForInductionVarOwner(iv.getValue()));
  for (auto iv : reductionIvs)
    loops.push_back(getForInductionVarOwner(iv.getValue()));

  return loops;
}

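// Dispatches to writeContractionAsLoops for the supported contractions
// (matmul, matvec, dot); returns llvm::None for any other operation.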
llvm::Optional<SmallVector<mlir::AffineForOp, 4>>
linalg::writeAsLoops(Operation *op) {
  if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
    return writeContractionAsLoops(matmulOp);
  } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
    return writeContractionAsLoops(matvecOp);
  } else if (auto dotOp = dyn_cast<linalg::DotOp>(op)) {
    return writeContractionAsLoops(dotOp);
  }
  return llvm::None;
}

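// Lowers every supported contraction in `f` to affine loops and erases the
// original linalg op.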
void linalg::lowerToLoops(mlir::Function *f) {
  f->walk([](Operation *op) {
    if (writeAsLoops(op))
      op->erase();
  });
}

/// Emits and returns the standard load and store ops from the view indexings.
/// If the indexing is of index type, use it as an index to the load/store.
/// If the indexing is a range, use range.min + indexing as an index to the
/// load/store.
template <typename LoadOrStoreOp>
static SmallVector<Value *, 8>
emitAndReturnLoadStoreOperands(LoadOrStoreOp loadOrStoreOp, ViewOp viewOp) {
  unsigned storeDim = 0;
  SmallVector<Value *, 8> operands;
  for (auto *indexing : viewOp.getIndexings()) {
    if (indexing->getType().isa<IndexType>()) {
      operands.push_back(indexing);
      continue;
    }
    RangeOp range = cast<RangeOp>(indexing->getDefiningOp());
    ValueHandle min(range.getMin());
    Value *storeIndex = *(loadOrStoreOp.getIndices().begin() + storeDim++);
    using edsc::op::operator+;
    operands.push_back(min + ValueHandle(storeIndex));
  }
  return operands;
}

namespace {

/// Rewriting linalg::LoadOp and linalg::StoreOp to mlir::LoadOp and
/// mlir::StoreOp requires finding the proper indexing in the supporting
/// MemRef. This is most easily achieved by calling
/// emitAndReturnFullyComposedView to fold away all the SliceOps.
template <typename LoadOrStoreOpTy>
struct Rewriter : public OpRewritePattern<LoadOrStoreOpTy> {
  using OpRewritePattern<LoadOrStoreOpTy>::OpRewritePattern;

  /// Performs the rewrite.
  PatternMatchResult matchAndRewrite(LoadOrStoreOpTy op,
                                     PatternRewriter &rewriter) const override;
};

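// Function pass that greedily applies the Rewriter patterns to convert
// linalg::LoadOp/linalg::StoreOp into their standard-dialect counterparts.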
struct LowerLinalgLoadStorePass
    : public FunctionPass<LowerLinalgLoadStorePass> {
  void runOnFunction() {
    OwningRewritePatternList patterns;
    auto *context = &getContext();
    patterns.push_back(llvm::make_unique<Rewriter<linalg::LoadOp>>(context));
    patterns.push_back(llvm::make_unique<Rewriter<linalg::StoreOp>>(context));
    applyPatternsGreedily(getFunction(), std::move(patterns));
  }
};

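// linalg::LoadOp lowering: fold any slice chain into a single view, recover
// the supporting memref and indices, and replace with a standard LoadOp.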
template <>
PatternMatchResult
Rewriter<linalg::LoadOp>::matchAndRewrite(linalg::LoadOp load,
                                          PatternRewriter &rewriter) const {
  SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
  ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                      : cast<ViewOp>(load.getView()->getDefiningOp());
  OpBuilder builder(load);
  ScopedContext scope(builder, load.getLoc());
  auto *memRef = view.getSupportingMemRef();
  auto operands = emitAndReturnLoadStoreOperands(load, view);
  rewriter.replaceOpWithNewOp<mlir::LoadOp>(load, memRef, operands);
  return matchSuccess();
}

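// linalg::StoreOp lowering, symmetric to the load case but also forwarding
// the value to store.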
template <>
PatternMatchResult
Rewriter<linalg::StoreOp>::matchAndRewrite(linalg::StoreOp store,
                                           PatternRewriter &rewriter) const {
  SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
  ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                      : cast<ViewOp>(store.getView()->getDefiningOp());
  OpBuilder builder(store);
  ScopedContext scope(builder, store.getLoc());
  auto *valueToStore = store.getValueToStore();
  auto *memRef = view.getSupportingMemRef();
  auto operands = emitAndReturnLoadStoreOperands(store, view);
  rewriter.replaceOpWithNewOp<mlir::StoreOp>(store, valueToStore, memRef,
                                             operands);
  return matchSuccess();
}
} // namespace

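// Factory for the lowering pass above.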
FunctionPassBase *linalg::createLowerLinalgLoadStorePass() {
  return new LowerLinalgLoadStorePass();
}