[mlir][linalg] Extend elementwise (#124661)

Implements Linalg elemwise named-op following the proposal and
discussions in RFC:
  https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1
This commit is contained in:
Javed Absar 2025-02-21 10:51:21 +00:00 committed by GitHub
parent b9622e84b4
commit 6de5d1e46d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 721 additions and 0 deletions

View File

@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
}];
}
// Attribute wrapper around the `ElementwiseKind` enum (e.g., add), registered
// in the Linalg dialect under the mnemonic `elementwise_kind`. The assembly
// format places the enum value between angle brackets, e.g.
// `#linalg.elementwise_kind<add>`.
def ElementwiseKindAttr : EnumAttr<Linalg_Dialect,
ElementwiseKind, "elementwise_kind"> {
let assemblyFormat = "`<` $value `>`";
}
// Define the function attribute enums matching the OpDSL functions.
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
let assemblyFormat = "`<` $value `>`";

View File

@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
// Join two I32EnumAttrCase lists into `result`. The cases of `a` are kept
// unchanged; every case of `b` is re-created with the same symbol and with
// its integer value shifted by the offset '!size(a)', so that the
// 'int enum values' in the combined list do not overlap.
// NOTE(review): non-overlap presumably relies on the values of `a` being the
// dense range [0, !size(a)) and the values of `b` being non-negative --
// confirm against the enums being joined.
class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
list<I32EnumAttrCase> b> {
// Number of cases in the first list; used as the value offset for `b`.
int aSize = !size(a);
// Fold over `b`, starting from `a` as the accumulator, appending one
// shifted case per element of `b`.
list<I32EnumAttrCase> result =
!foldl(a, b, acc, var,
acc # [I32EnumAttrCase<var.symbol,
!add(var.value, aSize)
>]);
}
// Flatten 'list of list of I32EnumAttrCase' to 'list of I32EnumAttrCase'.
// The sub-lists are folded left-to-right through
// 'JoinTwoI32EnumAttrCaseList', starting from an empty list, so the cases of
// each sub-list are shifted by the number of cases accumulated before it.
// The flattening (via call to 'join') ensures no overlap in enum values.
class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
list<I32EnumAttrCase> result =
!foldl([]<I32EnumAttrCase>, l, acc, var,
JoinTwoI32EnumAttrCaseList<acc, var>.result);
}
// Define a unified `enum class : i32` for all element-wise op functions.
// The cases are the concatenation of the UnaryFn, BinaryFn and TernaryFn
// cases, in that order, with values re-numbered by
// 'ConcatI32EnumAtrCaseList' so the three groups do not collide.
// No specialized attribute class is generated for this enum
// (genSpecializedAttr = 0).
def ElementwiseKind :
I32EnumAttr<"ElementwiseKind",
"",
ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
BinaryFn.enumerants,
TernaryFn.enumerants]>.result
> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
// Define an `enum class : i32` that marks where each individual enum class
// e.g. UnaryFn, BinaryFn, etc. end in the unified enum class ElementwiseKind.
// Each limit is a cumulative case count: `LastUnary` is the number of UnaryFn
// cases, `LastBinary` adds the number of BinaryFn cases, and `LastTernary`
// adds the number of TernaryFn cases.
// NOTE(review): despite the `Last*` names these look like exclusive upper
// bounds (one past the last value of each group), assuming the unified enum
// values start at 0 -- confirm against the users of these limits.
def ElementwiseCaseLimits : I32EnumAttr<"ElementwiseCaseLimits", "", []> {
int last_unary = !size(UnaryFn.enumerants);
int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));
// The case list is supplied here (rather than as a template argument)
// because it is computed from the fields above.
let enumerants = [
I32EnumAttrCase<"LastUnary", last_unary>,
I32EnumAttrCase<"LastBinary", last_binary>,
I32EnumAttrCase<"LastTernary", last_ternary>];
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
// Define an `enum class : i32` to categorise elementwise ops by arity.
// The integer value of each case is 1, 2 and 3 for Unary, Binary and Ternary
// respectively. No specialized attribute class is generated for this enum.
def ElementwiseArityGroup : I32EnumAttr<"ElementwiseArityGroup", "", [
I32EnumAttrCase<"Unary", 1>,
I32EnumAttrCase<"Binary", 2>,
I32EnumAttrCase<"Ternary", 3>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast_signed", 0>,
I32EnumAttrCase<"cast_unsigned", 1>

View File

@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// Op definition for ElementwiseOp
//===----------------------------------------------------------------------===//
def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
                      AttrSizedOperandSegments]> {
  let summary = [{ Performs element-wise operation }];
  let description = [{
    The attribute `kind` describes arithmetic operation to perform. The
    operation kind can be unary (e.g. exp), binary (e.g. add) or ternary
    (e.g. select).

    By default, all indexing maps are identities. In the case of default
    indexing map, all input and output shapes must match. The number of dims in
    each of the identity maps is equal to the rank of the output type.

    Affine-maps for operands and result are required to be provided by the user
    when a transpose and/or broadcast is needed on any operand. When a map is not
    provided, default identity maps are inferred for each operand.

    Iterator-types are always all `parallel`.
    Iterator-types are needed for constructing the underlying structured op.

    The number of dims of the iterator-types are inferred from the rank of
    the result type.

    Example:

    Defining a unary linalg.elementwise with default indexing-map:
    ```mlir
    %exp = linalg.elementwise
        kind=#linalg.elementwise_kind<exp>
        ins(%x : tensor<4x16x8xf32>)
        outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
    ```

    Defining a binary linalg.elementwise with user-defined indexing-map:
    ```mlir
    %add = linalg.elementwise
        kind=#linalg.elementwise_kind<add>
        indexing_maps = [#transpose, #broadcast, #identity]
        ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
        outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
    ```
  }];

  let arguments = (ins
      Variadic<AnyType>:$inputs,
      Variadic<AnyShaped>:$outputs,
      ElementwiseKindAttr:$kind,
      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
    );

  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);
  let skipDefaultBuilders = 1;

  // The only builder forwards to `buildStructuredOp`; the `kind` attribute
  // (and optionally `indexing_maps`) is expected to be passed in `attributes`.
  let builders = [
      OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
          attributes, ElementwiseOp::getRegionBuilder());
      }]>
    ];

  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasVerifier = 1;

  let extraClassDeclaration = structuredOpsBaseDecls # [{
      /// Get the arity enum corresponding to the kind of op, e.g. if arg is
      /// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
      static ElementwiseArityGroup getArityGroup(ElementwiseKind n);

      /// Both user-specified and default indexing map will always depend on
      /// the current Op instance.
      static bool hasDynamicIndexingMaps() { return true; }

      /// Implements the block region builder for the elementwiseOp. This is
      /// called by the 'fillStructuredOpRegion'.
      static void regionBuilder(ImplicitLocOpBuilder &b,
                                Block &block, ArrayRef<NamedAttribute> attrs);

      static std::function<void(ImplicitLocOpBuilder &,
                                Block &, ArrayRef<NamedAttribute>)>
      getRegionBuilder() {
        return regionBuilder;
      }

      /// Returns rank of the result tensor/memref. Useful for knowing
      /// the dimensionality of the iteration space when other means
      /// are not possible e.g. absence of user-provided indexing map.
      unsigned getResultRank() {
        Value output = getDpsInitOperand(0)->get();
        ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
        return shapedType.getRank();
      }

      /// Returns N 'parallel' iterator types where N is rank of result.
      SmallVector<utils::IteratorType> getIteratorTypesArray();

      /// The default indexing maps are identities.
      /// There will be N+1 such maps, where N is the arity of the Op.
      static SmallVector<AffineMap>
      getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
                             MLIRContext *context);

      /// Destination passing style interface method.
      ::mlir::MutableOperandRange getDpsInitsMutable() {
        return getOutputsMutable();
      }

      // Generic methods.
      std::string getLibraryCallName() {
        return generateLibraryCallName(getOperation());
      }
    }];
}
//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//

