2019-12-13 13:26:00 -08:00
|
|
|
//===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
|
|
|
|
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
2019-12-13 16:35:49 -08:00
|
|
|
#include "mlir/EDSC/Builders.h"
|
2019-12-13 13:26:00 -08:00
|
|
|
#include "mlir/EDSC/Intrinsics.h"
|
|
|
|
#include "mlir/IR/AffineExpr.h"
|
2019-12-13 16:35:49 -08:00
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include "mlir/Support/Functional.h"
|
2019-12-13 13:26:00 -08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::edsc;
|
2019-12-13 16:35:49 -08:00
|
|
|
using namespace mlir::edsc::intrinsics;
|
|
|
|
|
|
|
|
static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
|
|
|
|
unsigned &pos) {
|
|
|
|
for (auto sidx : structuredIndices) {
|
|
|
|
for (auto expr : sidx.getExprs()) {
|
|
|
|
expr.walk([&pos](AffineExpr e) {
|
|
|
|
if (auto d = e.dyn_cast<AffineDimExpr>())
|
|
|
|
pos = std::max(pos, d.getPosition());
|
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2019-12-13 13:26:00 -08:00
|
|
|
|
|
|
|
Operation *mlir::edsc::makeLinalgGenericOp(
|
2019-12-13 16:35:49 -08:00
|
|
|
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
|
|
|
|
ArrayRef<StructuredIndexed> outputs,
|
|
|
|
decltype(defaultRegionBuilder) regionBuilder, ArrayRef<Value *> otherValues,
|
|
|
|
ArrayRef<Attribute> otherAttributes) {
|
2019-12-13 13:26:00 -08:00
|
|
|
auto &builder = edsc::ScopedContext::getBuilder();
|
|
|
|
auto *ctx = builder.getContext();
|
2019-12-13 16:35:49 -08:00
|
|
|
unsigned nInputs = inputs.size();
|
|
|
|
unsigned nOutputs = outputs.size();
|
|
|
|
unsigned rank = 0;
|
|
|
|
getMaxDimIndex(inputs, rank);
|
|
|
|
getMaxDimIndex(outputs, rank);
|
2019-12-13 13:26:00 -08:00
|
|
|
|
|
|
|
SmallVector<AffineMap, 4> maps;
|
2019-12-13 16:35:49 -08:00
|
|
|
maps.reserve(nInputs + nOutputs);
|
|
|
|
for (auto in : inputs)
|
|
|
|
maps.push_back(
|
|
|
|
AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, in.getExprs()));
|
|
|
|
for (auto out : outputs)
|
|
|
|
maps.push_back(
|
|
|
|
AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, out.getExprs()));
|
2019-12-13 13:26:00 -08:00
|
|
|
|
2019-12-13 16:35:49 -08:00
|
|
|
unsigned nViews = nInputs + nOutputs;
|
|
|
|
SmallVector<Value *, 4> values;
|
|
|
|
values.reserve(nViews);
|
|
|
|
values.append(inputs.begin(), inputs.end());
|
|
|
|
values.append(outputs.begin(), outputs.end());
|
2019-12-13 13:26:00 -08:00
|
|
|
|
2019-12-13 16:35:49 -08:00
|
|
|
auto iteratorStrTypes = functional::map(toString, iteratorTypes);
|
|
|
|
// clang-format off
|
2019-12-13 13:26:00 -08:00
|
|
|
auto *op =
|
|
|
|
edsc::ScopedContext::getBuilder()
|
|
|
|
.create<linalg::GenericOp>(
|
2019-12-13 16:35:49 -08:00
|
|
|
edsc::ScopedContext::getLocation(),
|
|
|
|
values,
|
|
|
|
IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
|
|
|
|
IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
|
2019-12-13 13:26:00 -08:00
|
|
|
builder.getAffineMapArrayAttr(maps),
|
2019-12-13 16:35:49 -08:00
|
|
|
builder.getStrArrayAttr(iteratorStrTypes),
|
|
|
|
StringAttr() /*doc*/,
|
|
|
|
FlatSymbolRefAttr() /*fun*/,
|
|
|
|
StringAttr() /*library_call*/
|
|
|
|
/* TODO: other attributes in op */
|
2019-12-13 13:26:00 -08:00
|
|
|
)
|
|
|
|
.getOperation();
|
2019-12-13 16:35:49 -08:00
|
|
|
// clang-format on
|
2019-12-13 13:26:00 -08:00
|
|
|
|
|
|
|
using namespace edsc;
|
|
|
|
SmallVector<Type, 4> blockTypes;
|
2019-12-13 16:35:49 -08:00
|
|
|
blockTypes.reserve(values.size());
|
|
|
|
for (auto it : llvm::enumerate(values))
|
|
|
|
blockTypes.push_back((it.index() < nViews)
|
|
|
|
? getElementTypeOrSelf(it.value())
|
|
|
|
: it.value()->getType());
|
2019-12-13 13:26:00 -08:00
|
|
|
|
|
|
|
assert(op->getRegions().front().empty());
|
|
|
|
op->getRegions().front().push_front(new Block);
|
|
|
|
OpBuilder bb(op->getRegions().front());
|
|
|
|
ScopedContext scope(bb, op->getLoc());
|
|
|
|
BlockHandle b;
|
|
|
|
auto handles = makeValueHandles(blockTypes);
|
|
|
|
BlockBuilder(&b, makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
|
|
|
|
[&] { regionBuilder(b.getBlock()->getArguments()); });
|
|
|
|
return op;
|
|
|
|
}
|
2019-12-13 16:35:49 -08:00
|
|
|
|
|
|
|
using linalg_yield = OperationBuilder<linalg::YieldOp>;
|
|
|
|
|
|
|
|
Operation *mlir::edsc::linalg_matmul(ValueHandle vA, ValueHandle vB,
|
|
|
|
ValueHandle vC) {
|
|
|
|
// clang-format off
|
|
|
|
AffineExpr m, n, k;
|
|
|
|
bindDims(ScopedContext::getContext(), m, n, k);
|
|
|
|
StructuredIndexed A(vA), B(vB), C(vC);
|
|
|
|
return makeLinalgGenericOp(
|
|
|
|
{IterType::Parallel, IterType::Parallel, IterType::Reduction},
|
|
|
|
{A({m, n}), B({k, n})},
|
|
|
|
{C({m, n})},
|
|
|
|
[](ArrayRef<BlockArgument *> args) {
|
|
|
|
using edsc::op::operator*;
|
|
|
|
using edsc::op::operator+;
|
|
|
|
ValueHandle a(args[0]), b(args[1]), c(args[2]);
|
|
|
|
linalg_yield((c + a * b).getValue());
|
|
|
|
});
|
|
|
|
// clang-format on
|
|
|
|
}
|