2019-12-13 13:26:00 -08:00
|
|
|
//===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
|
|
|
|
//
|
2019-12-23 09:35:36 -08:00
|
|
|
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2019-12-13 13:26:00 -08:00
|
|
|
//
|
2019-12-23 09:35:36 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-12-13 13:26:00 -08:00
|
|
|
|
|
|
|
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
|
2019-12-16 13:32:02 -08:00
|
|
|
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
|
2019-12-13 13:26:00 -08:00
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
2019-12-13 16:35:49 -08:00
|
|
|
#include "mlir/EDSC/Builders.h"
|
2019-12-13 13:26:00 -08:00
|
|
|
#include "mlir/EDSC/Intrinsics.h"
|
|
|
|
#include "mlir/IR/AffineExpr.h"
|
2019-12-13 16:35:49 -08:00
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include "mlir/Support/Functional.h"
|
2019-12-13 13:26:00 -08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::edsc;
|
2019-12-13 16:35:49 -08:00
|
|
|
using namespace mlir::edsc::intrinsics;
|
2019-12-16 13:32:02 -08:00
|
|
|
using namespace mlir::edsc::ops;
|
[mlir][EDSC] Refactor dependencies involving EDSCs.
Summary: This diff removes the dependency of LinalgOps and VectorOps on EDSCs.
Reviewers: jpienaar, ftynse
Reviewed By: ftynse
Subscribers: merge_guards_bot, mgorny, mehdi_amini, rriddle, burmako, shauheen, antiagainst, csigg, arpith-jacob, mgester, lucyrfox, herhut, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72481
2020-01-15 09:28:12 -05:00
|
|
|
using namespace mlir::linalg;
|
|
|
|
using namespace mlir::loop;
|
|
|
|
|
|
|
|
/// Builds a loop `ForOp` from a range value and enters its body.
///
/// `range` must be the result of a `RangeOp`: the loop's lower bound, upper
/// bound and step are read from that op's `min()`, `max()` and `step()`.
/// On return `*iv` holds the induction variable of the new loop and the
/// builder has entered the loop body via `enter`; the matching `exit()` is
/// issued by `operator()`.
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv,
                                               ValueHandle range) {
  // Check the actual type: a bare non-null test accepts any typed value and
  // does not match what the assertion message promises.
  assert(range.getType().isa<linalg::RangeType>() &&
         "expected !linalg.range type");
  Operation *definingOp = range.getValue().getDefiningOp();
  assert(definingOp && "need operations to extract range parts");
  auto rangeOp = cast<RangeOp>(definingOp);
  auto lb = rangeOp.min();
  auto ub = rangeOp.max();
  auto step = rangeOp.step();
  auto forOp = OperationHandle::createOp<ForOp>(lb, ub, step);
  *iv = ValueHandle(forOp.getInductionVar());
  auto *body = forOp.getBody();
  enter(body, /*prev=*/1);
}
|
|
|
|
|
|
|
|
/// Builds a loop `ForOp` from the three components of a `SubViewOp::Range`
/// and enters its body.
///
/// `range.offset`, `range.size` and `range.stride` are forwarded, in that
/// order, as the operands of the created `ForOp`. On return `*iv` holds the
/// loop's induction variable and the builder has entered the loop body.
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv,
                                               SubViewOp::Range range) {
  auto loop =
      OperationHandle::createOp<ForOp>(range.offset, range.size, range.stride);
  *iv = ValueHandle(loop.getInductionVar());
  auto *loopBody = loop.getBody();
  enter(loopBody, /*prev=*/1);
}
|
|
|
|
|
|
|
|
/// Runs the optional `fun` callback, then calls `exit()`.
///
/// An empty `std::function` is accepted and simply skipped. Always returns a
/// null ValueHandle.
ValueHandle mlir::edsc::LoopRangeBuilder::
operator()(std::function<void(void)> fun) {
  const bool hasCallback = static_cast<bool>(fun);
  if (hasCallback) {
    fun();
  }
  exit();
  return ValueHandle::null();
}
|
|
|
|
|
|
|
|
/// Builds one LoopRangeBuilder per entry of `ranges`, in order, pairing
/// `ranges[i]` with the induction-variable handle `ivs[i]`.
///
/// `ivs` and `ranges` must have the same number of elements.
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
    ArrayRef<ValueHandle *> ivs, ArrayRef<SubViewOp::Range> ranges) {
  // Validate before the loop: the loop indexes `ivs[i]` for every range, so a
  // check placed after it would come too late to catch a short `ivs`.
  assert(ivs.size() == ranges.size() && "Mismatch loops vs ivs size");
  // Reserve up front so the emplace_back calls below do not reallocate.
  loops.reserve(ranges.size());
  for (unsigned i = 0, e = ranges.size(); i < e; ++i)
    loops.emplace_back(ivs[i], ranges[i]);
}
|
|
|
|
|
|
|
|
/// Builds one LoopRangeBuilder per entry of `ranges`, in order, pairing the
/// range handle `ranges[i]` with the induction-variable handle `ivs[i]`.
///
/// `ivs` and `ranges` must have the same number of elements.
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
    ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> ranges) {
  // Validate before the loop: the loop indexes `ivs[i]` for every range, so a
  // check placed after it would come too late to catch a short `ivs`.
  assert(ivs.size() == ranges.size() && "Mismatch loops vs ivs size");
  // Reserve up front so the emplace_back calls below do not reallocate.
  loops.reserve(ranges.size());
  for (unsigned i = 0, e = ranges.size(); i < e; ++i)
    loops.emplace_back(ivs[i], ranges[i]);
}
|
|
|
|
|
|
|
|
/// Convenience overload taking plain `Value` ranges.
///
/// Each element of `ranges` is converted to a ValueHandle in a temporary
/// vector, which is then forwarded to the `ArrayRef<ValueHandle>` constructor.
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
    ArrayRef<ValueHandle *> ivs, ArrayRef<Value> ranges)
    : LoopNestRangeBuilder(ivs, llvm::SmallVector<ValueHandle, 4>(
                                    ranges.begin(), ranges.end())) {}