View File

@ -4058,6 +4058,233 @@ Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
//===----------------------------------------------------------------------===//
// ElementwiseOp
//===----------------------------------------------------------------------===//
//
namespace {
/// Pairs the arity group of an elementwise op with the concrete function
/// selected inside that group.
struct ArityGroupAndKind {
  /// One of the enum class values {Unary, Binary, Ternary, ..}.
  ElementwiseArityGroup arityGroup;

  /// The function (e.g. `exp` or `add`) that belongs to `arityGroup`.
  union Kind {
    UnaryFn unaryFn;
    BinaryFn binaryFn;
    TernaryFn ternaryFn;
  } kind;
};

/// Returns the underlying enum value of `arityGroup` as an unsigned integer.
unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
  const auto asUInt = static_cast<unsigned>(arityGroup);
  return asUInt;
}
} // namespace
/// Splits a unified `ElementwiseKind` value into its arity group and the
/// function enum local to that group. The `ElementwiseCaseLimits` values are
/// used as exclusive upper bounds of the unary/binary/ternary sub-ranges, and
/// the start of each sub-range is subtracted to recover the per-group enum
/// value.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
  constexpr int unaryEnd = static_cast<int>(ElementwiseCaseLimits::LastUnary);
  constexpr int binaryEnd = static_cast<int>(ElementwiseCaseLimits::LastBinary);
  constexpr int ternaryEnd =
      static_cast<int>(ElementwiseCaseLimits::LastTernary);

  const int rawKind = static_cast<int>(kind);
  ArityGroupAndKind groupAndKind;

  if (rawKind < unaryEnd) {
    groupAndKind.arityGroup = ElementwiseArityGroup::Unary;
    groupAndKind.kind.unaryFn = static_cast<UnaryFn>(rawKind);
  } else if (rawKind < binaryEnd) {
    groupAndKind.arityGroup = ElementwiseArityGroup::Binary;
    groupAndKind.kind.binaryFn = static_cast<BinaryFn>(rawKind - unaryEnd);
  } else if (rawKind < ternaryEnd) {
    groupAndKind.arityGroup = ElementwiseArityGroup::Ternary;
    groupAndKind.kind.ternaryFn = static_cast<TernaryFn>(rawKind - binaryEnd);
  } else {
    llvm_unreachable("unhandled ElementwiseFn");
  }
  return groupAndKind;
}
/// Returns one `parallel` iterator type per dimension of the result, i.e.
/// `getResultRank()` entries in total.
SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
  SmallVector<utils::IteratorType> iteratorTypes(
      getResultRank(), utils::IteratorType::parallel);
  return iteratorTypes;
}
/// Returns `numMaps` copies of the `numDims`-dimensional identity affine map
/// created in `context`.
SmallVector<AffineMap>
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
                                      MLIRContext *context) {
  SmallVector<AffineMap> identityMaps(
      numMaps, AffineMap::getMultiDimIdentityMap(numDims, context));
  return identityMaps;
}
/// Parses the custom assembly of `linalg.elementwise`:
///   `kind` `=` <ElementwiseKind attribute>
///   (`indexing_maps` `=` `[` affine-map-attr (`,` affine-map-attr)* `]`)?
///   <common named structured op syntax>
/// When no indexing maps are given, identity maps are attached, one per region
/// argument, with as many dims as the rank of the last parsed operand.
ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
  // Mandatory leading clause, e.g. `kind = #linalg.elementwise_kind<add>`.
  if (parser.parseKeyword("kind") || parser.parseEqual())
    return failure();

  Attribute parsedKind;
  if (failed(parser.parseAttribute(parsedKind)))
    return parser.emitError(parser.getCurrentLocation(),
                            "expected operation 'kind' attribute");

  auto kindAttr = dyn_cast<ElementwiseKindAttr>(parsedKind);
  if (!kindAttr)
    return parser.emitError(parser.getCurrentLocation(),
                            "expected ElementwiseKind attribute");

  mlir::linalg::ElementwiseKind kindValue = kindAttr.getValue();
  result.addAttribute(
      "kind", ElementwiseKindAttr::get(parser.getContext(), kindValue));

  // Optional clause: `indexing_maps = [map0, map1, ...]` (at least one map).
  SmallVector<Attribute, 3> mapAttrs;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual() || parser.parseLSquare())
      return failure();

    do {
      Attribute mapAttr;
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr))
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      mapAttrs.push_back(mapAttr);
    } while (succeeded(parser.parseOptionalComma()));

    if (parser.parseRSquare())
      return failure();
  }

  // The operands are not parsed yet, so the number of region arguments can
  // only be derived from the op kind: one per input plus one for the output.
  const unsigned arity =
      getArityGroupAsUInt(getArityGroupAndKind(kindValue).arityGroup);
  int numRegionArgs = arity + 1 /*output*/;

  if (parseNamedStructuredOp(parser, result, numRegionArgs,
                             ElementwiseOp::getRegionBuilder())) {
    return parser.emitError(parser.getCurrentLocation(),
                            "unable to parse elemwise op");
  }

  // Fall back to identity maps when none were spelled out. Their number of
  // dims comes from the type of the last operand, which is parsed by now.
  if (mapAttrs.empty()) {
    Type lastOperandType = result.operands.back().getType();
    auto shapedType = llvm::dyn_cast<ShapedType>(lastOperandType);
    if (!shapedType)
      return parser.emitError(parser.getCurrentLocation(),
                              "return type needs to be shaped type");

    auto numDims = shapedType.getRank();
    for (AffineMap map : ElementwiseOp::getDefaultIndexingMaps(
             numRegionArgs, numDims, parser.getContext()))
      mapAttrs.push_back(AffineMapAttr::get(map));
  }

  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(mapAttrs));
  return success();
}
/// Prints the custom assembly of `linalg.elementwise`. The `kind` attribute is
/// always printed; `indexing_maps` is printed only when it differs from the
/// default list of identity maps, and the remaining parts are delegated to
/// `printNamedStructuredOp`.
void ElementwiseOp::print(OpAsmPrinter &p) {
  p << " kind=";
  p.printAttribute(getKindAttr());

  // Default maps to compare against: one identity map per input plus one for
  // the output, each with as many dims as the rank of the result.
  const unsigned numOperandMaps =
      getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup) +
      1 /*output*/;
  SmallVector<Attribute, 3> defaultMaps;
  for (AffineMap map : ElementwiseOp::getDefaultIndexingMaps(
           numOperandMaps, getResultRank(), getContext()))
    defaultMaps.push_back(AffineMapAttr::get(map));

  const bool hasDefaultMaps = llvm::equal(getIndexingMaps(), defaultMaps);
  if (!hasDefaultMaps) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }

  // These attributes are either printed above or implied by the syntax.
  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
                                           "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}
