mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-18 19:16:43 +00:00
[mlir] Dialect type/attr bytecode read/write generator.
Tool to help generate dialect bytecode Attribute & Type reader/writing. Show usage by flipping builtin dialect. It helps reduce boilerplate when writing dialect bytecode attribute and type readers/writers. It is not an attempt at a generic spec mechanism but rather practically focussing on boilerplate reduction while also considering that it need not be the only in memory format and make it relatively easy to change. There should be some cleanup in follow up as we expand to more dialects. Differential Revision: https://reviews.llvm.org/D144820
This commit is contained in:
parent
8e091b1220
commit
0911558005
@ -345,6 +345,37 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper for resource handle reading that returns LogicalResult.
|
||||
template <typename T, typename... Ts>
|
||||
static LogicalResult readResourceHandle(DialectBytecodeReader &reader,
|
||||
FailureOr<T> &value, Ts &&...params) {
|
||||
FailureOr<T> handle = reader.readResourceHandle<T>();
|
||||
if (failed(handle))
|
||||
return failure();
|
||||
if (auto *result = dyn_cast<T>(&*handle)) {
|
||||
value = std::move(*result);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
/// Helper method that injects context only if needed, this helps unify some of
|
||||
/// the attribute construction methods.
|
||||
template <typename T, typename... Ts>
|
||||
auto get(MLIRContext *context, Ts &&...params) {
|
||||
// Prefer a direct `get` method if one exists.
|
||||
if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
|
||||
(void)context;
|
||||
return T::get(std::forward<Ts>(params)...);
|
||||
} else if constexpr (llvm::is_detected<detail::has_get_method, T,
|
||||
MLIRContext *, Ts...>::value) {
|
||||
return T::get(context, std::forward<Ts>(params)...);
|
||||
} else {
|
||||
// Otherwise, pass to the base get.
|
||||
return T::Base::get(context, std::forward<Ts>(params)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
|
||||
|
566
mlir/include/mlir/IR/BuiltinDialectBytecode.td
Normal file
566
mlir/include/mlir/IR/BuiltinDialectBytecode.td
Normal file
@ -0,0 +1,566 @@
|
||||
//===-- BuiltinBytecode.td - Builtin bytecode defs ---------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the Builtin bytecode reader/writer definition file.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef BUILTIN_BYTECODE
|
||||
#define BUILTIN_BYTECODE
|
||||
|
||||
include "mlir/IR/BytecodeBase.td"
|
||||
|
||||
def LocationAttr : AttributeKind;
|
||||
def ShapedType: WithType<"ShapedType", Type>;
|
||||
|
||||
def Location : CompositeBytecode {
|
||||
dag members = (attr
|
||||
WithGetter<"(LocationAttr)$_attrType", WithType<"LocationAttr", LocationAttr>>:$value
|
||||
);
|
||||
let cBuilder = "Location($_args)";
|
||||
}
|
||||
|
||||
def String :
|
||||
WithParser <"succeeded($_reader.readString($_var))",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"$_writer.writeOwnedString($_getter)",
|
||||
WithGetter <"$_attrType",
|
||||
WithType <"StringRef">>>>>;
|
||||
|
||||
// enum AttributeCode {
|
||||
// /// ArrayAttr {
|
||||
// /// elements: Attribute[]
|
||||
// /// }
|
||||
// ///
|
||||
// kArrayAttr = 0,
|
||||
//
|
||||
def ArrayAttr : DialectAttribute<(attr
|
||||
Array<Attribute>:$value
|
||||
)>;
|
||||
|
||||
let cType = "StringAttr" in {
|
||||
// /// StringAttr {
|
||||
// /// value: string
|
||||
// /// }
|
||||
// kStringAttr = 2,
|
||||
def StringAttr : DialectAttribute<(attr
|
||||
String:$value
|
||||
)> {
|
||||
let printerPredicate = "$_val.getType().isa<NoneType>()";
|
||||
}
|
||||
|
||||
// /// StringAttrWithType {
|
||||
// /// value: string,
|
||||
// /// type: Type
|
||||
// /// }
|
||||
// /// A variant of StringAttr with a type.
|
||||
// kStringAttrWithType = 3,
|
||||
def StringAttrWithType : DialectAttribute<(attr
|
||||
String:$value,
|
||||
Type:$type
|
||||
)> { let printerPredicate = "!$_val.getType().isa<NoneType>()"; }
|
||||
}
|
||||
|
||||
// /// DictionaryAttr {
|
||||
// /// attrs: <StringAttr, Attribute>[]
|
||||
// /// }
|
||||
// kDictionaryAttr = 1,
|
||||
def NamedAttribute : CompositeBytecode {
|
||||
dag members = (attr
|
||||
StringAttr:$name,
|
||||
Attribute:$value
|
||||
);
|
||||
let cBuilder = "NamedAttribute($_args)";
|
||||
}
|
||||
def DictionaryAttr : DialectAttribute<(attr
|
||||
Array<NamedAttribute>:$value
|
||||
)>;
|
||||
|
||||
// /// FlatSymbolRefAttr {
|
||||
// /// rootReference: StringAttr
|
||||
// /// }
|
||||
// /// A variant of SymbolRefAttr with no leaf references.
|
||||
// kFlatSymbolRefAttr = 4,
|
||||
def FlatSymbolRefAttr: DialectAttribute<(attr
|
||||
StringAttr:$rootReference
|
||||
)>;
|
||||
|
||||
// /// SymbolRefAttr {
|
||||
// /// rootReference: StringAttr,
|
||||
// /// leafReferences: FlatSymbolRefAttr[]
|
||||
// /// }
|
||||
// kSymbolRefAttr = 5,
|
||||
def SymbolRefAttr: DialectAttribute<(attr
|
||||
StringAttr:$rootReference,
|
||||
Array<FlatSymbolRefAttr>:$nestedReferences
|
||||
)>;
|
||||
|
||||
// /// TypeAttr {
|
||||
// /// value: Type
|
||||
// /// }
|
||||
// kTypeAttr = 6,
|
||||
def TypeAttr: DialectAttribute<(attr
|
||||
Type:$value
|
||||
)>;
|
||||
|
||||
// /// UnitAttr {
|
||||
// /// }
|
||||
// kUnitAttr = 7,
|
||||
def UnitAttr: DialectAttribute<(attr)>;
|
||||
|
||||
// /// IntegerAttr {
|
||||
// /// type: Type
|
||||
// /// value: APInt,
|
||||
// /// }
|
||||
// kIntegerAttr = 8,
|
||||
def IntegerAttr: DialectAttribute<(attr
|
||||
Type:$type,
|
||||
KnownWidthAPInt<"type">:$value
|
||||
)> {
|
||||
let cBuilder = "get<$_resultType>(context, type, *value)";
|
||||
}
|
||||
|
||||
//
|
||||
// /// FloatAttr {
|
||||
// /// type: FloatType
|
||||
// /// value: APFloat
|
||||
// /// }
|
||||
// kFloatAttr = 9,
|
||||
defvar FloatType = Type;
|
||||
def FloatAttr : DialectAttribute<(attr
|
||||
FloatType:$type,
|
||||
KnownSemanticsAPFloat<"type">:$value
|
||||
)> {
|
||||
let cBuilder = "get<$_resultType>(context, type, *value)";
|
||||
}
|
||||
|
||||
// /// CallSiteLoc {
|
||||
// /// callee: LocationAttr,
|
||||
// /// caller: LocationAttr
|
||||
// /// }
|
||||
// kCallSiteLoc = 10,
|
||||
def CallSiteLoc : DialectAttribute<(attr
|
||||
LocationAttr:$callee,
|
||||
LocationAttr:$caller
|
||||
)>;
|
||||
|
||||
// /// FileLineColLoc {
|
||||
// /// filename: StringAttr,
|
||||
// /// line: varint,
|
||||
// /// column: varint
|
||||
// /// }
|
||||
// kFileLineColLoc = 11,
|
||||
def FileLineColLoc : DialectAttribute<(attr
|
||||
StringAttr:$filename,
|
||||
VarInt:$line,
|
||||
VarInt:$column
|
||||
)>;
|
||||
|
||||
let cType = "FusedLoc",
|
||||
cBuilder = "cast<FusedLoc>(get<FusedLoc>(context, $_args))" in {
|
||||
// /// FusedLoc {
|
||||
// /// locations: Location[]
|
||||
// /// }
|
||||
// kFusedLoc = 12,
|
||||
def FusedLoc : DialectAttribute<(attr
|
||||
Array<Location>:$locations
|
||||
)> {
|
||||
let printerPredicate = "!$_val.getMetadata()";
|
||||
}
|
||||
|
||||
// /// FusedLocWithMetadata {
|
||||
// /// locations: LocationAttr[],
|
||||
// /// metadata: Attribute
|
||||
// /// }
|
||||
// /// A variant of FusedLoc with metadata.
|
||||
// kFusedLocWithMetadata = 13,
|
||||
def FusedLocWithMetadata : DialectAttribute<(attr
|
||||
Array<Location>:$locations,
|
||||
Attribute:$metadata
|
||||
)> {
|
||||
let printerPredicate = "$_val.getMetadata()";
|
||||
}
|
||||
}
|
||||
|
||||
// /// NameLoc {
|
||||
// /// name: StringAttr,
|
||||
// /// childLoc: LocationAttr
|
||||
// /// }
|
||||
// kNameLoc = 14,
|
||||
def NameLoc : DialectAttribute<(attr
|
||||
StringAttr:$name,
|
||||
LocationAttr:$childLoc
|
||||
)>;
|
||||
|
||||
// /// UnknownLoc {
|
||||
// /// }
|
||||
// kUnknownLoc = 15,
|
||||
def UnknownLoc : DialectAttribute<(attr)>;
|
||||
|
||||
// /// DenseResourceElementsAttr {
|
||||
// /// type: ShapedType,
|
||||
// /// handle: ResourceHandle
|
||||
// /// }
|
||||
// kDenseResourceElementsAttr = 16,
|
||||
def DenseResourceElementsAttr : DialectAttribute<(attr
|
||||
ShapedType:$type,
|
||||
ResourceHandle<"DenseResourceElementsHandle">:$rawHandle
|
||||
)> {
|
||||
// Note: order of serialization does not match order of builder.
|
||||
let cBuilder = "get<$_resultType>(context, type, *rawHandle)";
|
||||
}
|
||||
|
||||
let cType = "RankedTensorType" in {
|
||||
// /// RankedTensorType {
|
||||
// /// shape: svarint[],
|
||||
// /// elementType: Type,
|
||||
// /// }
|
||||
// ///
|
||||
// kRankedTensorType = 13,
|
||||
def RankedTensorType : DialectType<(type
|
||||
Array<SignedVarInt>:$shape,
|
||||
Type:$elementType
|
||||
)> {
|
||||
let printerPredicate = "!$_val.getEncoding()";
|
||||
}
|
||||
|
||||
// /// RankedTensorTypeWithEncoding {
|
||||
// /// encoding: Attribute,
|
||||
// /// shape: svarint[],
|
||||
// /// elementType: Type
|
||||
// /// }
|
||||
// /// Variant of RankedTensorType with an encoding.
|
||||
// kRankedTensorTypeWithEncoding = 14,
|
||||
def RankedTensorTypeWithEncoding : DialectType<(type
|
||||
Attribute:$encoding,
|
||||
Array<SignedVarInt>:$shape,
|
||||
Type:$elementType
|
||||
)> {
|
||||
let printerPredicate = "$_val.getEncoding()";
|
||||
// Note: order of serialization does not match order of builder.
|
||||
let cBuilder = "get<$_resultType>(context, shape, elementType, encoding)";
|
||||
}
|
||||
}
|
||||
|
||||
// /// DenseArrayAttr {
|
||||
// /// elementType: Type,
|
||||
// /// size: varint
|
||||
// /// data: blob
|
||||
// /// }
|
||||
// kDenseArrayAttr = 17,
|
||||
def DenseArrayAttr : DialectAttribute<(attr
|
||||
Type:$elementType,
|
||||
VarInt:$size,
|
||||
Blob:$rawData
|
||||
)>;
|
||||
|
||||
// /// DenseIntOrFPElementsAttr {
|
||||
// /// type: ShapedType,
|
||||
// /// data: blob
|
||||
// /// }
|
||||
// kDenseIntOrFPElementsAttr = 18,
|
||||
def DenseElementsAttr : WithType<"DenseIntElementsAttr", Attribute>;
|
||||
def DenseIntOrFPElementsAttr : DialectAttribute<(attr
|
||||
ShapedType:$type,
|
||||
Blob:$rawData
|
||||
)> {
|
||||
let cBuilder = "cast<$_resultType>($_resultType::getFromRawBuffer($_args))";
|
||||
}
|
||||
|
||||
// /// DenseStringElementsAttr {
|
||||
// /// type: ShapedType,
|
||||
// /// isSplat: varint,
|
||||
// /// data: string[]
|
||||
// /// }
|
||||
// kDenseStringElementsAttr = 19,
|
||||
def DenseStringElementsAttr : DialectAttribute<(attr
|
||||
ShapedType:$type,
|
||||
WithGetter<"$_attrType.isSplat()", VarInt>:$_isSplat,
|
||||
WithBuilder<"$_args",
|
||||
WithType<"SmallVector<StringRef>",
|
||||
WithParser <"succeeded(readPotentiallySplatString($_reader, type, _isSplat, $_var))",
|
||||
WithPrinter<"writePotentiallySplatString($_writer, $_name)">>>>:$rawStringData
|
||||
)>;
|
||||
|
||||
// /// SparseElementsAttr {
|
||||
// /// type: ShapedType,
|
||||
// /// indices: DenseIntElementsAttr,
|
||||
// /// values: DenseElementsAttr
|
||||
// /// }
|
||||
// kSparseElementsAttr = 20,
|
||||
def DenseIntElementsAttr : WithType<"DenseIntElementsAttr", Attribute>;
|
||||
def SparseElementsAttr : DialectAttribute<(attr
|
||||
ShapedType:$type,
|
||||
DenseIntElementsAttr:$indices,
|
||||
DenseElementsAttr:$values
|
||||
)>;
|
||||
|
||||
// Types
|
||||
// -----
|
||||
|
||||
// enum TypeCode {
|
||||
// /// IntegerType {
|
||||
// /// widthAndSignedness: varint // (width << 2) | (signedness)
|
||||
// /// }
|
||||
// ///
|
||||
// kIntegerType = 0,
|
||||
def IntegerType : DialectType<(type
|
||||
// Yes not pretty,
|
||||
WithParser<"succeeded($_reader.readVarInt($_var))",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"$_writer.writeVarInt($_name.getWidth() << 2 | $_name.getSignedness())",
|
||||
WithType <"uint64_t">>>>:$_widthAndSignedness,
|
||||
// Split up parsed varint for create method.
|
||||
LocalVar<"uint64_t", "_widthAndSignedness >> 2">:$width,
|
||||
LocalVar<"IntegerType::SignednessSemantics",
|
||||
"static_cast<IntegerType::SignednessSemantics>(_widthAndSignedness & 0x3)">:$signedness
|
||||
)>;
|
||||
|
||||
//
|
||||
// /// IndexType {
|
||||
// /// }
|
||||
// ///
|
||||
// kIndexType = 1,
|
||||
def IndexType : DialectType<(type)>;
|
||||
|
||||
// /// FunctionType {
|
||||
// /// inputs: Type[],
|
||||
// /// results: Type[]
|
||||
// /// }
|
||||
// ///
|
||||
// kFunctionType = 2,
|
||||
def FunctionType : DialectType<(type
|
||||
Array<Type>:$inputs,
|
||||
Array<Type>:$results
|
||||
)>;
|
||||
|
||||
// /// BFloat16Type {
|
||||
// /// }
|
||||
// ///
|
||||
// kBFloat16Type = 3,
|
||||
def BFloat16Type : DialectType<(type)>;
|
||||
|
||||
// /// Float16Type {
|
||||
// /// }
|
||||
// ///
|
||||
// kFloat16Type = 4,
|
||||
def Float16Type : DialectType<(type)>;
|
||||
|
||||
// /// Float32Type {
|
||||
// /// }
|
||||
// ///
|
||||
// kFloat32Type = 5,
|
||||
def Float32Type : DialectType<(type)>;
|
||||
|
||||
// /// Float64Type {
|
||||
// /// }
|
||||
// ///
|
||||
// kFloat64Type = 6,
|
||||
def Float64Type : DialectType<(type)>;
|
||||
|
||||
// /// Float80Type {
|
||||
// /// }
|
||||
// ///
|
||||
// kFloat80Type = 7,
|
||||
def Float80Type : DialectType<(type)>;
|
||||
|
||||
// /// Float128Type {
|
||||
// /// }
|
||||
// ///
|
||||
// kFloat128Type = 8,
|
||||
def Float128Type : DialectType<(type)>;
|
||||
|
||||
// /// ComplexType {
|
||||
// /// elementType: Type
|
||||
// /// }
|
||||
// ///
|
||||
// kComplexType = 9,
|
||||
def ComplexType : DialectType<(type
|
||||
Type:$elementType
|
||||
)>;
|
||||
|
||||
def MemRefLayout: WithType<"MemRefLayoutAttrInterface", Attribute>;
|
||||
|
||||
let cType = "MemRefType" in {
|
||||
// /// MemRefType {
|
||||
// /// shape: svarint[],
|
||||
// /// elementType: Type,
|
||||
// /// layout: Attribute
|
||||
// /// }
|
||||
// ///
|
||||
// kMemRefType = 10,
|
||||
def MemRefType : DialectType<(type
|
||||
Array<SignedVarInt>:$shape,
|
||||
Type:$elementType,
|
||||
MemRefLayout:$layout
|
||||
)> {
|
||||
let printerPredicate = "!$_val.getMemorySpace()";
|
||||
}
|
||||
|
||||
// /// MemRefTypeWithMemSpace {
|
||||
// /// memorySpace: Attribute,
|
||||
// /// shape: svarint[],
|
||||
// /// elementType: Type,
|
||||
// /// layout: Attribute
|
||||
// /// }
|
||||
// /// Variant of MemRefType with non-default memory space.
|
||||
// kMemRefTypeWithMemSpace = 11,
|
||||
def MemRefTypeWithMemSpace : DialectType<(type
|
||||
Attribute:$memorySpace,
|
||||
Array<SignedVarInt>:$shape,
|
||||
Type:$elementType,
|
||||
MemRefLayout:$layout
|
||||
)> {
|
||||
let printerPredicate = "!!$_val.getMemorySpace()";
|
||||
// Note: order of serialization does not match order of builder.
|
||||
let cBuilder = "get<$_resultType>(context, shape, elementType, layout, memorySpace)";
|
||||
}
|
||||
}
|
||||
|
||||
// /// NoneType {
|
||||
// /// }
|
||||
// ///
|
||||
// kNoneType = 12,
|
||||
def NoneType : DialectType<(type)>;
|
||||
|
||||
// /// TupleType {
|
||||
// /// elementTypes: Type[]
|
||||
// /// }
|
||||
// kTupleType = 15,
|
||||
def TupleType : DialectType<(type
|
||||
Array<Type>:$types
|
||||
)>;
|
||||
|
||||
let cType = "UnrankedMemRefType" in {
|
||||
// /// UnrankedMemRefType {
|
||||
// /// elementType: Type
|
||||
// /// }
|
||||
// ///
|
||||
// kUnrankedMemRefType = 16,
|
||||
def UnrankedMemRefType : DialectType<(type
|
||||
Type:$elementType
|
||||
)> {
|
||||
let printerPredicate = "!$_val.getMemorySpace()";
|
||||
let cBuilder = "get<$_resultType>(context, elementType, Attribute())";
|
||||
}
|
||||
|
||||
// /// UnrankedMemRefTypeWithMemSpace {
|
||||
// /// memorySpace: Attribute,
|
||||
// /// elementType: Type
|
||||
// /// }
|
||||
// /// Variant of UnrankedMemRefType with non-default memory space.
|
||||
// kUnrankedMemRefTypeWithMemSpace = 17,
|
||||
def UnrankedMemRefTypeWithMemSpace : DialectType<(type
|
||||
Attribute:$memorySpace,
|
||||
Type:$elementType
|
||||
)> {
|
||||
let printerPredicate = "$_val.getMemorySpace()";
|
||||
// Note: order of serialization does not match order of builder.
|
||||
let cBuilder = "get<$_resultType>(context, elementType, memorySpace)";
|
||||
}
|
||||
}
|
||||
|
||||
// /// UnrankedTensorType {
|
||||
// /// elementType: Type
|
||||
// /// }
|
||||
// ///
|
||||
// kUnrankedTensorType = 18,
|
||||
def UnrankedTensorType : DialectType<(type
|
||||
Type:$elementType
|
||||
)>;
|
||||
|
||||
let cType = "VectorType" in {
|
||||
// /// VectorType {
|
||||
// /// shape: svarint[],
|
||||
// /// elementType: Type
|
||||
// /// }
|
||||
// ///
|
||||
// kVectorType = 19,
|
||||
def VectorType : DialectType<(type
|
||||
Array<SignedVarInt>:$shape,
|
||||
Type:$elementType
|
||||
)> {
|
||||
let printerPredicate = "!$_val.getNumScalableDims()";
|
||||
}
|
||||
|
||||
// /// VectorTypeWithScalableDims {
|
||||
// /// numScalableDims: varint,
|
||||
// /// shape: svarint[],
|
||||
// /// elementType: Type
|
||||
// /// }
|
||||
// /// Variant of VectorType with scalable dimensions.
|
||||
// kVectorTypeWithScalableDims = 20,
|
||||
def VectorTypeWithScalableDims : DialectType<(type
|
||||
VarInt:$numScalableDims,
|
||||
Array<SignedVarInt>:$shape,
|
||||
Type:$elementType
|
||||
)> {
|
||||
let printerPredicate = "$_val.getNumScalableDims()";
|
||||
// Note: order of serialization does not match order of builder.
|
||||
let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims)";
|
||||
}
|
||||
}
|
||||
|
||||
/// This enum contains marker codes used to indicate which attribute is
|
||||
/// currently being decoded, and how it should be decoded. The order of these
|
||||
/// codes should generally be unchanged, as any changes will inevitably break
|
||||
/// compatibility with older bytecode.
|
||||
|
||||
def BuiltinDialectAttributes : DialectAttributes<"Builtin"> {
|
||||
let elems = [
|
||||
ArrayAttr,
|
||||
DictionaryAttr,
|
||||
StringAttr,
|
||||
StringAttrWithType,
|
||||
FlatSymbolRefAttr,
|
||||
SymbolRefAttr,
|
||||
TypeAttr,
|
||||
UnitAttr,
|
||||
IntegerAttr,
|
||||
FloatAttr,
|
||||
CallSiteLoc,
|
||||
FileLineColLoc,
|
||||
FusedLoc,
|
||||
FusedLocWithMetadata,
|
||||
NameLoc,
|
||||
UnknownLoc,
|
||||
DenseResourceElementsAttr,
|
||||
DenseArrayAttr,
|
||||
DenseIntOrFPElementsAttr,
|
||||
DenseStringElementsAttr,
|
||||
SparseElementsAttr
|
||||
];
|
||||
}
|
||||
|
||||
def BuiltinDialectTypes : DialectTypes<"Builtin"> {
|
||||
let elems = [
|
||||
IntegerType,
|
||||
IndexType,
|
||||
FunctionType,
|
||||
BFloat16Type,
|
||||
Float16Type,
|
||||
Float32Type,
|
||||
Float64Type,
|
||||
Float80Type,
|
||||
Float128Type,
|
||||
ComplexType,
|
||||
MemRefType,
|
||||
MemRefTypeWithMemSpace,
|
||||
NoneType,
|
||||
RankedTensorType,
|
||||
RankedTensorTypeWithEncoding,
|
||||
TupleType,
|
||||
UnrankedMemRefType,
|
||||
UnrankedMemRefTypeWithMemSpace,
|
||||
UnrankedTensorType,
|
||||
VectorType,
|
||||
VectorTypeWithScalableDims
|
||||
];
|
||||
}
|
||||
|
||||
#endif // BUILTIN_BYTECODE
|
159
mlir/include/mlir/IR/BytecodeBase.td
Normal file
159
mlir/include/mlir/IR/BytecodeBase.td
Normal file
@ -0,0 +1,159 @@
|
||||
//===-- BytecodeBase.td - Base bytecode R/W defs -----------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the base bytecode reader/writer definition file.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef BYTECODE_BASE
|
||||
#define BYTECODE_BASE
|
||||
|
||||
class Bytecode<string parse="", string build="", string print="", string get="", string t=""> {
|
||||
// Template for parsing.
|
||||
// $_reader == dialect bytecode reader
|
||||
// $_resultType == result type of parsed instance
|
||||
// $_var == variable being parsed
|
||||
// If parser is not specified, then the parse of members is used.
|
||||
string cParser = parse;
|
||||
|
||||
// Template for building from parsed.
|
||||
// $_resultType == result type of parsed instance
|
||||
// $_args == args/members comma separated
|
||||
string cBuilder = build;
|
||||
|
||||
// Template for printing.
|
||||
// $_writer == dialect bytecode writer
|
||||
// $_name == parent attribute/type name
|
||||
// $_getter == getter
|
||||
string cPrinter = print;
|
||||
|
||||
// Template for getter from in memory form.
|
||||
// $_attrType == attribute/type
|
||||
// $_member == member instance
|
||||
// $_getMember == get + UpperCamelFromSnake($_member)
|
||||
string cGetter = get;
|
||||
|
||||
// Type built.
|
||||
// Note: if cType is empty, then name of def is used.
|
||||
string cType = t;
|
||||
|
||||
// Predicate guarding parse method as an Attribute/Type could have multiple
|
||||
// parse methods, specify predicates to be orthogonal and cover entire
|
||||
// "print space" to avoid order dependence.
|
||||
// If empty then method is unconditional.
|
||||
// $_val == predicate function to apply on value dyn_casted to cType.
|
||||
string printerPredicate = "";
|
||||
}
|
||||
|
||||
class WithParser<string p="", Bytecode t=Bytecode<>> :
|
||||
Bytecode<p, t.cBuilder, t.cPrinter, t.cGetter, t.cType>;
|
||||
class WithBuilder<string b="", Bytecode t=Bytecode<>> :
|
||||
Bytecode<t.cParser, b, t.cPrinter, t.cGetter, t.cType>;
|
||||
class WithPrinter<string p="", Bytecode t=Bytecode<>> :
|
||||
Bytecode<t.cParser, t.cBuilder, p, t.cGetter, t.cType>;
|
||||
class WithType<string ty="", Bytecode t=Bytecode<>> :
|
||||
Bytecode<t.cParser, t.cBuilder, t.cPrinter, t.cGetter, ty>;
|
||||
class WithGetter<string g="", Bytecode t=Bytecode<>> :
|
||||
Bytecode<t.cParser, t.cBuilder, t.cPrinter, g, t.cType>;
|
||||
|
||||
class CompositeBytecode<string t = ""> : WithType<t>;
|
||||
|
||||
class AttributeKind :
|
||||
WithParser <"succeeded($_reader.readAttribute($_var))",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"$_writer.writeAttribute($_getter)">>>;
|
||||
def Attribute : AttributeKind;
|
||||
class TypeKind :
|
||||
WithParser <"succeeded($_reader.readType($_var))",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"$_writer.writeType($_getter)">>>;
|
||||
def Type : TypeKind;
|
||||
def VarInt :
|
||||
WithParser <"succeeded($_reader.readVarInt($_var))",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"$_writer.writeVarInt($_getter)",
|
||||
WithType <"uint64_t">>>>;
|
||||
def SignedVarInt :
|
||||
WithParser <"succeeded($_reader.readSignedVarInt($_var))",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"$_writer.writeSignedVarInt($_getter)",
|
||||
WithGetter<"$_attrType",
|
||||
WithType <"int64_t">>>>>;
|
||||
def Blob :
|
||||
WithParser <"succeeded($_reader.readBlob($_var))",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"$_writer.writeOwnedBlob($_getter)",
|
||||
WithType <"ArrayRef<char>">>>>;
|
||||
|
||||
class KnownWidthAPInt<string s> :
|
||||
WithParser <"succeeded(readAPIntWithKnownWidth($_reader, " # s # ", $_var))",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"$_writer.writeAPIntWithKnownWidth($_getter)",
|
||||
WithType <"FailureOr<APInt>">>>>;
|
||||
class KnownSemanticsAPFloat<string s> :
|
||||
WithParser <"succeeded(readAPFloatWithKnownSemantics($_reader, " # s # ", $_var))",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"$_writer.writeAPFloatWithKnownSemantics($_getter)",
|
||||
WithType <"FailureOr<APFloat>">>>>;
|
||||
class ResourceHandle<string s> :
|
||||
WithParser <"succeeded(readResourceHandle<" # s # ">($_reader, $_var))",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"$_writer.writeResourceHandle($_getter)",
|
||||
WithType <"FailureOr<" # s # ">">>>>;
|
||||
|
||||
// Helper to define variable that is defined later but not parsed nor printed.
|
||||
class LocalVar<string t, string d> :
|
||||
WithParser <"(($_var = " # d # "), true)",
|
||||
WithBuilder<"$_args",
|
||||
WithPrinter<"",
|
||||
WithType <t>>>>;
|
||||
|
||||
// Array instances.
|
||||
class Array<Bytecode t> {
|
||||
Bytecode elemT = t;
|
||||
|
||||
string cBuilder = "$_args";
|
||||
}
|
||||
|
||||
// Define dialect attribute or type.
|
||||
class DialectAttrOrType<dag d> {
|
||||
// Any members starting with underscore is not fed to create function but
|
||||
// treated as purely local variable.
|
||||
dag members = d;
|
||||
|
||||
// When needing to specify a custom return type.
|
||||
string cType = "";
|
||||
|
||||
// Any post-processing that needs to be done.
|
||||
code postProcess = "";
|
||||
}
|
||||
|
||||
class DialectAttribute<dag d> : DialectAttrOrType<d>, AttributeKind {
|
||||
let cParser = "succeeded($_reader.readAttribute<$_resultType>($_var))";
|
||||
let cBuilder = "get<$_resultType>(context, $_args)";
|
||||
}
|
||||
class DialectType<dag d> : DialectAttrOrType<d>, TypeKind {
|
||||
let cParser = "succeeded($_reader.readType<$_resultType>($_var))";
|
||||
let cBuilder = "get<$_resultType>(context, $_args)";
|
||||
}
|
||||
|
||||
class DialectAttributes<string d> {
|
||||
string dialect = d;
|
||||
list<DialectAttrOrType> elems;
|
||||
}
|
||||
|
||||
class DialectTypes<string d> {
|
||||
string dialect = d;
|
||||
list<DialectAttrOrType> elems;
|
||||
}
|
||||
|
||||
def attr;
|
||||
def type;
|
||||
|
||||
#endif // BYTECODE_BASE
|
||||
|
@ -17,6 +17,10 @@ mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(BuiltinDialect.cpp.inc -gen-dialect-defs)
|
||||
add_public_tablegen_target(MLIRBuiltinDialectIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS BuiltinDialectBytecode.td)
|
||||
mlir_tablegen(BuiltinDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Builtin")
|
||||
add_public_tablegen_target(MLIRBuiltinDialectBytecodeIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS BuiltinLocationAttributes.td)
|
||||
mlir_tablegen(BuiltinLocationAttributes.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(BuiltinLocationAttributes.cpp.inc -gen-attrdef-defs)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -44,6 +44,7 @@ add_mlir_library(MLIRIR
|
||||
DEPENDS
|
||||
MLIRBuiltinAttributesIncGen
|
||||
MLIRBuiltinAttributeInterfacesIncGen
|
||||
MLIRBuiltinDialectBytecodeIncGen
|
||||
MLIRBuiltinDialectIncGen
|
||||
MLIRBuiltinLocationAttributesIncGen
|
||||
MLIRBuiltinOpsIncGen
|
||||
|
467
mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
Normal file
467
mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
Normal file
@ -0,0 +1,467 @@
|
||||
//===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/IndentedOstream.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include <regex>
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
|
||||
static llvm::cl::opt<std::string>
|
||||
selectedBcDialect("bytecode-dialect",
|
||||
llvm::cl::desc("The dialect to gen for"),
|
||||
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
|
||||
|
||||
namespace {
|
||||
|
||||
/// Helper class to generate C++ bytecode parser helpers.
|
||||
class Generator {
|
||||
public:
|
||||
Generator(raw_ostream &output) : output(output) {}
|
||||
|
||||
/// Returns whether successfully emitted attribute/type parsers.
|
||||
void emitParse(StringRef kind, Record &x);
|
||||
|
||||
/// Returns whether successfully emitted attribute/type printers.
|
||||
void emitPrint(StringRef kind, StringRef type,
|
||||
ArrayRef<std::pair<int64_t, Record *>> vec);
|
||||
|
||||
/// Emits parse dispatch table.
|
||||
void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
|
||||
|
||||
/// Emits print dispatch table.
|
||||
void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
|
||||
|
||||
private:
|
||||
/// Emits parse calls to construct given kind.
|
||||
void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
|
||||
ArrayRef<Init *> args, ArrayRef<std::string> argNames,
|
||||
StringRef failure, mlir::raw_indented_ostream &ios);
|
||||
|
||||
/// Emits print instructions.
|
||||
void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
|
||||
StringRef name, mlir::raw_indented_ostream &ios);
|
||||
|
||||
raw_ostream &output;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Helper to replace set of from strings to target in `s`.
|
||||
/// Assumed: non-overlapping replacements.
|
||||
static std::string format(StringRef templ,
|
||||
std::map<std::string, std::string> &&map) {
|
||||
std::string s = templ.str();
|
||||
for (const auto &[from, to] : map)
|
||||
// All replacements start with $, don't treat as anchor.
|
||||
s = std::regex_replace(s, std::regex("\\" + from), to);
|
||||
return s;
|
||||
}
|
||||
|
||||
/// Return string with first character capitalized.
|
||||
static std::string capitalize(StringRef str) {
|
||||
return ((Twine)toUpper(str[0]) + str.drop_front()).str();
|
||||
}
|
||||
|
||||
/// Return the C++ type for the given record.
|
||||
static std::string getCType(Record *def) {
|
||||
std::string format = "{0}";
|
||||
if (def->isSubClassOf("Array")) {
|
||||
def = def->getValueAsDef("elemT");
|
||||
format = "SmallVector<{0}>";
|
||||
}
|
||||
|
||||
StringRef cType = def->getValueAsString("cType");
|
||||
if (cType.empty()) {
|
||||
if (def->isAnonymous())
|
||||
PrintFatalError(def->getLoc(), "Unable to determine cType");
|
||||
|
||||
return formatv(format.c_str(), def->getName().str());
|
||||
}
|
||||
return formatv(format.c_str(), cType.str());
|
||||
}
|
||||
|
||||
void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
|
||||
mlir::raw_indented_ostream os(output);
|
||||
char const *head =
|
||||
R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
|
||||
os << formatv(head, capitalize(kind));
|
||||
auto funScope = os.scope(" {\n", "}\n\n");
|
||||
|
||||
os << "uint64_t kind;\n";
|
||||
os << "if (failed(reader.readVarInt(kind)))\n"
|
||||
<< " return " << capitalize(kind) << "();\n";
|
||||
os << "switch (kind) ";
|
||||
{
|
||||
auto switchScope = os.scope("{\n", "}\n");
|
||||
for (const auto &it : llvm::enumerate(vec)) {
|
||||
os << formatv("case {1}:\n return read{0}(context, reader);\n",
|
||||
it.value()->getName(), it.index());
|
||||
}
|
||||
os << "default:\n"
|
||||
<< " reader.emitError() << \"unknown attribute code: \" "
|
||||
<< "<< kind;\n"
|
||||
<< " return " << capitalize(kind) << "();\n";
|
||||
}
|
||||
os << "return " << capitalize(kind) << "();\n";
|
||||
}
|
||||
|
||||
void Generator::emitParse(StringRef kind, Record &x) {
|
||||
char const *head =
|
||||
R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
|
||||
mlir::raw_indented_ostream os(output);
|
||||
std::string returnType = getCType(&x);
|
||||
os << formatv(head, returnType, x.getName());
|
||||
DagInit *members = x.getValueAsDag("members");
|
||||
SmallVector<std::string> argNames =
|
||||
llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
|
||||
return init->getAsUnquotedString();
|
||||
}));
|
||||
StringRef builder = x.getValueAsString("cBuilder");
|
||||
emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
|
||||
returnType + "()", os);
|
||||
os << "\n\n";
|
||||
}
|
||||
|
||||
void printParseConditional(mlir::raw_indented_ostream &ios,
|
||||
ArrayRef<Init *> args,
|
||||
ArrayRef<std::string> argNames) {
|
||||
ios << "if ";
|
||||
auto parenScope = ios.scope("(", ") {");
|
||||
ios.indent();
|
||||
|
||||
auto listHelperName = [](StringRef name) {
|
||||
return formatv("read{0}", capitalize(name));
|
||||
};
|
||||
|
||||
auto parsedArgs =
|
||||
llvm::to_vector(make_filter_range(args, [](Init *const attr) {
|
||||
Record *def = cast<DefInit>(attr)->getDef();
|
||||
if (def->isSubClassOf("Array"))
|
||||
return true;
|
||||
return !def->getValueAsString("cParser").empty();
|
||||
}));
|
||||
|
||||
interleave(
|
||||
zip(parsedArgs, argNames),
|
||||
[&](std::tuple<llvm::Init *&, const std::string &> it) {
|
||||
Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
|
||||
std::string parser;
|
||||
if (auto optParser = attr->getValueAsOptionalString("cParser")) {
|
||||
parser = *optParser;
|
||||
} else if (attr->isSubClassOf("Array")) {
|
||||
Record *def = attr->getValueAsDef("elemT");
|
||||
bool composite = def->isSubClassOf("CompositeBytecode");
|
||||
if (!composite && def->isSubClassOf("AttributeKind"))
|
||||
parser = "succeeded($_reader.readAttributes($_var))";
|
||||
else if (!composite && def->isSubClassOf("TypeKind"))
|
||||
parser = "succeeded($_reader.readTypes($_var))";
|
||||
else
|
||||
parser = ("succeeded($_reader.readList($_var, " +
|
||||
listHelperName(std::get<1>(it)) + "))")
|
||||
.str();
|
||||
} else {
|
||||
PrintFatalError(attr->getLoc(), "No parser specified");
|
||||
}
|
||||
std::string type = getCType(attr);
|
||||
ios << format(parser, {{"$_reader", "reader"},
|
||||
{"$_resultType", type},
|
||||
{"$_var", std::get<1>(it)}});
|
||||
},
|
||||
[&]() { ios << " &&\n"; });
|
||||
}
|
||||
|
||||
void Generator::emitParseHelper(StringRef kind, StringRef returnType,
|
||||
StringRef builder, ArrayRef<Init *> args,
|
||||
ArrayRef<std::string> argNames,
|
||||
StringRef failure,
|
||||
mlir::raw_indented_ostream &ios) {
|
||||
auto funScope = ios.scope("{\n", "}");
|
||||
|
||||
if (args.empty()) {
|
||||
ios << formatv("return get<{0}>(context);\n", returnType);
|
||||
return;
|
||||
}
|
||||
|
||||
// Print decls.
|
||||
std::string lastCType = "";
|
||||
for (auto [arg, name] : zip(args, argNames)) {
|
||||
DefInit *first = dyn_cast<DefInit>(arg);
|
||||
if (!first)
|
||||
PrintFatalError("Unexpected type for " + name);
|
||||
Record *def = first->getDef();
|
||||
|
||||
// Create variable decls, if there are a block of same type then create
|
||||
// comma separated list of them.
|
||||
std::string cType = getCType(def);
|
||||
if (lastCType == cType) {
|
||||
ios << ", ";
|
||||
} else {
|
||||
if (!lastCType.empty())
|
||||
ios << ";\n";
|
||||
ios << cType << " ";
|
||||
}
|
||||
ios << name;
|
||||
lastCType = cType;
|
||||
}
|
||||
ios << ";\n";
|
||||
|
||||
// Returns the name of the helper used in list parsing. E.g., the name of the
|
||||
// lambda passed to array parsing.
|
||||
auto listHelperName = [](StringRef name) {
|
||||
return formatv("read{0}", capitalize(name));
|
||||
};
|
||||
|
||||
// Emit list helper functions.
|
||||
for (auto [arg, name] : zip(args, argNames)) {
|
||||
Record *attr = cast<DefInit>(arg)->getDef();
|
||||
if (!attr->isSubClassOf("Array"))
|
||||
continue;
|
||||
|
||||
// TODO: Dedupe readers.
|
||||
Record *def = attr->getValueAsDef("elemT");
|
||||
if (!def->isSubClassOf("CompositeBytecode") &&
|
||||
(def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
|
||||
continue;
|
||||
|
||||
std::string returnType = getCType(def);
|
||||
ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
|
||||
<< returnType << "> ";
|
||||
SmallVector<Init *> args;
|
||||
SmallVector<std::string> argNames;
|
||||
if (def->isSubClassOf("CompositeBytecode")) {
|
||||
DagInit *members = def->getValueAsDag("members");
|
||||
args = llvm::to_vector(members->getArgs());
|
||||
argNames = llvm::to_vector(
|
||||
map_range(members->getArgNames(), [](StringInit *init) {
|
||||
return init->getAsUnquotedString();
|
||||
}));
|
||||
} else {
|
||||
args = {def->getDefInit()};
|
||||
argNames = {"temp"};
|
||||
}
|
||||
StringRef builder = def->getValueAsString("cBuilder");
|
||||
emitParseHelper(kind, returnType, builder, args, argNames, "failure()",
|
||||
ios);
|
||||
ios << ";\n";
|
||||
}
|
||||
|
||||
// Print parse conditional.
|
||||
printParseConditional(ios, args, argNames);
|
||||
|
||||
// Compute args to pass to create method.
|
||||
auto passedArgs = llvm::to_vector(make_filter_range(
|
||||
argNames, [](StringRef str) { return !str.starts_with("_"); }));
|
||||
std::string argStr;
|
||||
raw_string_ostream argStream(argStr);
|
||||
interleaveComma(passedArgs, argStream,
|
||||
[&](const std::string &str) { argStream << str; });
|
||||
// Return the invoked constructor.
|
||||
ios << "\nreturn "
|
||||
<< format(builder, {{"$_resultType", returnType.str()},
|
||||
{"$_args", argStream.str()}})
|
||||
<< ";\n";
|
||||
ios.unindent();
|
||||
|
||||
// TODO: Emit error in debug.
|
||||
// This assumes the result types in error case can always be empty
|
||||
// constructed.
|
||||
ios << "}\nreturn " << failure << ";\n";
|
||||
}
|
||||
|
||||
void Generator::emitPrint(StringRef kind, StringRef type,
|
||||
ArrayRef<std::pair<int64_t, Record *>> vec) {
|
||||
char const *head =
|
||||
R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
|
||||
mlir::raw_indented_ostream os(output);
|
||||
os << formatv(head, type, kind);
|
||||
auto funScope = os.scope("{\n", "}\n\n");
|
||||
|
||||
// Check that predicates specified if multiple bytecode instances.
|
||||
for (llvm::Record *rec : make_second_range(vec)) {
|
||||
StringRef pred = rec->getValueAsString("printerPredicate");
|
||||
if (vec.size() > 1 && pred.empty()) {
|
||||
for (auto [index, rec] : vec) {
|
||||
(void)index;
|
||||
StringRef pred = rec->getValueAsString("printerPredicate");
|
||||
if (vec.size() > 1 && pred.empty())
|
||||
PrintError(rec->getLoc(),
|
||||
"Requires parsing predicate given common cType");
|
||||
}
|
||||
PrintFatalError("Unspecified for shared cType " + type);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto [index, rec] : vec) {
|
||||
StringRef pred = rec->getValueAsString("printerPredicate");
|
||||
if (!pred.empty()) {
|
||||
os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
|
||||
os.indent();
|
||||
}
|
||||
|
||||
os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
|
||||
<< ");\n";
|
||||
|
||||
auto *members = rec->getValueAsDag("members");
|
||||
for (auto [arg, name] :
|
||||
llvm::zip(members->getArgs(), members->getArgNames())) {
|
||||
DefInit *def = dyn_cast<DefInit>(arg);
|
||||
assert(def);
|
||||
Record *memberRec = def->getDef();
|
||||
emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
|
||||
}
|
||||
|
||||
if (!pred.empty()) {
|
||||
os.unindent();
|
||||
os << "}\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
|
||||
StringRef parent, StringRef name,
|
||||
mlir::raw_indented_ostream &ios) {
|
||||
std::string getter;
|
||||
if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
|
||||
cGetter && !cGetter->empty()) {
|
||||
getter = format(
|
||||
*cGetter,
|
||||
{{"$_attrType", parent.str()},
|
||||
{"$_member", name.str()},
|
||||
{"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}});
|
||||
} else {
|
||||
getter =
|
||||
formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true))
|
||||
.str();
|
||||
}
|
||||
|
||||
if (memberRec->isSubClassOf("Array")) {
|
||||
Record *def = memberRec->getValueAsDef("elemT");
|
||||
if (!def->isSubClassOf("CompositeBytecode")) {
|
||||
if (def->isSubClassOf("AttributeKind")) {
|
||||
ios << "writer.writeAttributes(" << getter << ");\n";
|
||||
return;
|
||||
}
|
||||
if (def->isSubClassOf("TypeKind")) {
|
||||
ios << "writer.writeTypes(" << getter << ");\n";
|
||||
return;
|
||||
}
|
||||
}
|
||||
std::string returnType = getCType(def);
|
||||
ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
|
||||
<< kind << ") ";
|
||||
auto lambdaScope = ios.scope("{\n", "});\n");
|
||||
return emitPrintHelper(def, kind, kind, kind, ios);
|
||||
}
|
||||
if (memberRec->isSubClassOf("CompositeBytecode")) {
|
||||
auto *members = memberRec->getValueAsDag("members");
|
||||
for (auto [arg, argName] :
|
||||
zip(members->getArgs(), members->getArgNames())) {
|
||||
DefInit *def = dyn_cast<DefInit>(arg);
|
||||
assert(def);
|
||||
emitPrintHelper(def->getDef(), kind, parent,
|
||||
argName->getAsUnquotedString(), ios);
|
||||
}
|
||||
}
|
||||
|
||||
if (std::string printer = memberRec->getValueAsString("cPrinter").str();
|
||||
!printer.empty())
|
||||
ios << format(printer, {{"$_writer", "writer"},
|
||||
{"$_name", kind.str()},
|
||||
{"$_getter", getter}})
|
||||
<< ";\n";
|
||||
}
|
||||
|
||||
void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
|
||||
mlir::raw_indented_ostream os(output);
|
||||
char const *head = R"(static LogicalResult write{0}({0} {1},
|
||||
DialectBytecodeWriter &writer))";
|
||||
os << formatv(head, capitalize(kind), kind);
|
||||
auto funScope = os.scope(" {\n", "}\n\n");
|
||||
|
||||
os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind
|
||||
<< ")";
|
||||
auto switchScope = os.scope("", "");
|
||||
for (StringRef type : vec) {
|
||||
os << "\n.Case([&](" << type << " t)";
|
||||
auto caseScope = os.scope(" {\n", "})");
|
||||
os << "return write(t, writer), success();\n";
|
||||
}
|
||||
os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n";
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Container of Attribute or Type for Dialect.
|
||||
struct AttrOrType {
|
||||
std::vector<Record *> attr, type;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
|
||||
MapVector<StringRef, AttrOrType> dialectAttrOrType;
|
||||
for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) {
|
||||
if (!selectedBcDialect.empty() &&
|
||||
it->getValueAsString("dialect") != selectedBcDialect)
|
||||
continue;
|
||||
dialectAttrOrType[it->getValueAsString("dialect")].attr =
|
||||
it->getValueAsListOfDefs("elems");
|
||||
}
|
||||
for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) {
|
||||
if (!selectedBcDialect.empty() &&
|
||||
it->getValueAsString("dialect") != selectedBcDialect)
|
||||
continue;
|
||||
dialectAttrOrType[it->getValueAsString("dialect")].type =
|
||||
it->getValueAsListOfDefs("elems");
|
||||
}
|
||||
|
||||
if (dialectAttrOrType.size() != 1)
|
||||
PrintFatalError("Single dialect per invocation required (either only "
|
||||
"one in input file or specified via dialect option)");
|
||||
|
||||
auto it = dialectAttrOrType.front();
|
||||
Generator gen(os);
|
||||
|
||||
SmallVector<std::vector<Record *> *, 2> vecs;
|
||||
SmallVector<std::string, 2> kinds;
|
||||
vecs.push_back(&it.second.attr);
|
||||
kinds.push_back("attribute");
|
||||
vecs.push_back(&it.second.type);
|
||||
kinds.push_back("type");
|
||||
for (auto [vec, kind] : zip(vecs, kinds)) {
|
||||
// Handle Attribute/Type emission.
|
||||
std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
|
||||
for (auto kt : llvm::enumerate(*vec))
|
||||
perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
|
||||
for (const auto &jt : perType) {
|
||||
for (auto kt : jt.second)
|
||||
gen.emitParse(kind, *std::get<1>(kt));
|
||||
gen.emitPrint(kind, jt.first, jt.second);
|
||||
}
|
||||
gen.emitParseDispatch(kind, *vec);
|
||||
|
||||
SmallVector<std::string> types;
|
||||
for (const auto &it : perType) {
|
||||
types.push_back(it.first);
|
||||
}
|
||||
gen.emitPrintDispatch(kind, types);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static mlir::GenRegistration
|
||||
genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitBCRW(records, os);
|
||||
});
|
@ -9,6 +9,7 @@ add_tablegen(mlir-tblgen MLIR
|
||||
EXPORT MLIR
|
||||
AttrOrTypeDefGen.cpp
|
||||
AttrOrTypeFormatGen.cpp
|
||||
BytecodeDialectGen.cpp
|
||||
DialectGen.cpp
|
||||
DirectiveCommonGen.cpp
|
||||
EnumsGen.cpp
|
||||
|
@ -91,6 +91,7 @@ td_library(
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":BytecodeTdFiles",
|
||||
":CallInterfacesTdFiles",
|
||||
":CastInterfacesTdFiles",
|
||||
":DataLayoutInterfacesTdFiles",
|
||||
@ -117,6 +118,20 @@ gentbl_cc_library(
|
||||
deps = [":BuiltinDialectTdFiles"],
|
||||
)
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "BuiltinDialectBytecodeGen",
|
||||
strip_include_prefix = "include",
|
||||
tbl_outs = [
|
||||
(
|
||||
["-gen-bytecode", "-bytecode-dialect=Builtin"],
|
||||
"include/mlir/IR/BuiltinDialectBytecode.cpp.inc",
|
||||
),
|
||||
],
|
||||
tblgen = ":mlir-tblgen",
|
||||
td_file = "include/mlir/IR/BuiltinDialectBytecode.td",
|
||||
deps = [":BuiltinDialectTdFiles"],
|
||||
)
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "BuiltinAttributesIncGen",
|
||||
strip_include_prefix = "include",
|
||||
@ -277,6 +292,7 @@ cc_library(
|
||||
deps = [
|
||||
":BuiltinAttributeInterfacesIncGen",
|
||||
":BuiltinAttributesIncGen",
|
||||
":BuiltinDialectBytecodeGen",
|
||||
":BuiltinDialectIncGen",
|
||||
":BuiltinLocationAttributesIncGen",
|
||||
":BuiltinOpsIncGen",
|
||||
@ -929,6 +945,12 @@ td_library(
|
||||
],
|
||||
)
|
||||
|
||||
td_library(
|
||||
name = "BytecodeTdFiles",
|
||||
srcs = ["include/mlir/IR/BytecodeBase.td"],
|
||||
includes = ["include"],
|
||||
)
|
||||
|
||||
td_library(
|
||||
name = "CallInterfacesTdFiles",
|
||||
srcs = ["include/mlir/Interfaces/CallInterfaces.td"],
|
||||
|
Loading…
x
Reference in New Issue
Block a user