Change the printing/parsing behavior for Attributes used in declarative assembly format

The new form of printing attribute in the declarative assembly is eliding the `#dialect.mnemonic` prefix to only keep the `<....>` part.

Differential Revision: https://reviews.llvm.org/D113873
This commit is contained in:
Mehdi Amini 2021-12-08 01:24:51 +00:00
parent 63cd1842a7
commit ee0908703d
32 changed files with 574 additions and 170 deletions

View File

@ -31,6 +31,7 @@ def ArmSVE_Dialect : Dialect {
vector operations, including a scalable vector type and intrinsics for
some Arm SVE instructions.
}];
let useDefaultTypePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
@ -66,20 +67,6 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
"Type":$elementType
);
let printer = [{
$_printer << "<";
for (int64_t dim : getShape())
$_printer << dim << 'x';
$_printer << getElementType() << '>';
}];
let parser = [{
VectorType vector;
if ($_parser.parseType(vector))
return Type();
return get($_ctxt, vector.getShape(), vector.getElementType());
}];
let extraClassDeclaration = [{
bool hasStaticShape() const {
return llvm::none_of(getShape(), ShapedType::isDynamic);

View File

@ -64,19 +64,19 @@ struct FieldParser<
AttributeT>> {
static FailureOr<AttributeT> parse(AsmParser &parser) {
AttributeT value;
if (parser.parseAttribute(value))
if (parser.parseCustomAttributeWithFallback(value))
return failure();
return value;
}
};
/// Parse a type.
/// Parse an attribute.
template <typename TypeT>
struct FieldParser<
TypeT, std::enable_if_t<std::is_base_of<Type, TypeT>::value, TypeT>> {
static FailureOr<TypeT> parse(AsmParser &parser) {
TypeT value;
if (parser.parseType(value))
if (parser.parseCustomTypeWithFallback(value))
return failure();
return value;
}

View File

@ -2984,6 +2984,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
AttrOrTypeDef<"Type", name, traits, baseCppClass> {
// Make it possible to use such type as parameters for other types.
string cppType = dialect.cppNamespace # "::" # cppClassName;
// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #

View File

@ -50,6 +50,36 @@ public:
virtual void printType(Type type);
virtual void printAttribute(Attribute attr);
/// Trait to check if `AttrType` provides a `print` method.
template <typename AttrOrType>
using has_print_method =
decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
template <typename AttrOrType>
using detect_has_print_method =
llvm::is_detected<has_print_method, AttrOrType>;
/// Print the provided attribute in the context of an operation custom
/// printer/parser: this will invoke directly the print method on the
/// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
template <typename AttrOrType,
std::enable_if_t<detect_has_print_method<AttrOrType>::value>
*sfinae = nullptr>
void printStrippedAttrOrType(AttrOrType attrOrType) {
if (succeeded(printAlias(attrOrType)))
return;
attrOrType.print(*this);
}
/// SFINAE for printing the provided attribute in the context of an operation
/// custom printer in the case where the attribute does not define a print
/// method.
template <typename AttrOrType,
std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
*sfinae = nullptr>
void printStrippedAttrOrType(AttrOrType attrOrType) {
*this << attrOrType;
}
/// Print the given attribute without its type. The corresponding parser must
/// provide a valid type for the attribute.
virtual void printAttributeWithoutType(Attribute attr);
@ -102,6 +132,14 @@ private:
AsmPrinter(const AsmPrinter &) = delete;
void operator=(const AsmPrinter &) = delete;
/// Print the alias for the given attribute, return failure if no alias could
/// be printed.
virtual LogicalResult printAlias(Attribute attr);
/// Print the alias for the given type, return failure if no alias could
/// be printed.
virtual LogicalResult printAlias(Type type);
/// The internal implementation of the printer.
Impl *impl;
};
@ -608,6 +646,13 @@ public:
/// Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
/// Parse a custom attribute with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked.
virtual ParseResult parseCustomAttributeWithFallback(
Attribute &result, Type type,
function_ref<ParseResult(Attribute &result, Type type)>
parseAttribute) = 0;
/// Parse an attribute of a specific kind and type.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type = {}) {
@ -639,9 +684,9 @@ public:
return parseAttribute(result, Type(), attrName, attrs);
}
/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified
/// name.
/// Parse an arbitrary attribute of a given type and populate it in `result`.
/// This also adds the attribute to the specified attribute list with the
/// specified name.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
NamedAttrList &attrs) {
@ -661,6 +706,82 @@ public:
return success();
}
/// Trait to check if `AttrType` provides a `parse` method.
template <typename AttrType>
using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
std::declval<Type>()));
template <typename AttrType>
using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
/// Parse a custom attribute of a given type unless the next token is `#`, in
/// which case the generic parser is invoked. The parsed attribute is
/// populated in `result` and also added to the specified attribute list with
/// the specified name.
template <typename AttrType>
std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result, Type type,
StringRef attrName, NamedAttrList &attrs) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of attribute.
Attribute attr;
if (parseCustomAttributeWithFallback(
attr, type, [&](Attribute &result, Type type) -> ParseResult {
result = AttrType::parse(*this, type);
if (!result)
return failure();
return success();
}))
return failure();
// Check for the right kind of attribute.
result = attr.dyn_cast<AttrType>();
if (!result)
return emitError(loc, "invalid kind of attribute specified");
attrs.append(attrName, result);
return success();
}
/// SFINAE parsing method for Attribute that don't implement a parse method.
template <typename AttrType>
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result, Type type,
StringRef attrName, NamedAttrList &attrs) {
return parseAttribute(result, type, attrName, attrs);
}
/// Parse a custom attribute of a given type unless the next token is `#`, in
/// which case the generic parser is invoked. The parsed attribute is
/// populated in `result`.
template <typename AttrType>
std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of attribute.
Attribute attr;
if (parseCustomAttributeWithFallback(
attr, {}, [&](Attribute &result, Type type) -> ParseResult {
result = AttrType::parse(*this, type);
return success(!!result);
}))
return failure();
// Check for the right kind of attribute.
result = attr.dyn_cast<AttrType>();
if (!result)
return emitError(loc, "invalid kind of attribute specified");
return success();
}
/// SFINAE parsing method for Attribute that don't implement a parse method.
template <typename AttrType>
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result) {
return parseAttribute(result);
}
/// Parse an arbitrary optional attribute of a given type and return it in
/// result.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
@ -740,6 +861,11 @@ public:
/// Parse a type.
virtual ParseResult parseType(Type &result) = 0;
/// Parse a custom type with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked.
virtual ParseResult parseCustomTypeWithFallback(
Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
/// Parse an optional type.
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
@ -753,7 +879,7 @@ public:
if (parseType(type))
return failure();
// Check for the right kind of attribute.
// Check for the right kind of type.
result = type.dyn_cast<TypeT>();
if (!result)
return emitError(loc, "invalid kind of type specified");
@ -761,6 +887,44 @@ public:
return success();
}
/// Trait to check if `TypeT` provides a `parse` method.
template <typename TypeT>
using type_has_parse_method =
decltype(TypeT::parse(std::declval<AsmParser &>()));
template <typename TypeT>
using detect_type_has_parse_method =
llvm::is_detected<type_has_parse_method, TypeT>;
/// Parse a custom Type of a given type unless the next token is `#`, in
/// which case the generic parser is invoked. The parsed Type is
/// populated in `result`.
template <typename TypeT>
std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
parseCustomTypeWithFallback(TypeT &result) {
llvm::SMLoc loc = getCurrentLocation();
// Parse any kind of Type.
Type type;
if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
result = TypeT::parse(*this);
return success(!!result);
}))
return failure();
// Check for the right kind of Type.
result = type.dyn_cast<TypeT>();
if (!result)
return emitError(loc, "invalid kind of Type specified");
return success();
}
/// SFINAE parsing method for Type that don't implement a parse method.
template <typename TypeT>
std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
parseCustomTypeWithFallback(TypeT &result) {
return parseType(result);
}
/// Parse a type list.
ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
do {
@ -792,7 +956,7 @@ public:
if (parseColonType(type))
return failure();
// Check for the right kind of attribute.
// Check for the right kind of type.
result = type.dyn_cast<TypeType>();
if (!result)
return emitError(loc, "invalid kind of type specified");

View File

@ -53,21 +53,21 @@ void ArmSVEDialect::initialize() {
// ScalableVectorType
//===----------------------------------------------------------------------===//
Type ArmSVEDialect::parseType(DialectAsmParser &parser) const {
llvm::SMLoc typeLoc = parser.getCurrentLocation();
{
Type genType;
auto parseResult = generatedTypeParser(parser, "vector", genType);
if (parseResult.hasValue())
return genType;
}
parser.emitError(typeLoc, "unknown type in ArmSVE dialect");
return Type();
void ScalableVectorType::print(AsmPrinter &printer) const {
printer << "<";
for (int64_t dim : getShape())
printer << dim << 'x';
printer << getElementType() << '>';
}
void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
if (failed(generatedTypePrinter(type, os)))
llvm_unreachable("unexpected 'arm_sve' type kind");
Type ScalableVectorType::parse(AsmParser &parser) {
SmallVector<int64_t> dims;
Type eltType;
if (parser.parseLess() ||
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
parser.parseType(eltType) || parser.parseGreater())
return {};
return ScalableVectorType::get(eltType.getContext(), dims, eltType);
}
//===----------------------------------------------------------------------===//

View File

@ -170,7 +170,7 @@ static constexpr const CombiningKind combiningKindsList[] = {
};
void CombiningKindAttr::print(AsmPrinter &printer) const {
printer << "kind<";
printer << "<";
auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
return bitEnumContains(this->getKind(), kind);
});
@ -215,10 +215,12 @@ Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
void VectorDialect::printAttribute(Attribute attr,
DialectAsmPrinter &os) const {
if (auto ck = attr.dyn_cast<CombiningKindAttr>())
if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
os << "kind";
ck.print(os);
else
llvm_unreachable("Unknown attribute type");
return;
}
llvm_unreachable("Unknown attribute type");
}
//===----------------------------------------------------------------------===//

View File

@ -1188,7 +1188,7 @@ private:
/// Ex:
/// ```
/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
/// %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
/// %1 = vector.multi_reduction add, %0 [1]
/// : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
@ -1198,7 +1198,7 @@ private:
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
/// kind = add} %0, %arg1, %cst_f0
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
@ -1247,7 +1247,7 @@ struct MultiReduceToContract
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
/// kind = add} %0, %arg1, %cst_f0
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
@ -1257,7 +1257,7 @@ struct MultiReduceToContract
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
/// kind = add} %arg0, %arg1, %cst_f0
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
@ -1304,7 +1304,7 @@ struct CombineContractTranspose
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
/// kind = add} %0, %arg1, %cst_f0
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
@ -1314,7 +1314,7 @@ struct CombineContractTranspose
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
/// kind = add} %arg0, %arg1, %cst_f0
/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast

View File

@ -474,6 +474,14 @@ private:
void printAttributeWithoutType(Attribute attr) override {
printAttribute(attr);
}
LogicalResult printAlias(Attribute attr) override {
initializer.visit(attr);
return success();
}
LogicalResult printAlias(Type type) override {
initializer.visit(type);
return success();
}
/// Print the given set of attributes with names not included within
/// 'elidedAttrs'.
@ -1252,8 +1260,16 @@ public:
void printAttribute(Attribute attr,
AttrTypeElision typeElision = AttrTypeElision::Never);
/// Print the alias for the given attribute, return failure if no alias could
/// be printed.
LogicalResult printAlias(Attribute attr);
void printType(Type type);
/// Print the alias for the given type, return failure if no alias could
/// be printed.
LogicalResult printAlias(Type type);
/// Print the given location to the stream. If `allowAlias` is true, this
/// allows for the internal location to use an attribute alias.
void printLocation(LocationAttr loc, bool allowAlias = false);
@ -1594,6 +1610,14 @@ static void printElidedElementsAttr(raw_ostream &os) {
os << R"(opaque<"_", "0xDEADBEEF">)";
}
LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
}
LogicalResult AsmPrinter::Impl::printAlias(Type type) {
return success(state && succeeded(state->getAliasState().getAlias(type, os)));
}
void AsmPrinter::Impl::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
if (!attr) {
@ -1602,7 +1626,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
}
// Try to print an alias for this attribute.
if (state && succeeded(state->getAliasState().getAlias(attr, os)))
if (succeeded(printAlias(attr)))
return;
if (!isa<BuiltinDialect>(attr.getDialect()))
@ -2104,6 +2128,16 @@ void AsmPrinter::printAttribute(Attribute attr) {
impl->printAttribute(attr);
}
LogicalResult AsmPrinter::printAlias(Attribute attr) {
assert(impl && "expected AsmPrinter::printAlias to be overriden");
return impl->printAlias(attr);
}
LogicalResult AsmPrinter::printAlias(Type type) {
assert(impl && "expected AsmPrinter::printAlias to be overriden");
return impl->printAlias(type);
}
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
assert(impl &&
"expected AsmPrinter::printAttributeWithoutType to be overriden");

