//===- Transforms.cpp - Implementation of the linalg Transformations ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements analyses and transformations for the linalg dialect.
//
//===----------------------------------------------------------------------===//
#include "linalg3/Transforms.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;
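
/// Rewrites the result of every SliceOp in `f` as an equivalent, fully
/// composed ViewOp and erases the now unused SliceOp.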
void linalg::composeSliceOps(mlir::Function *f) {
  f->walk<SliceOp>([](SliceOp sliceOp) {
    auto *sliceResult = sliceOp.getResult();
    auto viewOp = emitAndReturnFullyComposedView(sliceResult);
    sliceResult->replaceAllUsesWith(viewOp.getResult());
    sliceOp.erase();
  });
}
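
/// Rewrites every MatmulOp and MatvecOp in `f` as a finer-grained tensor
/// contraction (see writeAsFinerGrainTensorContraction) and erases the
/// original op.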
void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
  f->walk([](Operation *op) {
    if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
      matmulOp.writeAsFinerGrainTensorContraction();
    } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
      matvecOp.writeAsFinerGrainTensorContraction();
    } else {
      return;
    }
    op->erase();
  });
}

// Folding eagerly is necessary to abide by the affine.for static step
// requirement. Returns nullptr if folding is not trivially feasible.
static Value *tryFold(AffineMap map, SmallVector<Value *, 4> operands) {
  assert(map.getNumResults() == 1 && "single result map expected");
  auto expr = map.getResult(0);
  if (auto dim = expr.dyn_cast<AffineDimExpr>())
    return operands[dim.getPosition()];
  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
    return operands[map.getNumDims() + sym.getPosition()];
  if (auto cst = expr.dyn_cast<AffineConstantExpr>())
    return constant_index(cst.getValue());
  return nullptr;
}
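
/// Fully composes `map` with `operandsRef`, returns an existing Value or a
/// constant when the composed single-result map folds trivially, and emits an
/// affine.apply otherwise.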
Value *linalg::makeFoldedComposedAffineApply(AffineMap map,
                                             ArrayRef<Value *> operandsRef) {
  SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
  fullyComposeAffineMapAndOperands(&map, &operands);
  if (auto *v = tryFold(map, operands)) {
    return v;
  }
  auto *b = ScopedContext::getBuilder();
  auto loc = ScopedContext::getLocation();
  return b->create<AffineApplyOp>(loc, map, operands).getResult();
}
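
/// Pre-reserves space for `reserved` mins, maxes and steps.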
linalg::RangeParts::RangeParts(unsigned reserved) {
  mins.reserve(reserved);
  maxes.reserve(reserved);
  steps.reserve(reserved);
}
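
/// Extracts one component (min, max or step) from each RangeOp defining
/// `ranges`.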
static SmallVector<Value *, 4>
extractFromRanges(ArrayRef<Value *> ranges,
                  std::function<Value *(RangeOp)> extract) {
  SmallVector<Value *, 4> res;
  res.reserve(ranges.size());
  for (auto *v : ranges) {
    auto r = cast<RangeOp>(v->getDefiningOp());
    res.push_back(extract(r));
  }
  return res;
}
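
/// Builds a RangeParts from the mins, maxes and steps of the RangeOps that
/// define `ranges`.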
linalg::RangeParts::RangeParts(ArrayRef<Value *> ranges)
    : mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })),
      maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })),
      steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {}
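
/// Recombines the stored mins, maxes and steps into newly emitted RangeOps.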
SmallVector<Value *, 4> linalg::RangeParts::makeRanges() {
  SmallVector<Value *, 4> res;
  res.reserve(mins.size());
  for (auto z : llvm::zip(mins, maxes, steps)) {
    res.push_back(range(std::get<0>(z), std::get<1>(z), std::get<2>(z)));
  }
  return res;
}
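
/// Applies each result expression of `map` to the mins, maxes and steps of
/// `ranges` and returns the resulting RangeParts.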
static RangeParts makeGenericRangeParts(AffineMap map,
                                        ArrayRef<Value *> ranges) {
  assert(map.getNumInputs() == ranges.size());
  unsigned numDims = map.getNumDims();
  assert(map.getNumSymbols() == 0);

  RangeParts res(map.getNumResults());
  RangeParts rangeParts(ranges);
  for (auto expr : map.getResults()) {
    AffineMap exprMap = AffineMap::get(numDims, 0, expr);
    res.mins.push_back(makeFoldedComposedAffineApply(exprMap, rangeParts.mins));
    res.maxes.push_back(
        makeFoldedComposedAffineApply(exprMap, rangeParts.maxes));
    res.steps.push_back(
        makeFoldedComposedAffineApply(exprMap, rangeParts.steps));
  }
  return res;
}
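
/// Same as makeGenericRangeParts but directly returns newly emitted RangeOps.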
SmallVector<Value *, 4> makeGenericRanges(AffineMap map,
                                          ArrayRef<Value *> ranges) {
  return makeGenericRangeParts(map, ranges).makeRanges();
}
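
/// Computes the loop ranges obtained by applying `operandRangesToLoopMaps` to
/// `ranges`. When `tileSizes` is non-empty, each loop step is multiplied by
/// the corresponding tile size; both must be defined by ConstantIndexOp.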
SmallVector<Value *, 4>
linalg::makeGenericLoopRanges(AffineMap operandRangesToLoopMaps,
                              ArrayRef<Value *> ranges,
                              ArrayRef<Value *> tileSizes) {
  RangeParts res = makeGenericRangeParts(operandRangesToLoopMaps, ranges);
  if (tileSizes.empty())
    return res.makeRanges();

  SmallVector<Value *, 4> tiledSteps;
  for (auto z : llvm::zip(res.steps, tileSizes)) {
    auto *step = std::get<0>(z);
    auto tileSize = std::get<1>(z);
    auto stepValue = cast<ConstantIndexOp>(step->getDefiningOp()).getValue();
    auto tileSizeValue =
        cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue();
    assert(stepValue > 0);
    tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
  }
  res.steps = tiledSteps;
  return res.makeRanges();
}
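
/// Emits the loop nest computing `contraction`: parallel loops enclosing
/// reduction loops, with the op's scalar implementation in the innermost body.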
template <class ContractionOp>
static SmallVector<mlir::AffineForOp, 4>
writeContractionAsLoops(ContractionOp contraction) {
  OpBuilder builder(contraction.getOperation());
  ScopedContext scope(builder, contraction.getLoc());
  auto allRanges = getRanges(contraction);
  auto loopRanges =
      makeGenericLoopRanges(operandRangesToLoopsMap(contraction), allRanges);

  SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
  SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
  auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
  auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
  assert(loopRanges.size() == pivs.size() + rivs.size());

  // clang-format off
  using linalg::common::LoopNestRangeBuilder;
  ArrayRef<Value *> ranges(loopRanges);
  LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&]{
    LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))(
      [&contraction, &parallelIvs, &reductionIvs] {
        SmallVector<mlir::Value *, 4> parallel(
            parallelIvs.begin(), parallelIvs.end());
        SmallVector<mlir::Value *, 4> reduction(
            reductionIvs.begin(), reductionIvs.end());
        contraction.emitScalarImplementation(parallel, reduction);
      });
  });
  // clang-format on

  // Return the AffineForOp for better compositionality (e.g. tiling).
  SmallVector<mlir::AffineForOp, 4> loops;
  loops.reserve(pivs.size() + rivs.size());
  for (auto iv : parallelIvs)
    loops.push_back(getForInductionVarOwner(iv.getValue()));
  for (auto iv : reductionIvs)
    loops.push_back(getForInductionVarOwner(iv.getValue()));
  return loops;
}
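
/// Emits the loop nest computing `op` if it is a supported contraction
/// (matmul, matvec or dot) and returns the emitted loops; returns llvm::None
/// otherwise.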
llvm::Optional<SmallVector<mlir::AffineForOp, 4>>
linalg::writeAsLoops(Operation *op) {
  if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
    return writeContractionAsLoops(matmulOp);
  } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
    return writeContractionAsLoops(matvecOp);
  } else if (auto dotOp = dyn_cast<linalg::DotOp>(op)) {
    return writeContractionAsLoops(dotOp);
  }
  return llvm::None;
}
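
/// Rewrites all supported tensor contractions in `f` as loop nests and erases
/// the original ops.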
void linalg::lowerToLoops(mlir::Function *f) {
  f->walk([](Operation *op) {
    if (writeAsLoops(op))
      op->erase();
  });
}

/// Emits and returns the operands of the standard load/store that replaces
/// `loadOrStoreOp`, computed from the indexings of `viewOp`. An indexing of
/// index type is used directly as a load/store index; for a range indexing,
/// range.min + the corresponding original index is used.
template <typename LoadOrStoreOp>
static SmallVector<Value *, 8>
emitAndReturnLoadStoreOperands(LoadOrStoreOp loadOrStoreOp, ViewOp viewOp) {
  unsigned storeDim = 0;
  SmallVector<Value *, 8> operands;
  for (auto *indexing : viewOp.getIndexings()) {
    if (indexing->getType().isa<IndexType>()) {
      operands.push_back(indexing);
      continue;
    }
    RangeOp range = cast<RangeOp>(indexing->getDefiningOp());
    ValueHandle min(range.getMin());
    Value *storeIndex = *(loadOrStoreOp.getIndices().begin() + storeDim++);
    using edsc::op::operator+;
    operands.push_back(min + ValueHandle(storeIndex));
  }
  return operands;
}

namespace {

/// Rewriting linalg::LoadOp and linalg::StoreOp to mlir::LoadOp and
/// mlir::StoreOp requires finding the proper indexing in the supporting
/// MemRef. This is most easily achieved by calling
/// emitAndReturnFullyComposedView to fold away all the SliceOps.
template <typename LoadOrStoreOpTy>
struct Rewriter : public OpRewritePattern<LoadOrStoreOpTy> {
  using OpRewritePattern<LoadOrStoreOpTy>::OpRewritePattern;

  /// Performs the rewrite.
  PatternMatchResult matchAndRewrite(LoadOrStoreOpTy op,
                                     PatternRewriter &rewriter) const override;
};
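
/// Pass that lowers linalg::LoadOp and linalg::StoreOp to their standard
/// dialect counterparts by greedily applying the Rewriter patterns.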
struct LowerLinalgLoadStorePass
    : public FunctionPass<LowerLinalgLoadStorePass> {
  void runOnFunction() {
    OwningRewritePatternList patterns;
    auto *context = &getContext();
    patterns.push_back(llvm::make_unique<Rewriter<linalg::LoadOp>>(context));
    patterns.push_back(llvm::make_unique<Rewriter<linalg::StoreOp>>(context));
    applyPatternsGreedily(getFunction(), std::move(patterns));
  }
};
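
/// Lowers a linalg::LoadOp to a standard LoadOp on the view's supporting
/// memref, composing away an intervening SliceOp if the loaded view is
/// defined by one.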
template <>
PatternMatchResult
Rewriter<linalg::LoadOp>::matchAndRewrite(linalg::LoadOp load,
                                          PatternRewriter &rewriter) const {
  SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
  ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                      : cast<ViewOp>(load.getView()->getDefiningOp());
  OpBuilder builder(load);
  ScopedContext scope(builder, load.getLoc());
  auto *memRef = view.getSupportingMemRef();
  auto operands = emitAndReturnLoadStoreOperands(load, view);
  rewriter.replaceOpWithNewOp<mlir::LoadOp>(load, memRef, operands);
  return matchSuccess();
}
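
/// Lowers a linalg::StoreOp to a standard StoreOp on the view's supporting
/// memref, composing away an intervening SliceOp if the stored-to view is
/// defined by one.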
template <>
PatternMatchResult
Rewriter<linalg::StoreOp>::matchAndRewrite(linalg::StoreOp store,
                                           PatternRewriter &rewriter) const {
  SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
  ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
                      : cast<ViewOp>(store.getView()->getDefiningOp());
  OpBuilder builder(store);
  ScopedContext scope(builder, store.getLoc());
  auto *valueToStore = store.getValueToStore();
  auto *memRef = view.getSupportingMemRef();
  auto operands = emitAndReturnLoadStoreOperands(store, view);
  rewriter.replaceOpWithNewOp<mlir::StoreOp>(store, valueToStore, memRef,
                                             operands);
  return matchSuccess();
}

} // namespace
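
/// Creates an instance of the LowerLinalgLoadStorePass.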
FunctionPassBase *linalg::createLowerLinalgLoadStorePass() {
  return new LowerLinalgLoadStorePass();
}