Nicolas Vasilache 2140a973f2 [mlir][Linalg] Extend generic ops to allow tensors
Summary:
    This diff adds support to allow `linalg.generic` and
    `linalg.indexed_generic` to take tensor input and output
    arguments.

    The subset of output tensor operand types must appear
    verbatim in the result types after an arrow. The parser,
    printer and verifier are extended to accommodate this
    behavior.

    The Linalg operations now support variadic ranked tensor
    return values. This extension exhibited issues with the
    current handling of NativeCall in RewriterGen.cpp. As a
    consequence, an explicit cast to `SmallVector<Value, 4>`
    is added in the proper place to support the new behavior
    (better suggestions are welcome).

    Relevant cleanups and name uniformization are applied.

    Relevant invalid and roundtrip tests are added.

    Reviewers: mehdi_amini, rriddle, jpienaar, antiagainst, ftynse

    Subscribers: burmako, shauheen, llvm-commits

    Tags: #llvm

    Differential Revision: https://reviews.llvm.org/D72022
2020-01-02 13:54:57 -05:00

248 lines
9.5 KiB
C++

//===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/Functional.h"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::edsc::ops;
/// Walks every affine expression attached to `structuredIndices` and bumps
/// `pos` to the largest AffineDimExpr position encountered. The incoming
/// value of `pos` is preserved if no larger position is found.
static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
                           unsigned &pos) {
  for (const auto &indexed : structuredIndices) {
    for (AffineExpr expr : indexed.getExprs()) {
      expr.walk([&pos](AffineExpr e) {
        auto dimExpr = e.dyn_cast<AffineDimExpr>();
        if (dimExpr && dimExpr.getPosition() > pos)
          pos = dimExpr.getPosition();
      });
    }
  }
}
/// Builds a linalg.generic op from EDSC structured-indexing descriptors.
///
/// `iteratorTypes` gives one iterator kind per loop dimension; each entry of
/// `inputs`/`outputs` pairs an operand value with its indexing affine
/// expressions; `regionBuilder` populates the op's region given its block
/// arguments. `otherValues`/`otherAttributes` are accepted but not yet used
/// (see the TODO in the create call below). Returns the created operation.
Operation *mlir::edsc::makeGenericLinalgOp(
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
ArrayRef<StructuredIndexed> outputs,
function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
auto &builder = edsc::ScopedContext::getBuilder();
auto *ctx = builder.getContext();
unsigned nInputs = inputs.size();
unsigned nOutputs = outputs.size();
// The loop dimension count is the largest dim position referenced by any
// input or output indexing expression, plus one.
unsigned maxPos = 0;
getMaxDimIndex(inputs, maxPos);
getMaxDimIndex(outputs, maxPos);
// maxPos is 0 indexed, need to turn this into a count (i.e. +1)
unsigned nDims = maxPos + 1;
// Build one indexing AffineMap per operand: inputs first, then outputs,
// matching the operand order below.
SmallVector<AffineMap, 4> maps;
maps.reserve(nInputs + nOutputs);
for (auto in : inputs)
maps.push_back(
AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
for (auto out : outputs)
maps.push_back(
AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
unsigned nViews = nInputs + nOutputs;
// Collect the operand values in the same inputs-then-outputs order.
SmallVector<Value, 4> values;
values.reserve(nViews);
values.append(inputs.begin(), inputs.end());
values.append(outputs.begin(), outputs.end());
auto iteratorStrTypes = functional::map(toString, iteratorTypes);
// clang-format off
auto *op =
edsc::ScopedContext::getBuilder()
.create<linalg::GenericOp>(
edsc::ScopedContext::getLocation(),
ArrayRef<Type>{}, // TODO(ntv): support tensors
values,
IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
builder.getAffineMapArrayAttr(maps),
builder.getStrArrayAttr(iteratorStrTypes),
StringAttr() /*doc*/,
FlatSymbolRefAttr() /*fun*/,
StringAttr() /*library_call*/
/* TODO: other attributes in op */
)
.getOperation();
// clang-format on
using namespace edsc;
// Block argument types: view operands contribute their element type.
// NOTE(review): `values` holds exactly nViews entries, so the
// `it.index() < nViews` test appears to always be true here — presumably a
// placeholder for when extra (non-view) values are appended; confirm.
SmallVector<Type, 4> blockTypes;
blockTypes.reserve(values.size());
for (auto it : llvm::enumerate(values))
blockTypes.push_back((it.index() < nViews)
? getElementTypeOrSelf(it.value())
: it.value()->getType());
// Create the op's single block and let `regionBuilder` fill it in, under a
// fresh ScopedContext anchored inside the region.
assert(op->getRegions().front().empty());
op->getRegions().front().push_front(new Block);
OpBuilder bb(op->getRegions().front());
ScopedContext scope(bb, op->getLoc());
BlockHandle b;
auto handles = makeValueHandles(blockTypes);
BlockBuilder(&b, makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
[&] { regionBuilder(b.getBlock()->getArguments()); });
return op;
}
/// Region builder emitting a multiply-accumulate: yields acc + lhs * rhs for
/// the three block arguments (input a, input b, accumulator c).
void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
  using edsc::op::operator*;
  using edsc::op::operator+;
  assert(args.size() == 3 && "expected 3 block arguments");
  ValueHandle lhs(args[0]), rhs(args[1]), acc(args[2]);
  linalg_yield((acc + lhs * rhs).getValue());
}
/// Unary pointwise operation entry point: emits a linalg.generic with one
/// parallel iterator per output dimension whose region applies `unaryOp` to
/// the input element and yields the result.
Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
                                             StructuredIndexed I,
                                             StructuredIndexed O) {
  // All-parallel iteration space, sized by the output indexing expressions.
  SmallVector<edsc::IterType, 4> iterators(O.getExprs().size(),
                                           edsc::IterType::Parallel);
  auto regionFn = [&unaryOp](ArrayRef<BlockArgument> args) {
    assert(args.size() == 2 && "expected 2 block arguments");
    ValueHandle input(args[0]);
    linalg_yield(unaryOp(input));
  };
  return makeGenericLinalgOp(iterators, /*inputs=*/{I}, /*outputs=*/{O},
                             regionFn);
}
/// Pointwise tanh entry point: lowers to a unary pointwise linalg.generic
/// whose region applies the tanh intrinsic to each input element.
Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
                                                  StructuredIndexed O) {
  // Note: a stray empty statement (`;`) that used to open this body has been
  // removed; it had no effect but tripped -Wextra-semi-style diagnostics.
  using edsc::intrinsics::tanh;
  UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); });
  return linalg_pointwise(unOp, I, O);
}
/// Binary pointwise operation (with broadcast) entry point: emits a
/// linalg.generic with one parallel iterator per output dimension whose
/// region combines the two input elements via `binaryOp`.
Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
                                             StructuredIndexed I1,
                                             StructuredIndexed I2,
                                             StructuredIndexed O) {
  // All-parallel iteration space, sized by the output indexing expressions.
  SmallVector<edsc::IterType, 4> iterators(O.getExprs().size(),
                                           edsc::IterType::Parallel);
  auto regionFn = [&binaryOp](ArrayRef<BlockArgument> args) {
    assert(args.size() == 3 && "expected 3 block arguments");
    ValueHandle lhs(args[0]), rhs(args[1]);
    linalg_yield(binaryOp(lhs, rhs));
  };
  return makeGenericLinalgOp(iterators, /*inputs=*/{I1, I2}, /*outputs=*/{O},
                             regionFn);
}
/// Pointwise addition entry point: builds a binary pointwise linalg.generic
/// whose region yields the elementwise sum of the two inputs.
Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
                                                 StructuredIndexed I2,
                                                 StructuredIndexed O) {
  BinaryPointwiseOpBuilder addOp([](ValueHandle lhs, ValueHandle rhs) -> Value {
    using edsc::op::operator+;
    return lhs + rhs;
  });
  return linalg_pointwise(addOp, I1, I2, O);
}
/// Pointwise maximum entry point: builds a binary pointwise linalg.generic
/// whose region selects the larger of the two input elements.
Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
                                                 StructuredIndexed I2,
                                                 StructuredIndexed O) {
  BinaryPointwiseOpBuilder maxOp([](ValueHandle lhs, ValueHandle rhs) -> Value {
    using edsc::intrinsics::select;
    using edsc::op::operator>;
    // max(lhs, rhs) expressed as a compare-and-select.
    return select(lhs > rhs, lhs, rhs).getValue();
  });
  return linalg_pointwise(maxOp, I1, I2, O);
}
/// Emits a linalg.generic implementing C(m, n) += A(m, k) * B(k, n), i.e.
/// matrix multiplication with (parallel, parallel, reduction) iterators and
/// a multiply-accumulate region.
Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
                                          ValueHandle vC) {
  AffineExpr m, n, k;
  bindDims(ScopedContext::getContext(), m, n, k);
  StructuredIndexed A(vA), B(vB), C(vC);
  // clang-format off
  return makeGenericLinalgOp(
    {IterType::Parallel, IterType::Parallel, IterType::Reduction},
    {A({m, k}), B({k, n})},
    {C({m, n})},
    macRegionBuilder);
  // clang-format on
}
/// Emits a linalg.generic implementing a 2-D NHWC convolution:
///   O(b, h, w, f) += I(b, s0*h + d0*kh, s1*w + d1*kw, c) * W(kh, kw, c, f)
/// with parallel (b, f, h, w) and reduction (kh, kw, c) iterators. Empty
/// `strides`/`dilations` mean the default of 1 in each spatial dimension.
Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
                                             ValueHandle vO,
                                             ArrayRef<int> strides,
                                             ArrayRef<int> dilations) {
  MLIRContext *ctx = ScopedContext::getContext();
  // TODO(ntv) some template magic to make everything rank-polymorphic.
  assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
  assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
  // The asserts above explicitly allow empty arrays; indexing an empty
  // ArrayRef is out-of-bounds, so materialize the default of 1 here instead
  // of dereferencing `strides[0]`/`dilations[0]` unconditionally.
  int sh = strides.empty() ? 1 : strides[0];
  int sw = strides.empty() ? 1 : strides[1];
  int dh = dilations.empty() ? 1 : dilations[0];
  int dw = dilations.empty() ? 1 : dilations[1];
  // Some short names.
  auto par = IterType::Parallel;
  auto red = IterType::Reduction;
  AffineExpr b, f, h, w, kh, kw, c;
  bindDims(ctx, b, f, h, w, kh, kw, c);
  unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
  StructuredIndexed I(vI), W(vW), O(vO);
  // clang-format off
  return makeGenericLinalgOp(
    {par, par, par, par, red, red, red}, {
    I({b,
       // Roundtrip to flattened form to serve as canonicalization and ensure
       // consistent ordering of subexpressions.
       simplifyAffineExpr(sh * h + dh * kh, numDims, 0),
       simplifyAffineExpr(sw * w + dw * kw, numDims, 0),
       c}),
    W({kh, kw, c, f})}, {
    O({b, h, w, f})},
    macRegionBuilder);
  // clang-format on
}
/// Emits a linalg.generic implementing a 2-D dilated depthwise NHWC conv:
///   O(b, h, w, c*depth_multiplier + dm)
///     += I(b, s0*h + d0*kh, s1*w + d1*kw, c) * W(kh, kw, c, dm)
/// with parallel (b, dm, c, h, w) and reduction (kh, kw) iterators. Empty
/// `strides`/`dilations` mean the default of 1 in each spatial dimension.
Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
    ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
    ArrayRef<int> strides, ArrayRef<int> dilations) {
  MLIRContext *ctx = ScopedContext::getContext();
  // TODO(ntv) some template magic to make everything rank-polymorphic.
  assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
  assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
  // The asserts above explicitly allow empty arrays; indexing an empty
  // ArrayRef is out-of-bounds, so materialize the default of 1 here instead
  // of dereferencing `strides[0]`/`dilations[0]` unconditionally.
  int sh = strides.empty() ? 1 : strides[0];
  int sw = strides.empty() ? 1 : strides[1];
  int dh = dilations.empty() ? 1 : dilations[0];
  int dw = dilations.empty() ? 1 : dilations[1];
  // Some short names.
  auto par = IterType::Parallel;
  auto red = IterType::Reduction;
  // clang-format off
  AffineExpr b, dm, c, h, w, kh, kw;
  bindDims(ctx, b, dm, c, h, w, kh, kw);
  unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
  StructuredIndexed I(vI), W(vW), O(vO);
  return makeGenericLinalgOp(
    {par, par, par, par, par, red, red}, {
    I({b,
       // Roundtrip to flattened form to serve as canonicalization and ensure
       // consistent ordering of subexpressions.
       simplifyAffineExpr(sh * h + dh * kh, numDims, 0),
       simplifyAffineExpr(sw * w + dw * kw, numDims, 0),
       c}),
    W({kh, kw, c, dm})}, {
    O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
    macRegionBuilder);
  // clang-format on
}