View File

@ -374,6 +374,7 @@ BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
//===----------------------------------------------------------------------===//
// BoolAttr
//===----------------------------------------------------------------------===//
bool BoolAttr::getValue() const {
auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);

View File

@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
@ -633,7 +634,7 @@ bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
return true;
// Allow custom dialect attributes.
if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect()))
if (!isa<BuiltinDialect>(memorySpace.getDialect()))
return true;
return false;

View File

@ -343,6 +343,29 @@ public:
return success(static_cast<bool>(result));
}
/// Parse a custom attribute with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked.
ParseResult parseCustomAttributeWithFallback(
Attribute &result, Type type,
function_ref<ParseResult(Attribute &result, Type type)> parseAttribute)
override {
if (parser.getToken().isNot(Token::hash_identifier))
return parseAttribute(result, type);
result = parser.parseAttribute(type);
return success(static_cast<bool>(result));
}
/// Parse a custom attribute with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked.
ParseResult parseCustomTypeWithFallback(
Type &result,
function_ref<ParseResult(Type &result)> parseType) override {
if (parser.getToken().isNot(Token::exclamation_identifier))
return parseType(result);
result = parser.parseType();
return success(static_cast<bool>(result));
}
OptionalParseResult parseOptionalAttribute(Attribute &result,
Type type) override {
return parser.parseOptionalAttribute(result, type);

View File

@ -3,7 +3,7 @@
func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
// CHECK: arm_sve.sdot {{.*}}: <16xi8> to <4xi32
%0 = arm_sve.sdot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
@ -12,7 +12,7 @@ func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
// CHECK: arm_sve.smmla {{.*}}: <16xi8> to <4xi3
%0 = arm_sve.smmla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
@ -21,7 +21,7 @@ func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
// CHECK: arm_sve.udot {{.*}}: <16xi8> to <4xi32
%0 = arm_sve.udot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
@ -30,7 +30,7 @@ func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
// CHECK: arm_sve.ummla {{.*}}: <16xi8> to <4xi3
%0 = arm_sve.ummla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>

View File

@ -16,7 +16,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
%0 = arith.addf %arg0, %arg0 : f32
// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value<f32>
%1 = async.runtime.create: !async.value<f32>
// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value<f32>
// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : <f32>
async.runtime.store %0, %1: !async.value<f32>
// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value<f32>
async.runtime.set_available %1: !async.value<f32>
@ -32,9 +32,9 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
// CHECK: ^[[BRANCH_OK]]:
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : !async.value<f32>
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : <f32>
// CHECK: %[[RETURNED:.*]] = arith.mulf %[[ARG]], %[[LOADED]] : f32
// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : !async.value<f32>
// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
@ -84,8 +84,8 @@ func @simple_caller() -> f32 {
// CHECK: cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
// CHECK: ^[[BRANCH_VALUE_OK]]:
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : !async.value<f32>
// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : !async.value<f32>
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : <f32>
// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]
@ -133,7 +133,7 @@ func @double_caller() -> f32 {
// CHECK: cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
// CHECK: ^[[BRANCH_VALUE_OK_1]]:
// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : !async.value<f32>
// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : <f32>
// CHECK: %[[RETURNED_TO_CALLER_2:.*]]:2 = call @simple_callee(%[[LOADED_1]]) : (f32) -> (!async.token, !async.value<f32>)
// CHECK: %[[SAVED_2:.*]] = async.coro.save %[[HDL]]
// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_2]]#0, %[[HDL]]
@ -150,8 +150,8 @@ func @double_caller() -> f32 {
// CHECK: cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
// CHECK: ^[[BRANCH_VALUE_OK_2]]:
// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : !async.value<f32>
// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : !async.value<f32>
// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : <f32>
// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : <f32>
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: br ^[[CLEANUP]]

View File

@ -245,7 +245,7 @@ func @execute_and_return_f32() -> f32 {
}
// CHECK: async.runtime.await %[[RET]]#1 : !async.value<f32>
// CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : !async.value<f32>
// CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : <f32>
%0 = async.await %result : !async.value<f32>
// CHECK: return %[[VALUE]]
@ -323,7 +323,7 @@ func @async_value_operands() {
// // Load from the async.value argument after error checking.
// CHECK: ^[[CONTINUATION:.*]]:
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value<f32
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : <f32
// CHECK: arith.addf %[[LOADED]], %[[LOADED]] : f32
// CHECK: async.runtime.set_available %[[TOKEN]]

View File

@ -129,16 +129,16 @@ func @resume(%arg0: !async.coro.handle) {
// CHECK-LABEL: @store
func @store(%arg0: f32, %arg1: !async.value<f32>) {
// CHECK: async.runtime.store %arg0, %arg1 : !async.value<f32>
async.runtime.store %arg0, %arg1 : !async.value<f32>
// CHECK: async.runtime.store %arg0, %arg1 : <f32>
async.runtime.store %arg0, %arg1 : <f32>
return
}
// CHECK-LABEL: @load
func @load(%arg0: !async.value<f32>) -> f32 {
// CHECK: %0 = async.runtime.load %arg0 : !async.value<f32>
// CHECK: %0 = async.runtime.load %arg0 : <f32>
// CHECK: return %0 : f32
%0 = async.runtime.load %arg0 : !async.value<f32>
%0 = async.runtime.load %arg0 : <f32>
return %0 : f32
}

View File

@ -6,7 +6,7 @@
func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [0] : vector<1584xf32> to f32
// CHECK: vector.multi_reduction <add>, %{{.*}} [0] : vector<1584xf32> to f32
// CHECK: arith.addf %{{.*}}, %{{.*}} : f32
linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
outs(%C: memref<f32>)
@ -19,7 +19,7 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
outs(%C: memref<1584xf32>)
@ -31,7 +31,7 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
// CHECK-LABEL: contraction_matmul
func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
outs(%C: memref<1584x1584xf32>)
@ -43,7 +43,7 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
// CHECK-LABEL: contraction_batch_matmul
func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
linalg.batch_matmul
ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
@ -71,7 +71,7 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
linalg.generic #matmul_trait
@ -105,7 +105,7 @@ func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
linalg.generic #matmul_transpose_out_trait
@ -139,7 +139,7 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
// CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
// CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
// CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
@ -160,7 +160,7 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
linalg.matmul
ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
@ -523,7 +523,7 @@ func @matmul_tensors(
// linalg matmul lowers gets expanded to a 3D reduction, canonicalization later
// convert it to a 2D contract.
// CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
// CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
// CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
// CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
%0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
@ -744,7 +744,7 @@ func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
// CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
// CHECK: math.exp {{.*}} : vector<4x16x8xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
// CHECK: addf {{.*}} : vector<4x16xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
// CHECK: return {{.*}} : tensor<4x16xf32>
@ -779,7 +779,7 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: ten
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: addf {{.*}} : vector<2x3x4x5xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
// CHECK: vector.multi_reduction <add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
// CHECK: addf {{.*}} : vector<2x5xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
// CHECK: return {{.*}} : tensor<5x2xf32>
@ -808,7 +808,7 @@ func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: %[[R:.+]] = vector.multi_reduction <maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant -3.40282e+38 : f32
@ -833,7 +833,7 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
// CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: %[[R:.+]] = vector.multi_reduction <minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%maxf32 = arith.constant 3.40282e+38 : f32
@ -857,7 +857,7 @@ func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
// CHECK: vector.multi_reduction #vector.kind<mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: mulf {{.*}} : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant 1.0 : f32
@ -881,7 +881,7 @@ func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
// CHECK: vector.multi_reduction #vector.kind<or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.multi_reduction <or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant false
%init = linalg.init_tensor [4] : tensor<4xi1>
@ -904,7 +904,7 @@ func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
// CHECK: vector.multi_reduction #vector.kind<and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.multi_reduction <and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant true
%init = linalg.init_tensor [4] : tensor<4xi1>
@ -927,7 +927,7 @@ func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
// CHECK: vector.multi_reduction #vector.kind<xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.multi_reduction <xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant false
%init = linalg.init_tensor [4] : tensor<4xi1>
@ -979,7 +979,7 @@ func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) ->
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32>
// CHECK: subf {{.*}} : vector<4x4xf32>
// CHECK: math.exp {{.*}} : vector<4x4xf32>
// CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.multi_reduction <add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
// CHECK: addf {{.*}} : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
%c0 = arith.constant 0.0 : f32
@ -1019,7 +1019,7 @@ func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
// CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
// CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[r]] [0]
// CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]] [0]
// CHECK-SAME: : vector<32xf32> to f32
// CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<f32>

View File

@ -1027,7 +1027,7 @@ func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v
// CHECK-LABEL: func @vector_multi_reduction_single_parallel(
// CHECK-SAME: %[[v:.*]]: vector<2xf32>
func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
%0 = vector.multi_reduction <mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
// CHECK: return %[[v]] : vector<2xf32>
return %0 : vector<2xf32>

View File

@ -3,7 +3,7 @@
// -----
func @broadcast_to_scalar(%arg0: f32) -> f32 {
// expected-error@+1 {{'vector.broadcast' op result #0 must be vector of any type values, but got 'f32'}}
// expected-error@+1 {{custom op 'vector.broadcast' invalid kind of type specified}}
%0 = vector.broadcast %arg0 : f32 to f32
}
@ -1022,7 +1022,7 @@ func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
// -----
func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
// expected-error@+1 {{must be vector of any type values}}
// expected-error@+1 {{'vector.bitcast' invalid kind of type specified}}
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32
}

View File

@ -685,9 +685,9 @@ func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
// CHECK-LABEL: @multi_reduction
func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
%1 = vector.multi_reduction #vector.kind<add>, %0 [1, 3] :
%1 = vector.multi_reduction <add>, %0 [1, 3] :
vector<4x8x16x32xf32> to vector<4x16xf32>
%2 = vector.multi_reduction #vector.kind<add>, %1 [0, 1] :
%2 = vector.multi_reduction <add>, %1 [0, 1] :
vector<4x16xf32> to f32
return %2 : f32
}

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
%0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
@ -18,7 +18,7 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: return %[[RESULT_VEC]]
func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
%0 = vector.multi_reduction <mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
return %0 : f32
}
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
@ -30,7 +30,7 @@ func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
// CHECK: return %[[RES]]
func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
%0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
}
// CHECK-LABEL: func @vector_reduction_inner
@ -66,7 +66,7 @@ func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
%0 = vector.multi_reduction <add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
return %0 : vector<2x5xf32>
}
@ -78,7 +78,7 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x
// CHECK: return %[[RESULT]]
func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
%0 = vector.multi_reduction <mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
return %0 : vector<2x4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction_ordering

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s
func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
%0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
@ -18,7 +18,7 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
%0 = vector.multi_reduction #vector.kind<minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
%0 = vector.multi_reduction <minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
@ -35,7 +35,7 @@ func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
%0 = vector.multi_reduction #vector.kind<maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
%0 = vector.multi_reduction <maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
@ -52,7 +52,7 @@ func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
%0 = vector.multi_reduction #vector.kind<and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
%0 = vector.multi_reduction <and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}
@ -69,7 +69,7 @@ func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
%0 = vector.multi_reduction #vector.kind<or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
%0 = vector.multi_reduction <or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}
@ -86,7 +86,7 @@ func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
%0 = vector.multi_reduction #vector.kind<xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
%0 = vector.multi_reduction <xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}
@ -104,7 +104,7 @@ func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
%0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
}

