mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-28 01:16:06 +00:00
Add support for a standard TupleType. Though this is a standard type, it merely provides a common mechanism for representing tuples in MLIR. It is up to dialect authors to provides operations for manipulating them, e.g. extract_tuple_element.
TupleType has the following form: tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` Example: // Empty tuple. tuple<> // Single element. tuple<i32> // Multi element. tuple<i32, tuple<f32>, i16> PiperOrigin-RevId: 239226021
This commit is contained in:
parent
57270a9a99
commit
30e68230bd
@ -492,7 +492,7 @@ MLIR provides a first class set of polyhedral operations and analyses within the
|
||||
|
||||
Each SSA value in MLIR has a type defined by the type system below. There are a
|
||||
number of primitive types (like integers) and also aggregate types for tensors
|
||||
and memory buffers. MLIR standard types do not include complex numbers, tuples,
|
||||
and memory buffers. MLIR standard types do not include complex numbers,
|
||||
structures, arrays, or dictionaries.
|
||||
|
||||
MLIR has an open type system (there is no fixed list of types), and types may
|
||||
@ -556,7 +556,7 @@ Builtin types consist of only the types needed for the validity of the IR.
|
||||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
// MLIR doesn't have a tuple type but functions can return multiple values.
|
||||
// MLIR functions can return multiple values.
|
||||
function-result-type ::= type-list-parens
|
||||
| non-function-type
|
||||
|
||||
@ -860,6 +860,34 @@ access pattern analysis, and for performance optimizations like vectorization,
|
||||
copy elision and in-place updates. If an affine map composition is not specified
|
||||
for the memref, the identity affine map is assumed.
|
||||
|
||||
#### Tuple Type {#tuple-type}
|
||||
|
||||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
tuple-type ::= `tuple` `<` (type ( `,` type)*)? `>`
|
||||
```
|
||||
|
||||
The value of `tuple` type represents a fixed-size collection of elements, where
|
||||
each element may be of a different type.
|
||||
|
||||
**Rationale:** Though this type is first class in the type system, MLIR provides
|
||||
no standard operations for operating on `tuple` types
|
||||
[rationale](Rationale.md#tuple-type).
|
||||
|
||||
Examples:
|
||||
|
||||
```mlir {.mlir}
|
||||
// Empty tuple.
|
||||
tuple<>
|
||||
|
||||
// Single element
|
||||
tuple<f32>
|
||||
|
||||
// Many elements.
|
||||
tuple<i32, f32, tensor<i1>, i5>
|
||||
```
|
||||
|
||||
## Attributes {#attributes}
|
||||
|
||||
Syntax:
|
||||
|
@ -487,6 +487,20 @@ to think of these types as existing within the namespace of the dialect. If a
|
||||
dialect wishes to assign a canonical name to a type, it can be done via
|
||||
[type aliases](LangRef.md#type-aliases).
|
||||
|
||||
### Tuple types {#tuple-type}
|
||||
|
||||
The MLIR type system provides first class support for defining
|
||||
[tuple types](LangRef.md#tuple-type). This is due to the fact that `Tuple`
|
||||
represents a universal concept that is likely to, and already has, present
|
||||
itself in many different dialects. Though this type is first class in the type
|
||||
system, it merely serves to provide a common mechanism in which to represent
|
||||
this concept in MLIR. As such, MLIR provides no standard operations for
|
||||
interfacing with `tuple` types. It is up to dialect authors to provide
|
||||
operations, e.g. extract_tuple_element, to interpret and manipulate them. When
|
||||
possible, operations should prefer to use multiple results instead. These
|
||||
provide a myriad of benefits, such as alleviating any need for tuple-extract
|
||||
operations that merely get in the way of analysis and transformation.
|
||||
|
||||
### Assembly forms
|
||||
|
||||
MLIR decides to support both generic and custom assembly forms under the
|
||||
|
@ -37,6 +37,7 @@ class MemRefType;
|
||||
class VectorType;
|
||||
class RankedTensorType;
|
||||
class UnrankedTensorType;
|
||||
class TupleType;
|
||||
class BoolAttr;
|
||||
class IntegerAttr;
|
||||
class FloatAttr;
|
||||
@ -87,6 +88,7 @@ public:
|
||||
VectorType getVectorType(ArrayRef<int64_t> shape, Type elementType);
|
||||
RankedTensorType getTensorType(ArrayRef<int64_t> shape, Type elementType);
|
||||
UnrankedTensorType getTensorType(Type elementType);
|
||||
TupleType getTupleType(ArrayRef<Type> elementTypes);
|
||||
|
||||
/// Get or construct an instance of the type 'ty' with provided arguments.
|
||||
template <typename Ty, typename... Args> Ty getType(Args... args) {
|
||||
|
@ -42,6 +42,7 @@ struct TensorTypeStorage;
|
||||
struct RankedTensorTypeStorage;
|
||||
struct UnrankedTensorTypeStorage;
|
||||
struct MemRefTypeStorage;
|
||||
struct TupleTypeStorage;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@ -64,6 +65,7 @@ enum Kind {
|
||||
RankedTensor,
|
||||
UnrankedTensor,
|
||||
MemRef,
|
||||
Tuple,
|
||||
};
|
||||
|
||||
} // namespace StandardTypes
|
||||
@ -421,6 +423,43 @@ private:
|
||||
unsigned memorySpace, Optional<Location> location);
|
||||
};
|
||||
|
||||
/// Tuple types represent a collection of other types. Note: This type merely
|
||||
/// provides a common mechanism for representing tuples in MLIR. It is up to
|
||||
/// dialect authors to provides operations for manipulating them, e.g.
|
||||
/// extract_tuple_element. When possible, users should prefer multi-result
|
||||
/// operations in the place of tuples.
|
||||
class TupleType
|
||||
: public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Get or create a new TupleType with the provided element types. Assumes the
|
||||
/// arguments define a well-formed type.
|
||||
static TupleType get(ArrayRef<Type> elementTypes, MLIRContext *context);
|
||||
|
||||
/// Get or create an empty tuple type.
|
||||
static TupleType get(MLIRContext *context) { return get({}, context); }
|
||||
|
||||
/// Return the elements types for this tuple.
|
||||
ArrayRef<Type> getTypes() const;
|
||||
|
||||
/// Return the number of held types.
|
||||
unsigned size() const;
|
||||
|
||||
/// Iterate over the held elements.
|
||||
using iterator = ArrayRef<Type>::iterator;
|
||||
iterator begin() const { return getTypes().begin(); }
|
||||
iterator end() const { return getTypes().end(); }
|
||||
|
||||
/// Return the element type at index 'index'.
|
||||
Type getType(unsigned index) const {
|
||||
assert(index < size() && "invalid index for tuple type");
|
||||
return getTypes()[index];
|
||||
}
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; }
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_IR_STANDARDTYPES_H
|
||||
|
@ -134,6 +134,11 @@ public:
|
||||
// Allocate an instance of the provided type.
|
||||
template <typename T> T *allocate() { return allocator.Allocate<T>(); }
|
||||
|
||||
/// Allocate 'size' bytes of 'alignment' aligned memory.
|
||||
void *allocate(size_t size, size_t alignment) {
|
||||
return allocator.Allocate(size, alignment);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The raw allocator for type storage objects.
|
||||
llvm::BumpPtrAllocator allocator;
|
||||
|
@ -805,6 +805,13 @@ void ModulePrinter::printType(Type type) {
|
||||
os << '>';
|
||||
return;
|
||||
}
|
||||
case StandardTypes::Tuple: {
|
||||
auto tuple = type.cast<TupleType>();
|
||||
os << "tuple<";
|
||||
interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
|
||||
os << '>';
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,6 +96,10 @@ UnrankedTensorType Builder::getTensorType(Type elementType) {
|
||||
return UnrankedTensorType::get(elementType);
|
||||
}
|
||||
|
||||
TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
|
||||
return TupleType::get(elementTypes, context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attributes.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -105,7 +105,8 @@ namespace {
|
||||
struct BuiltinDialect : public Dialect {
|
||||
BuiltinDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) {
|
||||
addTypes<FunctionType, UnknownType, FloatType, IndexType, IntegerType,
|
||||
VectorType, RankedTensorType, UnrankedTensorType, MemRefType>();
|
||||
VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
|
||||
TupleType>();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -366,3 +366,21 @@ unsigned MemRefType::getMemorySpace() const {
|
||||
unsigned MemRefType::getNumDynamicDims() const {
|
||||
return llvm::count_if(getShape(), [](int64_t i) { return i < 0; });
|
||||
}
|
||||
|
||||
/// TupleType
|
||||
|
||||
/// Get or create a new TupleType with the provided element types. Assumes the
|
||||
/// arguments define a well-formed type.
|
||||
TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) {
|
||||
return Base::get(context, StandardTypes::Tuple, elementTypes);
|
||||
}
|
||||
|
||||
/// Return the elements types for this tuple.
|
||||
ArrayRef<Type> TupleType::getTypes() const {
|
||||
return static_cast<ImplType *>(type)->getTypes();
|
||||
}
|
||||
|
||||
/// Return the number of element types.
|
||||
unsigned TupleType::size() const {
|
||||
return static_cast<ImplType *>(type)->size();
|
||||
}
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "llvm/Support/TrailingObjects.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@ -255,6 +256,39 @@ struct MemRefTypeStorage : public TypeStorage {
|
||||
const unsigned memorySpace;
|
||||
};
|
||||
|
||||
/// A type representing a collection of other types.
|
||||
struct TupleTypeStorage final
|
||||
: public TypeStorage,
|
||||
public llvm::TrailingObjects<TupleTypeStorage, Type> {
|
||||
using KeyTy = ArrayRef<Type>;
|
||||
|
||||
TupleTypeStorage(unsigned numTypes) : TypeStorage(numTypes) {}
|
||||
|
||||
/// Construction.
|
||||
static TupleTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const ArrayRef<Type> &key) {
|
||||
// Allocate a new storage instance.
|
||||
auto byteSize = TupleTypeStorage::totalSizeToAlloc<Type>(key.size());
|
||||
auto rawMem = allocator.allocate(byteSize, alignof(TupleTypeStorage));
|
||||
auto result = ::new (rawMem) TupleTypeStorage(key.size());
|
||||
|
||||
// Copy in the element types into the trailing storage.
|
||||
std::uninitialized_copy(key.begin(), key.end(),
|
||||
result->getTrailingObjects<Type>());
|
||||
return result;
|
||||
}
|
||||
|
||||
bool operator==(const KeyTy &key) const { return key == getTypes(); }
|
||||
|
||||
/// Return the number of held types.
|
||||
unsigned size() const { return getSubclassData(); }
|
||||
|
||||
/// Return the held types.
|
||||
ArrayRef<Type> getTypes() const {
|
||||
return {getTrailingObjects<Type>(), size()};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace mlir
|
||||
#endif // TYPEDETAIL_H_
|
||||
|
@ -185,6 +185,7 @@ public:
|
||||
bool allowDynamic);
|
||||
Type parseExtendedType();
|
||||
Type parseTensorType();
|
||||
Type parseTupleType();
|
||||
Type parseMemRefType();
|
||||
Type parseFunctionType();
|
||||
Type parseNonFunctionType();
|
||||
@ -319,6 +320,7 @@ ParseResult Parser::parseCommaSeparatedListUntil(
|
||||
/// | vector-type
|
||||
/// | tensor-type
|
||||
/// | memref-type
|
||||
/// | tuple-type
|
||||
///
|
||||
/// index-type ::= `index`
|
||||
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
|
||||
@ -331,6 +333,8 @@ Type Parser::parseNonFunctionType() {
|
||||
return parseMemRefType();
|
||||
case Token::kw_tensor:
|
||||
return parseTensorType();
|
||||
case Token::kw_tuple:
|
||||
return parseTupleType();
|
||||
case Token::kw_vector:
|
||||
return parseVectorType();
|
||||
// integer-type
|
||||
@ -567,6 +571,30 @@ Type Parser::parseTensorType() {
|
||||
return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
|
||||
}
|
||||
|
||||
/// Parse a tuple type.
|
||||
///
|
||||
/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
|
||||
///
|
||||
Type Parser::parseTupleType() {
|
||||
consumeToken(Token::kw_tuple);
|
||||
|
||||
// Parse the '<'.
|
||||
if (parseToken(Token::less, "expected '<' in tuple type"))
|
||||
return nullptr;
|
||||
|
||||
// Check for an empty tuple by directly parsing '>'.
|
||||
if (consumeIf(Token::greater))
|
||||
return TupleType::get(getContext());
|
||||
|
||||
// Parse the element types and the '>'.
|
||||
SmallVector<Type, 4> types;
|
||||
if (parseTypeListNoParens(types) ||
|
||||
parseToken(Token::greater, "expected '>' in tuple type"))
|
||||
return nullptr;
|
||||
|
||||
return TupleType::get(types, getContext());
|
||||
}
|
||||
|
||||
/// Parse a memref type.
|
||||
///
|
||||
/// memref-type ::= `memref` `<` dimension-list-ranked element-type
|
||||
|
@ -111,6 +111,7 @@ TOK_KEYWORD(step)
|
||||
TOK_KEYWORD(tensor)
|
||||
TOK_KEYWORD(to)
|
||||
TOK_KEYWORD(true)
|
||||
TOK_KEYWORD(tuple)
|
||||
TOK_KEYWORD(type)
|
||||
TOK_KEYWORD(sparse)
|
||||
TOK_KEYWORD(vector)
|
||||
|
@ -933,3 +933,13 @@ func @$invalid_function_name()
|
||||
|
||||
// expected-error @+1 {{function arguments may only have dialect attributes}}
|
||||
func @invalid_func_arg_attr(i1 {non_dialect_attr: 10})
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected '<' in tuple type}}
|
||||
func @invalid_tuple_missing_less(tuple i32>)
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected '>' in tuple type}}
|
||||
func @invalid_tuple_missing_greater(tuple<i32)
|
||||
|
@ -825,3 +825,15 @@ func @external_func_arg_attrs(i32, i1 {dialect.attr: 10}, i32)
|
||||
func @func_arg_attrs(%arg0: i1 {dialect.attr: 10}) {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @empty_tuple(tuple<>)
|
||||
func @empty_tuple(tuple<>)
|
||||
|
||||
// CHECK-LABEL: func @tuple_single_element(tuple<i32>)
|
||||
func @tuple_single_element(tuple<i32>)
|
||||
|
||||
// CHECK-LABEL: func @tuple_multi_element(tuple<i32, i16, f32>)
|
||||
func @tuple_multi_element(tuple<i32, i16, f32>)
|
||||
|
||||
// CHECK-LABEL: func @tuple_nested(tuple<tuple<tuple<i32>>>)
|
||||
func @tuple_nested(tuple<tuple<tuple<i32>>>)
|
||||
|
Loading…
x
Reference in New Issue
Block a user