/// Op-specific verifier; it performs no checks of its own and always
/// succeeds.
LogicalResult ElementwiseOp::verify() {
// All necessary checks are done either by
// - EnumAttr (e.g. unknown operation kind)
// - verifyStructuredOpInterface (incorrect map, sizes).
return success();
}
/// Implements the block region builder for the ElementwiseOp. This is called by
/// 'fillStructuredOpRegion'.
///
/// Looks up the `kind` attribute in `attrs`, builds the matching unary, binary
/// or ternary function on the leading block arguments and yields its result.
/// `block` must have one argument per input plus one for the output.
///
/// Aborts with a fatal error if `attrs` carries no `kind` attribute of type
/// `ElementwiseKindAttr`, instead of reading an uninitialized kind.
void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                                  ArrayRef<NamedAttribute> attrs) {
  // Build the attribute name once rather than on every loop iteration.
  StringAttr kindName = b.getStringAttr("kind");

  // Stays null when `kind` is missing or has an unexpected attribute type.
  ElementwiseKindAttr kindAttr;
  for (NamedAttribute attr : attrs) {
    if (attr.getName() == kindName) {
      kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
      break;
    }
  }
  if (!kindAttr)
    llvm::report_fatal_error("op kind attribute incorrectly set");

  ArityGroupAndKind groupAndKind = getArityGroupAndKind(kindAttr.getValue());
  auto arityGroup = groupAndKind.arityGroup;
  auto kind = groupAndKind.kind;

  // The expected count is computed inside the assert so that builds without
  // assertions do not end up with an unused local variable.
  assert(block.getNumArguments() ==
             getArityGroupAsUInt(arityGroup) + 1 /*output*/ &&
         "Elementwise regionBuilder number of block args mismatch");

  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;
  Value result;

  if (arityGroup == ElementwiseArityGroup::Unary) {
    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
  } else if (arityGroup == ElementwiseArityGroup::Binary) {
    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
                                  block.getArgument(1));
  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
                                   block.getArgument(1), block.getArgument(2));
  } else {
    llvm_unreachable("found unhandled category in elemwise");
  }

  yields.push_back(result);
  helper.yieldOutputs(yields);
}
/// Folder hook: delegates entirely to `memref::foldMemRefCast` on this op.
/// The fold adaptor and the results vector are not used here.
LogicalResult ElementwiseOp::fold(FoldAdaptor,
SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}
/// Collects the memory effects of this op into `effects`. Nothing is added
/// when the op has pure tensor semantics; otherwise the generic structured-op
/// implementation is used.
void ElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (!hasPureTensorSemantics())
    getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
