mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-29 03:16:07 +00:00
Add a layer of EDSC for linalg.GenericOp
This will be evolved into a simple programming model for custom ops and custom layers in followup CLs. This CL also deletes the obsolete tablegen's reference-impl.td that was using EDSCs. PiperOrigin-RevId: 285459545
This commit is contained in:
parent
b030e4a4ec
commit
7923abd357
@ -1,5 +1,4 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(EDSC)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
43
mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
Normal file
43
mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
Normal file
@ -0,0 +1,43 @@
|
||||
//===- Builders.h - MLIR Declarative Linalg Builders ------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Provides intuitive composable interfaces for building structured MLIR
|
||||
// snippets in a declarative fashion.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
|
||||
#define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
||||
namespace mlir {
|
||||
class BlockArgument;
|
||||
namespace edsc {
|
||||
|
||||
inline void defaultRegionBuilder(ArrayRef<BlockArgument *> args) {}
|
||||
|
||||
/// EDSC entry point to build linalg.generic operations programmatically.
|
||||
Operation *makeLinalgGenericOp(
|
||||
ArrayRef<AffineExpr> indices, ArrayRef<ArrayRef<AffineExpr>> mapExpressions,
|
||||
ArrayRef<Value *> inputViews, ArrayRef<Value *> outputViews,
|
||||
ArrayRef<StringRef> iteratorTypes,
|
||||
decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder);
|
||||
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
|
@ -529,6 +529,18 @@ ValueHandle operator>(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs);
|
||||
|
||||
} // namespace op
|
||||
|
||||
/// Entry point to build multiple ValueHandle from a `Container` of Value* or
|
||||
/// Type.
|
||||
template <typename Container>
|
||||
inline SmallVector<ValueHandle, 8> makeValueHandles(Container values) {
|
||||
SmallVector<ValueHandle, 8> res;
|
||||
res.reserve(values.size());
|
||||
for (auto v : values)
|
||||
res.push_back(ValueHandle(v));
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -1,3 +0,0 @@
|
||||
set(LLVM_TARGET_DEFINITIONS "${MLIR_SOURCE_DIR}/test/mlir-tblgen/reference-impl.td")
|
||||
mlir_tablegen("reference-impl.inc" -gen-reference-implementations)
|
||||
add_public_tablegen_target(MLIRReferenceImplementationTestGen)
|
@ -67,8 +67,10 @@ inline SmallVector<IndexHandle, 8> makeIndexHandles(unsigned rank) {
|
||||
return SmallVector<IndexHandle, 8>(rank);
|
||||
}
|
||||
|
||||
/// Entry point to build multiple ValueHandle* from a mutable list `ivs` of T.
|
||||
template <typename T>
|
||||
inline SmallVector<ValueHandle *, 8>
|
||||
makeIndexHandlePointers(MutableArrayRef<IndexHandle> ivs) {
|
||||
makeHandlePointers(MutableArrayRef<T> ivs) {
|
||||
SmallVector<ValueHandle *, 8> pivs;
|
||||
pivs.reserve(ivs.size());
|
||||
for (auto &iv : ivs) {
|
||||
|
@ -286,6 +286,23 @@ bool getFlattenedAffineExprs(
|
||||
bool getFlattenedAffineExprs(
|
||||
IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs);
|
||||
|
||||
namespace detail {
|
||||
template <int N> void bindDims(MLIRContext *ctx) {}
|
||||
|
||||
template <int N, typename AffineExprTy, typename... AffineExprTy2>
|
||||
void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &... exprs) {
|
||||
e = getAffineDimExpr(N, ctx);
|
||||
bindDims<N + 1, AffineExprTy2 &...>(ctx, exprs...);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/// Bind a list of AffineExpr references to DimExpr at positions:
|
||||
/// [0 .. sizeof...(exprs)]
|
||||
template <typename... AffineExprTy>
|
||||
void bindDims(MLIRContext *ctx, AffineExprTy &... exprs) {
|
||||
detail::bindDims<0>(ctx, exprs...);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
|
@ -21,8 +21,8 @@
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
||||
#include "mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h"
|
||||
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
@ -270,7 +270,7 @@ PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
||||
VectorView vectorView(transfer.vector());
|
||||
SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
|
||||
SmallVector<ValueHandle *, 8> pivs =
|
||||
makeIndexHandlePointers(MutableArrayRef<IndexHandle>(ivs));
|
||||
makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
|
||||
coalesceCopy(transfer, &pivs, &vectorView);
|
||||
|
||||
auto lbs = vectorView.getLbs();
|
||||
@ -332,7 +332,8 @@ PatternMatchResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
|
||||
ValueHandle vectorValue(transfer.vector());
|
||||
VectorView vectorView(transfer.vector());
|
||||
SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
|
||||
SmallVector<ValueHandle *, 8> pivs = makeIndexHandlePointers(ivs);
|
||||
SmallVector<ValueHandle *, 8> pivs =
|
||||
makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
|
||||
coalesceCopy(transfer, &pivs, &vectorView);
|
||||
|
||||
auto lbs = vectorView.getLbs();
|
||||
|
@ -1,6 +1,7 @@
|
||||
add_llvm_library(MLIRLinalg
|
||||
LinalgRegistration.cpp
|
||||
Analysis/DependenceAnalysis.cpp
|
||||
EDSC/Builders.cpp
|
||||
IR/LinalgOps.cpp
|
||||
IR/LinalgTypes.cpp
|
||||
Transforms/Fusion.cpp
|
||||
@ -20,6 +21,7 @@ add_dependencies(MLIRLinalg
|
||||
|
||||
MLIRAffineOps
|
||||
MLIRAnalysis
|
||||
MLIREDSC
|
||||
MLIRLinalgOpsIncGen
|
||||
MLIRLinalgLibraryOpsIncGen
|
||||
MLIRLinalgTransformPatternsIncGen
|
||||
|
72
mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
Normal file
72
mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
Normal file
@ -0,0 +1,72 @@
|
||||
//===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
|
||||
Operation *mlir::edsc::makeLinalgGenericOp(
|
||||
ArrayRef<AffineExpr> indices, ArrayRef<ArrayRef<AffineExpr>> mapExpressions,
|
||||
ArrayRef<Value *> inputViews, ArrayRef<Value *> outputViews,
|
||||
ArrayRef<StringRef> iteratorTypes,
|
||||
decltype(defaultRegionBuilder) regionBuilder) {
|
||||
auto &builder = edsc::ScopedContext::getBuilder();
|
||||
auto *ctx = builder.getContext();
|
||||
|
||||
SmallVector<AffineMap, 4> maps;
|
||||
maps.reserve(mapExpressions.size());
|
||||
for (auto exprs : mapExpressions)
|
||||
maps.push_back(AffineMap::get(indices.size(), 0, exprs));
|
||||
|
||||
SmallVector<Value *, 4> views;
|
||||
views.reserve(inputViews.size() + outputViews.size());
|
||||
views.append(inputViews.begin(), inputViews.end());
|
||||
views.append(outputViews.begin(), outputViews.end());
|
||||
|
||||
auto *op =
|
||||
edsc::ScopedContext::getBuilder()
|
||||
.create<linalg::GenericOp>(
|
||||
edsc::ScopedContext::getLocation(), views,
|
||||
IntegerAttr::get(IntegerType::get(64, ctx), inputViews.size()),
|
||||
IntegerAttr::get(IntegerType::get(64, ctx), outputViews.size()),
|
||||
builder.getAffineMapArrayAttr(maps),
|
||||
builder.getStrArrayAttr(iteratorTypes), StringAttr() /*doc*/,
|
||||
FlatSymbolRefAttr() /*fun*/, StringAttr() /*library_call*/
|
||||
)
|
||||
.getOperation();
|
||||
|
||||
using namespace edsc;
|
||||
SmallVector<Type, 4> blockTypes;
|
||||
blockTypes.reserve(views.size());
|
||||
for (auto *v : views)
|
||||
blockTypes.push_back(getElementTypeOrSelf(v));
|
||||
|
||||
assert(op->getRegions().front().empty());
|
||||
op->getRegions().front().push_front(new Block);
|
||||
OpBuilder bb(op->getRegions().front());
|
||||
ScopedContext scope(bb, op->getLoc());
|
||||
BlockHandle b;
|
||||
auto handles = makeValueHandles(blockTypes);
|
||||
BlockBuilder(&b, makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
|
||||
[&] { regionBuilder(b.getBlock()->getArguments()); });
|
||||
return op;
|
||||
}
|
@ -430,7 +430,8 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
|
||||
auto nRed = linalgOp.getNumReductionLoops();
|
||||
auto nWin = linalgOp.getNumWindowLoops();
|
||||
SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
|
||||
SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
|
||||
SmallVector<ValueHandle *, 4> allPIvs =
|
||||
makeHandlePointers(MutableArrayRef<IndexHandle>(allIvs));
|
||||
auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
|
||||
invertedMap, getViewSizes(linalgOp));
|
||||
assert(loopRanges.size() == allIvs.size());
|
||||
|
@ -356,7 +356,7 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
||||
// 3. Create the tiled loops.
|
||||
LinalgOp res = op;
|
||||
SmallVector<IndexHandle, 4> ivs(loopRanges.size());
|
||||
auto pivs = makeIndexHandlePointers(ivs);
|
||||
auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
|
||||
LoopNestRangeBuilder(pivs, loopRanges)([&] {
|
||||
auto b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
|
@ -13,7 +13,6 @@ add_llvm_library(MLIREDSC
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/EDSC
|
||||
)
|
||||
add_dependencies(MLIREDSC MLIRReferenceImplementationTestGen)
|
||||
target_link_libraries(MLIREDSC
|
||||
PUBLIC
|
||||
MLIRAffineOps
|
||||
@ -30,7 +29,6 @@ add_llvm_library(MLIREDSCInterface
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/EDSC
|
||||
)
|
||||
add_dependencies(MLIREDSCInterface MLIRIR)
|
||||
add_dependencies(MLIREDSC MLIRReferenceImplementationTestGen)
|
||||
target_link_libraries(MLIREDSC
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
|
@ -9,6 +9,7 @@ target_link_libraries(mlir-edsc-builder-api-test
|
||||
MLIRAffineOps
|
||||
MLIREDSC
|
||||
MLIRIR
|
||||
MLIRLinalg
|
||||
MLIRLoopOps
|
||||
MLIRStandardOps
|
||||
MLIRTransforms
|
||||
@ -20,6 +21,7 @@ target_include_directories(mlir-edsc-builder-api-test PRIVATE ..)
|
||||
|
||||
whole_archive_link(mlir-edsc-builder-api-test
|
||||
MLIRAffineOps
|
||||
MLIRLinalg
|
||||
MLIRLoopOps
|
||||
MLIRStandardOps
|
||||
MLIRTransforms
|
||||
|
@ -18,10 +18,13 @@
|
||||
// RUN: mlir-edsc-builder-api-test | FileCheck %s
|
||||
|
||||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
@ -30,6 +33,7 @@
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/Functional.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
@ -806,6 +810,48 @@ TEST_FUNC(affine_if_op) {
|
||||
f.erase();
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
// CHECK-LABEL: func @linalg_matmul
|
||||
// CHECK: linalg.generic
|
||||
/// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
|
||||
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
|
||||
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
|
||||
// CHECK: linalg.yield %[[a4]] : f32
|
||||
// CHECK: }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
|
||||
// clang-format on
|
||||
TEST_FUNC(linalg_matmul) {
|
||||
using namespace edsc;
|
||||
using namespace edsc::intrinsics;
|
||||
using namespace edsc::op;
|
||||
using linalg_yield = OperationBuilder<linalg::YieldOp>;
|
||||
|
||||
auto f32Type = FloatType::getF32(&globalContext());
|
||||
auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
|
||||
auto f =
|
||||
makeFunction("linalg_matmul", {}, {memrefType, memrefType, memrefType});
|
||||
|
||||
// clang-format off
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
Value *A(f.getArgument(0)), *B(f.getArgument(1)), *C(f.getArgument(2));
|
||||
AffineExpr m, n, k;
|
||||
bindDims(f.getContext(), m, n, k);
|
||||
makeLinalgGenericOp(
|
||||
{m, n, k},
|
||||
{{m, n}, {k, n}, {m, n}},
|
||||
{A, B},
|
||||
{C},
|
||||
{"parallel", "parallel", "reduction"},
|
||||
[](ArrayRef<BlockArgument *> args) {
|
||||
ValueHandle a(args[0]), b(args[1]), c(args[2]);
|
||||
linalg_yield((c + a * b).getValue());
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
int main() {
|
||||
RUN_TESTS();
|
||||
return 0;
|
||||
|
@ -1,25 +0,0 @@
|
||||
// RUN: mlir-tblgen -gen-reference-implementations -I %S/../../include %s | FileCheck %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def X_Dialect : Dialect {
|
||||
let name = "x";
|
||||
}
|
||||
class X_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<X_Dialect, mnemonic, traits>;
|
||||
|
||||
def X_AddOp : X_Op<"add">,
|
||||
Arguments<(ins AnyTensor:$A, AnyTensor:$B)>,
|
||||
Results<(outs AnyTensor: $C)> {
|
||||
// TODO: extract referenceImplementation to Op.
|
||||
code referenceImplementation = [{
|
||||
auto ivs = IndexedLinalgValuemakeIndexHandles(view_A.rank());
|
||||
auto pivs = IndexedLinalgValuemakeIndexHandlePointers(ivs);
|
||||
IndexedValue A(arg_A), B(arg_B), C(arg_C);
|
||||
AffineLoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())({
|
||||
C(ivs) = A(ivs) + B(ivs)
|
||||
});
|
||||
}];
|
||||
}
|
||||
|
||||
// CHECK: printRefImplementation
|
Loading…
x
Reference in New Issue
Block a user