From 09115580056fc57d5cdaa1de20631a838a4ea1c4 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 24 Apr 2023 11:53:58 -0700 Subject: [PATCH] [mlir] Dialect type/attr bytecode read/write generator. Tool to help generate dialect bytecode Attribute & Type reader/writing. Show usage by flipping builtin dialect. It helps reduce boilerplate when writing dialect bytecode attribute and type readers/writers. It is not an attempt at a generic spec mechanism but rather practically focussing on boilerplate reduction while also considering that it need not be the only in memory format and make it relatively easy to change. There should be some cleanup in follow up as we expand to more dialects. Differential Revision: https://reviews.llvm.org/D144820 --- .../mlir/Bytecode/BytecodeImplementation.h | 31 + .../include/mlir/IR/BuiltinDialectBytecode.td | 566 ++++++++ mlir/include/mlir/IR/BytecodeBase.td | 159 +++ mlir/include/mlir/IR/CMakeLists.txt | 4 + mlir/lib/IR/BuiltinDialectBytecode.cpp | 1205 +---------------- mlir/lib/IR/CMakeLists.txt | 1 + mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp | 467 +++++++ mlir/tools/mlir-tblgen/CMakeLists.txt | 1 + .../llvm-project-overlay/mlir/BUILD.bazel | 22 + 9 files changed, 1324 insertions(+), 1132 deletions(-) create mode 100644 mlir/include/mlir/IR/BuiltinDialectBytecode.td create mode 100644 mlir/include/mlir/IR/BytecodeBase.td create mode 100644 mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h index ea9bcad735b3..6e7b9ff26342 100644 --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -345,6 +345,37 @@ public: } }; +/// Helper for resource handle reading that returns LogicalResult. 
+template +static LogicalResult readResourceHandle(DialectBytecodeReader &reader, + FailureOr &value, Ts &&...params) { + FailureOr handle = reader.readResourceHandle(); + if (failed(handle)) + return failure(); + if (auto *result = dyn_cast(&*handle)) { + value = std::move(*result); + return success(); + } + return failure(); +} + +/// Helper method that injects context only if needed, this helps unify some of +/// the attribute construction methods. +template +auto get(MLIRContext *context, Ts &&...params) { + // Prefer a direct `get` method if one exists. + if constexpr (llvm::is_detected::value) { + (void)context; + return T::get(std::forward(params)...); + } else if constexpr (llvm::is_detected::value) { + return T::get(context, std::forward(params)...); + } else { + // Otherwise, pass to the base get. + return T::Base::get(context, std::forward(params)...); + } +} + } // namespace mlir #endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td new file mode 100644 index 000000000000..b59f96c9fa9f --- /dev/null +++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td @@ -0,0 +1,566 @@ +//===-- BuiltinBytecode.td - Builtin bytecode defs ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the Builtin bytecode reader/writer definition file. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BUILTIN_BYTECODE +#define BUILTIN_BYTECODE + +include "mlir/IR/BytecodeBase.td" + +def LocationAttr : AttributeKind; +def ShapedType: WithType<"ShapedType", Type>; + +def Location : CompositeBytecode { + dag members = (attr + WithGetter<"(LocationAttr)$_attrType", WithType<"LocationAttr", LocationAttr>>:$value + ); + let cBuilder = "Location($_args)"; +} + +def String : + WithParser <"succeeded($_reader.readString($_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeOwnedString($_getter)", + WithGetter <"$_attrType", + WithType <"StringRef">>>>>; + +// enum AttributeCode { +// /// ArrayAttr { +// /// elements: Attribute[] +// /// } +// /// +// kArrayAttr = 0, +// +def ArrayAttr : DialectAttribute<(attr + Array:$value +)>; + +let cType = "StringAttr" in { +// /// StringAttr { +// /// value: string +// /// } +// kStringAttr = 2, +def StringAttr : DialectAttribute<(attr + String:$value +)> { + let printerPredicate = "$_val.getType().isa()"; +} + +// /// StringAttrWithType { +// /// value: string, +// /// type: Type +// /// } +// /// A variant of StringAttr with a type. +// kStringAttrWithType = 3, +def StringAttrWithType : DialectAttribute<(attr + String:$value, + Type:$type +)> { let printerPredicate = "!$_val.getType().isa()"; } +} + +// /// DictionaryAttr { +// /// attrs: [] +// /// } +// kDictionaryAttr = 1, +def NamedAttribute : CompositeBytecode { + dag members = (attr + StringAttr:$name, + Attribute:$value + ); + let cBuilder = "NamedAttribute($_args)"; +} +def DictionaryAttr : DialectAttribute<(attr + Array:$value +)>; + +// /// FlatSymbolRefAttr { +// /// rootReference: StringAttr +// /// } +// /// A variant of SymbolRefAttr with no leaf references. 
+// kFlatSymbolRefAttr = 4, +def FlatSymbolRefAttr: DialectAttribute<(attr + StringAttr:$rootReference +)>; + +// /// SymbolRefAttr { +// /// rootReference: StringAttr, +// /// leafReferences: FlatSymbolRefAttr[] +// /// } +// kSymbolRefAttr = 5, +def SymbolRefAttr: DialectAttribute<(attr + StringAttr:$rootReference, + Array:$nestedReferences +)>; + +// /// TypeAttr { +// /// value: Type +// /// } +// kTypeAttr = 6, +def TypeAttr: DialectAttribute<(attr + Type:$value +)>; + +// /// UnitAttr { +// /// } +// kUnitAttr = 7, +def UnitAttr: DialectAttribute<(attr)>; + +// /// IntegerAttr { +// /// type: Type +// /// value: APInt, +// /// } +// kIntegerAttr = 8, +def IntegerAttr: DialectAttribute<(attr + Type:$type, + KnownWidthAPInt<"type">:$value +)> { + let cBuilder = "get<$_resultType>(context, type, *value)"; +} + +// +// /// FloatAttr { +// /// type: FloatType +// /// value: APFloat +// /// } +// kFloatAttr = 9, +defvar FloatType = Type; +def FloatAttr : DialectAttribute<(attr + FloatType:$type, + KnownSemanticsAPFloat<"type">:$value +)> { + let cBuilder = "get<$_resultType>(context, type, *value)"; +} + +// /// CallSiteLoc { +// /// callee: LocationAttr, +// /// caller: LocationAttr +// /// } +// kCallSiteLoc = 10, +def CallSiteLoc : DialectAttribute<(attr + LocationAttr:$callee, + LocationAttr:$caller +)>; + +// /// FileLineColLoc { +// /// filename: StringAttr, +// /// line: varint, +// /// column: varint +// /// } +// kFileLineColLoc = 11, +def FileLineColLoc : DialectAttribute<(attr + StringAttr:$filename, + VarInt:$line, + VarInt:$column +)>; + +let cType = "FusedLoc", + cBuilder = "cast(get(context, $_args))" in { +// /// FusedLoc { +// /// locations: Location[] +// /// } +// kFusedLoc = 12, +def FusedLoc : DialectAttribute<(attr + Array:$locations +)> { + let printerPredicate = "!$_val.getMetadata()"; +} + +// /// FusedLocWithMetadata { +// /// locations: LocationAttr[], +// /// metadata: Attribute +// /// } +// /// A variant of FusedLoc with metadata. 
+// kFusedLocWithMetadata = 13, +def FusedLocWithMetadata : DialectAttribute<(attr + Array:$locations, + Attribute:$metadata +)> { + let printerPredicate = "$_val.getMetadata()"; +} +} + +// /// NameLoc { +// /// name: StringAttr, +// /// childLoc: LocationAttr +// /// } +// kNameLoc = 14, +def NameLoc : DialectAttribute<(attr + StringAttr:$name, + LocationAttr:$childLoc +)>; + +// /// UnknownLoc { +// /// } +// kUnknownLoc = 15, +def UnknownLoc : DialectAttribute<(attr)>; + +// /// DenseResourceElementsAttr { +// /// type: ShapedType, +// /// handle: ResourceHandle +// /// } +// kDenseResourceElementsAttr = 16, +def DenseResourceElementsAttr : DialectAttribute<(attr + ShapedType:$type, + ResourceHandle<"DenseResourceElementsHandle">:$rawHandle +)> { + // Note: order of serialization does not match order of builder. + let cBuilder = "get<$_resultType>(context, type, *rawHandle)"; +} + +let cType = "RankedTensorType" in { +// /// RankedTensorType { +// /// shape: svarint[], +// /// elementType: Type, +// /// } +// /// +// kRankedTensorType = 13, +def RankedTensorType : DialectType<(type + Array:$shape, + Type:$elementType +)> { + let printerPredicate = "!$_val.getEncoding()"; +} + +// /// RankedTensorTypeWithEncoding { +// /// encoding: Attribute, +// /// shape: svarint[], +// /// elementType: Type +// /// } +// /// Variant of RankedTensorType with an encoding. +// kRankedTensorTypeWithEncoding = 14, +def RankedTensorTypeWithEncoding : DialectType<(type + Attribute:$encoding, + Array:$shape, + Type:$elementType +)> { + let printerPredicate = "$_val.getEncoding()"; + // Note: order of serialization does not match order of builder. 
+ let cBuilder = "get<$_resultType>(context, shape, elementType, encoding)"; +} +} + +// /// DenseArrayAttr { +// /// elementType: Type, +// /// size: varint +// /// data: blob +// /// } +// kDenseArrayAttr = 17, +def DenseArrayAttr : DialectAttribute<(attr + Type:$elementType, + VarInt:$size, + Blob:$rawData +)>; + +// /// DenseIntOrFPElementsAttr { +// /// type: ShapedType, +// /// data: blob +// /// } +// kDenseIntOrFPElementsAttr = 18, +def DenseElementsAttr : WithType<"DenseIntElementsAttr", Attribute>; +def DenseIntOrFPElementsAttr : DialectAttribute<(attr + ShapedType:$type, + Blob:$rawData +)> { + let cBuilder = "cast<$_resultType>($_resultType::getFromRawBuffer($_args))"; +} + +// /// DenseStringElementsAttr { +// /// type: ShapedType, +// /// isSplat: varint, +// /// data: string[] +// /// } +// kDenseStringElementsAttr = 19, +def DenseStringElementsAttr : DialectAttribute<(attr + ShapedType:$type, + WithGetter<"$_attrType.isSplat()", VarInt>:$_isSplat, + WithBuilder<"$_args", + WithType<"SmallVector", + WithParser <"succeeded(readPotentiallySplatString($_reader, type, _isSplat, $_var))", + WithPrinter<"writePotentiallySplatString($_writer, $_name)">>>>:$rawStringData +)>; + +// /// SparseElementsAttr { +// /// type: ShapedType, +// /// indices: DenseIntElementsAttr, +// /// values: DenseElementsAttr +// /// } +// kSparseElementsAttr = 20, +def DenseIntElementsAttr : WithType<"DenseIntElementsAttr", Attribute>; +def SparseElementsAttr : DialectAttribute<(attr + ShapedType:$type, + DenseIntElementsAttr:$indices, + DenseElementsAttr:$values +)>; + +// Types +// ----- + +// enum TypeCode { +// /// IntegerType { +// /// widthAndSignedness: varint // (width << 2) | (signedness) +// /// } +// /// +// kIntegerType = 0, +def IntegerType : DialectType<(type + // Yes not pretty, + WithParser<"succeeded($_reader.readVarInt($_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeVarInt($_name.getWidth() << 2 | $_name.getSignedness())", + WithType 
<"uint64_t">>>>:$_widthAndSignedness, + // Split up parsed varint for create method. + LocalVar<"uint64_t", "_widthAndSignedness >> 2">:$width, + LocalVar<"IntegerType::SignednessSemantics", + "static_cast(_widthAndSignedness & 0x3)">:$signedness +)>; + +// +// /// IndexType { +// /// } +// /// +// kIndexType = 1, +def IndexType : DialectType<(type)>; + +// /// FunctionType { +// /// inputs: Type[], +// /// results: Type[] +// /// } +// /// +// kFunctionType = 2, +def FunctionType : DialectType<(type + Array:$inputs, + Array:$results +)>; + +// /// BFloat16Type { +// /// } +// /// +// kBFloat16Type = 3, +def BFloat16Type : DialectType<(type)>; + +// /// Float16Type { +// /// } +// /// +// kFloat16Type = 4, +def Float16Type : DialectType<(type)>; + +// /// Float32Type { +// /// } +// /// +// kFloat32Type = 5, +def Float32Type : DialectType<(type)>; + +// /// Float64Type { +// /// } +// /// +// kFloat64Type = 6, +def Float64Type : DialectType<(type)>; + +// /// Float80Type { +// /// } +// /// +// kFloat80Type = 7, +def Float80Type : DialectType<(type)>; + +// /// Float128Type { +// /// } +// /// +// kFloat128Type = 8, +def Float128Type : DialectType<(type)>; + +// /// ComplexType { +// /// elementType: Type +// /// } +// /// +// kComplexType = 9, +def ComplexType : DialectType<(type + Type:$elementType +)>; + +def MemRefLayout: WithType<"MemRefLayoutAttrInterface", Attribute>; + +let cType = "MemRefType" in { +// /// MemRefType { +// /// shape: svarint[], +// /// elementType: Type, +// /// layout: Attribute +// /// } +// /// +// kMemRefType = 10, +def MemRefType : DialectType<(type + Array:$shape, + Type:$elementType, + MemRefLayout:$layout +)> { + let printerPredicate = "!$_val.getMemorySpace()"; +} + +// /// MemRefTypeWithMemSpace { +// /// memorySpace: Attribute, +// /// shape: svarint[], +// /// elementType: Type, +// /// layout: Attribute +// /// } +// /// Variant of MemRefType with non-default memory space. 
+// kMemRefTypeWithMemSpace = 11, +def MemRefTypeWithMemSpace : DialectType<(type + Attribute:$memorySpace, + Array:$shape, + Type:$elementType, + MemRefLayout:$layout +)> { + let printerPredicate = "!!$_val.getMemorySpace()"; + // Note: order of serialization does not match order of builder. + let cBuilder = "get<$_resultType>(context, shape, elementType, layout, memorySpace)"; +} +} + +// /// NoneType { +// /// } +// /// +// kNoneType = 12, +def NoneType : DialectType<(type)>; + +// /// TupleType { +// /// elementTypes: Type[] +// /// } +// kTupleType = 15, +def TupleType : DialectType<(type + Array:$types +)>; + +let cType = "UnrankedMemRefType" in { +// /// UnrankedMemRefType { +// /// elementType: Type +// /// } +// /// +// kUnrankedMemRefType = 16, +def UnrankedMemRefType : DialectType<(type + Type:$elementType +)> { + let printerPredicate = "!$_val.getMemorySpace()"; + let cBuilder = "get<$_resultType>(context, elementType, Attribute())"; +} + +// /// UnrankedMemRefTypeWithMemSpace { +// /// memorySpace: Attribute, +// /// elementType: Type +// /// } +// /// Variant of UnrankedMemRefType with non-default memory space. +// kUnrankedMemRefTypeWithMemSpace = 17, +def UnrankedMemRefTypeWithMemSpace : DialectType<(type + Attribute:$memorySpace, + Type:$elementType +)> { + let printerPredicate = "$_val.getMemorySpace()"; + // Note: order of serialization does not match order of builder. 
+ let cBuilder = "get<$_resultType>(context, elementType, memorySpace)"; +} +} + +// /// UnrankedTensorType { +// /// elementType: Type +// /// } +// /// +// kUnrankedTensorType = 18, +def UnrankedTensorType : DialectType<(type + Type:$elementType +)>; + +let cType = "VectorType" in { +// /// VectorType { +// /// shape: svarint[], +// /// elementType: Type +// /// } +// /// +// kVectorType = 19, +def VectorType : DialectType<(type + Array:$shape, + Type:$elementType +)> { + let printerPredicate = "!$_val.getNumScalableDims()"; +} + +// /// VectorTypeWithScalableDims { +// /// numScalableDims: varint, +// /// shape: svarint[], +// /// elementType: Type +// /// } +// /// Variant of VectorType with scalable dimensions. +// kVectorTypeWithScalableDims = 20, +def VectorTypeWithScalableDims : DialectType<(type + VarInt:$numScalableDims, + Array:$shape, + Type:$elementType +)> { + let printerPredicate = "$_val.getNumScalableDims()"; + // Note: order of serialization does not match order of builder. + let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims)"; +} +} + +/// This enum contains marker codes used to indicate which attribute is +/// currently being decoded, and how it should be decoded. The order of these +/// codes should generally be unchanged, as any changes will inevitably break +/// compatibility with older bytecode. 
+ +def BuiltinDialectAttributes : DialectAttributes<"Builtin"> { + let elems = [ + ArrayAttr, + DictionaryAttr, + StringAttr, + StringAttrWithType, + FlatSymbolRefAttr, + SymbolRefAttr, + TypeAttr, + UnitAttr, + IntegerAttr, + FloatAttr, + CallSiteLoc, + FileLineColLoc, + FusedLoc, + FusedLocWithMetadata, + NameLoc, + UnknownLoc, + DenseResourceElementsAttr, + DenseArrayAttr, + DenseIntOrFPElementsAttr, + DenseStringElementsAttr, + SparseElementsAttr + ]; +} + +def BuiltinDialectTypes : DialectTypes<"Builtin"> { + let elems = [ + IntegerType, + IndexType, + FunctionType, + BFloat16Type, + Float16Type, + Float32Type, + Float64Type, + Float80Type, + Float128Type, + ComplexType, + MemRefType, + MemRefTypeWithMemSpace, + NoneType, + RankedTensorType, + RankedTensorTypeWithEncoding, + TupleType, + UnrankedMemRefType, + UnrankedMemRefTypeWithMemSpace, + UnrankedTensorType, + VectorType, + VectorTypeWithScalableDims + ]; +} + +#endif // BUILTIN_BYTECODE diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td new file mode 100644 index 000000000000..8cadf978b347 --- /dev/null +++ b/mlir/include/mlir/IR/BytecodeBase.td @@ -0,0 +1,159 @@ +//===-- BytecodeBase.td - Base bytecode R/W defs -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the base bytecode reader/writer definition file. +// +//===----------------------------------------------------------------------===// + +#ifndef BYTECODE_BASE +#define BYTECODE_BASE + +class Bytecode { + // Template for parsing. + // $_reader == dialect bytecode reader + // $_resultType == result type of parsed instance + // $_var == variable being parsed + // If parser is not specified, then the parse of members is used. 
+ string cParser = parse; + + // Template for building from parsed. + // $_resultType == result type of parsed instance + // $_args == args/members comma separated + string cBuilder = build; + + // Template for printing. + // $_writer == dialect bytecode writer + // $_name == parent attribute/type name + // $_getter == getter + string cPrinter = print; + + // Template for getter from in memory form. + // $_attrType == attribute/type + // $_member == member instance + // $_getMember == get + UpperCamelFromSnake($_member) + string cGetter = get; + + // Type built. + // Note: if cType is empty, then name of def is used. + string cType = t; + + // Predicate guarding parse method as an Attribute/Type could have multiple + // parse methods, specify predicates to be orthogonal and cover entire + // "print space" to avoid order dependence. + // If empty then method is unconditional. + // $_val == predicate function to apply on value dyn_casted to cType. + string printerPredicate = ""; +} + +class WithParser> : + Bytecode; +class WithBuilder> : + Bytecode; +class WithPrinter> : + Bytecode; +class WithType> : + Bytecode; +class WithGetter> : + Bytecode; + +class CompositeBytecode : WithType; + +class AttributeKind : + WithParser <"succeeded($_reader.readAttribute($_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeAttribute($_getter)">>>; +def Attribute : AttributeKind; +class TypeKind : + WithParser <"succeeded($_reader.readType($_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeType($_getter)">>>; +def Type : TypeKind; +def VarInt : + WithParser <"succeeded($_reader.readVarInt($_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeVarInt($_getter)", + WithType <"uint64_t">>>>; +def SignedVarInt : + WithParser <"succeeded($_reader.readSignedVarInt($_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeSignedVarInt($_getter)", + WithGetter<"$_attrType", + WithType <"int64_t">>>>>; +def Blob : + WithParser 
<"succeeded($_reader.readBlob($_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeOwnedBlob($_getter)", + WithType <"ArrayRef">>>>; + +class KnownWidthAPInt : + WithParser <"succeeded(readAPIntWithKnownWidth($_reader, " # s # ", $_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeAPIntWithKnownWidth($_getter)", + WithType <"FailureOr">>>>; +class KnownSemanticsAPFloat : + WithParser <"succeeded(readAPFloatWithKnownSemantics($_reader, " # s # ", $_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeAPFloatWithKnownSemantics($_getter)", + WithType <"FailureOr">>>>; +class ResourceHandle : + WithParser <"succeeded(readResourceHandle<" # s # ">($_reader, $_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeResourceHandle($_getter)", + WithType <"FailureOr<" # s # ">">>>>; + +// Helper to define variable that is defined later but not parsed nor printed. +class LocalVar : + WithParser <"(($_var = " # d # "), true)", + WithBuilder<"$_args", + WithPrinter<"", + WithType >>>; + +// Array instances. +class Array { + Bytecode elemT = t; + + string cBuilder = "$_args"; +} + +// Define dialect attribute or type. +class DialectAttrOrType { + // Any members starting with underscore is not fed to create function but + // treated as purely local variable. + dag members = d; + + // When needing to specify a custom return type. + string cType = ""; + + // Any post-processing that needs to be done. 
+ code postProcess = ""; +} + +class DialectAttribute : DialectAttrOrType, AttributeKind { + let cParser = "succeeded($_reader.readAttribute<$_resultType>($_var))"; + let cBuilder = "get<$_resultType>(context, $_args)"; +} +class DialectType : DialectAttrOrType, TypeKind { + let cParser = "succeeded($_reader.readType<$_resultType>($_var))"; + let cBuilder = "get<$_resultType>(context, $_args)"; +} + +class DialectAttributes { + string dialect = d; + list elems; +} + +class DialectTypes { + string dialect = d; + list elems; +} + +def attr; +def type; + +#endif // BYTECODE_BASE + diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt index 78d41d6dc4ab..404f130022f1 100644 --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -17,6 +17,10 @@ mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls) mlir_tablegen(BuiltinDialect.cpp.inc -gen-dialect-defs) add_public_tablegen_target(MLIRBuiltinDialectIncGen) +set(LLVM_TARGET_DEFINITIONS BuiltinDialectBytecode.td) +mlir_tablegen(BuiltinDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Builtin") +add_public_tablegen_target(MLIRBuiltinDialectBytecodeIncGen) + set(LLVM_TARGET_DEFINITIONS BuiltinLocationAttributes.td) mlir_tablegen(BuiltinLocationAttributes.h.inc -gen-attrdef-decls) mlir_tablegen(BuiltinLocationAttributes.cpp.inc -gen-attrdef-defs) diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp index 22a563dd7b2a..40af5f3b1744 100644 --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -8,6 +8,7 @@ #include "BuiltinDialectBytecode.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" @@ -16,285 +17,71 @@ using namespace mlir; -//===----------------------------------------------------------------------===// -// Encoding 
-//===----------------------------------------------------------------------===// - -namespace { -namespace builtin_encoding { -/// This enum contains marker codes used to indicate which attribute is -/// currently being decoded, and how it should be decoded. The order of these -/// codes should generally be unchanged, as any changes will inevitably break -/// compatibility with older bytecode. -enum AttributeCode { - /// ArrayAttr { - /// elements: Attribute[] - /// } - /// - kArrayAttr = 0, - - /// DictionaryAttr { - /// attrs: [] - /// } - kDictionaryAttr = 1, - - /// StringAttr { - /// value: string - /// } - kStringAttr = 2, - - /// StringAttrWithType { - /// value: string, - /// type: Type - /// } - /// A variant of StringAttr with a type. - kStringAttrWithType = 3, - - /// FlatSymbolRefAttr { - /// rootReference: StringAttr - /// } - /// A variant of SymbolRefAttr with no leaf references. - kFlatSymbolRefAttr = 4, - - /// SymbolRefAttr { - /// rootReference: StringAttr, - /// leafReferences: FlatSymbolRefAttr[] - /// } - kSymbolRefAttr = 5, - - /// TypeAttr { - /// value: Type - /// } - kTypeAttr = 6, - - /// UnitAttr { - /// } - kUnitAttr = 7, - - /// IntegerAttr { - /// type: Type - /// value: APInt, - /// } - kIntegerAttr = 8, - - /// FloatAttr { - /// type: FloatType - /// value: APFloat - /// } - kFloatAttr = 9, - - /// CallSiteLoc { - /// callee: LocationAttr, - /// caller: LocationAttr - /// } - kCallSiteLoc = 10, - - /// FileLineColLoc { - /// file: StringAttr, - /// line: varint, - /// column: varint - /// } - kFileLineColLoc = 11, - - /// FusedLoc { - /// locations: LocationAttr[] - /// } - kFusedLoc = 12, - - /// FusedLocWithMetadata { - /// locations: LocationAttr[], - /// metadata: Attribute - /// } - /// A variant of FusedLoc with metadata. 
- kFusedLocWithMetadata = 13, - - /// NameLoc { - /// name: StringAttr, - /// childLoc: LocationAttr - /// } - kNameLoc = 14, - - /// UnknownLoc { - /// } - kUnknownLoc = 15, - - /// DenseResourceElementsAttr { - /// type: Type, - /// handle: ResourceHandle - /// } - kDenseResourceElementsAttr = 16, - - /// DenseArrayAttr { - /// type: RankedTensorType, - /// data: blob - /// } - kDenseArrayAttr = 17, - - /// DenseIntOrFPElementsAttr { - /// type: ShapedType, - /// data: blob - /// } - kDenseIntOrFPElementsAttr = 18, - - /// DenseStringElementsAttr { - /// type: ShapedType, - /// isSplat: varint, - /// data: string[] - /// } - kDenseStringElementsAttr = 19, - - /// SparseElementsAttr { - /// type: ShapedType, - /// indices: DenseIntElementsAttr, - /// values: DenseElementsAttr - /// } - kSparseElementsAttr = 20, -}; - -/// This enum contains marker codes used to indicate which type is currently -/// being decoded, and how it should be decoded. The order of these codes should -/// generally be unchanged, as any changes will inevitably break compatibility -/// with older bytecode. 
-enum TypeCode { - /// IntegerType { - /// widthAndSignedness: varint // (width << 2) | (signedness) - /// } - /// - kIntegerType = 0, - - /// IndexType { - /// } - /// - kIndexType = 1, - - /// FunctionType { - /// inputs: Type[], - /// results: Type[] - /// } - /// - kFunctionType = 2, - - /// BFloat16Type { - /// } - /// - kBFloat16Type = 3, - - /// Float16Type { - /// } - /// - kFloat16Type = 4, - - /// Float32Type { - /// } - /// - kFloat32Type = 5, - - /// Float64Type { - /// } - /// - kFloat64Type = 6, - - /// Float80Type { - /// } - /// - kFloat80Type = 7, - - /// Float128Type { - /// } - /// - kFloat128Type = 8, - - /// ComplexType { - /// elementType: Type - /// } - /// - kComplexType = 9, - - /// MemRefType { - /// shape: svarint[], - /// elementType: Type, - /// layout: Attribute - /// } - /// - kMemRefType = 10, - - /// MemRefTypeWithMemSpace { - /// memorySpace: Attribute, - /// shape: svarint[], - /// elementType: Type, - /// layout: Attribute - /// } - /// Variant of MemRefType with non-default memory space. - kMemRefTypeWithMemSpace = 11, - - /// NoneType { - /// } - /// - kNoneType = 12, - - /// RankedTensorType { - /// shape: svarint[], - /// elementType: Type, - /// } - /// - kRankedTensorType = 13, - - /// RankedTensorTypeWithEncoding { - /// encoding: Attribute, - /// shape: svarint[], - /// elementType: Type - /// } - /// Variant of RankedTensorType with an encoding. - kRankedTensorTypeWithEncoding = 14, - - /// TupleType { - /// elementTypes: Type[] - /// } - kTupleType = 15, - - /// UnrankedMemRefType { - /// shape: svarint[] - /// } - /// - kUnrankedMemRefType = 16, - - /// UnrankedMemRefTypeWithMemSpace { - /// memorySpace: Attribute, - /// shape: svarint[] - /// } - /// Variant of UnrankedMemRefType with non-default memory space. 
- kUnrankedMemRefTypeWithMemSpace = 17, - - /// UnrankedTensorType { - /// elementType: Type - /// } - /// - kUnrankedTensorType = 18, - - /// VectorType { - /// shape: svarint[], - /// elementType: Type - /// } - /// - kVectorType = 19, - - /// VectorTypeWithScalableDims { - /// numScalableDims: varint, - /// shape: svarint[], - /// elementType: Type - /// } - /// Variant of VectorType with scalable dimensions. - kVectorTypeWithScalableDims = 20, -}; - -} // namespace builtin_encoding -} // namespace - //===----------------------------------------------------------------------===// // BuiltinDialectBytecodeInterface //===----------------------------------------------------------------------===// namespace { + +//===----------------------------------------------------------------------===// +// Utility functions + +// TODO: Move these to separate file. + +// Returns the bitwidth if known, else return 0. +static unsigned getIntegerBitWidth(DialectBytecodeReader &reader, Type type) { + if (auto intType = dyn_cast(type)) { + return intType.getWidth(); + } else if (type.isa()) { + return IndexType::kInternalStorageBitWidth; + } + reader.emitError() + << "expected integer or index type for IntegerAttr, but got: " << type; + return 0; +} + +static LogicalResult readAPIntWithKnownWidth(DialectBytecodeReader &reader, + Type type, FailureOr &val) { + unsigned bitWidth = getIntegerBitWidth(reader, type); + if (bitWidth == 0) + return failure(); + val = reader.readAPIntWithKnownWidth(bitWidth); + return val; +} + +static LogicalResult +readAPFloatWithKnownSemantics(DialectBytecodeReader &reader, Type type, + FailureOr &val) { + auto ftype = dyn_cast(type); + if (!ftype) + return failure(); + val = reader.readAPFloatWithKnownSemantics(ftype.getFloatSemantics()); + return success(); +} + +LogicalResult +readPotentiallySplatString(DialectBytecodeReader &reader, ShapedType type, + bool isSplat, + SmallVectorImpl &rawStringData) { + rawStringData.resize(isSplat ? 
1 : type.getNumElements()); + for (StringRef &value : rawStringData) + if (failed(reader.readString(value))) + return failure(); + return success(); +} + +void writePotentiallySplatString(DialectBytecodeWriter &writer, + DenseStringElementsAttr attr) { + bool isSplat = attr.isSplat(); + if (isSplat) + return writer.writeOwnedString(attr.getRawStringData().front()); + + for (StringRef str : attr.getRawStringData()) + writer.writeOwnedString(str); +} + +#include "mlir/IR/BuiltinDialectBytecode.cpp.inc" + /// This class implements the bytecode interface for the builtin dialect. struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface { BuiltinDialectBytecodeInterface(Dialect *dialect) @@ -303,875 +90,29 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface { //===--------------------------------------------------------------------===// // Attributes - Attribute readAttribute(DialectBytecodeReader &reader) const override; - ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const; - DenseArrayAttr readDenseArrayAttr(DialectBytecodeReader &reader) const; - DenseElementsAttr - readDenseIntOrFPElementsAttr(DialectBytecodeReader &reader) const; - DenseStringElementsAttr - readDenseStringElementsAttr(DialectBytecodeReader &reader) const; - DenseResourceElementsAttr - readDenseResourceElementsAttr(DialectBytecodeReader &reader) const; - DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const; - FloatAttr readFloatAttr(DialectBytecodeReader &reader) const; - IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const; - SparseElementsAttr - readSparseElementsAttr(DialectBytecodeReader &reader) const; - StringAttr readStringAttr(DialectBytecodeReader &reader, bool hasType) const; - SymbolRefAttr readSymbolRefAttr(DialectBytecodeReader &reader, - bool hasNestedRefs) const; - TypeAttr readTypeAttr(DialectBytecodeReader &reader) const; - - LocationAttr readCallSiteLoc(DialectBytecodeReader &reader) const; - 
LocationAttr readFileLineColLoc(DialectBytecodeReader &reader) const; - LocationAttr readFusedLoc(DialectBytecodeReader &reader, - bool hasMetadata) const; - LocationAttr readNameLoc(DialectBytecodeReader &reader) const; + Attribute readAttribute(DialectBytecodeReader &reader) const override { + return ::readAttribute(getContext(), reader); + } LogicalResult writeAttribute(Attribute attr, - DialectBytecodeWriter &writer) const override; - void write(ArrayAttr attr, DialectBytecodeWriter &writer) const; - void write(DenseArrayAttr attr, DialectBytecodeWriter &writer) const; - void write(DenseIntOrFPElementsAttr attr, - DialectBytecodeWriter &writer) const; - void write(DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const; - void write(DenseResourceElementsAttr attr, - DialectBytecodeWriter &writer) const; - void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const; - void write(IntegerAttr attr, DialectBytecodeWriter &writer) const; - void write(FloatAttr attr, DialectBytecodeWriter &writer) const; - void write(SparseElementsAttr attr, DialectBytecodeWriter &writer) const; - void write(StringAttr attr, DialectBytecodeWriter &writer) const; - void write(SymbolRefAttr attr, DialectBytecodeWriter &writer) const; - void write(TypeAttr attr, DialectBytecodeWriter &writer) const; - - void write(CallSiteLoc attr, DialectBytecodeWriter &writer) const; - void write(FileLineColLoc attr, DialectBytecodeWriter &writer) const; - void write(FusedLoc attr, DialectBytecodeWriter &writer) const; - void write(NameLoc attr, DialectBytecodeWriter &writer) const; - LogicalResult write(OpaqueLoc attr, DialectBytecodeWriter &writer) const; + DialectBytecodeWriter &writer) const override { + return ::writeAttribute(attr, writer); + } //===--------------------------------------------------------------------===// // Types - Type readType(DialectBytecodeReader &reader) const override; - ComplexType readComplexType(DialectBytecodeReader &reader) const; - IntegerType 
readIntegerType(DialectBytecodeReader &reader) const; - FunctionType readFunctionType(DialectBytecodeReader &reader) const; - MemRefType readMemRefType(DialectBytecodeReader &reader, - bool hasMemSpace) const; - RankedTensorType readRankedTensorType(DialectBytecodeReader &reader, - bool hasEncoding) const; - TupleType readTupleType(DialectBytecodeReader &reader) const; - UnrankedMemRefType readUnrankedMemRefType(DialectBytecodeReader &reader, - bool hasMemSpace) const; - UnrankedTensorType - readUnrankedTensorType(DialectBytecodeReader &reader) const; - VectorType readVectorType(DialectBytecodeReader &reader, - bool hasScalableDims) const; + Type readType(DialectBytecodeReader &reader) const override { + return ::readType(getContext(), reader); + } LogicalResult writeType(Type type, - DialectBytecodeWriter &writer) const override; - void write(ComplexType type, DialectBytecodeWriter &writer) const; - void write(IntegerType type, DialectBytecodeWriter &writer) const; - void write(FunctionType type, DialectBytecodeWriter &writer) const; - void write(MemRefType type, DialectBytecodeWriter &writer) const; - void write(RankedTensorType type, DialectBytecodeWriter &writer) const; - void write(TupleType type, DialectBytecodeWriter &writer) const; - void write(UnrankedMemRefType type, DialectBytecodeWriter &writer) const; - void write(UnrankedTensorType type, DialectBytecodeWriter &writer) const; - void write(VectorType type, DialectBytecodeWriter &writer) const; + DialectBytecodeWriter &writer) const override { + return ::writeType(type, writer); + } }; } // namespace void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) { dialect->addInterfaces(); } - -//===----------------------------------------------------------------------===// -// Attributes -//===----------------------------------------------------------------------===// - -Attribute BuiltinDialectBytecodeInterface::readAttribute( - DialectBytecodeReader &reader) const { - uint64_t code; - if 
(failed(reader.readVarInt(code))) - return Attribute(); - switch (code) { - case builtin_encoding::kArrayAttr: - return readArrayAttr(reader); - case builtin_encoding::kDictionaryAttr: - return readDictionaryAttr(reader); - case builtin_encoding::kStringAttr: - return readStringAttr(reader, /*hasType=*/false); - case builtin_encoding::kStringAttrWithType: - return readStringAttr(reader, /*hasType=*/true); - case builtin_encoding::kFlatSymbolRefAttr: - return readSymbolRefAttr(reader, /*hasNestedRefs=*/false); - case builtin_encoding::kSymbolRefAttr: - return readSymbolRefAttr(reader, /*hasNestedRefs=*/true); - case builtin_encoding::kTypeAttr: - return readTypeAttr(reader); - case builtin_encoding::kUnitAttr: - return UnitAttr::get(getContext()); - case builtin_encoding::kIntegerAttr: - return readIntegerAttr(reader); - case builtin_encoding::kFloatAttr: - return readFloatAttr(reader); - case builtin_encoding::kCallSiteLoc: - return readCallSiteLoc(reader); - case builtin_encoding::kFileLineColLoc: - return readFileLineColLoc(reader); - case builtin_encoding::kFusedLoc: - return readFusedLoc(reader, /*hasMetadata=*/false); - case builtin_encoding::kFusedLocWithMetadata: - return readFusedLoc(reader, /*hasMetadata=*/true); - case builtin_encoding::kNameLoc: - return readNameLoc(reader); - case builtin_encoding::kUnknownLoc: - return UnknownLoc::get(getContext()); - case builtin_encoding::kDenseResourceElementsAttr: - return readDenseResourceElementsAttr(reader); - case builtin_encoding::kDenseArrayAttr: - return readDenseArrayAttr(reader); - case builtin_encoding::kDenseIntOrFPElementsAttr: - return readDenseIntOrFPElementsAttr(reader); - case builtin_encoding::kDenseStringElementsAttr: - return readDenseStringElementsAttr(reader); - case builtin_encoding::kSparseElementsAttr: - return readSparseElementsAttr(reader); - default: - reader.emitError() << "unknown builtin attribute code: " << code; - return Attribute(); - } -} - -LogicalResult 
BuiltinDialectBytecodeInterface::writeAttribute( - Attribute attr, DialectBytecodeWriter &writer) const { - return TypeSwitch(attr) - .Case([&](auto attr) { - write(attr, writer); - return success(); - }) - .Case([&](auto attr) { - write(attr, writer); - return success(); - }) - .Case([&](OpaqueLoc attr) { return write(attr, writer); }) - .Case([&](UnitAttr) { - writer.writeVarInt(builtin_encoding::kUnitAttr); - return success(); - }) - .Case([&](UnknownLoc) { - writer.writeVarInt(builtin_encoding::kUnknownLoc); - return success(); - }) - .Default([&](Attribute) { return failure(); }); -} - -//===----------------------------------------------------------------------===// -// ArrayAttr - -ArrayAttr BuiltinDialectBytecodeInterface::readArrayAttr( - DialectBytecodeReader &reader) const { - SmallVector elements; - if (failed(reader.readAttributes(elements))) - return ArrayAttr(); - return ArrayAttr::get(getContext(), elements); -} - -void BuiltinDialectBytecodeInterface::write( - ArrayAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kArrayAttr); - writer.writeAttributes(attr.getValue()); -} - -//===----------------------------------------------------------------------===// -// DenseArrayAttr - -DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr( - DialectBytecodeReader &reader) const { - Type elementType; - uint64_t size; - ArrayRef blob; - if (failed(reader.readType(elementType)) || failed(reader.readVarInt(size)) || - failed(reader.readBlob(blob))) - return DenseArrayAttr(); - return DenseArrayAttr::get(elementType, size, blob); -} - -void BuiltinDialectBytecodeInterface::write( - DenseArrayAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kDenseArrayAttr); - writer.writeType(attr.getElementType()); - writer.writeVarInt(attr.getSize()); - writer.writeOwnedBlob(attr.getRawData()); -} - -//===----------------------------------------------------------------------===// -// 
DenseIntOrFPElementsAttr - -DenseElementsAttr BuiltinDialectBytecodeInterface::readDenseIntOrFPElementsAttr( - DialectBytecodeReader &reader) const { - ShapedType type; - ArrayRef blob; - if (failed(reader.readType(type)) || failed(reader.readBlob(blob))) - return DenseIntOrFPElementsAttr(); - return DenseIntOrFPElementsAttr::getFromRawBuffer(type, blob); -} - -void BuiltinDialectBytecodeInterface::write( - DenseIntOrFPElementsAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kDenseIntOrFPElementsAttr); - writer.writeType(attr.getType()); - writer.writeOwnedBlob(attr.getRawData()); -} - -//===----------------------------------------------------------------------===// -// DenseStringElementsAttr - -DenseStringElementsAttr -BuiltinDialectBytecodeInterface::readDenseStringElementsAttr( - DialectBytecodeReader &reader) const { - ShapedType type; - uint64_t isSplat; - if (failed(reader.readType(type)) || failed(reader.readVarInt(isSplat))) - return DenseStringElementsAttr(); - - SmallVector values(isSplat ? 1 : type.getNumElements()); - for (StringRef &value : values) - if (failed(reader.readString(value))) - return DenseStringElementsAttr(); - return DenseStringElementsAttr::get(type, values); -} - -void BuiltinDialectBytecodeInterface::write( - DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kDenseStringElementsAttr); - writer.writeType(attr.getType()); - - bool isSplat = attr.isSplat(); - writer.writeVarInt(isSplat); - - // If the attribute is a splat, only write out the single value. 
- if (isSplat) - return writer.writeOwnedString(attr.getRawStringData().front()); - - for (StringRef str : attr.getRawStringData()) - writer.writeOwnedString(str); -} - -//===----------------------------------------------------------------------===// -// DenseResourceElementsAttr - -DenseResourceElementsAttr -BuiltinDialectBytecodeInterface::readDenseResourceElementsAttr( - DialectBytecodeReader &reader) const { - ShapedType type; - if (failed(reader.readType(type))) - return DenseResourceElementsAttr(); - - FailureOr handle = - reader.readResourceHandle(); - if (failed(handle)) - return DenseResourceElementsAttr(); - - return DenseResourceElementsAttr::get(type, *handle); -} - -void BuiltinDialectBytecodeInterface::write( - DenseResourceElementsAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kDenseResourceElementsAttr); - writer.writeType(attr.getType()); - writer.writeResourceHandle(attr.getRawHandle()); -} - -//===----------------------------------------------------------------------===// -// DictionaryAttr - -DictionaryAttr BuiltinDialectBytecodeInterface::readDictionaryAttr( - DialectBytecodeReader &reader) const { - auto readNamedAttr = [&]() -> FailureOr { - StringAttr name; - Attribute value; - if (failed(reader.readAttribute(name)) || - failed(reader.readAttribute(value))) - return failure(); - return NamedAttribute(name, value); - }; - SmallVector attrs; - if (failed(reader.readList(attrs, readNamedAttr))) - return DictionaryAttr(); - return DictionaryAttr::get(getContext(), attrs); -} - -void BuiltinDialectBytecodeInterface::write( - DictionaryAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kDictionaryAttr); - writer.writeList(attr.getValue(), [&](NamedAttribute attr) { - writer.writeAttribute(attr.getName()); - writer.writeAttribute(attr.getValue()); - }); -} - -//===----------------------------------------------------------------------===// -// FloatAttr - -FloatAttr 
BuiltinDialectBytecodeInterface::readFloatAttr( - DialectBytecodeReader &reader) const { - FloatType type; - if (failed(reader.readType(type))) - return FloatAttr(); - FailureOr value = - reader.readAPFloatWithKnownSemantics(type.getFloatSemantics()); - if (failed(value)) - return FloatAttr(); - return FloatAttr::get(type, *value); -} - -void BuiltinDialectBytecodeInterface::write( - FloatAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kFloatAttr); - writer.writeType(attr.getType()); - writer.writeAPFloatWithKnownSemantics(attr.getValue()); -} - -//===----------------------------------------------------------------------===// -// IntegerAttr - -IntegerAttr BuiltinDialectBytecodeInterface::readIntegerAttr( - DialectBytecodeReader &reader) const { - Type type; - if (failed(reader.readType(type))) - return IntegerAttr(); - - // Extract the value storage width from the type. - unsigned bitWidth; - if (auto intType = type.dyn_cast()) { - bitWidth = intType.getWidth(); - } else if (type.isa()) { - bitWidth = IndexType::kInternalStorageBitWidth; - } else { - reader.emitError() - << "expected integer or index type for IntegerAttr, but got: " << type; - return IntegerAttr(); - } - - FailureOr value = reader.readAPIntWithKnownWidth(bitWidth); - if (failed(value)) - return IntegerAttr(); - return IntegerAttr::get(type, *value); -} - -void BuiltinDialectBytecodeInterface::write( - IntegerAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kIntegerAttr); - writer.writeType(attr.getType()); - writer.writeAPIntWithKnownWidth(attr.getValue()); -} - -//===----------------------------------------------------------------------===// -// SparseElementsAttr - -SparseElementsAttr BuiltinDialectBytecodeInterface::readSparseElementsAttr( - DialectBytecodeReader &reader) const { - ShapedType type; - DenseIntElementsAttr indices; - DenseElementsAttr values; - if (failed(reader.readType(type)) || 
failed(reader.readAttribute(indices)) || - failed(reader.readAttribute(values))) - return SparseElementsAttr(); - return SparseElementsAttr::get(type, indices, values); -} - -void BuiltinDialectBytecodeInterface::write( - SparseElementsAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kSparseElementsAttr); - writer.writeType(attr.getType()); - writer.writeAttribute(attr.getIndices()); - writer.writeAttribute(attr.getValues()); -} - -//===----------------------------------------------------------------------===// -// StringAttr - -StringAttr -BuiltinDialectBytecodeInterface::readStringAttr(DialectBytecodeReader &reader, - bool hasType) const { - StringRef string; - if (failed(reader.readString(string))) - return StringAttr(); - - // Read the type if present. - Type type; - if (!hasType) - type = NoneType::get(getContext()); - else if (failed(reader.readType(type))) - return StringAttr(); - return StringAttr::get(string, type); -} - -void BuiltinDialectBytecodeInterface::write( - StringAttr attr, DialectBytecodeWriter &writer) const { - // We only encode the type if it isn't NoneType, which is significantly less - // common. 
- Type type = attr.getType(); - if (!type.isa()) { - writer.writeVarInt(builtin_encoding::kStringAttrWithType); - writer.writeOwnedString(attr.getValue()); - writer.writeType(type); - return; - } - writer.writeVarInt(builtin_encoding::kStringAttr); - writer.writeOwnedString(attr.getValue()); -} - -//===----------------------------------------------------------------------===// -// SymbolRefAttr - -SymbolRefAttr BuiltinDialectBytecodeInterface::readSymbolRefAttr( - DialectBytecodeReader &reader, bool hasNestedRefs) const { - StringAttr rootReference; - if (failed(reader.readAttribute(rootReference))) - return SymbolRefAttr(); - SmallVector nestedReferences; - if (hasNestedRefs && failed(reader.readAttributes(nestedReferences))) - return SymbolRefAttr(); - return SymbolRefAttr::get(rootReference, nestedReferences); -} - -void BuiltinDialectBytecodeInterface::write( - SymbolRefAttr attr, DialectBytecodeWriter &writer) const { - ArrayRef nestedRefs = attr.getNestedReferences(); - writer.writeVarInt(nestedRefs.empty() ? 
builtin_encoding::kFlatSymbolRefAttr - : builtin_encoding::kSymbolRefAttr); - - writer.writeAttribute(attr.getRootReference()); - if (!nestedRefs.empty()) - writer.writeAttributes(nestedRefs); -} - -//===----------------------------------------------------------------------===// -// TypeAttr - -TypeAttr BuiltinDialectBytecodeInterface::readTypeAttr( - DialectBytecodeReader &reader) const { - Type type; - if (failed(reader.readType(type))) - return TypeAttr(); - return TypeAttr::get(type); -} - -void BuiltinDialectBytecodeInterface::write( - TypeAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kTypeAttr); - writer.writeType(attr.getValue()); -} - -//===----------------------------------------------------------------------===// -// CallSiteLoc - -LocationAttr BuiltinDialectBytecodeInterface::readCallSiteLoc( - DialectBytecodeReader &reader) const { - LocationAttr callee, caller; - if (failed(reader.readAttribute(callee)) || - failed(reader.readAttribute(caller))) - return LocationAttr(); - return CallSiteLoc::get(callee, caller); -} - -void BuiltinDialectBytecodeInterface::write( - CallSiteLoc attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kCallSiteLoc); - writer.writeAttribute(attr.getCallee()); - writer.writeAttribute(attr.getCaller()); -} - -//===----------------------------------------------------------------------===// -// FileLineColLoc - -LocationAttr BuiltinDialectBytecodeInterface::readFileLineColLoc( - DialectBytecodeReader &reader) const { - StringAttr filename; - uint64_t line, column; - if (failed(reader.readAttribute(filename)) || - failed(reader.readVarInt(line)) || failed(reader.readVarInt(column))) - return LocationAttr(); - return FileLineColLoc::get(filename, line, column); -} - -void BuiltinDialectBytecodeInterface::write( - FileLineColLoc attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kFileLineColLoc); - 
writer.writeAttribute(attr.getFilename()); - writer.writeVarInt(attr.getLine()); - writer.writeVarInt(attr.getColumn()); -} - -//===----------------------------------------------------------------------===// -// FusedLoc - -LocationAttr -BuiltinDialectBytecodeInterface::readFusedLoc(DialectBytecodeReader &reader, - bool hasMetadata) const { - // Parse the child locations. - auto readLoc = [&]() -> FailureOr { - LocationAttr locAttr; - if (failed(reader.readAttribute(locAttr))) - return failure(); - return Location(locAttr); - }; - SmallVector locations; - if (failed(reader.readList(locations, readLoc))) - return LocationAttr(); - - // Parse the metadata if present. - Attribute metadata; - if (hasMetadata && failed(reader.readAttribute(metadata))) - return LocationAttr(); - - return FusedLoc::get(locations, metadata, getContext()); -} - -void BuiltinDialectBytecodeInterface::write( - FusedLoc attr, DialectBytecodeWriter &writer) const { - if (Attribute metadata = attr.getMetadata()) { - writer.writeVarInt(builtin_encoding::kFusedLocWithMetadata); - writer.writeAttributes(attr.getLocations()); - writer.writeAttribute(metadata); - } else { - writer.writeVarInt(builtin_encoding::kFusedLoc); - writer.writeAttributes(attr.getLocations()); - } -} - -//===----------------------------------------------------------------------===// -// NameLoc - -LocationAttr BuiltinDialectBytecodeInterface::readNameLoc( - DialectBytecodeReader &reader) const { - StringAttr name; - LocationAttr childLoc; - if (failed(reader.readAttribute(name)) || - failed(reader.readAttribute(childLoc))) - return LocationAttr(); - return NameLoc::get(name, childLoc); -} - -void BuiltinDialectBytecodeInterface::write( - NameLoc attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kNameLoc); - writer.writeAttribute(attr.getName()); - writer.writeAttribute(attr.getChildLoc()); -} - -//===----------------------------------------------------------------------===// -// OpaqueLoc - 
-LogicalResult -BuiltinDialectBytecodeInterface::write(OpaqueLoc attr, - DialectBytecodeWriter &writer) const { - // We can't encode an OpaqueLoc directly given that it is in-memory only, so - // encode the fallback instead. - return writeAttribute(attr.getFallbackLocation(), writer); -} - -//===----------------------------------------------------------------------===// -// Types -//===----------------------------------------------------------------------===// - -Type BuiltinDialectBytecodeInterface::readType( - DialectBytecodeReader &reader) const { - uint64_t code; - if (failed(reader.readVarInt(code))) - return Type(); - switch (code) { - case builtin_encoding::kIntegerType: - return readIntegerType(reader); - case builtin_encoding::kIndexType: - return IndexType::get(getContext()); - case builtin_encoding::kFunctionType: - return readFunctionType(reader); - case builtin_encoding::kBFloat16Type: - return BFloat16Type::get(getContext()); - case builtin_encoding::kFloat16Type: - return Float16Type::get(getContext()); - case builtin_encoding::kFloat32Type: - return Float32Type::get(getContext()); - case builtin_encoding::kFloat64Type: - return Float64Type::get(getContext()); - case builtin_encoding::kFloat80Type: - return Float80Type::get(getContext()); - case builtin_encoding::kFloat128Type: - return Float128Type::get(getContext()); - case builtin_encoding::kComplexType: - return readComplexType(reader); - case builtin_encoding::kMemRefType: - return readMemRefType(reader, /*hasMemSpace=*/false); - case builtin_encoding::kMemRefTypeWithMemSpace: - return readMemRefType(reader, /*hasMemSpace=*/true); - case builtin_encoding::kNoneType: - return NoneType::get(getContext()); - case builtin_encoding::kRankedTensorType: - return readRankedTensorType(reader, /*hasEncoding=*/false); - case builtin_encoding::kRankedTensorTypeWithEncoding: - return readRankedTensorType(reader, /*hasEncoding=*/true); - case builtin_encoding::kTupleType: - return readTupleType(reader); - 
case builtin_encoding::kUnrankedMemRefType: - return readUnrankedMemRefType(reader, /*hasMemSpace=*/false); - case builtin_encoding::kUnrankedMemRefTypeWithMemSpace: - return readUnrankedMemRefType(reader, /*hasMemSpace=*/true); - case builtin_encoding::kUnrankedTensorType: - return readUnrankedTensorType(reader); - case builtin_encoding::kVectorType: - return readVectorType(reader, /*hasScalableDims=*/false); - case builtin_encoding::kVectorTypeWithScalableDims: - return readVectorType(reader, /*hasScalableDims=*/true); - - default: - reader.emitError() << "unknown builtin type code: " << code; - return Type(); - } -} - -LogicalResult BuiltinDialectBytecodeInterface::writeType( - Type type, DialectBytecodeWriter &writer) const { - return TypeSwitch(type) - .Case([&](auto type) { - write(type, writer); - return success(); - }) - .Case([&](IndexType) { - return writer.writeVarInt(builtin_encoding::kIndexType), success(); - }) - .Case([&](BFloat16Type) { - return writer.writeVarInt(builtin_encoding::kBFloat16Type), success(); - }) - .Case([&](Float16Type) { - return writer.writeVarInt(builtin_encoding::kFloat16Type), success(); - }) - .Case([&](Float32Type) { - return writer.writeVarInt(builtin_encoding::kFloat32Type), success(); - }) - .Case([&](Float64Type) { - return writer.writeVarInt(builtin_encoding::kFloat64Type), success(); - }) - .Case([&](Float80Type) { - return writer.writeVarInt(builtin_encoding::kFloat80Type), success(); - }) - .Case([&](Float128Type) { - return writer.writeVarInt(builtin_encoding::kFloat128Type), success(); - }) - .Case([&](NoneType) { - return writer.writeVarInt(builtin_encoding::kNoneType), success(); - }) - .Default([&](Type) { return failure(); }); -} - -//===----------------------------------------------------------------------===// -// ComplexType - -ComplexType BuiltinDialectBytecodeInterface::readComplexType( - DialectBytecodeReader &reader) const { - Type elementType; - if (failed(reader.readType(elementType))) - return 
ComplexType(); - return ComplexType::get(elementType); -} - -void BuiltinDialectBytecodeInterface::write( - ComplexType type, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kComplexType); - writer.writeType(type.getElementType()); -} - -//===----------------------------------------------------------------------===// -// IntegerType - -IntegerType BuiltinDialectBytecodeInterface::readIntegerType( - DialectBytecodeReader &reader) const { - uint64_t encoding; - if (failed(reader.readVarInt(encoding))) - return IntegerType(); - return IntegerType::get( - getContext(), encoding >> 2, - static_cast(encoding & 0x3)); -} - -void BuiltinDialectBytecodeInterface::write( - IntegerType type, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kIntegerType); - writer.writeVarInt((type.getWidth() << 2) | type.getSignedness()); -} - -//===----------------------------------------------------------------------===// -// FunctionType - -FunctionType BuiltinDialectBytecodeInterface::readFunctionType( - DialectBytecodeReader &reader) const { - SmallVector inputs, results; - if (failed(reader.readTypes(inputs)) || failed(reader.readTypes(results))) - return FunctionType(); - return FunctionType::get(getContext(), inputs, results); -} - -void BuiltinDialectBytecodeInterface::write( - FunctionType type, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kFunctionType); - writer.writeTypes(type.getInputs()); - writer.writeTypes(type.getResults()); -} - -//===----------------------------------------------------------------------===// -// MemRefType - -MemRefType -BuiltinDialectBytecodeInterface::readMemRefType(DialectBytecodeReader &reader, - bool hasMemSpace) const { - Attribute memorySpace; - if (hasMemSpace && failed(reader.readAttribute(memorySpace))) - return MemRefType(); - SmallVector shape; - Type elementType; - MemRefLayoutAttrInterface layout; - if (failed(reader.readSignedVarInts(shape)) || - 
failed(reader.readType(elementType)) || - failed(reader.readAttribute(layout))) - return MemRefType(); - return MemRefType::get(shape, elementType, layout, memorySpace); -} - -void BuiltinDialectBytecodeInterface::write( - MemRefType type, DialectBytecodeWriter &writer) const { - if (Attribute memSpace = type.getMemorySpace()) { - writer.writeVarInt(builtin_encoding::kMemRefTypeWithMemSpace); - writer.writeAttribute(memSpace); - } else { - writer.writeVarInt(builtin_encoding::kMemRefType); - } - writer.writeSignedVarInts(type.getShape()); - writer.writeType(type.getElementType()); - writer.writeAttribute(type.getLayout()); -} - -//===----------------------------------------------------------------------===// -// RankedTensorType - -RankedTensorType BuiltinDialectBytecodeInterface::readRankedTensorType( - DialectBytecodeReader &reader, bool hasEncoding) const { - Attribute encoding; - if (hasEncoding && failed(reader.readAttribute(encoding))) - return RankedTensorType(); - SmallVector shape; - Type elementType; - if (failed(reader.readSignedVarInts(shape)) || - failed(reader.readType(elementType))) - return RankedTensorType(); - return RankedTensorType::get(shape, elementType, encoding); -} - -void BuiltinDialectBytecodeInterface::write( - RankedTensorType type, DialectBytecodeWriter &writer) const { - if (Attribute encoding = type.getEncoding()) { - writer.writeVarInt(builtin_encoding::kRankedTensorTypeWithEncoding); - writer.writeAttribute(encoding); - } else { - writer.writeVarInt(builtin_encoding::kRankedTensorType); - } - writer.writeSignedVarInts(type.getShape()); - writer.writeType(type.getElementType()); -} - -//===----------------------------------------------------------------------===// -// TupleType - -TupleType BuiltinDialectBytecodeInterface::readTupleType( - DialectBytecodeReader &reader) const { - SmallVector elements; - if (failed(reader.readTypes(elements))) - return TupleType(); - return TupleType::get(getContext(), elements); -} - -void 
BuiltinDialectBytecodeInterface::write( - TupleType type, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kTupleType); - writer.writeTypes(type.getTypes()); -} - -//===----------------------------------------------------------------------===// -// UnrankedMemRefType - -UnrankedMemRefType BuiltinDialectBytecodeInterface::readUnrankedMemRefType( - DialectBytecodeReader &reader, bool hasMemSpace) const { - Attribute memorySpace; - if (hasMemSpace && failed(reader.readAttribute(memorySpace))) - return UnrankedMemRefType(); - Type elementType; - if (failed(reader.readType(elementType))) - return UnrankedMemRefType(); - return UnrankedMemRefType::get(elementType, memorySpace); -} - -void BuiltinDialectBytecodeInterface::write( - UnrankedMemRefType type, DialectBytecodeWriter &writer) const { - if (Attribute memSpace = type.getMemorySpace()) { - writer.writeVarInt(builtin_encoding::kUnrankedMemRefTypeWithMemSpace); - writer.writeAttribute(memSpace); - } else { - writer.writeVarInt(builtin_encoding::kUnrankedMemRefType); - } - writer.writeType(type.getElementType()); -} - -//===----------------------------------------------------------------------===// -// UnrankedTensorType - -UnrankedTensorType BuiltinDialectBytecodeInterface::readUnrankedTensorType( - DialectBytecodeReader &reader) const { - Type elementType; - if (failed(reader.readType(elementType))) - return UnrankedTensorType(); - return UnrankedTensorType::get(elementType); -} - -void BuiltinDialectBytecodeInterface::write( - UnrankedTensorType type, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kUnrankedTensorType); - writer.writeType(type.getElementType()); -} - -//===----------------------------------------------------------------------===// -// VectorType - -VectorType -BuiltinDialectBytecodeInterface::readVectorType(DialectBytecodeReader &reader, - bool hasScalableDims) const { - uint64_t numScalableDims = 0; - if (hasScalableDims && 
failed(reader.readVarInt(numScalableDims))) - return VectorType(); - SmallVector shape; - Type elementType; - if (failed(reader.readSignedVarInts(shape)) || - failed(reader.readType(elementType))) - return VectorType(); - return VectorType::get(shape, elementType, numScalableDims); -} - -void BuiltinDialectBytecodeInterface::write( - VectorType type, DialectBytecodeWriter &writer) const { - if (unsigned numScalableDims = type.getNumScalableDims()) { - writer.writeVarInt(builtin_encoding::kVectorTypeWithScalableDims); - writer.writeVarInt(numScalableDims); - } else { - writer.writeVarInt(builtin_encoding::kVectorType); - } - writer.writeSignedVarInts(type.getShape()); - writer.writeType(type.getElementType()); -} diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt index 4377ebe16055..b729282e627d 100644 --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -44,6 +44,7 @@ add_mlir_library(MLIRIR DEPENDS MLIRBuiltinAttributesIncGen MLIRBuiltinAttributeInterfacesIncGen + MLIRBuiltinDialectBytecodeIncGen MLIRBuiltinDialectIncGen MLIRBuiltinLocationAttributesIncGen MLIRBuiltinOpsIncGen diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp new file mode 100644 index 000000000000..f13bdd49413b --- /dev/null +++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp @@ -0,0 +1,467 @@ +//===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/IndentedOstream.h" +#include "mlir/TableGen/GenInfo.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include + +using namespace llvm; + +static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode"); +static llvm::cl::opt + selectedBcDialect("bytecode-dialect", + llvm::cl::desc("The dialect to gen for"), + llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); + +namespace { + +/// Helper class to generate C++ bytecode parser and printer helpers. +class Generator { +public: + Generator(raw_ostream &output) : output(output) {} + + /// Emits the attribute/type parser function for record `x`. + void emitParse(StringRef kind, Record &x); + + /// Emits attribute/type printers. + void emitPrint(StringRef kind, StringRef type, + ArrayRef> vec); + + /// Emits the parse dispatch function that switches on the encoded kind. + void emitParseDispatch(StringRef kind, ArrayRef vec); + + /// Emits print dispatch table. + void emitPrintDispatch(StringRef kind, ArrayRef vec); + +private: + /// Emits parse calls to construct given kind. + void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder, + ArrayRef args, ArrayRef argNames, + StringRef failure, mlir::raw_indented_ostream &ios); + + /// Emits print instructions. + void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent, + StringRef name, mlir::raw_indented_ostream &ios); + + raw_ostream &output; +}; +} // namespace + +/// Returns a copy of `templ` with each key of `map` replaced by its value. +/// Assumed: non-overlapping replacements.
+static std::string format(StringRef templ, + std::map &&map) { + std::string s = templ.str(); + for (const auto &[from, to] : map) + // Every key starts with '$'; escape it so the regex does not treat it as an anchor. + s = std::regex_replace(s, std::regex("\\" + from), to); + return s; +} + +/// Return `str` with its first character upper-cased; an empty `str` yields an empty string. +static std::string capitalize(StringRef str) { + return str.empty() ? std::string() : (Twine(toUpper(str[0])) + str.drop_front()).str(); +} + +/// Return the C++ type for the given record: its `cType` if set, otherwise the record name; for an `Array` record the element type is wrapped in `SmallVector<>`. +static std::string getCType(Record *def) { + std::string format = "{0}"; + if (def->isSubClassOf("Array")) { + def = def->getValueAsDef("elemT"); + format = "SmallVector<{0}>"; + } + + StringRef cType = def->getValueAsString("cType"); + if (cType.empty()) { + if (def->isAnonymous()) + PrintFatalError(def->getLoc(), "Unable to determine cType"); + + return formatv(format.c_str(), def->getName().str()); + } + return formatv(format.c_str(), cType.str()); +} + +void Generator::emitParseDispatch(StringRef kind, ArrayRef vec) { + mlir::raw_indented_ostream os(output); + char const *head = + R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))"; + os << formatv(head, capitalize(kind)); + auto funScope = os.scope(" {\n", "}\n\n"); + + os << "uint64_t kind;\n"; + os << "if (failed(reader.readVarInt(kind)))\n" + << " return " << capitalize(kind) << "();\n"; + os << "switch (kind) "; + { + auto switchScope = os.scope("{\n", "}\n"); + for (const auto &it : llvm::enumerate(vec)) { + os << formatv("case {1}:\n return read{0}(context, reader);\n", + it.value()->getName(), it.index()); + } + os << "default:\n" + << " reader.emitError() << \"unknown " << kind << " code: \" " // Name the kind being decoded rather than hard-coding "attribute". + << "<< kind;\n" + << " return " << capitalize(kind) << "();\n"; + } + os << "return " << capitalize(kind) << "();\n"; +} + +void Generator::emitParse(StringRef kind, Record &x) { + char const *head = + R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )"; + mlir::raw_indented_ostream os(output);
+ std::string returnType = getCType(&x); + os << formatv(head, returnType, x.getName()); + DagInit *members = x.getValueAsDag("members"); + SmallVector argNames = + llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) { + return init->getAsUnquotedString(); + })); + StringRef builder = x.getValueAsString("cBuilder"); + emitParseHelper(kind, returnType, builder, members->getArgs(), argNames, + returnType + "()", os); + os << "\n\n"; +} + +void printParseConditional(mlir::raw_indented_ostream &ios, + ArrayRef args, + ArrayRef argNames) { + ios << "if "; + auto parenScope = ios.scope("(", ") {"); + ios.indent(); + + auto listHelperName = [](StringRef name) { + return formatv("read{0}", capitalize(name)); + }; + + auto parsedArgs = + llvm::to_vector(make_filter_range(args, [](Init *const attr) { + Record *def = cast(attr)->getDef(); + if (def->isSubClassOf("Array")) + return true; + return !def->getValueAsString("cParser").empty(); + })); + + interleave( + zip(parsedArgs, argNames), + [&](std::tuple it) { + Record *attr = cast(std::get<0>(it))->getDef(); + std::string parser; + if (auto optParser = attr->getValueAsOptionalString("cParser")) { + parser = *optParser; + } else if (attr->isSubClassOf("Array")) { + Record *def = attr->getValueAsDef("elemT"); + bool composite = def->isSubClassOf("CompositeBytecode"); + if (!composite && def->isSubClassOf("AttributeKind")) + parser = "succeeded($_reader.readAttributes($_var))"; + else if (!composite && def->isSubClassOf("TypeKind")) + parser = "succeeded($_reader.readTypes($_var))"; + else + parser = ("succeeded($_reader.readList($_var, " + + listHelperName(std::get<1>(it)) + "))") + .str(); + } else { + PrintFatalError(attr->getLoc(), "No parser specified"); + } + std::string type = getCType(attr); + ios << format(parser, {{"$_reader", "reader"}, + {"$_resultType", type}, + {"$_var", std::get<1>(it)}}); + }, + [&]() { ios << " &&\n"; }); +} + +void Generator::emitParseHelper(StringRef kind, StringRef 
returnType, + StringRef builder, ArrayRef args, + ArrayRef argNames, + StringRef failure, + mlir::raw_indented_ostream &ios) { + auto funScope = ios.scope("{\n", "}"); + + if (args.empty()) { + ios << formatv("return get<{0}>(context);\n", returnType); + return; + } + + // Print decls. + std::string lastCType = ""; + for (auto [arg, name] : zip(args, argNames)) { + DefInit *first = dyn_cast(arg); + if (!first) + PrintFatalError("Unexpected type for " + name); + Record *def = first->getDef(); + + // Create variable decls, if there are a block of same type then create + // comma separated list of them. + std::string cType = getCType(def); + if (lastCType == cType) { + ios << ", "; + } else { + if (!lastCType.empty()) + ios << ";\n"; + ios << cType << " "; + } + ios << name; + lastCType = cType; + } + ios << ";\n"; + + // Returns the name of the helper used in list parsing. E.g., the name of the + // lambda passed to array parsing. + auto listHelperName = [](StringRef name) { + return formatv("read{0}", capitalize(name)); + }; + + // Emit list helper functions. + for (auto [arg, name] : zip(args, argNames)) { + Record *attr = cast(arg)->getDef(); + if (!attr->isSubClassOf("Array")) + continue; + + // TODO: Dedupe readers. 
+ Record *def = attr->getValueAsDef("elemT"); + if (!def->isSubClassOf("CompositeBytecode") && + (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind"))) + continue; + + std::string returnType = getCType(def); + ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<" + << returnType << "> "; + SmallVector args; + SmallVector argNames; + if (def->isSubClassOf("CompositeBytecode")) { + DagInit *members = def->getValueAsDag("members"); + args = llvm::to_vector(members->getArgs()); + argNames = llvm::to_vector( + map_range(members->getArgNames(), [](StringInit *init) { + return init->getAsUnquotedString(); + })); + } else { + args = {def->getDefInit()}; + argNames = {"temp"}; + } + StringRef builder = def->getValueAsString("cBuilder"); + emitParseHelper(kind, returnType, builder, args, argNames, "failure()", + ios); + ios << ";\n"; + } + + // Print parse conditional. + printParseConditional(ios, args, argNames); + + // Compute args to pass to create method. + auto passedArgs = llvm::to_vector(make_filter_range( + argNames, [](StringRef str) { return !str.starts_with("_"); })); + std::string argStr; + raw_string_ostream argStream(argStr); + interleaveComma(passedArgs, argStream, + [&](const std::string &str) { argStream << str; }); + // Return the invoked constructor. + ios << "\nreturn " + << format(builder, {{"$_resultType", returnType.str()}, + {"$_args", argStream.str()}}) + << ";\n"; + ios.unindent(); + + // TODO: Emit error in debug. + // This assumes the result types in error case can always be empty + // constructed. + ios << "}\nreturn " << failure << ";\n"; +} + +void Generator::emitPrint(StringRef kind, StringRef type, + ArrayRef> vec) { + char const *head = + R"(static void write({0} {1}, DialectBytecodeWriter &writer) )"; + mlir::raw_indented_ostream os(output); + os << formatv(head, type, kind); + auto funScope = os.scope("{\n", "}\n\n"); + + // Check that predicates specified if multiple bytecode instances. 
+ for (llvm::Record *rec : make_second_range(vec)) { + StringRef pred = rec->getValueAsString("printerPredicate"); + if (vec.size() > 1 && pred.empty()) { + for (auto [index, rec] : vec) { + (void)index; + StringRef pred = rec->getValueAsString("printerPredicate"); + if (vec.size() > 1 && pred.empty()) + PrintError(rec->getLoc(), + "Requires parsing predicate given common cType"); + } + PrintFatalError("Unspecified for shared cType " + type); + } + } + + for (auto [index, rec] : vec) { + StringRef pred = rec->getValueAsString("printerPredicate"); + if (!pred.empty()) { + os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n"; + os.indent(); + } + + os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index + << ");\n"; + + auto *members = rec->getValueAsDag("members"); + for (auto [arg, name] : + llvm::zip(members->getArgs(), members->getArgNames())) { + DefInit *def = dyn_cast(arg); + assert(def); + Record *memberRec = def->getDef(); + emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os); + } + + if (!pred.empty()) { + os.unindent(); + os << "}\n"; + } + } +} + +void Generator::emitPrintHelper(Record *memberRec, StringRef kind, + StringRef parent, StringRef name, + mlir::raw_indented_ostream &ios) { + std::string getter; + if (auto cGetter = memberRec->getValueAsOptionalString("cGetter"); + cGetter && !cGetter->empty()) { + getter = format( + *cGetter, + {{"$_attrType", parent.str()}, + {"$_member", name.str()}, + {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}}); + } else { + getter = + formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true)) + .str(); + } + + if (memberRec->isSubClassOf("Array")) { + Record *def = memberRec->getValueAsDef("elemT"); + if (!def->isSubClassOf("CompositeBytecode")) { + if (def->isSubClassOf("AttributeKind")) { + ios << "writer.writeAttributes(" << getter << ");\n"; + return; + } + if (def->isSubClassOf("TypeKind")) { + ios << "writer.writeTypes(" << getter << 
");\n"; + return; + } + } + std::string returnType = getCType(def); + ios << "writer.writeList(" << getter << ", [&](" << returnType << " " + << kind << ") "; + auto lambdaScope = ios.scope("{\n", "});\n"); + return emitPrintHelper(def, kind, kind, kind, ios); + } + if (memberRec->isSubClassOf("CompositeBytecode")) { + auto *members = memberRec->getValueAsDag("members"); + for (auto [arg, argName] : + zip(members->getArgs(), members->getArgNames())) { + DefInit *def = dyn_cast(arg); + assert(def); + emitPrintHelper(def->getDef(), kind, parent, + argName->getAsUnquotedString(), ios); + } + } + + if (std::string printer = memberRec->getValueAsString("cPrinter").str(); + !printer.empty()) + ios << format(printer, {{"$_writer", "writer"}, + {"$_name", kind.str()}, + {"$_getter", getter}}) + << ";\n"; +} + +void Generator::emitPrintDispatch(StringRef kind, ArrayRef vec) { + mlir::raw_indented_ostream os(output); + char const *head = R"(static LogicalResult write{0}({0} {1}, + DialectBytecodeWriter &writer))"; + os << formatv(head, capitalize(kind), kind); + auto funScope = os.scope(" {\n", "}\n\n"); + + os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind + << ")"; + auto switchScope = os.scope("", ""); + for (StringRef type : vec) { + os << "\n.Case([&](" << type << " t)"; + auto caseScope = os.scope(" {\n", "})"); + os << "return write(t, writer), success();\n"; + } + os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n"; +} + +namespace { +/// Container of Attribute or Type for Dialect. 
+struct AttrOrType { + std::vector attr, type; +}; +} // namespace + +static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) { + MapVector dialectAttrOrType; + for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) { + if (!selectedBcDialect.empty() && + it->getValueAsString("dialect") != selectedBcDialect) + continue; + dialectAttrOrType[it->getValueAsString("dialect")].attr = + it->getValueAsListOfDefs("elems"); + } + for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) { + if (!selectedBcDialect.empty() && + it->getValueAsString("dialect") != selectedBcDialect) + continue; + dialectAttrOrType[it->getValueAsString("dialect")].type = + it->getValueAsListOfDefs("elems"); + } + + if (dialectAttrOrType.size() != 1) + PrintFatalError("Single dialect per invocation required (either only " + "one in input file or specified via dialect option)"); + + auto it = dialectAttrOrType.front(); + Generator gen(os); + + SmallVector *, 2> vecs; + SmallVector kinds; + vecs.push_back(&it.second.attr); + kinds.push_back("attribute"); + vecs.push_back(&it.second.type); + kinds.push_back("type"); + for (auto [vec, kind] : zip(vecs, kinds)) { + // Handle Attribute/Type emission. 
+ std::map>> perType; + for (auto kt : llvm::enumerate(*vec)) + perType[getCType(kt.value())].emplace_back(kt.index(), kt.value()); + for (const auto &jt : perType) { + for (auto kt : jt.second) + gen.emitParse(kind, *std::get<1>(kt)); + gen.emitPrint(kind, jt.first, jt.second); + } + gen.emitParseDispatch(kind, *vec); + + SmallVector types; + for (const auto &it : perType) { + types.push_back(it.first); + } + gen.emitPrintDispatch(kind, types); + } + + return false; +} + +static mlir::GenRegistration + genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers", + [](const RecordKeeper &records, raw_ostream &os) { + return emitBCRW(records, os); + }); diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt index fe217450d6fb..0835b6d27c71 100644 --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -9,6 +9,7 @@ add_tablegen(mlir-tblgen MLIR EXPORT MLIR AttrOrTypeDefGen.cpp AttrOrTypeFormatGen.cpp + BytecodeDialectGen.cpp DialectGen.cpp DirectiveCommonGen.cpp EnumsGen.cpp diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 881a12426baf..fc189556b864 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -91,6 +91,7 @@ td_library( ], includes = ["include"], deps = [ + ":BytecodeTdFiles", ":CallInterfacesTdFiles", ":CastInterfacesTdFiles", ":DataLayoutInterfacesTdFiles", @@ -117,6 +118,20 @@ gentbl_cc_library( deps = [":BuiltinDialectTdFiles"], ) +gentbl_cc_library( + name = "BuiltinDialectBytecodeGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-bytecode", "-bytecode-dialect=Builtin"], + "include/mlir/IR/BuiltinDialectBytecode.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/IR/BuiltinDialectBytecode.td", + deps = [":BuiltinDialectTdFiles"], +) + gentbl_cc_library( name = "BuiltinAttributesIncGen", 
strip_include_prefix = "include", @@ -277,6 +292,7 @@ cc_library( deps = [ ":BuiltinAttributeInterfacesIncGen", ":BuiltinAttributesIncGen", + ":BuiltinDialectBytecodeGen", ":BuiltinDialectIncGen", ":BuiltinLocationAttributesIncGen", ":BuiltinOpsIncGen", @@ -929,6 +945,12 @@ td_library( ], ) +td_library( + name = "BytecodeTdFiles", + srcs = ["include/mlir/IR/BytecodeBase.td"], + includes = ["include"], +) + td_library( name = "CallInterfacesTdFiles", srcs = ["include/mlir/Interfaces/CallInterfaces.td"],