mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 01:56:06 +00:00
[mlir][linalg] Extend elementwise (#124661)
Implements Linalg elemwise named-op following the proposal and discussions in RFC: https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1
This commit is contained in:
parent
b9622e84b4
commit
6de5d1e46d
@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
|
||||
}];
|
||||
}
|
||||
|
||||
// Define the attribute enums matching elementwise op kind (e.g., add).
|
||||
def ElementwiseKindAttr : EnumAttr<Linalg_Dialect,
|
||||
ElementwiseKind, "elementwise_kind"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
// Define the function attribute enums matching the OpDSL functions.
|
||||
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
|
@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::linalg";
|
||||
}
|
||||
|
||||
// Join two I32EnumAttrCase lists. This joining takes care that the
|
||||
// 'int enum values' in the combined list do not overlap. It does this
|
||||
// by adding to each element of second list the offset '!size(a)'.
|
||||
class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
|
||||
list<I32EnumAttrCase> b> {
|
||||
int aSize = !size(a);
|
||||
list<I32EnumAttrCase> result =
|
||||
!foldl(a, b, acc, var,
|
||||
acc # [I32EnumAttrCase<var.symbol,
|
||||
!add(var.value, aSize)
|
||||
>]);
|
||||
}
|
||||
|
||||
// Flatten 'list of list of I32EnumAttrCase' to 'list of I32EnumAttrCase'.
|
||||
// The flattening (via call to 'join') ensures no overlap in enum values.
|
||||
class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
|
||||
list<I32EnumAttrCase> result =
|
||||
!foldl([]<I32EnumAttrCase>, l, acc, var,
|
||||
JoinTwoI32EnumAttrCaseList<acc, var>.result);
|
||||
}
|
||||
|
||||
// Define a unified `enum class : i32` for all element-wise op functions.
|
||||
def ElementwiseKind :
|
||||
I32EnumAttr<"ElementwiseKind",
|
||||
"",
|
||||
ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
|
||||
BinaryFn.enumerants,
|
||||
TernaryFn.enumerants]>.result
|
||||
> {
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::linalg";
|
||||
}
|
||||
|
||||
// Define an `enum class : i32` that marks where each individual enum class
|
||||
// e.g. UnaryFn, BinaryFn, etc. end in the unified enum class ElementwiseKind.
|
||||
def ElementwiseCaseLimits : I32EnumAttr<"ElementwiseCaseLimits", "", []> {
|
||||
int last_unary = !size(UnaryFn.enumerants);
|
||||
int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
|
||||
int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));
|
||||
|
||||
let enumerants = [
|
||||
I32EnumAttrCase<"LastUnary", last_unary>,
|
||||
I32EnumAttrCase<"LastBinary", last_binary>,
|
||||
I32EnumAttrCase<"LastTernary", last_ternary>];
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::linalg";
|
||||
}
|
||||
|
||||
// Define an `enum class : i32` to categorise arity elementwise ops.
|
||||
def ElementwiseArityGroup : I32EnumAttr<"ElementwiseArityGroup", "", [
|
||||
I32EnumAttrCase<"Unary", 1>,
|
||||
I32EnumAttrCase<"Binary", 2>,
|
||||
I32EnumAttrCase<"Ternary", 3>
|
||||
]> {
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::linalg";
|
||||
}
|
||||
|
||||
def TypeFn : I32EnumAttr<"TypeFn", "", [
|
||||
I32EnumAttrCase<"cast_signed", 0>,
|
||||
I32EnumAttrCase<"cast_unsigned", 1>
|
||||
|
@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op definition for ElementwiseOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
|
||||
AttrSizedOperandSegments]> {
|
||||
let summary = [{ Performs element-wise operation }];
|
||||
let description = [{
|
||||
The attribute `kind` describes arithmetic operation to perform. The
|
||||
operation kind can be unary (e.g. max), binary (e.g. add) or ternary
|
||||
(e.g. select).
|
||||
|
||||
By default, all indexing maps are identities. In the case of default
|
||||
indexing map, all input and output shapes must match. The number of dims in
|
||||
each of the identity maps is equal to the rank of the output type.
|
||||
|
||||
Affine-maps for operands and result are required to be provided by the user
|
||||
when a transpose and/or broadcast is needed on any operand. When a map is not
|
||||
provided, default identity maps are inferred for each operand.
|
||||
|
||||
Iterator-types are always all `parallel`.
|
||||
Iterator-types are needed for constructing the underlying structured op.
|
||||
|
||||
The number of dims of the iterator-types are inferred from the rank of
|
||||
the result type.
|
||||
|
||||
Example:
|
||||
|
||||
Defining a unary linalg.elemwise with default indexing-map:
|
||||
```mlir
|
||||
%exp = linalg.elemwise
|
||||
kind=#linalg.elemwise_kind<exp>
|
||||
ins(%x : tensor<4x16x8xf32>)
|
||||
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
|
||||
```
|
||||
|
||||
Defining a binary linalg.elemwise with user-defined indexing-map:
|
||||
```mlir
|
||||
%add = linalg.elemwise
|
||||
kind=#linalg.elemwise_kind<add>
|
||||
indexing_maps = [#transpose, #broadcast, #identity]
|
||||
ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
|
||||
outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$inputs,
|
||||
Variadic<AnyShaped>:$outputs,
|
||||
ElementwiseKindAttr:$kind,
|
||||
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
|
||||
);
|
||||
|
||||
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
|
||||
let regions = (region AnyRegion:$region);
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<
|
||||
(ins "ValueRange":$inputs, "ValueRange":$outputs,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
[{
|
||||
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
|
||||
attributes, ElementwiseOp::getRegionBuilder());
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = structuredOpsBaseDecls # [{
|
||||
/// Get the arity enum corresponding to the kind of op, e.g. if arg is
|
||||
/// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
|
||||
static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
|
||||
|
||||
/// Both user-specified and default indexing map will always depend on
|
||||
/// the current Op instance.
|
||||
static bool hasDynamicIndexingMaps() { return true; }
|
||||
|
||||
/// Implements the block region builder for the elementwiseOp. This is
|
||||
/// called by the 'fillStructuredOpRegion'.
|
||||
static void regionBuilder(ImplicitLocOpBuilder &b,
|
||||
Block &block, ArrayRef<NamedAttribute> attrs);
|
||||
|
||||
static std::function<void(ImplicitLocOpBuilder &,
|
||||
Block &, ArrayRef<NamedAttribute>)>
|
||||
getRegionBuilder() {
|
||||
return regionBuilder;
|
||||
}
|
||||
|
||||
/// Returns rank of the result tensor/memref. Useful for knowing
|
||||
/// the dimensionality of the iteration space when others means
|
||||
/// are not possible e.g. absence of user-provided indexing map.
|
||||
unsigned getResultRank() {
|
||||
Value output = getDpsInitOperand(0)->get();
|
||||
ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
|
||||
return shapedType.getRank();
|
||||
}
|
||||
|
||||
/// Returns N 'parallel' iterator types where N is rank of result.
|
||||
SmallVector<utils::IteratorType> getIteratorTypesArray();
|
||||
|
||||
/// The default indexing maps are identities.
|
||||
/// There will be N+1 such maps, where N is the arity of the Op.
|
||||
static SmallVector<AffineMap>
|
||||
getDefaultIndexingMaps(unsigned NumMaps, unsigned numDims,
|
||||
MLIRContext *context);
|
||||
|
||||
/// Destination passing style interface method.
|
||||
::mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputsMutable();
|
||||
}
|
||||
|
||||
// Generic methods.
|
||||
std::string getLibraryCallName() {
|
||||
return generateLibraryCallName(getOperation());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op definition for MatmulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -4058,6 +4058,233 @@ Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
|
||||
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ElementwiseOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
namespace {
|
||||
struct ArityGroupAndKind {
|
||||
// The enum class {Unary, Binary, Ternary, ..}
|
||||
ElementwiseArityGroup arityGroup;
|
||||
|
||||
// The kind (e.g. `exp` or `add`) belonging to the arity group.
|
||||
union Kind {
|
||||
UnaryFn unaryFn;
|
||||
BinaryFn binaryFn;
|
||||
TernaryFn ternaryFn;
|
||||
} kind;
|
||||
};
|
||||
|
||||
unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
|
||||
return static_cast<unsigned>(arityGroup);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
|
||||
constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
|
||||
constexpr int lastBinary =
|
||||
static_cast<int>(ElementwiseCaseLimits::LastBinary);
|
||||
constexpr int lastTernary =
|
||||
static_cast<int>(ElementwiseCaseLimits::LastTernary);
|
||||
|
||||
int val = static_cast<int>(kind);
|
||||
ArityGroupAndKind result;
|
||||
|
||||
if (val < lastUnary) {
|
||||
result.arityGroup = ElementwiseArityGroup::Unary;
|
||||
result.kind.unaryFn = static_cast<UnaryFn>(val);
|
||||
return result;
|
||||
}
|
||||
if (val < lastBinary) {
|
||||
result.arityGroup = ElementwiseArityGroup::Binary;
|
||||
result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
|
||||
return result;
|
||||
}
|
||||
if (val >= lastTernary) {
|
||||
llvm_unreachable("unhandled ElementwiseFn");
|
||||
}
|
||||
result.arityGroup = ElementwiseArityGroup::Ternary;
|
||||
result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
|
||||
return result;
|
||||
}
|
||||
|
||||
SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
|
||||
auto rank = getResultRank();
|
||||
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
|
||||
}
|
||||
|
||||
SmallVector<AffineMap>
|
||||
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
|
||||
MLIRContext *context) {
|
||||
auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
|
||||
return SmallVector<AffineMap>(numMaps, map);
|
||||
}
|
||||
|
||||
ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
// Expect e.g. `kind = #linalg.elemwise_kind<add>`
|
||||
Attribute attr;
|
||||
mlir::linalg::ElementwiseKind elemwiseKindVal;
|
||||
if (parser.parseKeyword("kind") || parser.parseEqual())
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseAttribute(attr))) {
|
||||
auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
|
||||
if (!elemwiseKindAttr)
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"expected ElementwiseKind attribute");
|
||||
elemwiseKindVal = elemwiseKindAttr.getValue();
|
||||
} else {
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"expected operation 'kind' attribute");
|
||||
}
|
||||
result.addAttribute(
|
||||
"kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
|
||||
|
||||
// Parse optional `indexing_maps`
|
||||
SmallVector<Attribute, 3> indexingMapsAttr;
|
||||
Attribute mapAttr;
|
||||
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
|
||||
if (parser.parseEqual())
|
||||
return failure();
|
||||
if (parser.parseLSquare())
|
||||
return failure();
|
||||
do {
|
||||
if (parser.parseAttribute(mapAttr))
|
||||
return failure();
|
||||
if (!isa<AffineMapAttr>(mapAttr))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"expected affine map attribute");
|
||||
indexingMapsAttr.push_back(mapAttr);
|
||||
if (parser.parseOptionalComma())
|
||||
break;
|
||||
} while (true);
|
||||
if (parser.parseRSquare())
|
||||
return failure();
|
||||
}
|
||||
// At this stage of parsing the only way to infer number of region
|
||||
// args is through op kind, as input output tensors are not parsed yet.
|
||||
auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
|
||||
int numRegionArgs =
|
||||
getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
|
||||
if (parseNamedStructuredOp(parser, result, numRegionArgs,
|
||||
ElementwiseOp::getRegionBuilder())) {
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"unable to parse elemwise op");
|
||||
}
|
||||
|
||||
// Initialize indexingMaps, if not supplied explicitly.
|
||||
if (indexingMapsAttr.empty()) {
|
||||
// We need to infer the numDims of the indexing maps from the output
|
||||
// type which is already parsed by now.
|
||||
auto resultType = result.operands[result.operands.size() - 1].getType();
|
||||
auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
|
||||
if (!shapedType)
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"return type needs to be shaped type");
|
||||
auto numDims = shapedType.getRank();
|
||||
indexingMapsAttr = llvm::map_to_vector(
|
||||
ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
|
||||
parser.getContext()),
|
||||
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
||||
}
|
||||
|
||||
result.addAttribute("indexing_maps",
|
||||
parser.getBuilder().getArrayAttr(indexingMapsAttr));
|
||||
return success();
|
||||
}
|
||||
|
||||
void ElementwiseOp::print(OpAsmPrinter &p) {
|
||||
p << " kind=";
|
||||
p.printAttribute(getKindAttr());
|
||||
SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
|
||||
"indexing_maps"};
|
||||
unsigned arity =
|
||||
getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
|
||||
unsigned numDims = getResultRank();
|
||||
|
||||
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
|
||||
ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
|
||||
getContext()),
|
||||
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
||||
|
||||
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
|
||||
p << " indexing_maps = [";
|
||||
llvm::interleaveComma(getIndexingMaps(), p,
|
||||
[&](Attribute attr) { p.printAttribute(attr); });
|
||||
p << "]";
|
||||
}
|
||||
|
||||
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
|
||||
elidedAttrs);
|
||||
}
|
||||
|
||||
LogicalResult ElementwiseOp::verify() {
|
||||
// All necessary checks are done either by
|
||||
// - EnumAttr (e.g. unknown operation kind)
|
||||
// - verifyStructuredOpInterface (incorrect map, sizes).
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Implements the block region builder for the ElementwiseOp. This is called by
|
||||
/// 'fillStructuredOpRegion'.
|
||||
void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
ElementwiseKind elemwiseKind;
|
||||
for (auto attr : attrs) {
|
||||
if (attr.getName() == b.getStringAttr("kind")) {
|
||||
auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
|
||||
assert(kindAttr && "op kind attribute incorrectly set");
|
||||
elemwiseKind = kindAttr.getValue();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
|
||||
auto arityGroup = groupAndKind.arityGroup;
|
||||
auto kind = groupAndKind.kind;
|
||||
unsigned numBlockArgs = getArityGroupAsUInt(arityGroup) + 1 /*output*/;
|
||||
assert(block.getNumArguments() == numBlockArgs &&
|
||||
"Elementwise regionBuilder number of block args mismatch");
|
||||
|
||||
RegionBuilderHelper helper(b, block);
|
||||
SmallVector<Value> yields;
|
||||
Value result;
|
||||
|
||||
if (arityGroup == ElementwiseArityGroup::Unary) {
|
||||
result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
|
||||
|
||||
} else if (arityGroup == ElementwiseArityGroup::Binary) {
|
||||
result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
|
||||
block.getArgument(1));
|
||||
|
||||
} else if (arityGroup == ElementwiseArityGroup::Ternary) {
|
||||
result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
|
||||
block.getArgument(1), block.getArgument(2));
|
||||
|
||||
} else
|
||||
assert(false && "found unhandled category in elemwise");
|
||||
|
||||
yields.push_back(result);
|
||||
helper.yieldOutputs(yields);
|
||||
}
|
||||
|
||||
LogicalResult ElementwiseOp::fold(FoldAdaptor,
|
||||
SmallVectorImpl<OpFoldResult> &) {
|
||||
return memref::foldMemRefCast(*this);
|
||||
}
|
||||
|
||||
void ElementwiseOp::getEffects(
|
||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||
&effects) {
|
||||
if (hasPureTensorSemantics())
|
||||
return;
|
||||
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
||||
}
|
||||
|
||||
Speculation::Speculatability ElementwiseOp::getSpeculatability() {
|
||||
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PackOp/UnPackOp Common
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
165
mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
Normal file
165
mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
Normal file
@ -0,0 +1,165 @@
|
||||
// RUN: mlir-opt %s -linalg-generalize-named-ops -split-input-file | FileCheck %s
|
||||
// CHECK: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
//
|
||||
// CHECK: @unary_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[IDENTITY]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[A]]
|
||||
// CHECK-SAME: outs(%[[B]]
|
||||
//
|
||||
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
|
||||
// CHECK: %[[EXP:.+]] = math.exp %[[A_ARG]] : f32
|
||||
// CHECK: linalg.yield %[[EXP]] : f32
|
||||
//
|
||||
func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<exp>
|
||||
ins(%A : tensor<8x16x32xf32>)
|
||||
outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
|
||||
return %r : tensor<8x16x32xf32>
|
||||
}
|
||||
// -----
|
||||
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
|
||||
//
|
||||
// CHECK: @unary_transpose_broadcast_tanh(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[A]]
|
||||
// CHECK-SAME: outs(%[[B]]
|
||||
//
|
||||
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
|
||||
// CHECK: %[[TANH:.+]] = math.tanh %[[A_ARG]] : f32
|
||||
// CHECK: linalg.yield %[[TANH]] : f32
|
||||
//
|
||||
func.func @unary_transpose_broadcast_tanh(%A : tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<tanh>
|
||||
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
|
||||
ins(%A : tensor<32x16xf32>)
|
||||
outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
|
||||
return %r : tensor<8x16x32xf32>
|
||||
}
|
||||
// -----
|
||||
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
//
|
||||
// CHECK: @binary_div_on_memrefs(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>)
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[A]], %[[B]]
|
||||
// CHECK-SAME: outs(%[[C]]
|
||||
//
|
||||
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
|
||||
// CHECK: %[[DIV:.+]] = arith.divf %[[A_ARG]], %[[B_ARG]] : f32
|
||||
// CHECK: linalg.yield %[[DIV]] : f32
|
||||
//
|
||||
func.func @binary_div_on_memrefs(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
|
||||
linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<div>
|
||||
ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>)
|
||||
outs(%C: memref<16x8xf32>)
|
||||
return
|
||||
}
|
||||
// -----
|
||||
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
//
|
||||
// CHECK: @binary_mul_on_tensors(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[A]], %[[B]]
|
||||
// CHECK-SAME: outs(%[[C]]
|
||||
//
|
||||
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
|
||||
// CHECK: %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32
|
||||
// CHECK: linalg.yield %[[MUL]] : f32
|
||||
//
|
||||
func.func @binary_mul_on_tensors(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<mul>
|
||||
ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
|
||||
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
|
||||
return %r : tensor<16x8xf32>
|
||||
}
|
||||
// -----
|
||||
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
|
||||
//
|
||||
// CHECK: @binary_transpose_a(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[IDENTITY]], #[[IDENTITY]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[A]], %[[B]]
|
||||
// CHECK-SAME: outs(%[[C]]
|
||||
//
|
||||
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
|
||||
// CHECK: %[[SUB:.+]] = arith.subf %[[A_ARG]], %[[B_ARG]] : f32
|
||||
// CHECK: linalg.yield %[[SUB]] : f32
|
||||
//
|
||||
func.func @binary_transpose_a(%A : tensor<8x16xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<sub>
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>]
|
||||
ins(%A, %B: tensor<8x16xf32>, tensor<16x8xf32>)
|
||||
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
|
||||
return %r : tensor<16x8xf32>
|
||||
}
|
||||
// -----
|
||||
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
|
||||
// CHECK-DAG: #[[BROADCAST:.+]] = affine_map<(d0, d1) -> (d0)>
|
||||
//
|
||||
// CHECK: @binary_transpose_a_broadcast_b(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16xf32>, %[[C:.+]]: tensor<16x8xf32>)
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[BROADCAST]], #[[IDENTITY]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[A]], %[[B]]
|
||||
// CHECK-SAME: outs(%[[C]]
|
||||
//
|
||||
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
|
||||
// CHECK: %[[ADD:.+]] = arith.addf %[[A_ARG]], %[[B_ARG]] : f32
|
||||
// CHECK: linalg.yield %[[ADD]] : f32
|
||||
//
|
||||
func.func @binary_transpose_a_broadcast_b(%A : tensor<8x16xf32>, %B: tensor<16xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<add>
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
|
||||
affine_map<(d0, d1) -> (d0)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>]
|
||||
ins(%A, %B: tensor<8x16xf32>, tensor<16xf32>)
|
||||
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
|
||||
return %r : tensor<16x8xf32>
|
||||
}
|
||||
// -----
|
||||
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
|
||||
//
|
||||
// CHECK: @ternary(%[[A:.+]]: tensor<32x16xi1>, %[[B:.+]]: tensor<8x16x32xf32>, %[[C:.+]]: tensor<8x16x32xf32>, %[[D:.+]]: tensor<8x16x32xf32>)
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]], #[[IDENTITY]], #[[IDENTITY]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
|
||||
//
|
||||
// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]]
|
||||
// CHECK-SAME: outs(%[[D]]
|
||||
//
|
||||
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i1, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32, %[[D_ARG:.+]]: f32)
|
||||
// CHECK: %[[SELECTED:.+]] = arith.select %[[A_ARG]], %[[B_ARG]], %[[C_ARG]] : f32
|
||||
// CHECK: linalg.yield %[[SELECTED]] : f32
|
||||
//
|
||||
func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8x16x32xf32>, %D : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<select>
|
||||
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
|
||||
ins(%A, %B, %C : tensor<32x16xi1>, tensor<8x16x32xf32>, tensor<8x16x32xf32>)
|
||||
outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
|
||||
return %r : tensor<8x16x32xf32>
|
||||
}
|
54
mlir/test/Dialect/Linalg/elementwise/invalid.mlir
Normal file
54
mlir/test/Dialect/Linalg/elementwise/invalid.mlir
Normal file
@ -0,0 +1,54 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
|
||||
func.func @misspelt_op_div(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
|
||||
// expected-error@+3 {{expected ::mlir::linalg::ElementwiseKind to be one of: exp, log, abs, ceil, floor}}
|
||||
// expected-error@+2 {{failed to parse ElementwiseKindAttr parameter}}
|
||||
// expected-error@+1 {{custom op 'linalg.elementwise' expected operation 'kind' attribute}}
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<dive> ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
func.func @missing_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
|
||||
// expected-error@+1 {{'linalg.elementwise' op expected the number of indexing_map (2) to be equal to the number of input/output operands (3)}}
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map] ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
func.func @identity_map_when_transpose_expected(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
|
||||
// expected-error@+1 {{'linalg.elementwise' op inferred input/output operand #1 has shape's dimension #0 to be 8, but found 16}}
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map, #map] ins(%A, %B: memref<8x16xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#map1 = affine_map<(d0, d1) -> (d0)>
|
||||
func.func @incorrect_result_rank(%A : memref<8x16xf32>, %B: memref<8x16xf32>, %C: memref<8xf32>) {
|
||||
// expected-error@+1 {{'linalg.elementwise' op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map, #map1] ins(%A, %B: memref<8x16xf32>, memref<8x16xf32>) outs(%C: memref<8xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @unary_too_many_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>, %C: memref<8x16x32xf32>) {
|
||||
// expected-error@+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 2 args, got 3}}
|
||||
// expected-error@+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<exp> ins(%A, %B : memref<8x16x32xf32>, memref<8x16x32xf32>) outs(%C: memref<8x16x32xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @binary_too_few_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>) {
|
||||
// expected-error@+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 3 args, got 2}}
|
||||
// expected-error@+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%A : memref<8x16x32xf32>) outs(%B: memref<8x16x32xf32>)
|
||||
return
|
||||
}
|
90
mlir/test/Dialect/Linalg/elementwise/round-trip.mlir
Normal file
90
mlir/test/Dialect/Linalg/elementwise/round-trip.mlir
Normal file
@ -0,0 +1,90 @@
|
||||
// RUN: mlir-opt %s -split-input-file | FileCheck %s
|
||||
//
|
||||
// Note - the functions are named @{unary|binary}_{identity|transpose|broadcast|transpose_a|...}_{exp|mul|div|..}
|
||||
|
||||
// CHECK: @unary_identity_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
|
||||
// CHECK: %{{.*}} = linalg.elementwise kind=#linalg.elementwise_kind<exp>
|
||||
// CHECK-SAME ins(%[[A:.+]] : tensor<8x16x32xf32>) outs(%[[B:.+]] : tensor<8x16x32xf32>)
|
||||
//
|
||||
func.func @unary_identity_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<exp>
|
||||
ins(%A : tensor<8x16x32xf32>)
|
||||
outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
|
||||
return %r : tensor<8x16x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
|
||||
//
|
||||
// CHECK: @unary_projection_tanh(%[[A:.+]]: tensor<?x16xf32>,
|
||||
// CHECK-SAME: %[[B:.+]]: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
|
||||
// CHECK: {{.*}} = linalg.elementwise kind=#linalg.elementwise_kind<tanh>
|
||||
// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
|
||||
// CHECK-SAME: ins(%[[A]] : tensor<?x16xf32>) outs(%[[B]] : tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
|
||||
//
|
||||
func.func @unary_projection_tanh(%A: tensor<?x16xf32>,
|
||||
%B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<tanh>
|
||||
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
|
||||
ins(%A : tensor<?x16xf32>)
|
||||
outs(%B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
|
||||
return %r : tensor<8x16x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: @binary_identity_div(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>,
|
||||
// CHECK-SAME: %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
|
||||
// CHECK: {{.*}} = linalg.elementwise
|
||||
// CHECK-SAME: kind=#linalg.elementwise_kind<div>
|
||||
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
|
||||
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
|
||||
//
|
||||
func.func @binary_identity_div(%A: tensor<16x8xf32>, %B: tensor<16x8xf32>,
|
||||
%C: tensor<16x8xf32>) -> tensor<16x8xf32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<div>
|
||||
ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
|
||||
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
|
||||
return %r : tensor<16x8xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: @binary_identity_mul_5Di(%[[A]]: tensor<1x2x3x4x5xi32>,
|
||||
// CHECK-SAME: %[[B:.+]]: tensor<1x2x3x4x5xi32>,
|
||||
// CHECK-SAME: %[[C:.+]]: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
|
||||
// CHECK: {{.*}} = linalg.elementwise
|
||||
// CHECK-SAME: kind=#linalg.elementwise_kind<mul>
|
||||
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
|
||||
// CHECK-SAME: outs(%[[C]] : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
|
||||
//
|
||||
func.func @binary_identity_mul_5Di(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
|
||||
%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<mul>
|
||||
ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
|
||||
outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
|
||||
return %r : tensor<1x2x3x4x5xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: @redundant_maps
|
||||
// CHECK-NOT: indexing_maps
|
||||
//
|
||||
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
|
||||
func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
|
||||
%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
|
||||
%r = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<mul>
|
||||
indexing_maps = [#map, #map, #map]
|
||||
ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
|
||||
outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
|
||||
return %r : tensor<1x2x3x4x5xi32>
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user