View File

@ -12,7 +12,7 @@
func @multidimreduction_contract(
%arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
%0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
%1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
%1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
return %1 : vector<8x16xf32>
}
@ -30,7 +30,7 @@ func @multidimreduction_contract(
func @multidimreduction_contract_int(
%arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
%0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
%1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
%1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
return %1 : vector<8x16xi32>
}

View File

@ -14,7 +14,7 @@
#define TEST_ATTRDEFS
// To get the test dialect definition.
include "TestOps.td"
include "TestDialect.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/SubElementInterfaces.td"
@ -121,6 +121,29 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
);
}
// A more complex parameterized attribute with multiple level of nesting.
def CompoundNestedInner : Test_Attr<"CompoundNestedInner"> {
let mnemonic = "cmpnd_nested_inner";
// List of type parameters.
let parameters = (
ins
"int":$some_int,
CompoundAttrA:$cmpdA
);
let assemblyFormat = "`<` $some_int $cmpdA `>`";
}
def CompoundNestedOuter : Test_Attr<"CompoundNestedOuter"> {
let mnemonic = "cmpnd_nested_outer";
// List of type parameters.
let parameters = (
ins
CompoundNestedInner:$inner
);
let assemblyFormat = "`<` `i` $inner `>`";
}
def TestParamOne : AttrParameter<"int64_t", ""> {}
def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> {

View File

@ -0,0 +1,46 @@
//===-- TestDialect.td - Test dialect definition -----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef TEST_DIALECT
#define TEST_DIALECT
include "mlir/IR/OpBase.td"
def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "::test";
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
let hasCanonicalizer = 1;
let hasConstantMaterializer = 1;
let hasOperationAttrVerify = 1;
let hasRegionArgAttrVerify = 1;
let hasRegionResultAttrVerify = 1;
let hasOperationInterfaceFallback = 1;
let hasNonDefaultDestructor = 1;
let useDefaultAttributePrinterParser = 1;
let dependentDialects = ["::mlir::DLTIDialect"];
let extraClassDeclaration = [{
void registerAttributes();
void registerTypes();
// Provides a custom printing/parsing for some operations.
::llvm::Optional<ParseOpHook>
getParseOperationHook(::llvm::StringRef opName) const override;
::llvm::unique_function<void(::mlir::Operation *,
::mlir::OpAsmPrinter &printer)>
getOperationPrinter(::mlir::Operation *op) const override;
private:
// Storage for a custom fallback interface.
void *fallbackEffectOpInterfaces;
}];
}
#endif // TEST_DIALECT