/// Reports the speculatability of this op by forwarding to
/// `getGenericSpeculatabilityImpl` with the op viewed as a `LinalgOp`.
Speculation::Speculatability ElementwiseOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,165 @@
// RUN: mlir-opt %s -linalg-generalize-named-ops -split-input-file | FileCheck %s
// CHECK: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//
// CHECK: @unary_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[IDENTITY]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[A]]
// CHECK-SAME: outs(%[[B]]
//
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
// CHECK: %[[EXP:.+]] = math.exp %[[A_ARG]] : f32
// CHECK: linalg.yield %[[EXP]] : f32
//
// Unary `exp` on rank-3 tensors without explicit indexing_maps.
func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%r = linalg.elementwise
kind=#linalg.elementwise_kind<exp>
ins(%A : tensor<8x16x32xf32>)
outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
}
// -----
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
//
// CHECK: @unary_transpose_broadcast_tanh(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[A]]
// CHECK-SAME: outs(%[[B]]
//
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
// CHECK: %[[TANH:.+]] = math.tanh %[[A_ARG]] : f32
// CHECK: linalg.yield %[[TANH]] : f32
//
// Unary `tanh` whose rank-2 input is mapped through (d2, d1) onto a rank-3
// output, with the maps given explicitly.
func.func @unary_transpose_broadcast_tanh(%A : tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%r = linalg.elementwise
kind=#linalg.elementwise_kind<tanh>
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
ins(%A : tensor<32x16xf32>)
outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
}
// -----
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
//
// CHECK: @binary_div_on_memrefs(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>)
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[A]], %[[B]]
// CHECK-SAME: outs(%[[C]]
//
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK: %[[DIV:.+]] = arith.divf %[[A_ARG]], %[[B_ARG]] : f32
// CHECK: linalg.yield %[[DIV]] : f32
//
// Binary `div` on memref operands (no result value) without explicit
// indexing_maps.
func.func @binary_div_on_memrefs(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
linalg.elementwise
kind=#linalg.elementwise_kind<div>
ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>)
outs(%C: memref<16x8xf32>)
return
}
// -----
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
//
// CHECK: @binary_mul_on_tensors(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[A]], %[[B]]
// CHECK-SAME: outs(%[[C]]
//
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK: %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32
// CHECK: linalg.yield %[[MUL]] : f32
//
// Binary `mul` on tensor operands without explicit indexing_maps.
func.func @binary_mul_on_tensors(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
%r = linalg.elementwise
kind=#linalg.elementwise_kind<mul>
ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
return %r : tensor<16x8xf32>
}
// -----
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
//
// CHECK: @binary_transpose_a(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[IDENTITY]], #[[IDENTITY]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[A]], %[[B]]
// CHECK-SAME: outs(%[[C]]
//
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK: %[[SUB:.+]] = arith.subf %[[A_ARG]], %[[B_ARG]] : f32
// CHECK: linalg.yield %[[SUB]] : f32
//
// Binary `sub` where the first input uses a transposing map (d1, d0) and the
// second input and the output use identity maps.
func.func @binary_transpose_a(%A : tensor<8x16xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
%r = linalg.elementwise
kind=#linalg.elementwise_kind<sub>
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>]
ins(%A, %B: tensor<8x16xf32>, tensor<16x8xf32>)
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
return %r : tensor<16x8xf32>
}
// -----
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[BROADCAST:.+]] = affine_map<(d0, d1) -> (d0)>
//
// CHECK: @binary_transpose_a_broadcast_b(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16xf32>, %[[C:.+]]: tensor<16x8xf32>)
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[BROADCAST]], #[[IDENTITY]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[A]], %[[B]]
// CHECK-SAME: outs(%[[C]]
//
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK: %[[ADD:.+]] = arith.addf %[[A_ARG]], %[[B_ARG]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
//
// Binary `add` where the first input uses a transposing map (d1, d0), the
// rank-1 second input uses the map (d0), and the output uses the identity map.
func.func @binary_transpose_a_broadcast_b(%A : tensor<8x16xf32>, %B: tensor<16xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
%r = linalg.elementwise
kind=#linalg.elementwise_kind<add>
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>]
ins(%A, %B: tensor<8x16xf32>, tensor<16xf32>)
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
return %r : tensor<16x8xf32>
}
// -----
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
//
// CHECK: @ternary(%[[A:.+]]: tensor<32x16xi1>, %[[B:.+]]: tensor<8x16x32xf32>, %[[C:.+]]: tensor<8x16x32xf32>, %[[D:.+]]: tensor<8x16x32xf32>)
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]], #[[IDENTITY]], #[[IDENTITY]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
//
// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]]
// CHECK-SAME: outs(%[[D]]
//
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i1, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32, %[[D_ARG:.+]]: f32)
// CHECK: %[[SELECTED:.+]] = arith.select %[[A_ARG]], %[[B_ARG]], %[[C_ARG]] : f32
// CHECK: linalg.yield %[[SELECTED]] : f32
//
// Ternary `select` whose rank-2 i1 condition is mapped through (d2, d1); the
// two value inputs and the output use identity maps.
func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8x16x32xf32>, %D : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%r = linalg.elementwise
kind=#linalg.elementwise_kind<select>
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
ins(%A, %B, %C : tensor<32x16xi1>, tensor<8x16x32xf32>, tensor<8x16x32xf32>)
outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
}

