Revert "[mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (#104783)"
This reverts commit 03483737a7a2d72a257a5ab6ff01748ad9cf0f75 and 99c8557, which is a fix-up on top of the former.

I'm reverting because this commit broke two tests:

    mlir/test/python/integration/dialects/linalg/opsrun.py
    mlir/test/python/integration/dialects/transform.py

See https://lab.llvm.org/buildbot/#/builders/138/builds/4872

I'm not familiar with the tests, so I'm leaving it to the original author to either remove or adapt the broken tests, as discussed here:
https://github.com/llvm/llvm-project/pull/104783#issuecomment-2406390905
This commit is contained in:
parent 72f339de45
commit 1276ce9e97
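For context, the reverted change taught `linalg.matmul` to carry an optional `indexing_maps` attribute expressing transpose and broadcast semantics. A representative form, taken from the reverted op documentation in the diff below:

```
linalg.matmul indexing_maps = [
                affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                affine_map<(d0, d1, d2) -> (d2, d1)>,
                affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
              outs(%arg2: memref<3x7xf32>)
```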
@@ -684,16 +684,6 @@ def LinalgStructuredInterface
        return;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if the user has supplied explicit indexing maps for this op.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasUserDefinedMaps",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return false; }]
    >,
    //===------------------------------------------------------------------===//
    // Linalg generalization hooks.
    //===------------------------------------------------------------------===//
@@ -1065,6 +1065,78 @@ structured_op: !LinalgStructuredOpConfig
            - !ScalarExpression
              scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: matmul
  cpp_class_name: MatmulOp
  doc: |-
    Performs a matrix multiplication of two 2D inputs.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
  implements:
  - LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: A
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
  - !LinalgOperandDefConfig
    name: B
    kind: input_tensor
    type_var: T2
    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
  - !LinalgOperandDefConfig
    name: C
    kind: output_tensor
    type_var: U
    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
  - !LinalgOperandDefConfig
    name: cast
    kind: type_fn_attr
    default_fn: cast_signed
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
  iterator_types:
  - parallel
  - parallel
  - reduction
  assignments:
  - !ScalarAssign
    arg: C
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_arg: C
        - !ScalarExpression
          scalar_fn:
            kind: binary
            fn_name: mul
            operands:
            - !ScalarExpression
              scalar_fn:
                kind: type
                attr_name: cast
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: A
            - !ScalarExpression
              scalar_fn:
                kind: type
                attr_name: cast
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: quantized_matmul
  cpp_class_name: QuantizedMatmulOp
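For reference, the assignment tree above encodes the scalar update spelled out in the Python definition later in this diff: `C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])`.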
@@ -535,140 +535,6 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
  let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//

def MatmulOp : LinalgStructuredBase_Op<"matmul", [
    AttrSizedOperandSegments,
    LinalgContractionOpInterface]> {

  let summary = [{
    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
  }];
  let description = [{
    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.

    Broadcast and transpose semantics can be applied by specifying the explicit
    attribute 'indexing_maps' as shown below. This is a list attribute, so the
    list must include all the maps if specified.

    Example Transpose:
    ```
    linalg.matmul indexing_maps = [
                    affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                    affine_map<(d0, d1, d2) -> (d2, d1)>,
                    affine_map<(d0, d1, d2) -> (d0, d1)>
                  ]
                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                  outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast:
    ```
    linalg.matmul indexing_maps = [
                    affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
                    affine_map<(d0, d1, d2) -> (d2, d1)>,
                    affine_map<(d0, d1, d2) -> (d0, d1)>
                  ]
                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
                  outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast and Transpose:
    ```
    linalg.matmul indexing_maps = [
                    affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                    affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
                    affine_map<(d0, d1, d2) -> (d0, d1)>
                  ]
                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
                  outs(%arg2: memref<3x7xf32>)
    ```
  }];

  let arguments = (ins
    Variadic<AnyType>:$inputs,
    Variadic<AnyShaped>:$outputs,
    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
  );
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
                          attributes, MatmulOp::getRegionBuilder());
      }]>,
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
           "ValueRange":$outputs,
           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        buildStructuredOp($_builder, $_state, resultTensorTypes,
                          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
      }]>,
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(operands);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion();
      }]>,
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
           "ValueRange":$outputs,
           "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addAttribute("cast", cast);
        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
                          attributes, MatmulOp::getRegionBuilder());
      }]>
  ];
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasVerifier = 1;

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    SmallVector<utils::IteratorType> getIteratorTypesArray();

    /// Implements the block region builder.
    static void regionBuilder(ImplicitLocOpBuilder &b,
                              Block &block, ArrayRef<NamedAttribute> attrs);

    /// Returns a list of AffineMaps with the typical matmul indexing characteristic.
    SmallVector<AffineMap> getDefaultIndexingMaps();

    /// Returns true if the given broadcast map \p bcastMap is valid for this op.
    bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);

    static std::function<void(ImplicitLocOpBuilder &,
                              Block &, ArrayRef<NamedAttribute>)>
    getRegionBuilder() {
      return regionBuilder;
    }

    ::mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputsMutable();
    }

    // Generic methods.
    static unsigned getNumRegionArgs();
    std::string getLibraryCallName();
    bool hasDynamicIndexingMaps();
    /// Check if the op has broadcast and/or transpose semantics. Returns true
    /// if the user-defined indexing maps are not equal to the default maps.
    bool hasUserDefinedMaps();
  }];
}

//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
@@ -15,20 +15,13 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;
@@ -1149,6 +1142,7 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);

  // Mixed tensor/buffer operands are not allowed.
  if (!linalgOp.hasPureTensorSemantics() &&
      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1168,8 +1162,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
         << ") to be equal to the number of input/output operands ("
         << linalgOp->getNumOperands() << ")";

  // Set this flag if this op has user-defined maps. This is required to guard
  // the error conditions below, which assume default indexing maps.
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

@@ -1186,13 +1178,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(&opOperand);

    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }

  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

@@ -1202,8 +1194,9 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  // Check if given shapes match to inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
  // Verify only static cases since we can't get exact dimension sizes and
  // loop ranges for dynamic cases in this stage.

  // Verify only static cases since we can't get exact dimension sizes and loop
  // ranges for dynamic cases in this stage.
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
@@ -27,7 +27,6 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
@@ -38,17 +37,12 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
@@ -155,36 +149,15 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
  // iterator_types is an auto-generated method.
}

/// Helper to create the typical indexing maps for MatmulOp. Returns a list of
/// AffineMaps.
static SmallVector<AffineMap, 3>
getDefaultIndexingMapsForMatmul(MLIRContext *context) {
  AffineExpr d0, d1, d2;
  SmallVector<AffineMap, 3> indexingMaps;
  bindDims(context, d0, d1, d2);
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
  return indexingMaps;
}

/// Wrapper to return the typical indexing map array attribute for MatmulOp.
static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
  return llvm::map_to_vector(
      getDefaultIndexingMapsForMatmul(context),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(
    OpBuilder &b, OperationState &state,
    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
    RegionBuilderFn regionBuilder,
    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
@@ -195,20 +168,6 @@ static void buildStructuredOp(
  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  // Initialize indexingMaps for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  if (indexingMaps.has_value()) {
    for (mlir::AffineMap map : *indexingMaps) {
      // Convert each AffineMap to an AffineMapAttr.
      indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
    }
    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  } else {
    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  }

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
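When no explicit maps are supplied, the builder above attaches the default maps, and the custom printer (further down in this diff) elides the attribute again when it matches the defaults, so a plain matmul round-trips unchanged; for example (shapes taken from the generalization test later in this diff):

```
linalg.matmul ins(%A, %B : memref<16x8xf32>, memref<8x32xf32>)
              outs(%C : memref<16x32xf32>)
```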
@@ -340,48 +299,11 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {

  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();

    if (parser.parseLSquare())
      return failure();

    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));

  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
@@ -407,9 +329,13 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
                                   ValueRange inputs, ValueRange outputs) {
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"operandSegmentSizes",
                       // See generated code in
                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
                       "linalg.memoized_indexing_maps"});

  // Printing is shared with generic ops, except for the region and
  // attributes.
@@ -3456,168 +3382,3 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

/// Returns true if the result AffineExprs of \p explicitMap are the same as
/// those of \p defaultMap.
static bool isValidResultDimExprs(AffineMap explicitMap, AffineMap defaultMap) {
  auto explicitRange = explicitMap.getResults();
  auto defaultRange = defaultMap.getResults();
  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
  llvm::set_union(explicitSet, defaultSet);
  return explicitSet == defaultSet;
}

/// Returns true if \p explicitMap is broadcast with respect to
/// \p defaultMap.
static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
  return explicitMap.getNumResults() < defaultMap.getNumResults();
}

/// Verifies the broadcast and transpose semantics specified by the explicit
/// indexing map for the MatmulOp \p matmulOp for the operand specified by
/// \p opIndex.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
                                                  unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      matmulOp.getDefaultIndexingMaps();

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  // Check general validity of indexing map results.
  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
    return matmulOp->emitOpError()
           << "Unexpected dim expression in map result.";

  // Check if the requested broadcast is valid.
  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
      return matmulOp->emitOpError()
             << "Invalid broadcast requested, should be (d2).";
    }
    return success();
  }
  return success();
}
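Concretely, the verifier accepts result expressions that are drawn from the operand's default map and rejects anything else; e.g. (mirroring the invalid.mlir tests further down in this diff), this LHS map fails with "Unexpected dim expression in map result":

```
linalg.matmul indexing_maps = [
                affine_map<(d0, d1, d2) -> (d1, d2)>, // invalid: d1 is not in the default LHS map
                affine_map<(d0, d1, d2) -> (d2, d1)>,
                affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>)
```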
namespace mlir {
namespace linalg {
//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
}

unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

/// Check if the op has broadcast and/or transpose semantics. Returns true if
/// the user-defined indexing maps are not equal to the default maps.
bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  // Cast both inputs to the output element type, multiply, then accumulate
  // into the output block argument.
  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
  Value value4 =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}

/// Returns a list of AffineMaps with the typical matmul indexing characteristic.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
  MLIRContext *context = this->getContext();
  return getDefaultIndexingMapsForMatmul(context);
}

/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr exp = bcastMap.getResult(0);
  // The map is invalid if the common (reduction) dimension of the matmul is
  // not found.
  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
}
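For a broadcast, only the reduction dimension may remain; with loops (d0, d1, d2) that means the one-result map must be `(d2)`. E.g. (again mirroring the tests below), this form is accepted:

```
linalg.matmul indexing_maps = [
                affine_map<(d0, d1, d2) -> (d2)>,     // broadcast LHS
                affine_map<(d0, d1, d2) -> (d2, d1)>,
                affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
```

while `affine_map<(d0, d1, d2) -> (d0)>` in the same position is rejected with "Invalid broadcast requested, should be (d2)".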
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}
void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);

  SmallVector<Attribute, 3> indexingMaps =
      getDefaultIndexingMapAttr(getContext());
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
}

/// Verify the user-defined indexing maps.
LogicalResult MatmulOp::verify() {
  // Verification of a pure matmul is handled by verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

} // namespace linalg
} // namespace mlir
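When the maps differ from the defaults, the printer appends them after the ins/outs clause, as the round-trip checks in named-ops.mlir below confirm; schematically:

```
linalg.matmul ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
              outs(%arg2 : memref<3x7xf32>)
              indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                               affine_map<(d0, d1, d2) -> (d1, d2)>,
                               affine_map<(d0, d1, d2) -> (d0, d1)>]
```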
@@ -31,13 +31,6 @@ using namespace mlir::linalg;
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                     linalg::MatmulOp matmulOp,
                                                     bool transposeLHS) {
  // Bail out on matmul ops with extended (user-defined maps) semantics; this
  // transform only supports the default indexing maps.
  if (matmulOp.hasUserDefinedMaps()) {
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with non-extended semantics are supported");
  }

  if (!bufferization::hasTensorSemantics(matmulOp))
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with tensors are supported");
@@ -2071,11 +2071,6 @@ vectorizeScalableVectorPrecondition(Operation *op,
      return failure();
  }

  // Bail out on matmul ops with extended semantics; this transform only
  // supports the default indexing maps.
  if (linalgOp.hasUserDefinedMaps())
    return failure();

  // Cond 4: Only the following ops are supported in the
  // presence of scalable vectors
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
@@ -821,12 +821,6 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
  bool fail = true;
  // TODO: more robust detection of matmulOp, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    // Bail out on matmul ops with extended semantics; this transform only
    // supports the default indexing maps.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId, for now assume a single warp.
    Value laneId = rewriter.create<gpu::ThreadIdOp>(
@@ -383,6 +383,23 @@ def select(
    O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])


@linalg_structured_op
def matmul(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    """Performs a matrix multiplication of two 2D inputs.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    domain(D.m, D.n, D.k)
    implements(ContractionOpInterface)
    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])


@linalg_structured_op
def quantized_matmul(
    A=TensorDef(T1, S.M, S.K),
@@ -29,34 +29,6 @@ func.func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>,

// -----

func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @matmul_bcast_a(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
// CHECK: %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
// CHECK: linalg.yield %[[VAL_7]] : f32
// CHECK: }
// CHECK: return
// CHECK: }

// -----

func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
@@ -919,86 +891,3 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec

  return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}

// -----

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_transpose_a_explicit(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK: arith.mulf
// CHECK: arith.addf

func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                outs(%arg2: memref<3x7xf32>)

  return
}

// -----

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK: arith.mulf
// CHECK: arith.addf

func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                outs(%arg2: memref<3x7xf32>)

  return
}

// -----

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {

// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK: arith.mulf
// CHECK: arith.addf

func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
                outs(%arg2: memref<3x7xf32>)

  return
}

// -----

@@ -361,165 +361,6 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,

// -----

func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
  // expected-error @+1 {{expected attribute value}}
  linalg.matmul indexing_maps = [
                  ,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
                outs(%arg2 : memref<2x4xf32>)
  return
}

// -----

func.func @invalid_matmul_dim_a(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) {
  // expected-error @+1 {{Unexpected dim expression in map result}}
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>)
  return
}

// -----

func.func @invalid_matmul_dim_b(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) {
  // expected-error @+1 {{Unexpected dim expression in map result}}
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>)
  return
}

// -----

func.func @invalid_transpose_a_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 1}}
  %0 = linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                     ]
                     ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
                     outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
  return %0: tensor<4x64xf32>
}

// -----

func.func @invalid_transpose_b_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #1 to be 1, but found 64}}
  %0 = linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                     ]
                     ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
                     outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
  return %0: tensor<4x64xf32>
}

// -----

func.func @invalid_bcast_a(%arg0: memref<3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}}
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// -----

func.func @invalid_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
  // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}}
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3x5xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// -----

func.func @invalid_bcast_a_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #0 (1)}}
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// -----

func.func @invalid_bcast_b_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #1 (1)}}
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// -----

func.func @invalid_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 5, but found 7}}
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// -----

func.func @invalid_matmul_bcast_b_transpose_a_wrong_dim(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
  // expected-error @+1 {{'linalg.matmul' op Unexpected dim expression in map result.}}
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// -----

func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
  // expected-error @+2 {{custom op 'indexing_maps' is unknown (tried 'func.indexing_maps' as well)}}
  linalg.matmul ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) outs(%init : tensor<4x64xf32>)
                indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
  return
}

// -----

func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
  // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
  linalg.conv_2d_nhwc_hwcf
@@ -1201,249 +1201,6 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a

// -----

// CHECK-LABEL: func @matmul_transpose_a_explicit
// CHECK: linalg.matmul
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                outs(%arg2: memref<3x7xf32>)

  return
}

// -----

func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                outs(%arg2: memref<3x7xf32>)

  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: return
// CHECK: }

// -----

func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
                outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: return
// CHECK: }

// -----

func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a
// CHECK: linalg.matmul
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

// -----

func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_a_dim1
// CHECK: linalg.matmul
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

// -----

func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b
// CHECK: linalg.matmul
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

// -----

func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_bcast_a_b(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
// CHECK: return
// CHECK: }

// -----

func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @matmul_bcast_b_dim1
// CHECK: linalg.matmul
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)

// -----

func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @dynamic_matmul_bcast_a(
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<?x?xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: return
// CHECK: }

// -----

func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_bcast_a_transpose_b(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: return
// CHECK: }

// -----

func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_bcast_b_transpose_a(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: return
// CHECK: }

// -----

// CHECK-LABEL: func @matmul_transpose_b
// CHECK: linalg.matmul_transpose_b
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
@@ -84,6 +84,81 @@ def testNamedStructuredOpCustomForm():

    print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
@run
def testNamedStructuredOpGenericForm():
    with Context() as ctx, Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
            )
            def named_form(lhs, rhs):
                init_result = tensor.empty([4, 8], f32)
                # CHECK: "linalg.matmul"(%{{.*}})
                # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
                # CHECK-SAME: operandSegmentSizes = array<i32: 2, 1>
                # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
                # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
                # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
                # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
                # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
                return linalg.matmul(lhs, rhs, outs=[init_result])

        module.operation.print(print_generic_op_form=True)


# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
@run
def testNamedStructuredAsGenericOp():
    with Context() as ctx, Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
            )
            def generic_form(lhs, rhs):
                init_result = tensor.EmptyOp([4, 8], f32)
                # CHECK: linalg.generic
                return linalg.matmul(
                    lhs, rhs, outs=[init_result.result], emit_generic=True
                )

        print(module)


# CHECK-LABEL: TEST: testOpResultFromOtherOp
@run
def testOpResultFromOtherOp():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
            )
            def pass_an_op_directly(arg0, arg1):
                one = arith.ConstantOp(F32Type.get(), 1.0)
                # CHECK: %[[LHS:.*]] = linalg.fill
                lhs = linalg.fill(one, outs=[arg0])
                # CHECK: %[[RHS:.*]] = linalg.fill
                rhs = linalg.fill(one, outs=[arg1])
                # CHECK: %[[INIT:.*]] = tensor.empty
                init = tensor.EmptyOp([4, 8], f32)
                # CHECK: linalg.matmul
                # CHECK: ins(%[[LHS]], %[[RHS]]
                # CHECK: outs(%[[INIT]]
                return linalg.matmul(lhs, rhs, outs=init)

        print(module)


# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():

@@ -681,11 +681,7 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
                        {0}::getNumRegionArgs(), {0}::getRegionBuilder());
}
void {0}::print(OpAsmPrinter &p) {{
  SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
                                           "linalg.memoized_indexing_maps",
                                           "indexing_maps"};
  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                           elidedAttrs);
  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
}
)FMT";