mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-28 03:46:06 +00:00
Change the printing/parsing behavior for Attributes used in declarative assembly format
The new form of printing attribute in the declarative assembly is eliding the `#dialect.mnemonic` prefix to only keep the `<....>` part. Differential Revision: https://reviews.llvm.org/D113873
This commit is contained in:
parent
63cd1842a7
commit
ee0908703d
@ -31,6 +31,7 @@ def ArmSVE_Dialect : Dialect {
|
||||
vector operations, including a scalable vector type and intrinsics for
|
||||
some Arm SVE instructions.
|
||||
}];
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -66,20 +67,6 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
|
||||
"Type":$elementType
|
||||
);
|
||||
|
||||
let printer = [{
|
||||
$_printer << "<";
|
||||
for (int64_t dim : getShape())
|
||||
$_printer << dim << 'x';
|
||||
$_printer << getElementType() << '>';
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
VectorType vector;
|
||||
if ($_parser.parseType(vector))
|
||||
return Type();
|
||||
return get($_ctxt, vector.getShape(), vector.getElementType());
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
bool hasStaticShape() const {
|
||||
return llvm::none_of(getShape(), ShapedType::isDynamic);
|
||||
|
@ -64,19 +64,19 @@ struct FieldParser<
|
||||
AttributeT>> {
|
||||
static FailureOr<AttributeT> parse(AsmParser &parser) {
|
||||
AttributeT value;
|
||||
if (parser.parseAttribute(value))
|
||||
if (parser.parseCustomAttributeWithFallback(value))
|
||||
return failure();
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
/// Parse a type.
|
||||
/// Parse an attribute.
|
||||
template <typename TypeT>
|
||||
struct FieldParser<
|
||||
TypeT, std::enable_if_t<std::is_base_of<Type, TypeT>::value, TypeT>> {
|
||||
static FailureOr<TypeT> parse(AsmParser &parser) {
|
||||
TypeT value;
|
||||
if (parser.parseType(value))
|
||||
if (parser.parseCustomTypeWithFallback(value))
|
||||
return failure();
|
||||
return value;
|
||||
}
|
||||
|
@ -2984,6 +2984,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
|
||||
string baseCppClass = "::mlir::Type">
|
||||
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
|
||||
AttrOrTypeDef<"Type", name, traits, baseCppClass> {
|
||||
// Make it possible to use such type as parameters for other types.
|
||||
string cppType = dialect.cppNamespace # "::" # cppClassName;
|
||||
|
||||
// A constant builder provided when the type has no parameters.
|
||||
let builderCall = !if(!empty(parameters),
|
||||
"$_builder.getType<" # dialect.cppNamespace #
|
||||
|
@ -50,6 +50,36 @@ public:
|
||||
virtual void printType(Type type);
|
||||
virtual void printAttribute(Attribute attr);
|
||||
|
||||
/// Trait to check if `AttrType` provides a `print` method.
|
||||
template <typename AttrOrType>
|
||||
using has_print_method =
|
||||
decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
|
||||
template <typename AttrOrType>
|
||||
using detect_has_print_method =
|
||||
llvm::is_detected<has_print_method, AttrOrType>;
|
||||
|
||||
/// Print the provided attribute in the context of an operation custom
|
||||
/// printer/parser: this will invoke directly the print method on the
|
||||
/// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
|
||||
template <typename AttrOrType,
|
||||
std::enable_if_t<detect_has_print_method<AttrOrType>::value>
|
||||
*sfinae = nullptr>
|
||||
void printStrippedAttrOrType(AttrOrType attrOrType) {
|
||||
if (succeeded(printAlias(attrOrType)))
|
||||
return;
|
||||
attrOrType.print(*this);
|
||||
}
|
||||
|
||||
/// SFINAE for printing the provided attribute in the context of an operation
|
||||
/// custom printer in the case where the attribute does not define a print
|
||||
/// method.
|
||||
template <typename AttrOrType,
|
||||
std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
|
||||
*sfinae = nullptr>
|
||||
void printStrippedAttrOrType(AttrOrType attrOrType) {
|
||||
*this << attrOrType;
|
||||
}
|
||||
|
||||
/// Print the given attribute without its type. The corresponding parser must
|
||||
/// provide a valid type for the attribute.
|
||||
virtual void printAttributeWithoutType(Attribute attr);
|
||||
@ -102,6 +132,14 @@ private:
|
||||
AsmPrinter(const AsmPrinter &) = delete;
|
||||
void operator=(const AsmPrinter &) = delete;
|
||||
|
||||
/// Print the alias for the given attribute, return failure if no alias could
|
||||
/// be printed.
|
||||
virtual LogicalResult printAlias(Attribute attr);
|
||||
|
||||
/// Print the alias for the given type, return failure if no alias could
|
||||
/// be printed.
|
||||
virtual LogicalResult printAlias(Type type);
|
||||
|
||||
/// The internal implementation of the printer.
|
||||
Impl *impl;
|
||||
};
|
||||
@ -608,6 +646,13 @@ public:
|
||||
/// Parse an arbitrary attribute of a given type and return it in result.
|
||||
virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
|
||||
|
||||
/// Parse a custom attribute with the provided callback, unless the next
|
||||
/// token is `#`, in which case the generic parser is invoked.
|
||||
virtual ParseResult parseCustomAttributeWithFallback(
|
||||
Attribute &result, Type type,
|
||||
function_ref<ParseResult(Attribute &result, Type type)>
|
||||
parseAttribute) = 0;
|
||||
|
||||
/// Parse an attribute of a specific kind and type.
|
||||
template <typename AttrType>
|
||||
ParseResult parseAttribute(AttrType &result, Type type = {}) {
|
||||
@ -639,9 +684,9 @@ public:
|
||||
return parseAttribute(result, Type(), attrName, attrs);
|
||||
}
|
||||
|
||||
/// Parse an arbitrary attribute of a given type and return it in result. This
|
||||
/// also adds the attribute to the specified attribute list with the specified
|
||||
/// name.
|
||||
/// Parse an arbitrary attribute of a given type and populate it in `result`.
|
||||
/// This also adds the attribute to the specified attribute list with the
|
||||
/// specified name.
|
||||
template <typename AttrType>
|
||||
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
|
||||
NamedAttrList &attrs) {
|
||||
@ -661,6 +706,82 @@ public:
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Trait to check if `AttrType` provides a `parse` method.
|
||||
template <typename AttrType>
|
||||
using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
|
||||
std::declval<Type>()));
|
||||
template <typename AttrType>
|
||||
using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
|
||||
|
||||
/// Parse a custom attribute of a given type unless the next token is `#`, in
|
||||
/// which case the generic parser is invoked. The parsed attribute is
|
||||
/// populated in `result` and also added to the specified attribute list with
|
||||
/// the specified name.
|
||||
template <typename AttrType>
|
||||
std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
|
||||
parseCustomAttributeWithFallback(AttrType &result, Type type,
|
||||
StringRef attrName, NamedAttrList &attrs) {
|
||||
llvm::SMLoc loc = getCurrentLocation();
|
||||
|
||||
// Parse any kind of attribute.
|
||||
Attribute attr;
|
||||
if (parseCustomAttributeWithFallback(
|
||||
attr, type, [&](Attribute &result, Type type) -> ParseResult {
|
||||
result = AttrType::parse(*this, type);
|
||||
if (!result)
|
||||
return failure();
|
||||
return success();
|
||||
}))
|
||||
return failure();
|
||||
|
||||
// Check for the right kind of attribute.
|
||||
result = attr.dyn_cast<AttrType>();
|
||||
if (!result)
|
||||
return emitError(loc, "invalid kind of attribute specified");
|
||||
|
||||
attrs.append(attrName, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// SFINAE parsing method for Attribute that don't implement a parse method.
|
||||
template <typename AttrType>
|
||||
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
|
||||
parseCustomAttributeWithFallback(AttrType &result, Type type,
|
||||
StringRef attrName, NamedAttrList &attrs) {
|
||||
return parseAttribute(result, type, attrName, attrs);
|
||||
}
|
||||
|
||||
/// Parse a custom attribute of a given type unless the next token is `#`, in
|
||||
/// which case the generic parser is invoked. The parsed attribute is
|
||||
/// populated in `result`.
|
||||
template <typename AttrType>
|
||||
std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
|
||||
parseCustomAttributeWithFallback(AttrType &result) {
|
||||
llvm::SMLoc loc = getCurrentLocation();
|
||||
|
||||
// Parse any kind of attribute.
|
||||
Attribute attr;
|
||||
if (parseCustomAttributeWithFallback(
|
||||
attr, {}, [&](Attribute &result, Type type) -> ParseResult {
|
||||
result = AttrType::parse(*this, type);
|
||||
return success(!!result);
|
||||
}))
|
||||
return failure();
|
||||
|
||||
// Check for the right kind of attribute.
|
||||
result = attr.dyn_cast<AttrType>();
|
||||
if (!result)
|
||||
return emitError(loc, "invalid kind of attribute specified");
|
||||
return success();
|
||||
}
|
||||
|
||||
/// SFINAE parsing method for Attribute that don't implement a parse method.
|
||||
template <typename AttrType>
|
||||
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
|
||||
parseCustomAttributeWithFallback(AttrType &result) {
|
||||
return parseAttribute(result);
|
||||
}
|
||||
|
||||
/// Parse an arbitrary optional attribute of a given type and return it in
|
||||
/// result.
|
||||
virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
|
||||
@ -740,6 +861,11 @@ public:
|
||||
/// Parse a type.
|
||||
virtual ParseResult parseType(Type &result) = 0;
|
||||
|
||||
/// Parse a custom type with the provided callback, unless the next
|
||||
/// token is `#`, in which case the generic parser is invoked.
|
||||
virtual ParseResult parseCustomTypeWithFallback(
|
||||
Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
|
||||
|
||||
/// Parse an optional type.
|
||||
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
|
||||
|
||||
@ -753,7 +879,7 @@ public:
|
||||
if (parseType(type))
|
||||
return failure();
|
||||
|
||||
// Check for the right kind of attribute.
|
||||
// Check for the right kind of type.
|
||||
result = type.dyn_cast<TypeT>();
|
||||
if (!result)
|
||||
return emitError(loc, "invalid kind of type specified");
|
||||
@ -761,6 +887,44 @@ public:
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Trait to check if `TypeT` provides a `parse` method.
|
||||
template <typename TypeT>
|
||||
using type_has_parse_method =
|
||||
decltype(TypeT::parse(std::declval<AsmParser &>()));
|
||||
template <typename TypeT>
|
||||
using detect_type_has_parse_method =
|
||||
llvm::is_detected<type_has_parse_method, TypeT>;
|
||||
|
||||
/// Parse a custom Type of a given type unless the next token is `#`, in
|
||||
/// which case the generic parser is invoked. The parsed Type is
|
||||
/// populated in `result`.
|
||||
template <typename TypeT>
|
||||
std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
|
||||
parseCustomTypeWithFallback(TypeT &result) {
|
||||
llvm::SMLoc loc = getCurrentLocation();
|
||||
|
||||
// Parse any kind of Type.
|
||||
Type type;
|
||||
if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
|
||||
result = TypeT::parse(*this);
|
||||
return success(!!result);
|
||||
}))
|
||||
return failure();
|
||||
|
||||
// Check for the right kind of Type.
|
||||
result = type.dyn_cast<TypeT>();
|
||||
if (!result)
|
||||
return emitError(loc, "invalid kind of Type specified");
|
||||
return success();
|
||||
}
|
||||
|
||||
/// SFINAE parsing method for Type that don't implement a parse method.
|
||||
template <typename TypeT>
|
||||
std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
|
||||
parseCustomTypeWithFallback(TypeT &result) {
|
||||
return parseType(result);
|
||||
}
|
||||
|
||||
/// Parse a type list.
|
||||
ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
|
||||
do {
|
||||
@ -792,7 +956,7 @@ public:
|
||||
if (parseColonType(type))
|
||||
return failure();
|
||||
|
||||
// Check for the right kind of attribute.
|
||||
// Check for the right kind of type.
|
||||
result = type.dyn_cast<TypeType>();
|
||||
if (!result)
|
||||
return emitError(loc, "invalid kind of type specified");
|
||||
|
@ -53,21 +53,21 @@ void ArmSVEDialect::initialize() {
|
||||
// ScalableVectorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Type ArmSVEDialect::parseType(DialectAsmParser &parser) const {
|
||||
llvm::SMLoc typeLoc = parser.getCurrentLocation();
|
||||
{
|
||||
Type genType;
|
||||
auto parseResult = generatedTypeParser(parser, "vector", genType);
|
||||
if (parseResult.hasValue())
|
||||
return genType;
|
||||
}
|
||||
parser.emitError(typeLoc, "unknown type in ArmSVE dialect");
|
||||
return Type();
|
||||
void ScalableVectorType::print(AsmPrinter &printer) const {
|
||||
printer << "<";
|
||||
for (int64_t dim : getShape())
|
||||
printer << dim << 'x';
|
||||
printer << getElementType() << '>';
|
||||
}
|
||||
|
||||
void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
if (failed(generatedTypePrinter(type, os)))
|
||||
llvm_unreachable("unexpected 'arm_sve' type kind");
|
||||
Type ScalableVectorType::parse(AsmParser &parser) {
|
||||
SmallVector<int64_t> dims;
|
||||
Type eltType;
|
||||
if (parser.parseLess() ||
|
||||
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
|
||||
parser.parseType(eltType) || parser.parseGreater())
|
||||
return {};
|
||||
return ScalableVectorType::get(eltType.getContext(), dims, eltType);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -170,7 +170,7 @@ static constexpr const CombiningKind combiningKindsList[] = {
|
||||
};
|
||||
|
||||
void CombiningKindAttr::print(AsmPrinter &printer) const {
|
||||
printer << "kind<";
|
||||
printer << "<";
|
||||
auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
|
||||
return bitEnumContains(this->getKind(), kind);
|
||||
});
|
||||
@ -215,10 +215,12 @@ Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
|
||||
|
||||
void VectorDialect::printAttribute(Attribute attr,
|
||||
DialectAsmPrinter &os) const {
|
||||
if (auto ck = attr.dyn_cast<CombiningKindAttr>())
|
||||
if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
|
||||
os << "kind";
|
||||
ck.print(os);
|
||||
else
|
||||
llvm_unreachable("Unknown attribute type");
|
||||
return;
|
||||
}
|
||||
llvm_unreachable("Unknown attribute type");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1188,7 +1188,7 @@ private:
|
||||
/// Ex:
|
||||
/// ```
|
||||
/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
|
||||
/// %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
|
||||
/// %1 = vector.multi_reduction add, %0 [1]
|
||||
/// : vector<8x32x16xf32> to vector<8x16xf32>
|
||||
/// ```
|
||||
/// Gets converted to:
|
||||
@ -1198,7 +1198,7 @@ private:
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
|
||||
/// kind = add} %0, %arg1, %cst_f0
|
||||
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
struct MultiReduceToContract
|
||||
@ -1247,7 +1247,7 @@ struct MultiReduceToContract
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
|
||||
/// kind = add} %0, %arg1, %cst_f0
|
||||
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
/// Gets converted to:
|
||||
@ -1257,7 +1257,7 @@ struct MultiReduceToContract
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
|
||||
/// kind = add} %arg0, %arg1, %cst_f0
|
||||
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
struct CombineContractTranspose
|
||||
@ -1304,7 +1304,7 @@ struct CombineContractTranspose
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
|
||||
/// kind = add} %0, %arg1, %cst_f0
|
||||
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
/// Gets converted to:
|
||||
@ -1314,7 +1314,7 @@ struct CombineContractTranspose
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
|
||||
/// kind = add} %arg0, %arg1, %cst_f0
|
||||
/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
struct CombineContractBroadcast
|
||||
|
@ -474,6 +474,14 @@ private:
|
||||
void printAttributeWithoutType(Attribute attr) override {
|
||||
printAttribute(attr);
|
||||
}
|
||||
LogicalResult printAlias(Attribute attr) override {
|
||||
initializer.visit(attr);
|
||||
return success();
|
||||
}
|
||||
LogicalResult printAlias(Type type) override {
|
||||
initializer.visit(type);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Print the given set of attributes with names not included within
|
||||
/// 'elidedAttrs'.
|
||||
@ -1252,8 +1260,16 @@ public:
|
||||
void printAttribute(Attribute attr,
|
||||
AttrTypeElision typeElision = AttrTypeElision::Never);
|
||||
|
||||
/// Print the alias for the given attribute, return failure if no alias could
|
||||
/// be printed.
|
||||
LogicalResult printAlias(Attribute attr);
|
||||
|
||||
void printType(Type type);
|
||||
|
||||
/// Print the alias for the given type, return failure if no alias could
|
||||
/// be printed.
|
||||
LogicalResult printAlias(Type type);
|
||||
|
||||
/// Print the given location to the stream. If `allowAlias` is true, this
|
||||
/// allows for the internal location to use an attribute alias.
|
||||
void printLocation(LocationAttr loc, bool allowAlias = false);
|
||||
@ -1594,6 +1610,14 @@ static void printElidedElementsAttr(raw_ostream &os) {
|
||||
os << R"(opaque<"_", "0xDEADBEEF">)";
|
||||
}
|
||||
|
||||
LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
|
||||
return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
|
||||
}
|
||||
|
||||
LogicalResult AsmPrinter::Impl::printAlias(Type type) {
|
||||
return success(state && succeeded(state->getAliasState().getAlias(type, os)));
|
||||
}
|
||||
|
||||
void AsmPrinter::Impl::printAttribute(Attribute attr,
|
||||
AttrTypeElision typeElision) {
|
||||
if (!attr) {
|
||||
@ -1602,7 +1626,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
|
||||
}
|
||||
|
||||
// Try to print an alias for this attribute.
|
||||
if (state && succeeded(state->getAliasState().getAlias(attr, os)))
|
||||
if (succeeded(printAlias(attr)))
|
||||
return;
|
||||
|
||||
if (!isa<BuiltinDialect>(attr.getDialect()))
|
||||
@ -2104,6 +2128,16 @@ void AsmPrinter::printAttribute(Attribute attr) {
|
||||
impl->printAttribute(attr);
|
||||
}
|
||||
|
||||
LogicalResult AsmPrinter::printAlias(Attribute attr) {
|
||||
assert(impl && "expected AsmPrinter::printAlias to be overriden");
|
||||
return impl->printAlias(attr);
|
||||
}
|
||||
|
||||
LogicalResult AsmPrinter::printAlias(Type type) {
|
||||
assert(impl && "expected AsmPrinter::printAlias to be overriden");
|
||||
return impl->printAlias(type);
|
||||
}
|
||||
|
||||
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
|
||||
assert(impl &&
|
||||
"expected AsmPrinter::printAttributeWithoutType to be overriden");
|
||||
|
@ -374,6 +374,7 @@ BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BoolAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool BoolAttr::getValue() const {
|
||||
auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/TensorEncoding.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
@ -633,7 +634,7 @@ bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
|
||||
return true;
|
||||
|
||||
// Allow custom dialect attributes.
|
||||
if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect()))
|
||||
if (!isa<BuiltinDialect>(memorySpace.getDialect()))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
|
@ -343,6 +343,29 @@ public:
|
||||
return success(static_cast<bool>(result));
|
||||
}
|
||||
|
||||
/// Parse a custom attribute with the provided callback, unless the next
|
||||
/// token is `#`, in which case the generic parser is invoked.
|
||||
ParseResult parseCustomAttributeWithFallback(
|
||||
Attribute &result, Type type,
|
||||
function_ref<ParseResult(Attribute &result, Type type)> parseAttribute)
|
||||
override {
|
||||
if (parser.getToken().isNot(Token::hash_identifier))
|
||||
return parseAttribute(result, type);
|
||||
result = parser.parseAttribute(type);
|
||||
return success(static_cast<bool>(result));
|
||||
}
|
||||
|
||||
/// Parse a custom attribute with the provided callback, unless the next
|
||||
/// token is `#`, in which case the generic parser is invoked.
|
||||
ParseResult parseCustomTypeWithFallback(
|
||||
Type &result,
|
||||
function_ref<ParseResult(Type &result)> parseType) override {
|
||||
if (parser.getToken().isNot(Token::exclamation_identifier))
|
||||
return parseType(result);
|
||||
result = parser.parseType();
|
||||
return success(static_cast<bool>(result));
|
||||
}
|
||||
|
||||
OptionalParseResult parseOptionalAttribute(Attribute &result,
|
||||
Type type) override {
|
||||
return parser.parseOptionalAttribute(result, type);
|
||||
|
@ -3,7 +3,7 @@
|
||||
func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
|
||||
%b: !arm_sve.vector<16xi8>,
|
||||
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
|
||||
// CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
|
||||
// CHECK: arm_sve.sdot {{.*}}: <16xi8> to <4xi32
|
||||
%0 = arm_sve.sdot %c, %a, %b :
|
||||
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
|
||||
return %0 : !arm_sve.vector<4xi32>
|
||||
@ -12,7 +12,7 @@ func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
|
||||
func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
|
||||
%b: !arm_sve.vector<16xi8>,
|
||||
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
|
||||
// CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
|
||||
// CHECK: arm_sve.smmla {{.*}}: <16xi8> to <4xi3
|
||||
%0 = arm_sve.smmla %c, %a, %b :
|
||||
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
|
||||
return %0 : !arm_sve.vector<4xi32>
|
||||
@ -21,7 +21,7 @@ func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
|
||||
func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
|
||||
%b: !arm_sve.vector<16xi8>,
|
||||
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
|
||||
// CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
|
||||
// CHECK: arm_sve.udot {{.*}}: <16xi8> to <4xi32
|
||||
%0 = arm_sve.udot %c, %a, %b :
|
||||
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
|
||||
return %0 : !arm_sve.vector<4xi32>
|
||||
@ -30,7 +30,7 @@ func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
|
||||
func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
|
||||
%b: !arm_sve.vector<16xi8>,
|
||||
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
|
||||
// CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
|
||||
// CHECK: arm_sve.ummla {{.*}}: <16xi8> to <4xi3
|
||||
%0 = arm_sve.ummla %c, %a, %b :
|
||||
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
|
||||
return %0 : !arm_sve.vector<4xi32>
|
||||
|
@ -16,7 +16,7 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
|
||||
%0 = arith.addf %arg0, %arg0 : f32
|
||||
// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value<f32>
|
||||
%1 = async.runtime.create: !async.value<f32>
|
||||
// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value<f32>
|
||||
// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : <f32>
|
||||
async.runtime.store %0, %1: !async.value<f32>
|
||||
// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value<f32>
|
||||
async.runtime.set_available %1: !async.value<f32>
|
||||
@ -32,9 +32,9 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
|
||||
// CHECK: cond_br %[[IS_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_OK:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_OK]]:
|
||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : !async.value<f32>
|
||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[VAL_STORAGE]] : <f32>
|
||||
// CHECK: %[[RETURNED:.*]] = arith.mulf %[[ARG]], %[[LOADED]] : f32
|
||||
// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : !async.value<f32>
|
||||
// CHECK: async.runtime.store %[[RETURNED]], %[[RETURNED_STORAGE]] : <f32>
|
||||
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
@ -84,8 +84,8 @@ func @simple_caller() -> f32 {
|
||||
// CHECK: cond_br %[[IS_VALUE_ERROR]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_VALUE_OK]]:
|
||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : !async.value<f32>
|
||||
// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : !async.value<f32>
|
||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1 : <f32>
|
||||
// CHECK: async.runtime.store %[[LOADED]], %[[RETURNED_STORAGE]] : <f32>
|
||||
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
@ -133,7 +133,7 @@ func @double_caller() -> f32 {
|
||||
// CHECK: cond_br %[[IS_VALUE_ERROR_1]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_1:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_VALUE_OK_1]]:
|
||||
// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : !async.value<f32>
|
||||
// CHECK: %[[LOADED_1:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_1]]#1 : <f32>
|
||||
// CHECK: %[[RETURNED_TO_CALLER_2:.*]]:2 = call @simple_callee(%[[LOADED_1]]) : (f32) -> (!async.token, !async.value<f32>)
|
||||
// CHECK: %[[SAVED_2:.*]] = async.coro.save %[[HDL]]
|
||||
// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_2]]#0, %[[HDL]]
|
||||
@ -150,8 +150,8 @@ func @double_caller() -> f32 {
|
||||
// CHECK: cond_br %[[IS_VALUE_ERROR_2]], ^[[BRANCH_ERROR:.*]], ^[[BRANCH_VALUE_OK_2:.*]]
|
||||
|
||||
// CHECK: ^[[BRANCH_VALUE_OK_2]]:
|
||||
// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : !async.value<f32>
|
||||
// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : !async.value<f32>
|
||||
// CHECK: %[[LOADED_2:.*]] = async.runtime.load %[[RETURNED_TO_CALLER_2]]#1 : <f32>
|
||||
// CHECK: async.runtime.store %[[LOADED_2]], %[[RETURNED_STORAGE]] : <f32>
|
||||
// CHECK: async.runtime.set_available %[[RETURNED_STORAGE]]
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
// CHECK: br ^[[CLEANUP]]
|
||||
|
@ -245,7 +245,7 @@ func @execute_and_return_f32() -> f32 {
|
||||
}
|
||||
|
||||
// CHECK: async.runtime.await %[[RET]]#1 : !async.value<f32>
|
||||
// CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : !async.value<f32>
|
||||
// CHECK: %[[VALUE:.*]] = async.runtime.load %[[RET]]#1 : <f32>
|
||||
%0 = async.await %result : !async.value<f32>
|
||||
|
||||
// CHECK: return %[[VALUE]]
|
||||
@ -323,7 +323,7 @@ func @async_value_operands() {
|
||||
|
||||
// // Load from the async.value argument after error checking.
|
||||
// CHECK: ^[[CONTINUATION:.*]]:
|
||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value<f32
|
||||
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : <f32
|
||||
// CHECK: arith.addf %[[LOADED]], %[[LOADED]] : f32
|
||||
// CHECK: async.runtime.set_available %[[TOKEN]]
|
||||
|
||||
|
@ -129,16 +129,16 @@ func @resume(%arg0: !async.coro.handle) {
|
||||
|
||||
// CHECK-LABEL: @store
|
||||
func @store(%arg0: f32, %arg1: !async.value<f32>) {
|
||||
// CHECK: async.runtime.store %arg0, %arg1 : !async.value<f32>
|
||||
async.runtime.store %arg0, %arg1 : !async.value<f32>
|
||||
// CHECK: async.runtime.store %arg0, %arg1 : <f32>
|
||||
async.runtime.store %arg0, %arg1 : <f32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @load
|
||||
func @load(%arg0: !async.value<f32>) -> f32 {
|
||||
// CHECK: %0 = async.runtime.load %arg0 : !async.value<f32>
|
||||
// CHECK: %0 = async.runtime.load %arg0 : <f32>
|
||||
// CHECK: return %0 : f32
|
||||
%0 = async.runtime.load %arg0 : !async.value<f32>
|
||||
%0 = async.runtime.load %arg0 : <f32>
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
|
@ -6,7 +6,7 @@
|
||||
func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
|
||||
|
||||
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
|
||||
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [0] : vector<1584xf32> to f32
|
||||
// CHECK: vector.multi_reduction <add>, %{{.*}} [0] : vector<1584xf32> to f32
|
||||
// CHECK: arith.addf %{{.*}}, %{{.*}} : f32
|
||||
linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
|
||||
outs(%C: memref<f32>)
|
||||
@ -19,7 +19,7 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
|
||||
func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
|
||||
|
||||
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
|
||||
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
|
||||
// CHECK: vector.multi_reduction <add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
|
||||
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
|
||||
linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
|
||||
outs(%C: memref<1584xf32>)
|
||||
@ -31,7 +31,7 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
|
||||
// CHECK-LABEL: contraction_matmul
|
||||
func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
|
||||
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
|
||||
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
|
||||
// CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
|
||||
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
|
||||
linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
|
||||
outs(%C: memref<1584x1584xf32>)
|
||||
@ -43,7 +43,7 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
|
||||
// CHECK-LABEL: contraction_batch_matmul
|
||||
func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
|
||||
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
|
||||
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
|
||||
// CHECK: vector.multi_reduction <add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
|
||||
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
|
||||
linalg.batch_matmul
|
||||
ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
|
||||
@ -71,7 +71,7 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
|
||||
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
|
||||
// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
|
||||
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
|
||||
// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
|
||||
// CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
|
||||
// CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
|
||||
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
|
||||
linalg.generic #matmul_trait
|
||||
@ -105,7 +105,7 @@ func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
|
||||
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
|
||||
// CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
|
||||
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
|
||||
// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
|
||||
// CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
|
||||
// CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
|
||||
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
|
||||
linalg.generic #matmul_transpose_out_trait
|
||||
@ -139,7 +139,7 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
|
||||
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
|
||||
// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
|
||||
// CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
|
||||
// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
|
||||
// CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
|
||||
// CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
|
||||
|
||||
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
|
||||
@ -160,7 +160,7 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
|
||||
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
|
||||
%C: memref<8x32xf32>) {
|
||||
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
|
||||
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
|
||||
// CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
|
||||
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
|
||||
linalg.matmul
|
||||
ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
|
||||
@ -523,7 +523,7 @@ func @matmul_tensors(
|
||||
// linalg matmul lowers gets expanded to a 3D reduction, canonicalization later
|
||||
// convert it to a 2D contract.
|
||||
// CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
|
||||
// CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
|
||||
// CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
|
||||
// CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
|
||||
// CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
|
||||
%0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
|
||||
@ -744,7 +744,7 @@ func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
|
||||
// CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
|
||||
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
|
||||
// CHECK: math.exp {{.*}} : vector<4x16x8xf32>
|
||||
// CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
|
||||
// CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
|
||||
// CHECK: addf {{.*}} : vector<4x16xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
|
||||
// CHECK: return {{.*}} : tensor<4x16xf32>
|
||||
@ -779,7 +779,7 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: ten
|
||||
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
|
||||
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
|
||||
// CHECK: addf {{.*}} : vector<2x3x4x5xf32>
|
||||
// CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
|
||||
// CHECK: vector.multi_reduction <add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
|
||||
// CHECK: addf {{.*}} : vector<2x5xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
|
||||
// CHECK: return {{.*}} : tensor<5x2xf32>
|
||||
@ -808,7 +808,7 @@ func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
|
||||
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
|
||||
// CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
|
||||
// CHECK: %[[R:.+]] = vector.multi_reduction <maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
|
||||
// CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
|
||||
%ident = arith.constant -3.40282e+38 : f32
|
||||
@ -833,7 +833,7 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
|
||||
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
|
||||
// CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
|
||||
// CHECK: %[[R:.+]] = vector.multi_reduction <minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
|
||||
// CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
|
||||
%maxf32 = arith.constant 3.40282e+38 : f32
|
||||
@ -857,7 +857,7 @@ func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
|
||||
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
|
||||
// CHECK: vector.multi_reduction #vector.kind<mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
|
||||
// CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
|
||||
// CHECK: mulf {{.*}} : vector<4xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
|
||||
%ident = arith.constant 1.0 : f32
|
||||
@ -881,7 +881,7 @@ func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
|
||||
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
|
||||
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
|
||||
// CHECK: vector.multi_reduction #vector.kind<or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
|
||||
// CHECK: vector.multi_reduction <or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
|
||||
%ident = arith.constant false
|
||||
%init = linalg.init_tensor [4] : tensor<4xi1>
|
||||
@ -904,7 +904,7 @@ func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
|
||||
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
|
||||
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
|
||||
// CHECK: vector.multi_reduction #vector.kind<and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
|
||||
// CHECK: vector.multi_reduction <and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
|
||||
%ident = arith.constant true
|
||||
%init = linalg.init_tensor [4] : tensor<4xi1>
|
||||
@ -927,7 +927,7 @@ func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
|
||||
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
|
||||
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
|
||||
// CHECK: vector.multi_reduction #vector.kind<xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
|
||||
// CHECK: vector.multi_reduction <xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
|
||||
%ident = arith.constant false
|
||||
%init = linalg.init_tensor [4] : tensor<4xi1>
|
||||
@ -979,7 +979,7 @@ func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) ->
|
||||
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32>
|
||||
// CHECK: subf {{.*}} : vector<4x4xf32>
|
||||
// CHECK: math.exp {{.*}} : vector<4x4xf32>
|
||||
// CHECK: vector.multi_reduction #vector.kind<add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
|
||||
// CHECK: vector.multi_reduction <add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
|
||||
// CHECK: addf {{.*}} : vector<4xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
|
||||
%c0 = arith.constant 0.0 : f32
|
||||
@ -1019,7 +1019,7 @@ func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
|
||||
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
|
||||
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
|
||||
// CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
|
||||
// CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[r]] [0]
|
||||
// CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]] [0]
|
||||
// CHECK-SAME: : vector<32xf32> to f32
|
||||
// CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32
|
||||
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<f32>
|
||||
|
@ -1027,7 +1027,7 @@ func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v
|
||||
// CHECK-LABEL: func @vector_multi_reduction_single_parallel(
|
||||
// CHECK-SAME: %[[v:.*]]: vector<2xf32>
|
||||
func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {
|
||||
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
|
||||
%0 = vector.multi_reduction <mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
|
||||
|
||||
// CHECK: return %[[v]] : vector<2xf32>
|
||||
return %0 : vector<2xf32>
|
||||
|
@ -3,7 +3,7 @@
|
||||
// -----
|
||||
|
||||
func @broadcast_to_scalar(%arg0: f32) -> f32 {
|
||||
// expected-error@+1 {{'vector.broadcast' op result #0 must be vector of any type values, but got 'f32'}}
|
||||
// expected-error@+1 {{custom op 'vector.broadcast' invalid kind of type specified}}
|
||||
%0 = vector.broadcast %arg0 : f32 to f32
|
||||
}
|
||||
|
||||
@ -1022,7 +1022,7 @@ func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
|
||||
// -----
|
||||
|
||||
func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
|
||||
// expected-error@+1 {{must be vector of any type values}}
|
||||
// expected-error@+1 {{'vector.bitcast' invalid kind of type specified}}
|
||||
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32
|
||||
}
|
||||
|
||||
|
@ -685,9 +685,9 @@ func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
|
||||
|
||||
// CHECK-LABEL: @multi_reduction
|
||||
func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
|
||||
%1 = vector.multi_reduction #vector.kind<add>, %0 [1, 3] :
|
||||
%1 = vector.multi_reduction <add>, %0 [1, 3] :
|
||||
vector<4x8x16x32xf32> to vector<4x16xf32>
|
||||
%2 = vector.multi_reduction #vector.kind<add>, %1 [0, 1] :
|
||||
%2 = vector.multi_reduction <add>, %1 [0, 1] :
|
||||
vector<4x16xf32> to f32
|
||||
return %2 : f32
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
|
||||
|
||||
func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||
%0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||
return %0 : vector<2xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @vector_multi_reduction
|
||||
@ -18,7 +18,7 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||
// CHECK: return %[[RESULT_VEC]]
|
||||
|
||||
func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
|
||||
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
|
||||
%0 = vector.multi_reduction <mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
|
||||
return %0 : f32
|
||||
}
|
||||
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
|
||||
@ -30,7 +30,7 @@ func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
|
||||
// CHECK: return %[[RES]]
|
||||
|
||||
func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
|
||||
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
|
||||
%0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
|
||||
return %0 : vector<2x3xi32>
|
||||
}
|
||||
// CHECK-LABEL: func @vector_reduction_inner
|
||||
@ -66,7 +66,7 @@ func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
|
||||
|
||||
|
||||
func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
|
||||
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
|
||||
%0 = vector.multi_reduction <add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
|
||||
return %0 : vector<2x5xf32>
|
||||
}
|
||||
|
||||
@ -78,7 +78,7 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
|
||||
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
|
||||
%0 = vector.multi_reduction <mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
|
||||
return %0 : vector<2x4xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @vector_multi_reduction_ordering
|
||||
|
@ -1,7 +1,7 @@
|
||||
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s
|
||||
|
||||
func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||
%0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||
return %0 : vector<2xf32>
|
||||
}
|
||||
|
||||
@ -18,7 +18,7 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
|
||||
|
||||
func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||
%0 = vector.multi_reduction #vector.kind<minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||
%0 = vector.multi_reduction <minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||
return %0 : vector<2xf32>
|
||||
}
|
||||
|
||||
@ -35,7 +35,7 @@ func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
|
||||
|
||||
func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||
%0 = vector.multi_reduction #vector.kind<maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||
%0 = vector.multi_reduction <maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||
return %0 : vector<2xf32>
|
||||
}
|
||||
|
||||
@ -52,7 +52,7 @@ func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
|
||||
|
||||
func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
|
||||
%0 = vector.multi_reduction #vector.kind<and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
|
||||
%0 = vector.multi_reduction <and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
|
||||
return %0 : vector<2xi32>
|
||||
}
|
||||
|
||||
@ -69,7 +69,7 @@ func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
|
||||
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
|
||||
|
||||
func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
|
||||
%0 = vector.multi_reduction #vector.kind<or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
|
||||
%0 = vector.multi_reduction <or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
|
||||
return %0 : vector<2xi32>
|
||||
}
|
||||
|
||||
@ -86,7 +86,7 @@ func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
|
||||
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
|
||||
|
||||
func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
|
||||
%0 = vector.multi_reduction #vector.kind<xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
|
||||
%0 = vector.multi_reduction <xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
|
||||
return %0 : vector<2xi32>
|
||||
}
|
||||
|
||||
@ -104,7 +104,7 @@ func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
|
||||
|
||||
|
||||
func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
|
||||
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
|
||||
%0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
|
||||
return %0 : vector<2x3xi32>
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
func @multidimreduction_contract(
|
||||
%arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
|
||||
%0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
|
||||
%1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
|
||||
%1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
|
||||
return %1 : vector<8x16xf32>
|
||||
}
|
||||
|
||||
@ -30,7 +30,7 @@ func @multidimreduction_contract(
|
||||
func @multidimreduction_contract_int(
|
||||
%arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
|
||||
%0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
|
||||
%1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
|
||||
%1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
|
||||
return %1 : vector<8x16xi32>
|
||||
}
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
#define TEST_ATTRDEFS
|
||||
|
||||
// To get the test dialect definition.
|
||||
include "TestOps.td"
|
||||
include "TestDialect.td"
|
||||
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
||||
include "mlir/IR/SubElementInterfaces.td"
|
||||
|
||||
@ -121,6 +121,29 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
|
||||
);
|
||||
}
|
||||
|
||||
// A more complex parameterized attribute with multiple level of nesting.
|
||||
def CompoundNestedInner : Test_Attr<"CompoundNestedInner"> {
|
||||
let mnemonic = "cmpnd_nested_inner";
|
||||
// List of type parameters.
|
||||
let parameters = (
|
||||
ins
|
||||
"int":$some_int,
|
||||
CompoundAttrA:$cmpdA
|
||||
);
|
||||
let assemblyFormat = "`<` $some_int $cmpdA `>`";
|
||||
}
|
||||
|
||||
def CompoundNestedOuter : Test_Attr<"CompoundNestedOuter"> {
|
||||
let mnemonic = "cmpnd_nested_outer";
|
||||
|
||||
// List of type parameters.
|
||||
let parameters = (
|
||||
ins
|
||||
CompoundNestedInner:$inner
|
||||
);
|
||||
let assemblyFormat = "`<` `i` $inner `>`";
|
||||
}
|
||||
|
||||
def TestParamOne : AttrParameter<"int64_t", ""> {}
|
||||
|
||||
def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> {
|
||||
|
46
mlir/test/lib/Dialect/Test/TestDialect.td
Normal file
46
mlir/test/lib/Dialect/Test/TestDialect.td
Normal file
@ -0,0 +1,46 @@
|
||||
//===-- TestDialect.td - Test dialect definition -----------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TEST_DIALECT
|
||||
#define TEST_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Test_Dialect : Dialect {
|
||||
let name = "test";
|
||||
let cppNamespace = "::test";
|
||||
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
|
||||
let hasCanonicalizer = 1;
|
||||
let hasConstantMaterializer = 1;
|
||||
let hasOperationAttrVerify = 1;
|
||||
let hasRegionArgAttrVerify = 1;
|
||||
let hasRegionResultAttrVerify = 1;
|
||||
let hasOperationInterfaceFallback = 1;
|
||||
let hasNonDefaultDestructor = 1;
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
let dependentDialects = ["::mlir::DLTIDialect"];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void registerAttributes();
|
||||
void registerTypes();
|
||||
|
||||
// Provides a custom printing/parsing for some operations.
|
||||
::llvm::Optional<ParseOpHook>
|
||||
getParseOperationHook(::llvm::StringRef opName) const override;
|
||||
::llvm::unique_function<void(::mlir::Operation *,
|
||||
::mlir::OpAsmPrinter &printer)>
|
||||
getOperationPrinter(::mlir::Operation *op) const override;
|
||||
|
||||
private:
|
||||
// Storage for a custom fallback interface.
|
||||
void *fallbackEffectOpInterfaces;
|
||||
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TEST_DIALECT
|
@ -9,6 +9,7 @@
|
||||
#ifndef TEST_OPS
|
||||
#define TEST_OPS
|
||||
|
||||
include "TestDialect.td"
|
||||
include "mlir/Dialect/DLTI/DLTIBase.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
@ -23,40 +24,11 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
|
||||
include "TestInterfaces.td"
|
||||
|
||||
def Test_Dialect : Dialect {
|
||||
let name = "test";
|
||||
let cppNamespace = "::test";
|
||||
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
|
||||
let hasCanonicalizer = 1;
|
||||
let hasConstantMaterializer = 1;
|
||||
let hasOperationAttrVerify = 1;
|
||||
let hasRegionArgAttrVerify = 1;
|
||||
let hasRegionResultAttrVerify = 1;
|
||||
let hasOperationInterfaceFallback = 1;
|
||||
let hasNonDefaultDestructor = 1;
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
let dependentDialects = ["::mlir::DLTIDialect"];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void registerAttributes();
|
||||
void registerTypes();
|
||||
|
||||
// Provides a custom printing/parsing for some operations.
|
||||
::llvm::Optional<ParseOpHook>
|
||||
getParseOperationHook(::llvm::StringRef opName) const override;
|
||||
::llvm::unique_function<void(::mlir::Operation *,
|
||||
::mlir::OpAsmPrinter &printer)>
|
||||
getOperationPrinter(::mlir::Operation *op) const override;
|
||||
|
||||
private:
|
||||
// Storage for a custom fallback interface.
|
||||
void *fallbackEffectOpInterfaces;
|
||||
|
||||
}];
|
||||
}
|
||||
|
||||
// Include the attribute definitions.
|
||||
include "TestAttrDefs.td"
|
||||
// Include the type definitions.
|
||||
include "TestTypeDefs.td"
|
||||
|
||||
|
||||
class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
@ -1933,6 +1905,16 @@ def FormatNestedAttr : TEST_Op<"format_nested_attr"> {
|
||||
let assemblyFormat = "$nested attr-dict-with-keyword";
|
||||
}
|
||||
|
||||
def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> {
|
||||
let arguments = (ins CompoundNestedOuter:$nested);
|
||||
let assemblyFormat = "`nested` $nested attr-dict-with-keyword";
|
||||
}
|
||||
|
||||
def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> {
|
||||
let arguments = (ins CompoundNestedOuterType:$nested);
|
||||
let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Custom Directives
|
||||
|
||||
|
@ -14,8 +14,9 @@
|
||||
#define TEST_TYPEDEFS
|
||||
|
||||
// To get the test dialect def.
|
||||
include "TestOps.td"
|
||||
include "TestDialect.td"
|
||||
include "TestAttrDefs.td"
|
||||
include "TestInterfaces.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "mlir/Interfaces/DataLayoutInterfaces.td"
|
||||
|
||||
@ -49,6 +50,29 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
|
||||
}];
|
||||
}
|
||||
|
||||
// A more complex and nested parameterized type.
|
||||
def CompoundNestedInnerType : Test_Type<"CompoundNestedInner"> {
|
||||
let mnemonic = "cmpnd_inner";
|
||||
// List of type parameters.
|
||||
let parameters = (
|
||||
ins
|
||||
"int":$some_int,
|
||||
CompoundTypeA:$cmpdA
|
||||
);
|
||||
let assemblyFormat = "`<` $some_int $cmpdA `>`";
|
||||
}
|
||||
|
||||
def CompoundNestedOuterType : Test_Type<"CompoundNestedOuter"> {
|
||||
let mnemonic = "cmpnd_nested_outer";
|
||||
|
||||
// List of type parameters.
|
||||
let parameters = (
|
||||
ins
|
||||
CompoundNestedInnerType:$inner
|
||||
);
|
||||
let assemblyFormat = "`<` `i` $inner `>`";
|
||||
}
|
||||
|
||||
// An example of how one could implement a standard integer.
|
||||
def IntegerType : Test_Type<"TestInteger"> {
|
||||
let mnemonic = "int";
|
||||
|
@ -25,6 +25,7 @@
|
||||
#include "mlir/Interfaces/DataLayoutInterfaces.h"
|
||||
|
||||
namespace test {
|
||||
class TestAttrWithFormatAttr;
|
||||
|
||||
/// FieldInfo represents a field in the StructType data type. It is used as a
|
||||
/// parameter in TestTypeDefs.td.
|
||||
@ -63,13 +64,13 @@ struct FieldParser<test::CustomParam> {
|
||||
return test::CustomParam{value.getValue()};
|
||||
}
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
|
||||
const test::CustomParam ¶m) {
|
||||
test::CustomParam param) {
|
||||
return printer << param.value;
|
||||
}
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#include "TestTypeInterfaces.h.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
|
@ -61,7 +61,7 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
|
||||
// ATTR: printer << ' ' << "hello";
|
||||
// ATTR: printer << ' ' << "=";
|
||||
// ATTR: printer << ' ';
|
||||
// ATTR: printer << getValue();
|
||||
// ATTR: printer.printStrippedAttrOrType(getValue());
|
||||
// ATTR: printer << ",";
|
||||
// ATTR: printer << ' ';
|
||||
// ATTR: ::printAttrParamA(printer, getComplex());
|
||||
@ -154,10 +154,10 @@ def AttrB : TestAttr<"TestB"> {
|
||||
|
||||
// ATTR: void TestFAttr::print(::mlir::AsmPrinter &printer) const {
|
||||
// ATTR: printer << ' ';
|
||||
// ATTR: printer << getV0();
|
||||
// ATTR: printer.printStrippedAttrOrType(getV0());
|
||||
// ATTR: printer << ",";
|
||||
// ATTR: printer << ' ';
|
||||
// ATTR: printer << getV1();
|
||||
// ATTR: printer.printStrippedAttrOrType(getV1());
|
||||
// ATTR: }
|
||||
|
||||
def AttrC : TestAttr<"TestF"> {
|
||||
@ -213,7 +213,7 @@ def AttrC : TestAttr<"TestF"> {
|
||||
// TYPE: printer << ' ' << "bob";
|
||||
// TYPE: printer << ' ' << "bar";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getValue();
|
||||
// TYPE: printer.printStrippedAttrOrType(getValue());
|
||||
// TYPE: printer << ' ' << "complex";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
@ -361,21 +361,21 @@ def TypeB : TestType<"TestD"> {
|
||||
// TYPE: printer << "v0";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getV0();
|
||||
// TYPE: printer.printStrippedAttrOrType(getV0());
|
||||
// TYPE: printer << ",";
|
||||
// TYPE: printer << ' ' << "v2";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getV2();
|
||||
// TYPE: printer.printStrippedAttrOrType(getV2());
|
||||
// TYPE: printer << "v1";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getV1();
|
||||
// TYPE: printer.printStrippedAttrOrType(getV1());
|
||||
// TYPE: printer << ",";
|
||||
// TYPE: printer << ' ' << "v3";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getV3();
|
||||
// TYPE: printer.printStrippedAttrOrType(getV3());
|
||||
// TYPE: }
|
||||
|
||||
def TypeC : TestType<"TestE"> {
|
||||
|
@ -256,16 +256,50 @@ test.format_optional_else else
|
||||
// Format a custom attribute
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]>
|
||||
test.format_compound_attr #test.cmpnd_a<1, !test.smpla, [5, 6]>
|
||||
// CHECK: test.format_compound_attr <1, !test.smpla, [5, 6]>
|
||||
test.format_compound_attr <1, !test.smpla, [5, 6]>
|
||||
|
||||
// CHECK: module attributes {test.nested = #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>} {
|
||||
//-----
|
||||
|
||||
|
||||
// CHECK: module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
|
||||
module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
|
||||
}
|
||||
|
||||
//-----
|
||||
|
||||
// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`.
|
||||
// CHECK: module attributes {test.nested = #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>} {
|
||||
module attributes {test.nested = #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>} {
|
||||
}
|
||||
|
||||
// CHECK: test.format_nested_attr #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>
|
||||
// CHECK: test.format_nested_attr <nested = <1, !test.smpla, [5, 6]>>
|
||||
test.format_nested_attr #test.cmpnd_nested<nested = <1, !test.smpla, [5, 6]>>
|
||||
|
||||
//-----
|
||||
|
||||
// Same as above, but fully spelling the inner attribute prefix `#test.cmpnd_a`.
|
||||
// CHECK: test.format_nested_attr <nested = <1, !test.smpla, [5, 6]>>
|
||||
test.format_nested_attr #test.cmpnd_nested<nested = #test.cmpnd_a<1, !test.smpla, [5, 6]>>
|
||||
|
||||
//-----
|
||||
|
||||
// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>}
|
||||
module attributes {test.someAttr = #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>}
|
||||
{
|
||||
}
|
||||
|
||||
//-----
|
||||
|
||||
// CHECK: module attributes {test.someAttr = #test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>}
|
||||
module attributes {test.someAttr = #test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>}
|
||||
{
|
||||
}
|
||||
|
||||
//-----
|
||||
|
||||
// CHECK: test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
|
||||
test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Format custom directives
|
||||
|
@ -13,6 +13,22 @@ func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: @compoundNested(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>)
|
||||
func @compoundNested(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>) -> () {
|
||||
return
|
||||
}
|
||||
|
||||
// Same as above, but we're parsing the complete spec for the inner type
|
||||
// CHECK: @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i <42 <1, !test.smpla, [5, 6]>>>)
|
||||
func @compoundNestedExplicit(%arg0: !test.cmpnd_nested_outer<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>) -> () {
|
||||
// Verify that the type prefix is elided and optional
|
||||
// CHECK: format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
|
||||
// CHECK: format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
|
||||
test.format_cpmd_nested_type %arg0 nested !test.cmpnd_nested_outer<i !test.cmpnd_inner<42 <1, !test.smpla, [5, 6]>>>
|
||||
test.format_cpmd_nested_type %arg0 nested <i <42 <1, !test.smpla, [5, 6]>>>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: @testInt(%arg0: !test.int<signed, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<none, 1>)
|
||||
func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
|
||||
return
|
||||
|
@ -163,7 +163,8 @@ static const char *const defaultParameterParser =
|
||||
"::mlir::FieldParser<$0>::parse($_parser)";
|
||||
|
||||
/// Default printer for attribute or type parameters.
|
||||
static const char *const defaultParameterPrinter = "$_printer << $_self";
|
||||
static const char *const defaultParameterPrinter =
|
||||
"$_printer.printStrippedAttrOrType($_self)";
|
||||
|
||||
/// Print an error when failing to parse an element.
|
||||
///
|
||||
|
@ -496,13 +496,25 @@ static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
|
||||
/// {0}: The name of the attribute.
|
||||
/// {1}: The type for the attribute.
|
||||
const char *const attrParserCode = R"(
|
||||
if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
|
||||
if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}",
|
||||
result.attributes)) {{
|
||||
return ::mlir::failure();
|
||||
}
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a parser call for an attribute.
|
||||
///
|
||||
/// {0}: The name of the attribute.
|
||||
/// {1}: The type for the attribute.
|
||||
const char *const genericAttrParserCode = R"(
|
||||
if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes))
|
||||
return ::mlir::failure();
|
||||
)";
|
||||
|
||||
const char *const optionalAttrParserCode = R"(
|
||||
{
|
||||
::mlir::OptionalParseResult parseResult =
|
||||
parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
|
||||
parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes);
|
||||
if (parseResult.hasValue() && failed(*parseResult))
|
||||
return ::mlir::failure();
|
||||
}
|
||||
@ -635,8 +647,12 @@ const char *const optionalTypeParserCode = R"(
|
||||
}
|
||||
)";
|
||||
const char *const typeParserCode = R"(
|
||||
if (parser.parseType({0}RawTypes[0]))
|
||||
return ::mlir::failure();
|
||||
{
|
||||
{0} type;
|
||||
if (parser.parseCustomTypeWithFallback(type))
|
||||
return ::mlir::failure();
|
||||
{1}RawTypes[0] = type;
|
||||
}
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a parser call for a functional type.
|
||||
@ -1269,12 +1285,19 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
|
||||
std::string attrTypeStr;
|
||||
if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
|
||||
llvm::raw_string_ostream os(attrTypeStr);
|
||||
os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
|
||||
os << tgfmt(*typeBuilder, &attrTypeCtx);
|
||||
} else {
|
||||
attrTypeStr = "Type{}";
|
||||
}
|
||||
if (var->attr.isOptional()) {
|
||||
body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
|
||||
} else {
|
||||
if (var->attr.getStorageType() == "::mlir::Attribute")
|
||||
body << formatv(genericAttrParserCode, var->name, attrTypeStr);
|
||||
else
|
||||
body << formatv(attrParserCode, var->name, attrTypeStr);
|
||||
}
|
||||
|
||||
body << formatv(var->attr.isOptional() ? optionalAttrParserCode
|
||||
: attrParserCode,
|
||||
var->name, attrTypeStr);
|
||||
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
|
||||
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
|
||||
StringRef name = operand->getVar()->name;
|
||||
@ -1334,14 +1357,23 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
|
||||
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
|
||||
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
|
||||
body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
|
||||
else if (lengthKind == ArgumentLengthKind::Variadic)
|
||||
} else if (lengthKind == ArgumentLengthKind::Variadic) {
|
||||
body << llvm::formatv(variadicTypeParserCode, listName);
|
||||
else if (lengthKind == ArgumentLengthKind::Optional)
|
||||
} else if (lengthKind == ArgumentLengthKind::Optional) {
|
||||
body << llvm::formatv(optionalTypeParserCode, listName);
|
||||
else
|
||||
body << formatv(typeParserCode, listName);
|
||||
} else {
|
||||
TypeSwitch<Element *>(dir->getOperand())
|
||||
.Case<OperandVariable, ResultVariable>([&](auto operand) {
|
||||
body << formatv(typeParserCode,
|
||||
operand->getVar()->constraint.getCPPClassName(),
|
||||
listName);
|
||||
})
|
||||
.Default([&](auto operand) {
|
||||
body << formatv(typeParserCode, "Type", listName);
|
||||
});
|
||||
}
|
||||
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
|
||||
ArgumentLengthKind ignored;
|
||||
body << formatv(functionalTypeParserCode,
|
||||
@ -1761,7 +1793,8 @@ static void genVariadicRegionPrinter(const Twine ®ionListName,
|
||||
|
||||
/// Generate the C++ for an operand to a (*-)type directive.
|
||||
static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
|
||||
MethodBody &body) {
|
||||
MethodBody &body,
|
||||
bool useArrayRef = true) {
|
||||
if (isa<OperandsDirective>(arg))
|
||||
return body << "getOperation()->getOperandTypes()";
|
||||
if (isa<ResultsDirective>(arg))
|
||||
@ -1778,8 +1811,10 @@ static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
|
||||
"({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
|
||||
"::llvm::ArrayRef<::mlir::Type>())",
|
||||
op.getGetterName(var->name));
|
||||
return body << "::llvm::ArrayRef<::mlir::Type>("
|
||||
<< op.getGetterName(var->name) << "().getType())";
|
||||
if (useArrayRef)
|
||||
return body << "::llvm::ArrayRef<::mlir::Type>("
|
||||
<< op.getGetterName(var->name) << "().getType())";
|
||||
return body << op.getGetterName(var->name) << "().getType()";
|
||||
}
|
||||
|
||||
/// Generate the printer for an enum attribute.
|
||||
@ -1978,9 +2013,15 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
|
||||
if (attr->getTypeBuilder())
|
||||
body << " _odsPrinter.printAttributeWithoutType("
|
||||
<< op.getGetterName(var->name) << "Attr());\n";
|
||||
else
|
||||
else if (var->attr.isOptional())
|
||||
body << "_odsPrinter.printAttribute(" << op.getGetterName(var->name)
|
||||
<< "Attr());\n";
|
||||
else if (var->attr.getStorageType() == "::mlir::Attribute")
|
||||
body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name)
|
||||
<< "Attr());\n";
|
||||
else
|
||||
body << "_odsPrinter.printStrippedAttrOrType("
|
||||
<< op.getGetterName(var->name) << "Attr());\n";
|
||||
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
|
||||
if (operand->getVar()->isVariadicOfVariadic()) {
|
||||
body << " ::llvm::interleaveComma("
|
||||
@ -2033,8 +2074,29 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
|
||||
return;
|
||||
}
|
||||
}
|
||||
const NamedTypeConstraint *var = nullptr;
|
||||
{
|
||||
if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand()))
|
||||
var = operand->getVar();
|
||||
else if (auto *operand = dyn_cast<ResultVariable>(dir->getOperand()))
|
||||
var = operand->getVar();
|
||||
}
|
||||
if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
|
||||
!var->isOptional()) {
|
||||
std::string cppClass = var->constraint.getCPPClassName();
|
||||
body << " {\n"
|
||||
<< " auto type = " << op.getGetterName(var->name)
|
||||
<< "().getType();\n"
|
||||
<< " if (auto validType = type.dyn_cast<" << cppClass << ">())\n"
|
||||
<< " _odsPrinter.printStrippedAttrOrType(validType);\n"
|
||||
<< " else\n"
|
||||
<< " _odsPrinter << type;\n"
|
||||
<< " }\n";
|
||||
return;
|
||||
}
|
||||
body << " _odsPrinter << ";
|
||||
genTypeOperandPrinter(dir->getOperand(), op, body) << ";\n";
|
||||
genTypeOperandPrinter(dir->getOperand(), op, body, /*useArrayRef=*/false)
|
||||
<< ";\n";
|
||||
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
|
||||
body << " _odsPrinter.printFunctionalType(";
|
||||
genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";
|
||||
|
Loading…
x
Reference in New Issue
Block a user