View File

@ -9,6 +9,7 @@
#ifndef TEST_OPS
#define TEST_OPS
include "TestDialect.td"
include "mlir/Dialect/DLTI/DLTIBase.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
@ -23,40 +24,11 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "TestInterfaces.td"
def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "::test";
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
let hasCanonicalizer = 1;
let hasConstantMaterializer = 1;
let hasOperationAttrVerify = 1;
let hasRegionArgAttrVerify = 1;
let hasRegionResultAttrVerify = 1;
let hasOperationInterfaceFallback = 1;
let hasNonDefaultDestructor = 1;
let useDefaultAttributePrinterParser = 1;
let dependentDialects = ["::mlir::DLTIDialect"];
let extraClassDeclaration = [{
void registerAttributes();
void registerTypes();
// Provides a custom printing/parsing for some operations.
::llvm::Optional<ParseOpHook>
getParseOperationHook(::llvm::StringRef opName) const override;
::llvm::unique_function<void(::mlir::Operation *,
::mlir::OpAsmPrinter &printer)>
getOperationPrinter(::mlir::Operation *op) const override;
private:
// Storage for a custom fallback interface.
void *fallbackEffectOpInterfaces;
}];
}
// Include the attribute definitions.
include "TestAttrDefs.td"
// Include the type definitions.
include "TestTypeDefs.td"
class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
@ -1933,6 +1905,16 @@ def FormatNestedAttr : TEST_Op<"format_nested_attr"> {
let assemblyFormat = "$nested attr-dict-with-keyword";
}
def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> {
let arguments = (ins CompoundNestedOuter:$nested);
let assemblyFormat = "`nested` $nested attr-dict-with-keyword";
}
def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> {
let arguments = (ins CompoundNestedOuterType:$nested);
let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword";
}
//===----------------------------------------------------------------------===//
// Custom Directives