View File

@ -0,0 +1,54 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
// Negative test: `dive` is not a valid elementwise kind, so parsing the `kind`
// attribute is rejected.
func.func @misspelt_op_div(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
// expected-error@+3 {{expected ::mlir::linalg::ElementwiseKind to be one of: exp, log, abs, ceil, floor}}
// expected-error@+2 {{failed to parse ElementwiseKindAttr parameter}}
// expected-error@+1 {{custom op 'linalg.elementwise' expected operation 'kind' attribute}}
linalg.elementwise kind=#linalg.elementwise_kind<dive> ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
return
}
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
// Negative test: a binary op with three operands is given only two indexing
// maps.
func.func @missing_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
// expected-error@+1 {{'linalg.elementwise' op expected the number of indexing_map (2) to be equal to the number of input/output operands (3)}}
linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map] ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
return
}
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
// Negative test: identity maps are used for all operands although the first
// input (8x16) has the transposed shape of the other operands (16x8).
func.func @identity_map_when_transpose_expected(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
// expected-error@+1 {{'linalg.elementwise' op inferred input/output operand #1 has shape's dimension #0 to be 8, but found 16}}
linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map, #map] ins(%A, %B: memref<8x16xf32>, memref<8x16xf32>) outs(%C: memref<16x8xf32>)
return
}
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
// Negative test: the output is rank-1 while the input maps are written over
// two dims.
func.func @incorrect_result_rank(%A : memref<8x16xf32>, %B: memref<8x16xf32>, %C: memref<8xf32>) {
// expected-error@+1 {{'linalg.elementwise' op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map, #map1] ins(%A, %B: memref<8x16xf32>, memref<8x16xf32>) outs(%C: memref<8xf32>)
return
}
// -----
// Negative test: a unary kind (`exp`) is given two inputs plus one output.
func.func @unary_too_many_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>, %C: memref<8x16x32xf32>) {
// expected-error@+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 2 args, got 3}}
// expected-error@+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
linalg.elementwise kind=#linalg.elementwise_kind<exp> ins(%A, %B : memref<8x16x32xf32>, memref<8x16x32xf32>) outs(%C: memref<8x16x32xf32>)
return
}
// -----
// Negative test: a binary kind (`add`) is given one input plus one output.
func.func @binary_too_few_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>) {
// expected-error@+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 3 args, got 2}}
// expected-error@+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%A : memref<8x16x32xf32>) outs(%B: memref<8x16x32xf32>)
return
}

