diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 4d975af9cfe7..e01e21c287fa 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -492,7 +492,7 @@ MLIR provides a first class set of polyhedral operations and analyses within the Each SSA value in MLIR has a type defined by the type system below. There are a number of primitive types (like integers) and also aggregate types for tensors -and memory buffers. MLIR standard types do not include complex numbers, tuples, +and memory buffers. MLIR standard types do not include complex numbers, structures, arrays, or dictionaries. MLIR has an open type system (there is no fixed list of types), and types may @@ -556,7 +556,7 @@ Builtin types consist of only the types needed for the validity of the IR. Syntax: ``` {.ebnf} -// MLIR doesn't have a tuple type but functions can return multiple values. +// MLIR functions can return multiple values. function-result-type ::= type-list-parens | non-function-type @@ -860,6 +860,34 @@ access pattern analysis, and for performance optimizations like vectorization, copy elision and in-place updates. If an affine map composition is not specified for the memref, the identity affine map is assumed. +#### Tuple Type {#tuple-type} + +Syntax: + +``` {.ebnf} +tuple-type ::= `tuple` `<` (type ( `,` type)*)? `>` +``` + +The value of `tuple` type represents a fixed-size collection of elements, where +each element may be of a different type. + +**Rationale:** Though this type is first class in the type system, MLIR provides +no standard operations for operating on `tuple` types +[rationale](Rationale.md#tuple-type). + +Examples: + +```mlir {.mlir} +// Empty tuple. +tuple<> + +// Single element +tuple + +// Many elements. +tuple, i5> +``` + ## Attributes {#attributes} Syntax: diff --git a/mlir/g3doc/Rationale.md b/mlir/g3doc/Rationale.md index 3e7d068140b4..907cbd9cd7d8 100644 --- a/mlir/g3doc/Rationale.md +++ b/mlir/g3doc/Rationale.md @@ -487,6 +487,20 @@ to think of these types as existing within the namespace of the dialect. If a dialect wishes to assign a canonical name to a type, it can be done via [type aliases](LangRef.md#type-aliases). +### Tuple types {#tuple-type} + +The MLIR type system provides first class support for defining +[tuple types](LangRef.md#tuple-type). This is due to the fact that `Tuple` +represents a universal concept that is likely to, and already has, present +itself in many different dialects. Though this type is first class in the type +system, it merely serves to provide a common mechanism in which to represent +this concept in MLIR. As such, MLIR provides no standard operations for +interfacing with `tuple` types. It is up to dialect authors to provide +operations, e.g. extract_tuple_element, to interpret and manipulate them. When +possible, operations should prefer to use multiple results instead. These +provide a myriad of benefits, such as alleviating any need for tuple-extract +operations that merely get in the way of analysis and transformation. + ### Assembly forms MLIR decides to support both generic and custom assembly forms under the diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1768314e30d5..78eb7a54d6a0 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -37,6 +37,7 @@ class MemRefType; class VectorType; class RankedTensorType; class UnrankedTensorType; +class TupleType; class BoolAttr; class IntegerAttr; class FloatAttr; @@ -87,6 +88,7 @@ public: VectorType getVectorType(ArrayRef shape, Type elementType); RankedTensorType getTensorType(ArrayRef shape, Type elementType); UnrankedTensorType getTensorType(Type elementType); + TupleType getTupleType(ArrayRef elementTypes); /// Get or construct an instance of the type 'ty' with provided arguments. template Ty getType(Args... args) { diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h index e11806c5d00f..93b4f4cf4a1d 100644 --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -42,6 +42,7 @@ struct TensorTypeStorage; struct RankedTensorTypeStorage; struct UnrankedTensorTypeStorage; struct MemRefTypeStorage; +struct TupleTypeStorage; } // namespace detail @@ -64,6 +65,7 @@ enum Kind { RankedTensor, UnrankedTensor, MemRef, + Tuple, }; } // namespace StandardTypes @@ -421,6 +423,43 @@ private: unsigned memorySpace, Optional location); }; +/// Tuple types represent a collection of other types. Note: This type merely +/// provides a common mechanism for representing tuples in MLIR. It is up to +/// dialect authors to provides operations for manipulating them, e.g. +/// extract_tuple_element. When possible, users should prefer multi-result +/// operations in the place of tuples. +class TupleType + : public Type::TypeBase { +public: + using Base::Base; + + /// Get or create a new TupleType with the provided element types. Assumes the + /// arguments define a well-formed type. + static TupleType get(ArrayRef elementTypes, MLIRContext *context); + + /// Get or create an empty tuple type. + static TupleType get(MLIRContext *context) { return get({}, context); } + + /// Return the elements types for this tuple. + ArrayRef getTypes() const; + + /// Return the number of held types. + unsigned size() const; + + /// Iterate over the held elements. + using iterator = ArrayRef::iterator; + iterator begin() const { return getTypes().begin(); } + iterator end() const { return getTypes().end(); } + + /// Return the element type at index 'index'. + Type getType(unsigned index) const { + assert(index < size() && "invalid index for tuple type"); + return getTypes()[index]; + } + + static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; } +}; + } // end namespace mlir #endif // MLIR_IR_STANDARDTYPES_H diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h index f6d89f7fd4bf..5698841afd4d 100644 --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -134,6 +134,11 @@ public: // Allocate an instance of the provided type. template T *allocate() { return allocator.Allocate(); } + /// Allocate 'size' bytes of 'alignment' aligned memory. + void *allocate(size_t size, size_t alignment) { + return allocator.Allocate(size, alignment); + } + private: /// The raw allocator for type storage objects. llvm::BumpPtrAllocator allocator; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 4454f69b7fc8..b9ca89dfb730 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -805,6 +805,13 @@ void ModulePrinter::printType(Type type) { os << '>'; return; } + case StandardTypes::Tuple: { + auto tuple = type.cast(); + os << "tuple<"; + interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); }); + os << '>'; + return; + } } } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 56d0ad059fa2..6f1936b951f8 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -96,6 +96,10 @@ UnrankedTensorType Builder::getTensorType(Type elementType) { return UnrankedTensorType::get(elementType); } +TupleType Builder::getTupleType(ArrayRef elementTypes) { + return TupleType::get(elementTypes, context); +} + //===----------------------------------------------------------------------===// // Attributes. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 140dfa6b3eb6..7ea757229bbb 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -105,7 +105,8 @@ namespace { struct BuiltinDialect : public Dialect { BuiltinDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { addTypes(); + VectorType, RankedTensorType, UnrankedTensorType, MemRefType, + TupleType>(); } }; diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 0d46aa59e055..b9da3b922855 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -366,3 +366,21 @@ unsigned MemRefType::getMemorySpace() const { unsigned MemRefType::getNumDynamicDims() const { return llvm::count_if(getShape(), [](int64_t i) { return i < 0; }); } + +/// TupleType + +/// Get or create a new TupleType with the provided element types. Assumes the +/// arguments define a well-formed type. +TupleType TupleType::get(ArrayRef elementTypes, MLIRContext *context) { + return Base::get(context, StandardTypes::Tuple, elementTypes); +} + +/// Return the elements types for this tuple. +ArrayRef TupleType::getTypes() const { + return static_cast(type)->getTypes(); +} + +/// Return the number of element types. +unsigned TupleType::size() const { + return static_cast(type)->size(); +} diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h index 91762df53d6f..c55f79563343 100644 --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -26,6 +26,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" +#include "llvm/Support/TrailingObjects.h" namespace mlir { @@ -255,6 +256,39 @@ struct MemRefTypeStorage : public TypeStorage { const unsigned memorySpace; }; +/// A type representing a collection of other types. +struct TupleTypeStorage final + : public TypeStorage, + public llvm::TrailingObjects { + using KeyTy = ArrayRef; + + TupleTypeStorage(unsigned numTypes) : TypeStorage(numTypes) {} + + /// Construction. + static TupleTypeStorage *construct(TypeStorageAllocator &allocator, + const ArrayRef &key) { + // Allocate a new storage instance. + auto byteSize = TupleTypeStorage::totalSizeToAlloc(key.size()); + auto rawMem = allocator.allocate(byteSize, alignof(TupleTypeStorage)); + auto result = ::new (rawMem) TupleTypeStorage(key.size()); + + // Copy in the element types into the trailing storage. + std::uninitialized_copy(key.begin(), key.end(), + result->getTrailingObjects()); + return result; + } + + bool operator==(const KeyTy &key) const { return key == getTypes(); } + + /// Return the number of held types. + unsigned size() const { return getSubclassData(); } + + /// Return the held types. + ArrayRef getTypes() const { + return {getTrailingObjects(), size()}; + } +}; + } // namespace detail } // namespace mlir #endif // TYPEDETAIL_H_ diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index d01727395a95..4905beebf908 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -185,6 +185,7 @@ public: bool allowDynamic); Type parseExtendedType(); Type parseTensorType(); + Type parseTupleType(); Type parseMemRefType(); Type parseFunctionType(); Type parseNonFunctionType(); @@ -319,6 +320,7 @@ ParseResult Parser::parseCommaSeparatedListUntil( /// | vector-type /// | tensor-type /// | memref-type +/// | tuple-type /// /// index-type ::= `index` /// float-type ::= `f16` | `bf16` | `f32` | `f64` @@ -331,6 +333,8 @@ Type Parser::parseNonFunctionType() { return parseMemRefType(); case Token::kw_tensor: return parseTensorType(); + case Token::kw_tuple: + return parseTupleType(); case Token::kw_vector: return parseVectorType(); // integer-type @@ -567,6 +571,30 @@ Type Parser::parseTensorType() { return RankedTensorType::getChecked(dimensions, elementType, typeLocation); } +/// Parse a tuple type. +/// +/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` +/// +Type Parser::parseTupleType() { + consumeToken(Token::kw_tuple); + + // Parse the '<'. + if (parseToken(Token::less, "expected '<' in tuple type")) + return nullptr; + + // Check for an empty tuple by directly parsing '>'. + if (consumeIf(Token::greater)) + return TupleType::get(getContext()); + + // Parse the element types and the '>'. + SmallVector types; + if (parseTypeListNoParens(types) || + parseToken(Token::greater, "expected '>' in tuple type")) + return nullptr; + + return TupleType::get(types, getContext()); +} + /// Parse a memref type. /// /// memref-type ::= `memref` `<` dimension-list-ranked element-type diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index ec00f98b3f5c..f58fa9cef41b 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -111,6 +111,7 @@ TOK_KEYWORD(step) TOK_KEYWORD(tensor) TOK_KEYWORD(to) TOK_KEYWORD(true) +TOK_KEYWORD(tuple) TOK_KEYWORD(type) TOK_KEYWORD(sparse) TOK_KEYWORD(vector) diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index cebcdf6cc9e6..8e8bcf44aec6 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -933,3 +933,13 @@ func @$invalid_function_name() // expected-error @+1 {{function arguments may only have dialect attributes}} func @invalid_func_arg_attr(i1 {non_dialect_attr: 10}) + +// ----- + +// expected-error @+1 {{expected '<' in tuple type}} +func @invalid_tuple_missing_less(tuple i32>) + +// ----- + +// expected-error @+1 {{expected '>' in tuple type}} +func @invalid_tuple_missing_greater(tuple) +func @empty_tuple(tuple<>) + +// CHECK-LABEL: func @tuple_single_element(tuple) +func @tuple_single_element(tuple) + +// CHECK-LABEL: func @tuple_multi_element(tuple) +func @tuple_multi_element(tuple) + +// CHECK-LABEL: func @tuple_nested(tuple>>) +func @tuple_nested(tuple>>)