View File

@ -14,8 +14,9 @@
#define TEST_TYPEDEFS
// To get the test dialect def.
include "TestOps.td"
include "TestDialect.td"
include "TestAttrDefs.td"
include "TestInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
@ -49,6 +50,29 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
}];
}
// A more complex and nested parameterized type.
def CompoundNestedInnerType : Test_Type<"CompoundNestedInner"> {
let mnemonic = "cmpnd_inner";
// List of type parameters.
let parameters = (
ins
"int":$some_int,
CompoundTypeA:$cmpdA
);
let assemblyFormat = "`<` $some_int $cmpdA `>`";
}
def CompoundNestedOuterType : Test_Type<"CompoundNestedOuter"> {
let mnemonic = "cmpnd_nested_outer";
// List of type parameters.
let parameters = (
ins
CompoundNestedInnerType:$inner
);
let assemblyFormat = "`<` `i` $inner `>`";
}
// An example of how one could implement a standard integer.
def IntegerType : Test_Type<"TestInteger"> {
let mnemonic = "int";

View File

@ -25,6 +25,7 @@
#include "mlir/Interfaces/DataLayoutInterfaces.h"
namespace test {
class TestAttrWithFormatAttr;
/// FieldInfo represents a field in the StructType data type. It is used as a
/// parameter in TestTypeDefs.td.
@ -63,13 +64,13 @@ struct FieldParser<test::CustomParam> {
return test::CustomParam{value.getValue()};
}
};
} // end namespace mlir
inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
const test::CustomParam &param) {
test::CustomParam param) {
return printer << param.value;
}
} // end namespace mlir
#include "TestTypeInterfaces.h.inc"
#define GET_TYPEDEF_CLASSES