View File

@ -0,0 +1,90 @@
// RUN: mlir-opt %s -split-input-file | FileCheck %s
//
// Note - the functions are named @{unary|binary}_{identity|transpose|broadcast|transpose_a|...}_{exp|mul|div|..}
// Round-trips a unary `exp` without explicit indexing_maps. The operand
// captures from the signature are reused on the op line so that the printed
// ins/outs are actually verified.
// CHECK: @unary_identity_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
// CHECK: %{{.*}} = linalg.elementwise kind=#linalg.elementwise_kind<exp>
// CHECK-SAME: ins(%[[A]] : tensor<8x16x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>)
//
func.func @unary_identity_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
  %r = linalg.elementwise
      kind=#linalg.elementwise_kind<exp>
      ins(%A : tensor<8x16x32xf32>)
      outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
  return %r : tensor<8x16x32xf32>
}
// -----
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
//
// CHECK: @unary_projection_tanh(%[[A:.+]]: tensor<?x16xf32>,
// CHECK-SAME: %[[B:.+]]: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
// CHECK: {{.*}} = linalg.elementwise kind=#linalg.elementwise_kind<tanh>
// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]] : tensor<?x16xf32>) outs(%[[B]] : tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
//
// Round-trips a unary `tanh` with explicit, non-default indexing_maps and
// dynamic dimensions; the maps are expected to be printed back.
func.func @unary_projection_tanh(%A: tensor<?x16xf32>,
%B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
%r = linalg.elementwise
kind=#linalg.elementwise_kind<tanh>
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
ins(%A : tensor<?x16xf32>)
outs(%B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
return %r : tensor<8x16x?xf32>
}
// -----
// CHECK: @binary_identity_div(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>,
// CHECK-SAME: %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<div>
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
// Round-trips a binary `div` on tensors without explicit indexing_maps.
func.func @binary_identity_div(%A: tensor<16x8xf32>, %B: tensor<16x8xf32>,
%C: tensor<16x8xf32>) -> tensor<16x8xf32> {
%r = linalg.elementwise
kind=#linalg.elementwise_kind<div>
ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
return %r : tensor<16x8xf32>
}
// -----
// Round-trips a binary `mul` on rank-5 integer tensors without explicit
// indexing_maps. `A` is captured on the signature line so this case does not
// rely on a capture left over from an earlier test case.
// CHECK: @binary_identity_mul_5Di(%[[A:.+]]: tensor<1x2x3x4x5xi32>,
// CHECK-SAME: %[[B:.+]]: tensor<1x2x3x4x5xi32>,
// CHECK-SAME: %[[C:.+]]: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
// CHECK: {{.*}} = linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<mul>
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
// CHECK-SAME: outs(%[[C]] : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
//
func.func @binary_identity_mul_5Di(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
                                   %C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
  %r = linalg.elementwise
      kind=#linalg.elementwise_kind<mul>
      ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
      outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
  return %r : tensor<1x2x3x4x5xi32>
}
// -----
// CHECK: @redundant_maps
// CHECK-NOT: indexing_maps
//
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// Explicit indexing_maps that are all identities are redundant; the printed
// form is expected to omit them.
func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
%r = linalg.elementwise
kind=#linalg.elementwise_kind<mul>
indexing_maps = [#map, #map, #map]
ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
return %r : tensor<1x2x3x4x5xi32>
}