|
|
|
|
|
|
|
|
/// Runs the optional `fun` callback, then invokes every stored loop builder
/// with an empty callback, walking `loops` in reverse order of creation.
///
/// Always returns a null ValueHandle.
ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::
operator()(std::function<void(void)> fun) {
  if (fun) {
    fun();
  }
  // Each per-loop call receives an empty std::function.
  for (auto &loopBuilder : llvm::reverse(loops))
    loopBuilder({});
  return ValueHandle::null();
}
|
|
|
|
|
2020-01-15 11:12:53 -05:00
|
|
|
namespace mlir {
|
|
|
|
namespace edsc {
|
[mlir][EDSC] Refactor dependencies involving EDSCs.
Summary: This diff removes the dependency of LinalgOps and VectorOps on EDSCs.
Reviewers: jpienaar, ftynse
Reviewed By: ftynse
Subscribers: merge_guards_bot, mgorny, mehdi_amini, rriddle, burmako, shauheen, antiagainst, csigg, arpith-jacob, mgester, lucyrfox, herhut, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72481
2020-01-15 09:28:12 -05:00
|
|
|
template <>
|
2020-01-15 11:12:53 -05:00
|
|
|
GenericLoopNestRangeBuilder<loop::ForOp>::GenericLoopNestRangeBuilder(
|
|
|
|
ArrayRef<edsc::ValueHandle *> ivs, ArrayRef<Value> ranges) {
|
[mlir][EDSC] Refactor dependencies involving EDSCs.
Summary: This diff removes the dependency of LinalgOps and VectorOps on EDSCs.
Reviewers: jpienaar, ftynse
Reviewed By: ftynse
Subscribers: merge_guards_bot, mgorny, mehdi_amini, rriddle, burmako, shauheen, antiagainst, csigg, arpith-jacob, mgester, lucyrfox, herhut, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72481
2020-01-15 09:28:12 -05:00
|
|
|
builder = std::make_unique<LoopNestRangeBuilder>(ivs, ranges);
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
2020-01-15 11:12:53 -05:00
|
|
|
GenericLoopNestRangeBuilder<AffineForOp>::GenericLoopNestRangeBuilder(
|
|
|
|
ArrayRef<ValueHandle *> ivs, ArrayRef<Value> ranges) {
|
[mlir][EDSC] Refactor dependencies involving EDSCs.
Summary: This diff removes the dependency of LinalgOps and VectorOps on EDSCs.
Reviewers: jpienaar, ftynse
Reviewed By: ftynse
Subscribers: merge_guards_bot, mgorny, mehdi_amini, rriddle, burmako, shauheen, antiagainst, csigg, arpith-jacob, mgester, lucyrfox, herhut, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72481
2020-01-15 09:28:12 -05:00
|
|
|
SmallVector<ValueHandle, 4> lbs;
|
|
|
|
SmallVector<ValueHandle, 4> ubs;
|
|
|
|
SmallVector<int64_t, 4> steps;
|
|
|
|
for (Value range : ranges) {
|
|
|
|
assert(range.getType() && "expected linalg.range type");
|
|
|
|
assert(range.getDefiningOp() && "need operations to extract range parts");
|
|
|
|
RangeOp rangeOp = cast<RangeOp>(range.getDefiningOp());
|
|
|
|
lbs.emplace_back(ValueHandle(rangeOp.min()));
|
|
|
|
ubs.emplace_back(ValueHandle(rangeOp.max()));
|
|
|
|
steps.emplace_back(ValueHandle(rangeOp.step()));
|
|
|
|
}
|
|
|
|
builder = std::make_unique<AffineLoopNestBuilder>(ivs, lbs, ubs, steps);
|
|
|
|
}
|
2020-01-15 11:12:53 -05:00
|
|
|
} // namespace edsc
|
|
|
|
} // namespace mlir
|
2019-12-13 16:35:49 -08:00
|
|
|
|
|
|
|
/// Raises `pos` to the largest AffineDimExpr position found in any expression
/// of `structuredIndices`. `pos` is only ever increased, so successive calls
/// accumulate a running maximum.
static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
                           unsigned &pos) {
  // Visitor applied to every sub-expression; only dimension expressions
  // contribute to the maximum.
  auto recordDimPosition = [&pos](AffineExpr e) {
    if (auto dimExpr = e.dyn_cast<AffineDimExpr>())
      pos = std::max(pos, dimExpr.getPosition());
  };
  for (auto indexed : structuredIndices)
    for (auto topLevelExpr : indexed.getExprs())
      topLevelExpr.walk(recordDimPosition);
}
|
2019-12-13 13:26:00 -08:00
|
|
|
|
2020-01-02 09:14:23 -05:00
|
|
|
/// Creates a `linalg::GenericOp` with the current ScopedContext builder and
/// location, then fills its region through `regionBuilder`.
///
/// - `iteratorTypes` are converted with `toString` and stored as the op's
///   string-array iterator attribute.
/// - `inputs` followed by `outputs` provide, in that order, the operand values
///   and one indexing AffineMap each, built from their affine expressions. All
///   maps share a dimension count of (largest dim position used) + 1.
/// - `regionBuilder` is called with the arguments of the block produced for
///   the op's region; the block argument types are the element types (or the
///   types themselves) of the operands.
/// - `otherValues` and `otherAttributes` are accepted but not referenced by
///   this implementation.
///
/// Returns the created operation.
Operation *mlir::edsc::makeGenericLinalgOp(
    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
    ArrayRef<StructuredIndexed> outputs,
    function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
    ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
  auto &builder = edsc::ScopedContext::getBuilder();
  auto *ctx = builder.getContext();
  unsigned numInputs = inputs.size();
  unsigned numOutputs = outputs.size();

  // Positions are 0-based, so the dimension count is the largest position + 1.
  unsigned highestDimPos = 0;
  getMaxDimIndex(inputs, highestDimPos);
  getMaxDimIndex(outputs, highestDimPos);
  unsigned numDims = highestDimPos + 1;

  // One indexing map per view: all inputs first, then all outputs.
  SmallVector<AffineMap, 4> indexingMaps;
  indexingMaps.reserve(numInputs + numOutputs);
  auto appendMapsFor = [&](ArrayRef<StructuredIndexed> views) {
    for (auto view : views)
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/numDims,
                                            /*symbolCount=*/0,
                                            view.getExprs()));
  };
  appendMapsFor(inputs);
  appendMapsFor(outputs);

  // Operands follow the same order as the indexing maps.
  unsigned numViews = numInputs + numOutputs;
  SmallVector<Value, 4> operands;
  operands.reserve(numViews);
  operands.append(inputs.begin(), inputs.end());
  operands.append(outputs.begin(), outputs.end());

  auto iteratorNames = functional::map(toString, iteratorTypes);
  // clang-format off
  auto *op =
      builder
          .create<linalg::GenericOp>(
              edsc::ScopedContext::getLocation(),
              ArrayRef<Type>{}, // TODO(ntv): support tensors
              operands,
              IntegerAttr::get(IntegerType::get(64, ctx), numInputs),
              IntegerAttr::get(IntegerType::get(64, ctx), numOutputs),
              builder.getAffineMapArrayAttr(indexingMaps),
              builder.getStrArrayAttr(iteratorNames),
              StringAttr() /*doc*/,
              FlatSymbolRefAttr() /*fun*/,
              StringAttr() /*library_call*/
              /* TODO: other attributes in op */
              )
          .getOperation();
  // clang-format on

  using namespace edsc;
  // Block argument types: element type for each view operand, plain type for
  // anything past the views.
  SmallVector<Type, 4> blockArgTypes;
  blockArgTypes.reserve(operands.size());
  for (unsigned idx = 0, e = operands.size(); idx < e; ++idx) {
    Value operand = operands[idx];
    if (idx < numViews)
      blockArgTypes.push_back(getElementTypeOrSelf(operand));
    else
      blockArgTypes.push_back(operand.getType());
  }

  // The freshly created op must start with an empty region.
  auto &region = op->getRegions().front();
  assert(region.empty());
  region.push_front(new Block);
  OpBuilder bb(region);
  ScopedContext scope(bb, op->getLoc());
  BlockHandle b;
  auto handles = makeValueHandles(blockArgTypes);
  BlockBuilder(&b, makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
      [&] { regionBuilder(b.getBlock()->getArguments()); });
  return op;
}
|
2019-12-13 16:35:49 -08:00
|
|
|
|
2019-12-23 14:45:01 -08:00
|
|
|
/// Region body for a multiply-accumulate: given exactly three block arguments
/// (a, b, c), yields `c + a * b`.
void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
  using edsc::op::operator+;
  using edsc::op::operator*;
  assert(args.size() == 3 && "expected 3 block arguments");
  ValueHandle lhs(args[0]);
  ValueHandle rhs(args[1]);
  ValueHandle acc(args[2]);
  // The product is built first, then added to the accumulator.
  auto product = lhs * rhs;
  auto sum = acc + product;
  linalg_yield(sum.getValue());
}
|
|
|
|
|
|
|
|
/// Unary pointwise operation entry point.
///
/// Builds a generic op with one input `I` and one output `O` in which every
/// iterator is parallel (one per expression of `O`). The region yields
/// `unaryOp` applied to the first block argument.
Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
                                             StructuredIndexed I,
                                             StructuredIndexed O) {
  const auto numLoops = O.getExprs().size();
  SmallVector<edsc::IterType, 4> parallelIterators(numLoops,
                                                   edsc::IterType::Parallel);
  auto yieldUnaryResult = [&unaryOp](ArrayRef<BlockArgument> args) {
    // Two block arguments are expected; only the first is read here.
    assert(args.size() == 2 && "expected 2 block arguments");
    ValueHandle operand(args[0]);
    linalg_yield(unaryOp(operand));
  };
  return makeGenericLinalgOp(parallelIterators, {I}, {O}, yieldUnaryResult);
}
|
|
|
|
|
|
|
|
/// Pointwise tanh: forwards to the unary `linalg_pointwise` with a builder
/// that applies the `tanh` intrinsic to its operand.
Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
                                                  StructuredIndexed O) {
  using edsc::intrinsics::tanh;
  // Keep the callable in a named local so that it outlives `unOp`: if
  // UnaryPointwiseOpBuilder is a non-owning callable reference, wrapping a
  // temporary lambda would leave it referring to a destroyed object.
  auto applyTanh = [](ValueHandle a) -> Value { return tanh(a); };
  UnaryPointwiseOpBuilder unOp(applyTanh);
  return linalg_pointwise(unOp, I, O);
}
|
|
|
|
|
|
|
|
/// Binary pointwise operation (with broadcast) entry point.
|
|
|
|
Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
|
|
|
|
StructuredIndexed I1,
|
|
|
|
StructuredIndexed I2,
|
|
|
|
StructuredIndexed O) {
|
|
|
|
SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
|
|
|
|
edsc::IterType::Parallel);
|
2019-12-23 14:45:01 -08:00
|
|
|
auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
|
2019-12-16 13:32:02 -08:00
|
|
|
assert(args.size() == 3 && "expected 3 block arguments");
|
|
|
|
ValueHandle a(args[0]), b(args[1]);
|
|
|
|
linalg_yield(binaryOp(a, b));
|
|
|
|
};
|
2020-01-02 09:14:23 -05:00
|
|
|
return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
|
2019-12-16 13:32:02 -08:00
|
|
|
}
|
2019-12-13 16:35:49 -08:00
|
|
|
|
2019-12-16 13:32:02 -08:00
|
|
|
/// Pointwise addition: forwards to the binary `linalg_pointwise` with a
/// builder that returns `a + b`.
Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
                                                 StructuredIndexed I2,
                                                 StructuredIndexed O) {
  using edsc::op::operator+;
  // Keep the callable in a named local so that it outlives `binOp`: if
  // BinaryPointwiseOpBuilder is a non-owning callable reference, wrapping a
  // temporary lambda would leave it referring to a destroyed object.
  auto addOperands = [](ValueHandle a, ValueHandle b) -> Value {
    return a + b;
  };
  BinaryPointwiseOpBuilder binOp(addOperands);
  return linalg_pointwise(binOp, I1, I2, O);
}
|
|
|
|
|
|
|
|
/// Pointwise maximum: forwards to the binary `linalg_pointwise` with a builder
/// that returns `select(a > b, a, b)`.
Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
                                                 StructuredIndexed I2,
                                                 StructuredIndexed O) {
  // Keep the callable in a named local so that it outlives `binOp`: if
  // BinaryPointwiseOpBuilder is a non-owning callable reference, wrapping a
  // temporary lambda would leave it referring to a destroyed object.
  auto selectLarger = [](ValueHandle a, ValueHandle b) -> Value {
    using edsc::intrinsics::select;
    using edsc::op::operator>;
    return select(a > b, a, b).getValue();
  };
  BinaryPointwiseOpBuilder binOp(selectLarger);
  return linalg_pointwise(binOp, I1, I2, O);
}
|
|
|
|
|
|
|
|
/// Matrix multiplication expressed as a generic op over dims (m, n, k):
///   C[m, n] += A[m, k] * B[k, n]
/// `m` and `n` are parallel iterators, `k` is a reduction; the region is
/// built by `macRegionBuilder`.
Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
                                          ValueHandle vC) {
  AffineExpr m, n, k;
  bindDims(ScopedContext::getContext(), m, n, k);
  StructuredIndexed A(vA), B(vB), C(vC);
  const auto par = IterType::Parallel;
  const auto red = IterType::Reduction;
  // The indexed operands are built inside the call expression itself, exactly
  // as the arguments of a single full-expression.
  // clang-format off
  return makeGenericLinalgOp(
    {par, par, red},
    {A({m, k}), B({k, n})},
    {C({m, n})},
    macRegionBuilder);
  // clang-format on
}
|
|
|
|
|
|
|
|
/// 2-D convolution expressed as a generic op over dims
/// (b, f, h, w, kh, kw, c):
///   O[b, h, w, f] += I[b, s0 * h + d0 * kh, s1 * w + d1 * kw, c]
///                    * W[kh, kw, c, f]
/// where (s0, s1) are `strides` and (d0, d1) are `dilations`. `b`, `f`, `h`,
/// `w` are parallel iterators; `kh`, `kw`, `c` are reductions. The region is
/// built by `macRegionBuilder`.
///
/// `strides` and `dilations` must each be empty or hold exactly two entries;
/// an empty array means 1 along both spatial dimensions.
Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
                                             ValueHandle vO,
                                             ArrayRef<int> strides,
                                             ArrayRef<int> dilations) {
  MLIRContext *ctx = ScopedContext::getContext();
  // TODO(ntv) some template magic to make everything rank-polymorphic.
  assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
  assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");

  // Some short names.
  auto par = IterType::Parallel;
  auto red = IterType::Reduction;
  // The asserts above explicitly allow empty arrays, so never index them
  // blindly: fall back to a unit stride / dilation instead of reading out of
  // bounds.
  const int strideH = strides.empty() ? 1 : strides[0];
  const int strideW = strides.empty() ? 1 : strides[1];
  const int dilationH = dilations.empty() ? 1 : dilations[0];
  const int dilationW = dilations.empty() ? 1 : dilations[1];

  AffineExpr b, f, h, w, kh, kw, c;
  bindDims(ctx, b, f, h, w, kh, kw, c);
  // `c` is the last dim bound above, so its position + 1 is the dim count.
  unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
  StructuredIndexed I(vI), W(vW), O(vO);
  // clang-format off
  return makeGenericLinalgOp(
    {par, par, par, par, red, red, red}, {
      I({b,
         // Roundtrip to flattened form to serve as canonicalization and ensure
         // consistent ordering of subexpressions.
         simplifyAffineExpr(strideH * h + dilationH * kh, numDims, 0),
         simplifyAffineExpr(strideW * w + dilationW * kw, numDims, 0),
         c}),
      W({kh, kw, c, f})}, {
      O({b, h, w, f})},
    macRegionBuilder);
  // clang-format on
}
|
|
|
|
|
|
|
|
/// 2-D convolution with a depth multiplier, expressed as a generic op over
/// dims (b, dm, c, h, w, kh, kw):
///   O[b, h, w, c * depth_multiplier + dm] +=
///       I[b, s0 * h + d0 * kh, s1 * w + d1 * kw, c] * W[kh, kw, c, dm]
/// where (s0, s1) are `strides` and (d0, d1) are `dilations`. `b`, `dm`, `c`,
/// `h`, `w` are parallel iterators; `kh`, `kw` are reductions. The region is
/// built by `macRegionBuilder`.
///
/// `strides` and `dilations` must each be empty or hold exactly two entries;
/// an empty array means 1 along both spatial dimensions.
Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
    ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
    ArrayRef<int> strides, ArrayRef<int> dilations) {
  MLIRContext *ctx = ScopedContext::getContext();
  // TODO(ntv) some template magic to make everything rank-polymorphic.
  assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
  assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");

  // Some short names.
  auto par = IterType::Parallel;
  auto red = IterType::Reduction;
  // The asserts above explicitly allow empty arrays, so never index them
  // blindly: fall back to a unit stride / dilation instead of reading out of
  // bounds.
  const int strideH = strides.empty() ? 1 : strides[0];
  const int strideW = strides.empty() ? 1 : strides[1];
  const int dilationH = dilations.empty() ? 1 : dilations[0];
  const int dilationW = dilations.empty() ? 1 : dilations[1];

  // clang-format off
  AffineExpr b, dm, c, h, w, kh, kw;
  bindDims(ctx, b, dm, c, h, w, kh, kw);
  // `kw` is the last dim bound above, so its position + 1 is the dim count.
  unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
  StructuredIndexed I(vI), W(vW), O(vO);
  return makeGenericLinalgOp(
    {par, par, par, par, par, red, red}, {
      I({b,
         // Roundtrip to flattened form to serve as canonicalization and ensure
         // consistent ordering of subexpressions.
         simplifyAffineExpr(strideH * h + dilationH * kh, numDims, 0),
         simplifyAffineExpr(strideW * w + dilationW * kw, numDims, 0),
         c}),
      W({kh, kw, c, dm})}, {
      O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
    macRegionBuilder);
  // clang-format on
}
|