View File

@ -61,7 +61,7 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
// ATTR: printer << ' ' << "hello";
// ATTR: printer << ' ' << "=";
// ATTR: printer << ' ';
// ATTR: printer << getValue();
// ATTR: printer.printStrippedAttrOrType(getValue());
// ATTR: printer << ",";
// ATTR: printer << ' ';
// ATTR: ::printAttrParamA(printer, getComplex());
@ -154,10 +154,10 @@ def AttrB : TestAttr<"TestB"> {
// ATTR: void TestFAttr::print(::mlir::AsmPrinter &printer) const {
// ATTR: printer << ' ';
// ATTR: printer << getV0();
// ATTR: printer.printStrippedAttrOrType(getV0());
// ATTR: printer << ",";
// ATTR: printer << ' ';
// ATTR: printer << getV1();
// ATTR: printer.printStrippedAttrOrType(getV1());
// ATTR: }
def AttrC : TestAttr<"TestF"> {
@ -213,7 +213,7 @@ def AttrC : TestAttr<"TestF"> {
// TYPE: printer << ' ' << "bob";
// TYPE: printer << ' ' << "bar";
// TYPE: printer << ' ';
// TYPE: printer << getValue();
// TYPE: printer.printStrippedAttrOrType(getValue());
// TYPE: printer << ' ' << "complex";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
@ -361,21 +361,21 @@ def TypeB : TestType<"TestD"> {
// TYPE: printer << "v0";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
// TYPE: printer << getV0();
// TYPE: printer.printStrippedAttrOrType(getV0());
// TYPE: printer << ",";
// TYPE: printer << ' ' << "v2";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
// TYPE: printer << getV2();
// TYPE: printer.printStrippedAttrOrType(getV2());
// TYPE: printer << "v1";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
// TYPE: printer << getV1();
// TYPE: printer.printStrippedAttrOrType(getV1());
// TYPE: printer << ",";
// TYPE: printer << ' ' << "v3";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
// TYPE: printer << getV3();
// TYPE: printer.printStrippedAttrOrType(getV3());
// TYPE: }
def TypeC : TestType<"TestE"> {

View File

@ -256,16 +256,50 @@ test.format_optional_else else
// Format a custom attribute
//===----------------------------------------------------------------------===//
// CHECK: test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]>
test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]>
// CHECK: test.format_compound_attr <1, !test.smpla, [5, 6]>
test.format_compound_attr <1, !test.smpla, [5, 6]>
// CHECK: module attributes {test.nested = #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>} {
//-----
// CHECK: module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
}
//-----
// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`.
// CHECK: module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
module attributes {test.nested = #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>} {
}
// CHECK: test.format_nested_attr #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>
// CHECK: test.format_nested_attr <nested = <1, !test.smpla, [5, 6]>>
test.format_nested_attr #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>
//-----
// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`.
// CHECK: test.format_nested_attr <nested = <1, !test.smpla, [5, 6]>>
test.format_nested_attr #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>
//-----
// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>}
module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>}
{
}
//-----
// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>}
module attributes {test.someAttr = #test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>}
{
}
//-----
// CHECK: test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
//===----------------------------------------------------------------------===//
// Format custom directives

View File

@ -13,6 +13,22 @@ func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () {
return
}
// CHECK: @compoundNested(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>)
func @compoundNested(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>) -> () {
return
}
// Same as above, but we're parsing the complete spec for the inner type
// CHECK: @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>)
func @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>) -> () {
// Verify that the type prefix is elided and optional
// CHECK: format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
// CHECK: format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
test.format_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>
test.format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
return
}
// CHECK: @testInt(%arg0: !test.int<signed, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<none, 1>)
func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
return

View File

@ -163,7 +163,8 @@ static const char *const defaultParameterParser =
"::mlir::FieldParser<$0>::parse($_parser)";
/// Default printer for attribute or type parameters.
static const char *const defaultParameterPrinter = "$_printer << $_self";
static const char *const defaultParameterPrinter =
"$_printer.printStrippedAttrOrType($_self)";
/// Print an error when failing to parse an element.
///

View File

@ -496,13 +496,25 @@ static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}",
result.attributes)) {{
return ::mlir::failure();
}
)";
/// The code snippet used to generate a parser call for an attribute.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes))
return ::mlir::failure();
)";
const char *const optionalAttrParserCode = R"(
{
::mlir::OptionalParseResult parseResult =
parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes);
if (parseResult.hasValue() && failed(*parseResult))
return ::mlir::failure();
}
@ -635,8 +647,12 @@ const char *const optionalTypeParserCode = R"(
}
)";
const char *const typeParserCode = R"(
if (parser.parseType({0}RawTypes[0]))
return ::mlir::failure();
{
{0} type;
if (parser.parseCustomTypeWithFallback(type))
return ::mlir::failure();
{1}RawTypes[0] = type;
}
)";
/// The code snippet used to generate a parser call for a functional type.
@ -1269,12 +1285,19 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
std::string attrTypeStr;
if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
llvm::raw_string_ostream os(attrTypeStr);
os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
os << tgfmt(*typeBuilder, &attrTypeCtx);
} else {
attrTypeStr = "Type{}";
}
if (var->attr.isOptional()) {
body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
} else {
if (var->attr.getStorageType() == "::mlir::Attribute")
body << formatv(genericAttrParserCode, var->name, attrTypeStr);
else
body << formatv(attrParserCode, var->name, attrTypeStr);
}
body << formatv(var->attr.isOptional() ? optionalAttrParserCode
: attrParserCode,
var->name, attrTypeStr);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
@ -1334,14 +1357,23 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
else if (lengthKind == ArgumentLengthKind::Variadic)
} else if (lengthKind == ArgumentLengthKind::Variadic) {
body << llvm::formatv(variadicTypeParserCode, listName);
else if (lengthKind == ArgumentLengthKind::Optional)
} else if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(optionalTypeParserCode, listName);
else
body << formatv(typeParserCode, listName);
} else {
TypeSwitch<Element *>(dir->getOperand())
.Case<OperandVariable, ResultVariable>([&](auto operand) {
body << formatv(typeParserCode,
operand->getVar()->constraint.getCPPClassName(),
listName);
})
.Default([&](auto operand) {
body << formatv(typeParserCode, "Type", listName);
});
}
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
ArgumentLengthKind ignored;
body << formatv(functionalTypeParserCode,
@ -1761,7 +1793,8 @@ static void genVariadicRegionPrinter(const Twine &regionListName,
/// Generate the C++ for an operand to a (*-)type directive.
static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
MethodBody &body) {
MethodBody &body,
bool useArrayRef = true) {
if (isa<OperandsDirective>(arg))
return body << "getOperation()->getOperandTypes()";
if (isa<ResultsDirective>(arg))
@ -1778,8 +1811,10 @@ static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
"({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
"::llvm::ArrayRef<::mlir::Type>())",
op.getGetterName(var->name));
return body << "::llvm::ArrayRef<::mlir::Type>("
<< op.getGetterName(var->name) << "().getType())";
if (useArrayRef)
return body << "::llvm::ArrayRef<::mlir::Type>("
<< op.getGetterName(var->name) << "().getType())";
return body << op.getGetterName(var->name) << "().getType()";
}
/// Generate the printer for an enum attribute.
@ -1978,9 +2013,15 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
if (attr->getTypeBuilder())
body << " _odsPrinter.printAttributeWithoutType("
<< op.getGetterName(var->name) << "Attr());\n";
else
else if (var->attr.isOptional())
body << "_odsPrinter.printAttribute(" << op.getGetterName(var->name)
<< "Attr());\n";
else if (var->attr.getStorageType() == "::mlir::Attribute")
body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name)
<< "Attr());\n";
else
body << "_odsPrinter.printStrippedAttrOrType("
<< op.getGetterName(var->name) << "Attr());\n";
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
if (operand->getVar()->isVariadicOfVariadic()) {
body << " ::llvm::interleaveComma("
@ -2033,8 +2074,29 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
return;
}
}
const NamedTypeConstraint *var = nullptr;
{
if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand()))
var = operand->getVar();
else if (auto *operand = dyn_cast<ResultVariable>(dir->getOperand()))
var = operand->getVar();
}
if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
!var->isOptional()) {
std::string cppClass = var->constraint.getCPPClassName();
body << " {\n"
<< " auto type = " << op.getGetterName(var->name)
<< "().getType();\n"
<< " if (auto validType = type.dyn_cast<" << cppClass << ">())\n"
<< " _odsPrinter.printStrippedAttrOrType(validType);\n"
<< " else\n"
<< " _odsPrinter << type;\n"
<< " }\n";
return;
}
body << " _odsPrinter << ";
genTypeOperandPrinter(dir->getOperand(), op, body) << ";\n";
genTypeOperandPrinter(dir->getOperand(), op, body, /*useArrayRef=*/false)
<< ";\n";
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
body << " _odsPrinter.printFunctionalType(";
genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";