
Summary:
This diff adds support to allow `linalg.generic` and `linalg.indexed_generic` to take tensor input and output arguments.

The subset of output tensor operand types must appear verbatim in the result types after an arrow. The parser, printer and verifier are extended to accommodate this behavior.

The Linalg operations now support variadic ranked tensor return values. This extension exhibited issues with the current handling of NativeCall in RewriterGen.cpp. As a consequence, an explicit cast to `SmallVector<Value, 4>` is added in the proper place to support the new behavior (better suggestions are welcome).

Relevant cleanups and name uniformization are applied.

Relevant invalid and roundtrip tests are added.

Reviewers: mehdi_amini, rriddle, jpienaar, antiagainst, ftynse

Subscribers: burmako, shauheen, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72022
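For illustration, the extended form reads roughly as follows (a hypothetical sketch in MLIR assembly of that era, not taken from this diff's tests; `#pointwise_trait` is an assumed trait attribute). Note the output tensor operand's type repeated verbatim after the `->`:

    %0 = linalg.generic #pointwise_trait %arg0, %arg1 {
      ^bb0(%a: f32, %b: f32):
        %sum = addf %a, %b : f32
        linalg.yield %sum : f32
    } : tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>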
248 lines · 9.5 KiB · C++
//===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/Functional.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::edsc::ops;

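/// Raises `pos` to the highest AffineDimExpr position found in the affine
/// expressions of `structuredIndices`; `pos` is never lowered.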
static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
                           unsigned &pos) {
  for (auto sidx : structuredIndices) {
    for (auto expr : sidx.getExprs()) {
      expr.walk([&pos](AffineExpr e) {
        if (auto d = e.dyn_cast<AffineDimExpr>())
          pos = std::max(pos, d.getPosition());
      });
    }
  }
}

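/// Builds a linalg.generic op from declaratively specified `inputs` and
/// `outputs` (indexed views), `iteratorTypes`, and a `regionBuilder` callback
/// that populates the body block. The indexing maps are derived from the
/// affine expressions of each StructuredIndexed; the loop dimension count is
/// the largest dimension position used, plus one.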
Operation *mlir::edsc::makeGenericLinalgOp(
    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
    ArrayRef<StructuredIndexed> outputs,
    function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
    ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
  auto &builder = edsc::ScopedContext::getBuilder();
  auto *ctx = builder.getContext();
  unsigned nInputs = inputs.size();
  unsigned nOutputs = outputs.size();
  unsigned maxPos = 0;
  getMaxDimIndex(inputs, maxPos);
  getMaxDimIndex(outputs, maxPos);
  // maxPos is 0-indexed, need to turn this into a count (i.e. +1).
  unsigned nDims = maxPos + 1;

  SmallVector<AffineMap, 4> maps;
  maps.reserve(nInputs + nOutputs);
  for (auto in : inputs)
    maps.push_back(
        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
  for (auto out : outputs)
    maps.push_back(
        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));

  unsigned nViews = nInputs + nOutputs;
  SmallVector<Value, 4> values;
  values.reserve(nViews);
  values.append(inputs.begin(), inputs.end());
  values.append(outputs.begin(), outputs.end());

  auto iteratorStrTypes = functional::map(toString, iteratorTypes);
  // clang-format off
  auto *op =
      edsc::ScopedContext::getBuilder()
          .create<linalg::GenericOp>(
              edsc::ScopedContext::getLocation(),
              ArrayRef<Type>{}, // TODO(ntv): support tensors
              values,
              IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
              IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
              builder.getAffineMapArrayAttr(maps),
              builder.getStrArrayAttr(iteratorStrTypes),
              StringAttr() /*doc*/,
              FlatSymbolRefAttr() /*fun*/,
              StringAttr() /*library_call*/
              /* TODO: other attributes in op */
              )
          .getOperation();
  // clang-format on

  using namespace edsc;
  SmallVector<Type, 4> blockTypes;
  blockTypes.reserve(values.size());
  for (auto it : llvm::enumerate(values))
    blockTypes.push_back((it.index() < nViews)
                             ? getElementTypeOrSelf(it.value())
                             : it.value()->getType());

  assert(op->getRegions().front().empty());
  op->getRegions().front().push_front(new Block);
  OpBuilder bb(op->getRegions().front());
  ScopedContext scope(bb, op->getLoc());
  BlockHandle b;
  auto handles = makeValueHandles(blockTypes);
  BlockBuilder(&b, makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
      [&] { regionBuilder(b.getBlock()->getArguments()); });
  return op;
}

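/// Region builder for a multiply-accumulate: yields `c + a * b` from the
/// three block arguments (a, b, c).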
void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
  using edsc::op::operator+;
  using edsc::op::operator*;
  assert(args.size() == 3 && "expected 3 block arguments");
  ValueHandle a(args[0]), b(args[1]), c(args[2]);
  linalg_yield((c + a * b).getValue());
}

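/// Unary pointwise operation entry point: applies `unaryOp` to input `I` and
/// writes the result to `O`, with one parallel iterator per output expression.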
Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
                                             StructuredIndexed I,
                                             StructuredIndexed O) {
  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
                                           edsc::IterType::Parallel);
  auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
    assert(args.size() == 2 && "expected 2 block arguments");
    ValueHandle a(args[0]);
    linalg_yield(unaryOp(a));
  };
  return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
}

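/// Pointwise tanh, expressed in terms of the unary linalg_pointwise entry
/// point above.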
Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
                                                  StructuredIndexed O) {
  using edsc::intrinsics::tanh;
  UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); });
  return linalg_pointwise(unOp, I, O);
}

/// Binary pointwise operation (with broadcast) entry point.
Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
                                             StructuredIndexed I1,
                                             StructuredIndexed I2,
                                             StructuredIndexed O) {
  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
                                           edsc::IterType::Parallel);
  auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
    assert(args.size() == 3 && "expected 3 block arguments");
    ValueHandle a(args[0]), b(args[1]);
    linalg_yield(binaryOp(a, b));
  };
  return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
}

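/// Pointwise addition, expressed in terms of the binary linalg_pointwise
/// entry point above.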
Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
                                                 StructuredIndexed I2,
                                                 StructuredIndexed O) {
  using edsc::op::operator+;
  BinaryPointwiseOpBuilder binOp(
      [](ValueHandle a, ValueHandle b) -> Value { return a + b; });
  return linalg_pointwise(binOp, I1, I2, O);
}

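/// Pointwise maximum, implemented as select(a > b, a, b) through the binary
/// linalg_pointwise entry point above.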
Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
                                                 StructuredIndexed I2,
                                                 StructuredIndexed O) {
  BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value {
    using edsc::intrinsics::select;
    using edsc::op::operator>;
    return select(a > b, a, b).getValue();
  });
  return linalg_pointwise(binOp, I1, I2, O);
}

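/// Matrix multiplication C(m, n) += A(m, k) * B(k, n) as a generic linalg op:
/// parallel iterators over m and n, a reduction over k, and a
/// multiply-accumulate region.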
Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
                                          ValueHandle vC) {
  // clang-format off
  AffineExpr m, n, k;
  bindDims(ScopedContext::getContext(), m, n, k);
  StructuredIndexed A(vA), B(vB), C(vC);
  return makeGenericLinalgOp(
    {IterType::Parallel, IterType::Parallel, IterType::Reduction},
    {A({m, k}), B({k, n})},
    {C({m, n})},
    macRegionBuilder);
  // clang-format on
}

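/// 2-D convolution in NHWC layout (filter in HWCF layout) as a generic linalg
/// op. Note that the body indexes `strides` and `dilations` unconditionally,
/// so callers are expected to pass two entries for each even though the
/// asserts also admit empty arrays.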
Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
                                             ValueHandle vO,
                                             ArrayRef<int> strides,
                                             ArrayRef<int> dilations) {
  MLIRContext *ctx = ScopedContext::getContext();
  // TODO(ntv) some template magic to make everything rank-polymorphic.
  assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
  assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");

  // Some short names.
  auto par = IterType::Parallel;
  auto red = IterType::Reduction;
  auto s = strides;
  auto d = dilations;

  AffineExpr b, f, h, w, kh, kw, c;
  bindDims(ctx, b, f, h, w, kh, kw, c);
  unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
  StructuredIndexed I(vI), W(vW), O(vO);
  // clang-format off
  return makeGenericLinalgOp(
    {par, par, par, par, red, red, red}, {
    I({b,
       // Roundtrip to flattened form to serve as canonicalization and ensure
       // consistent ordering of subexpressions.
       simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
       simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
       c}),
    W({kh, kw, c, f})}, {
    O({b, h, w, f})},
    macRegionBuilder);
  // clang-format on
}

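/// Depthwise (dilated) 2-D convolution in NHWC layout with a channel
/// `depth_multiplier`: the output channel index is c * depth_multiplier + dm.
/// The same caveat on `strides`/`dilations` as for linalg_conv_nhwc applies.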
Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
    ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
    ArrayRef<int> strides, ArrayRef<int> dilations) {
  MLIRContext *ctx = ScopedContext::getContext();
  // TODO(ntv) some template magic to make everything rank-polymorphic.
  assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
  assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");

  // Some short names.
  auto par = IterType::Parallel;
  auto red = IterType::Reduction;
  auto s = strides;
  auto d = dilations;

  // clang-format off
  AffineExpr b, dm, c, h, w, kh, kw;
  bindDims(ctx, b, dm, c, h, w, kh, kw);
  unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
  StructuredIndexed I(vI), W(vW), O(vO);
  return makeGenericLinalgOp(
    {par, par, par, par, par, red, red}, {
    I({b,
       // Roundtrip to flattened form to serve as canonicalization and ensure
       // consistent ordering of subexpressions.
       simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
       simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
       c}),
    W({kh, kw, c, dm})}, {
    O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
    macRegionBuilder);
  // clang-format on
}