[mlir] Dialect type/attr bytecode read/write generator.

Tool to help generate dialect bytecode Attribute & Type reader/writing.
Show usage by flipping builtin dialect.

It helps reduce boilerplate when writing dialect bytecode attribute and
type readers/writers. It is not an attempt at a generic spec mechanism
but rather practically focussing on boilerplate reduction while also
considering that it need not be the only in memory format and make it
relatively easy to change.

There should be some cleanup in follow up as we expand to more dialects.

Differential Revision: https://reviews.llvm.org/D144820
This commit is contained in:
Jacques Pienaar 2023-04-24 11:53:58 -07:00
parent 8e091b1220
commit 0911558005
9 changed files with 1324 additions and 1132 deletions

View File

@ -345,6 +345,37 @@ public:
}
};
/// Helper for resource handle reading that returns LogicalResult.
template <typename T, typename... Ts>
static LogicalResult readResourceHandle(DialectBytecodeReader &reader,
FailureOr<T> &value, Ts &&...params) {
FailureOr<T> handle = reader.readResourceHandle<T>();
if (failed(handle))
return failure();
if (auto *result = dyn_cast<T>(&*handle)) {
value = std::move(*result);
return success();
}
return failure();
}
/// Helper method that injects context only if needed, this helps unify some of
/// the attribute construction methods.
template <typename T, typename... Ts>
auto get(MLIRContext *context, Ts &&...params) {
// Prefer a direct `get` method if one exists.
if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
(void)context;
return T::get(std::forward<Ts>(params)...);
} else if constexpr (llvm::is_detected<detail::has_get_method, T,
MLIRContext *, Ts...>::value) {
return T::get(context, std::forward<Ts>(params)...);
} else {
// Otherwise, pass to the base get.
return T::Base::get(context, std::forward<Ts>(params)...);
}
}
} // namespace mlir
#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H

View File

@ -0,0 +1,566 @@
//===-- BuiltinBytecode.td - Builtin bytecode defs ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the Builtin bytecode reader/writer definition file.
//
//===----------------------------------------------------------------------===//
#ifndef BUILTIN_BYTECODE
#define BUILTIN_BYTECODE
include "mlir/IR/BytecodeBase.td"
def LocationAttr : AttributeKind;
def ShapedType: WithType<"ShapedType", Type>;
def Location : CompositeBytecode {
dag members = (attr
WithGetter<"(LocationAttr)$_attrType", WithType<"LocationAttr", LocationAttr>>:$value
);
let cBuilder = "Location($_args)";
}
def String :
WithParser <"succeeded($_reader.readString($_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeOwnedString($_getter)",
WithGetter <"$_attrType",
WithType <"StringRef">>>>>;
// enum AttributeCode {
// /// ArrayAttr {
// /// elements: Attribute[]
// /// }
// ///
// kArrayAttr = 0,
//
def ArrayAttr : DialectAttribute<(attr
Array<Attribute>:$value
)>;
let cType = "StringAttr" in {
// /// StringAttr {
// /// value: string
// /// }
// kStringAttr = 2,
def StringAttr : DialectAttribute<(attr
String:$value
)> {
let printerPredicate = "$_val.getType().isa<NoneType>()";
}
// /// StringAttrWithType {
// /// value: string,
// /// type: Type
// /// }
// /// A variant of StringAttr with a type.
// kStringAttrWithType = 3,
def StringAttrWithType : DialectAttribute<(attr
String:$value,
Type:$type
)> { let printerPredicate = "!$_val.getType().isa<NoneType>()"; }
}
// /// DictionaryAttr {
// /// attrs: <StringAttr, Attribute>[]
// /// }
// kDictionaryAttr = 1,
def NamedAttribute : CompositeBytecode {
dag members = (attr
StringAttr:$name,
Attribute:$value
);
let cBuilder = "NamedAttribute($_args)";
}
def DictionaryAttr : DialectAttribute<(attr
Array<NamedAttribute>:$value
)>;
// /// FlatSymbolRefAttr {
// /// rootReference: StringAttr
// /// }
// /// A variant of SymbolRefAttr with no leaf references.
// kFlatSymbolRefAttr = 4,
def FlatSymbolRefAttr: DialectAttribute<(attr
StringAttr:$rootReference
)>;
// /// SymbolRefAttr {
// /// rootReference: StringAttr,
// /// leafReferences: FlatSymbolRefAttr[]
// /// }
// kSymbolRefAttr = 5,
def SymbolRefAttr: DialectAttribute<(attr
StringAttr:$rootReference,
Array<FlatSymbolRefAttr>:$nestedReferences
)>;
// /// TypeAttr {
// /// value: Type
// /// }
// kTypeAttr = 6,
def TypeAttr: DialectAttribute<(attr
Type:$value
)>;
// /// UnitAttr {
// /// }
// kUnitAttr = 7,
def UnitAttr: DialectAttribute<(attr)>;
// /// IntegerAttr {
// /// type: Type
// /// value: APInt,
// /// }
// kIntegerAttr = 8,
def IntegerAttr: DialectAttribute<(attr
Type:$type,
KnownWidthAPInt<"type">:$value
)> {
let cBuilder = "get<$_resultType>(context, type, *value)";
}
//
// /// FloatAttr {
// /// type: FloatType
// /// value: APFloat
// /// }
// kFloatAttr = 9,
defvar FloatType = Type;
def FloatAttr : DialectAttribute<(attr
FloatType:$type,
KnownSemanticsAPFloat<"type">:$value
)> {
let cBuilder = "get<$_resultType>(context, type, *value)";
}
// /// CallSiteLoc {
// /// callee: LocationAttr,
// /// caller: LocationAttr
// /// }
// kCallSiteLoc = 10,
def CallSiteLoc : DialectAttribute<(attr
LocationAttr:$callee,
LocationAttr:$caller
)>;
// /// FileLineColLoc {
// /// filename: StringAttr,
// /// line: varint,
// /// column: varint
// /// }
// kFileLineColLoc = 11,
def FileLineColLoc : DialectAttribute<(attr
StringAttr:$filename,
VarInt:$line,
VarInt:$column
)>;
let cType = "FusedLoc",
cBuilder = "cast<FusedLoc>(get<FusedLoc>(context, $_args))" in {
// /// FusedLoc {
// /// locations: Location[]
// /// }
// kFusedLoc = 12,
def FusedLoc : DialectAttribute<(attr
Array<Location>:$locations
)> {
let printerPredicate = "!$_val.getMetadata()";
}
// /// FusedLocWithMetadata {
// /// locations: LocationAttr[],
// /// metadata: Attribute
// /// }
// /// A variant of FusedLoc with metadata.
// kFusedLocWithMetadata = 13,
def FusedLocWithMetadata : DialectAttribute<(attr
Array<Location>:$locations,
Attribute:$metadata
)> {
let printerPredicate = "$_val.getMetadata()";
}
}
// /// NameLoc {
// /// name: StringAttr,
// /// childLoc: LocationAttr
// /// }
// kNameLoc = 14,
def NameLoc : DialectAttribute<(attr
StringAttr:$name,
LocationAttr:$childLoc
)>;
// /// UnknownLoc {
// /// }
// kUnknownLoc = 15,
def UnknownLoc : DialectAttribute<(attr)>;
// /// DenseResourceElementsAttr {
// /// type: ShapedType,
// /// handle: ResourceHandle
// /// }
// kDenseResourceElementsAttr = 16,
def DenseResourceElementsAttr : DialectAttribute<(attr
ShapedType:$type,
ResourceHandle<"DenseResourceElementsHandle">:$rawHandle
)> {
// Note: order of serialization does not match order of builder.
let cBuilder = "get<$_resultType>(context, type, *rawHandle)";
}
let cType = "RankedTensorType" in {
// /// RankedTensorType {
// /// shape: svarint[],
// /// elementType: Type,
// /// }
// ///
// kRankedTensorType = 13,
def RankedTensorType : DialectType<(type
Array<SignedVarInt>:$shape,
Type:$elementType
)> {
let printerPredicate = "!$_val.getEncoding()";
}
// /// RankedTensorTypeWithEncoding {
// /// encoding: Attribute,
// /// shape: svarint[],
// /// elementType: Type
// /// }
// /// Variant of RankedTensorType with an encoding.
// kRankedTensorTypeWithEncoding = 14,
def RankedTensorTypeWithEncoding : DialectType<(type
Attribute:$encoding,
Array<SignedVarInt>:$shape,
Type:$elementType
)> {
let printerPredicate = "$_val.getEncoding()";
// Note: order of serialization does not match order of builder.
let cBuilder = "get<$_resultType>(context, shape, elementType, encoding)";
}
}
// /// DenseArrayAttr {
// /// elementType: Type,
// /// size: varint
// /// data: blob
// /// }
// kDenseArrayAttr = 17,
def DenseArrayAttr : DialectAttribute<(attr
Type:$elementType,
VarInt:$size,
Blob:$rawData
)>;
// /// DenseIntOrFPElementsAttr {
// /// type: ShapedType,
// /// data: blob
// /// }
// kDenseIntOrFPElementsAttr = 18,
def DenseElementsAttr : WithType<"DenseIntElementsAttr", Attribute>;
def DenseIntOrFPElementsAttr : DialectAttribute<(attr
ShapedType:$type,
Blob:$rawData
)> {
let cBuilder = "cast<$_resultType>($_resultType::getFromRawBuffer($_args))";
}
// /// DenseStringElementsAttr {
// /// type: ShapedType,
// /// isSplat: varint,
// /// data: string[]
// /// }
// kDenseStringElementsAttr = 19,
def DenseStringElementsAttr : DialectAttribute<(attr
ShapedType:$type,
WithGetter<"$_attrType.isSplat()", VarInt>:$_isSplat,
WithBuilder<"$_args",
WithType<"SmallVector<StringRef>",
WithParser <"succeeded(readPotentiallySplatString($_reader, type, _isSplat, $_var))",
WithPrinter<"writePotentiallySplatString($_writer, $_name)">>>>:$rawStringData
)>;
// /// SparseElementsAttr {
// /// type: ShapedType,
// /// indices: DenseIntElementsAttr,
// /// values: DenseElementsAttr
// /// }
// kSparseElementsAttr = 20,
def DenseIntElementsAttr : WithType<"DenseIntElementsAttr", Attribute>;
def SparseElementsAttr : DialectAttribute<(attr
ShapedType:$type,
DenseIntElementsAttr:$indices,
DenseElementsAttr:$values
)>;
// Types
// -----
// enum TypeCode {
// /// IntegerType {
// /// widthAndSignedness: varint // (width << 2) | (signedness)
// /// }
// ///
// kIntegerType = 0,
def IntegerType : DialectType<(type
// Yes not pretty,
WithParser<"succeeded($_reader.readVarInt($_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeVarInt($_name.getWidth() << 2 | $_name.getSignedness())",
WithType <"uint64_t">>>>:$_widthAndSignedness,
// Split up parsed varint for create method.
LocalVar<"uint64_t", "_widthAndSignedness >> 2">:$width,
LocalVar<"IntegerType::SignednessSemantics",
"static_cast<IntegerType::SignednessSemantics>(_widthAndSignedness & 0x3)">:$signedness
)>;
//
// /// IndexType {
// /// }
// ///
// kIndexType = 1,
def IndexType : DialectType<(type)>;
// /// FunctionType {
// /// inputs: Type[],
// /// results: Type[]
// /// }
// ///
// kFunctionType = 2,
def FunctionType : DialectType<(type
Array<Type>:$inputs,
Array<Type>:$results
)>;
// /// BFloat16Type {
// /// }
// ///
// kBFloat16Type = 3,
def BFloat16Type : DialectType<(type)>;
// /// Float16Type {
// /// }
// ///
// kFloat16Type = 4,
def Float16Type : DialectType<(type)>;
// /// Float32Type {
// /// }
// ///
// kFloat32Type = 5,
def Float32Type : DialectType<(type)>;
// /// Float64Type {
// /// }
// ///
// kFloat64Type = 6,
def Float64Type : DialectType<(type)>;
// /// Float80Type {
// /// }
// ///
// kFloat80Type = 7,
def Float80Type : DialectType<(type)>;
// /// Float128Type {
// /// }
// ///
// kFloat128Type = 8,
def Float128Type : DialectType<(type)>;
// /// ComplexType {
// /// elementType: Type
// /// }
// ///
// kComplexType = 9,
def ComplexType : DialectType<(type
Type:$elementType
)>;
def MemRefLayout: WithType<"MemRefLayoutAttrInterface", Attribute>;
let cType = "MemRefType" in {
// /// MemRefType {
// /// shape: svarint[],
// /// elementType: Type,
// /// layout: Attribute
// /// }
// ///
// kMemRefType = 10,
def MemRefType : DialectType<(type
Array<SignedVarInt>:$shape,
Type:$elementType,
MemRefLayout:$layout
)> {
let printerPredicate = "!$_val.getMemorySpace()";
}
// /// MemRefTypeWithMemSpace {
// /// memorySpace: Attribute,
// /// shape: svarint[],
// /// elementType: Type,
// /// layout: Attribute
// /// }
// /// Variant of MemRefType with non-default memory space.
// kMemRefTypeWithMemSpace = 11,
def MemRefTypeWithMemSpace : DialectType<(type
Attribute:$memorySpace,
Array<SignedVarInt>:$shape,
Type:$elementType,
MemRefLayout:$layout
)> {
let printerPredicate = "!!$_val.getMemorySpace()";
// Note: order of serialization does not match order of builder.
let cBuilder = "get<$_resultType>(context, shape, elementType, layout, memorySpace)";
}
}
// /// NoneType {
// /// }
// ///
// kNoneType = 12,
def NoneType : DialectType<(type)>;
// /// TupleType {
// /// elementTypes: Type[]
// /// }
// kTupleType = 15,
def TupleType : DialectType<(type
Array<Type>:$types
)>;
let cType = "UnrankedMemRefType" in {
// /// UnrankedMemRefType {
// /// elementType: Type
// /// }
// ///
// kUnrankedMemRefType = 16,
def UnrankedMemRefType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.getMemorySpace()";
let cBuilder = "get<$_resultType>(context, elementType, Attribute())";
}
// /// UnrankedMemRefTypeWithMemSpace {
// /// memorySpace: Attribute,
// /// elementType: Type
// /// }
// /// Variant of UnrankedMemRefType with non-default memory space.
// kUnrankedMemRefTypeWithMemSpace = 17,
def UnrankedMemRefTypeWithMemSpace : DialectType<(type
Attribute:$memorySpace,
Type:$elementType
)> {
let printerPredicate = "$_val.getMemorySpace()";
// Note: order of serialization does not match order of builder.
let cBuilder = "get<$_resultType>(context, elementType, memorySpace)";
}
}
// /// UnrankedTensorType {
// /// elementType: Type
// /// }
// ///
// kUnrankedTensorType = 18,
def UnrankedTensorType : DialectType<(type
Type:$elementType
)>;
let cType = "VectorType" in {
// /// VectorType {
// /// shape: svarint[],
// /// elementType: Type
// /// }
// ///
// kVectorType = 19,
def VectorType : DialectType<(type
Array<SignedVarInt>:$shape,
Type:$elementType
)> {
let printerPredicate = "!$_val.getNumScalableDims()";
}
// /// VectorTypeWithScalableDims {
// /// numScalableDims: varint,
// /// shape: svarint[],
// /// elementType: Type
// /// }
// /// Variant of VectorType with scalable dimensions.
// kVectorTypeWithScalableDims = 20,
def VectorTypeWithScalableDims : DialectType<(type
VarInt:$numScalableDims,
Array<SignedVarInt>:$shape,
Type:$elementType
)> {
let printerPredicate = "$_val.getNumScalableDims()";
// Note: order of serialization does not match order of builder.
let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims)";
}
}
/// This enum contains marker codes used to indicate which attribute is
/// currently being decoded, and how it should be decoded. The order of these
/// codes should generally be unchanged, as any changes will inevitably break
/// compatibility with older bytecode.
def BuiltinDialectAttributes : DialectAttributes<"Builtin"> {
let elems = [
ArrayAttr,
DictionaryAttr,
StringAttr,
StringAttrWithType,
FlatSymbolRefAttr,
SymbolRefAttr,
TypeAttr,
UnitAttr,
IntegerAttr,
FloatAttr,
CallSiteLoc,
FileLineColLoc,
FusedLoc,
FusedLocWithMetadata,
NameLoc,
UnknownLoc,
DenseResourceElementsAttr,
DenseArrayAttr,
DenseIntOrFPElementsAttr,
DenseStringElementsAttr,
SparseElementsAttr
];
}
def BuiltinDialectTypes : DialectTypes<"Builtin"> {
let elems = [
IntegerType,
IndexType,
FunctionType,
BFloat16Type,
Float16Type,
Float32Type,
Float64Type,
Float80Type,
Float128Type,
ComplexType,
MemRefType,
MemRefTypeWithMemSpace,
NoneType,
RankedTensorType,
RankedTensorTypeWithEncoding,
TupleType,
UnrankedMemRefType,
UnrankedMemRefTypeWithMemSpace,
UnrankedTensorType,
VectorType,
VectorTypeWithScalableDims
];
}
#endif // BUILTIN_BYTECODE

View File

@ -0,0 +1,159 @@
//===-- BytecodeBase.td - Base bytecode R/W defs -----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the base bytecode reader/writer definition file.
//
//===----------------------------------------------------------------------===//
#ifndef BYTECODE_BASE
#define BYTECODE_BASE
class Bytecode<string parse="", string build="", string print="", string get="", string t=""> {
// Template for parsing.
// $_reader == dialect bytecode reader
// $_resultType == result type of parsed instance
// $_var == variable being parsed
// If parser is not specified, then the parse of members is used.
string cParser = parse;
// Template for building from parsed.
// $_resultType == result type of parsed instance
// $_args == args/members comma separated
string cBuilder = build;
// Template for printing.
// $_writer == dialect bytecode writer
// $_name == parent attribute/type name
// $_getter == getter
string cPrinter = print;
// Template for getter from in memory form.
// $_attrType == attribute/type
// $_member == member instance
// $_getMember == get + UpperCamelFromSnake($_member)
string cGetter = get;
// Type built.
// Note: if cType is empty, then name of def is used.
string cType = t;
// Predicate guarding parse method as an Attribute/Type could have multiple
// parse methods, specify predicates to be orthogonal and cover entire
// "print space" to avoid order dependence.
// If empty then method is unconditional.
// $_val == predicate function to apply on value dyn_casted to cType.
string printerPredicate = "";
}
class WithParser<string p="", Bytecode t=Bytecode<>> :
Bytecode<p, t.cBuilder, t.cPrinter, t.cGetter, t.cType>;
class WithBuilder<string b="", Bytecode t=Bytecode<>> :
Bytecode<t.cParser, b, t.cPrinter, t.cGetter, t.cType>;
class WithPrinter<string p="", Bytecode t=Bytecode<>> :
Bytecode<t.cParser, t.cBuilder, p, t.cGetter, t.cType>;
class WithType<string ty="", Bytecode t=Bytecode<>> :
Bytecode<t.cParser, t.cBuilder, t.cPrinter, t.cGetter, ty>;
class WithGetter<string g="", Bytecode t=Bytecode<>> :
Bytecode<t.cParser, t.cBuilder, t.cPrinter, g, t.cType>;
class CompositeBytecode<string t = ""> : WithType<t>;
class AttributeKind :
WithParser <"succeeded($_reader.readAttribute($_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeAttribute($_getter)">>>;
def Attribute : AttributeKind;
class TypeKind :
WithParser <"succeeded($_reader.readType($_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeType($_getter)">>>;
def Type : TypeKind;
def VarInt :
WithParser <"succeeded($_reader.readVarInt($_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeVarInt($_getter)",
WithType <"uint64_t">>>>;
def SignedVarInt :
WithParser <"succeeded($_reader.readSignedVarInt($_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeSignedVarInt($_getter)",
WithGetter<"$_attrType",
WithType <"int64_t">>>>>;
def Blob :
WithParser <"succeeded($_reader.readBlob($_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeOwnedBlob($_getter)",
WithType <"ArrayRef<char>">>>>;
class KnownWidthAPInt<string s> :
WithParser <"succeeded(readAPIntWithKnownWidth($_reader, " # s # ", $_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeAPIntWithKnownWidth($_getter)",
WithType <"FailureOr<APInt>">>>>;
class KnownSemanticsAPFloat<string s> :
WithParser <"succeeded(readAPFloatWithKnownSemantics($_reader, " # s # ", $_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeAPFloatWithKnownSemantics($_getter)",
WithType <"FailureOr<APFloat>">>>>;
class ResourceHandle<string s> :
WithParser <"succeeded(readResourceHandle<" # s # ">($_reader, $_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeResourceHandle($_getter)",
WithType <"FailureOr<" # s # ">">>>>;
// Helper to define variable that is defined later but not parsed nor printed.
class LocalVar<string t, string d> :
WithParser <"(($_var = " # d # "), true)",
WithBuilder<"$_args",
WithPrinter<"",
WithType <t>>>>;
// Array instances.
class Array<Bytecode t> {
Bytecode elemT = t;
string cBuilder = "$_args";
}
// Define dialect attribute or type.
class DialectAttrOrType<dag d> {
// Any members starting with underscore is not fed to create function but
// treated as purely local variable.
dag members = d;
// When needing to specify a custom return type.
string cType = "";
// Any post-processing that needs to be done.
code postProcess = "";
}
class DialectAttribute<dag d> : DialectAttrOrType<d>, AttributeKind {
let cParser = "succeeded($_reader.readAttribute<$_resultType>($_var))";
let cBuilder = "get<$_resultType>(context, $_args)";
}
class DialectType<dag d> : DialectAttrOrType<d>, TypeKind {
let cParser = "succeeded($_reader.readType<$_resultType>($_var))";
let cBuilder = "get<$_resultType>(context, $_args)";
}
class DialectAttributes<string d> {
string dialect = d;
list<DialectAttrOrType> elems;
}
class DialectTypes<string d> {
string dialect = d;
list<DialectAttrOrType> elems;
}
def attr;
def type;
#endif // BYTECODE_BASE

View File

@ -17,6 +17,10 @@ mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
mlir_tablegen(BuiltinDialect.cpp.inc -gen-dialect-defs)
add_public_tablegen_target(MLIRBuiltinDialectIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinDialectBytecode.td)
mlir_tablegen(BuiltinDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Builtin")
add_public_tablegen_target(MLIRBuiltinDialectBytecodeIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinLocationAttributes.td)
mlir_tablegen(BuiltinLocationAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinLocationAttributes.cpp.inc -gen-attrdef-defs)

File diff suppressed because it is too large Load Diff

View File

@ -44,6 +44,7 @@ add_mlir_library(MLIRIR
DEPENDS
MLIRBuiltinAttributesIncGen
MLIRBuiltinAttributeInterfacesIncGen
MLIRBuiltinDialectBytecodeIncGen
MLIRBuiltinDialectIncGen
MLIRBuiltinLocationAttributesIncGen
MLIRBuiltinOpsIncGen

View File

@ -0,0 +1,467 @@
//===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <regex>
using namespace llvm;
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
static llvm::cl::opt<std::string>
selectedBcDialect("bytecode-dialect",
llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
namespace {
/// Helper class to generate C++ bytecode parser helpers.
class Generator {
public:
Generator(raw_ostream &output) : output(output) {}
/// Returns whether successfully emitted attribute/type parsers.
void emitParse(StringRef kind, Record &x);
/// Returns whether successfully emitted attribute/type printers.
void emitPrint(StringRef kind, StringRef type,
ArrayRef<std::pair<int64_t, Record *>> vec);
/// Emits parse dispatch table.
void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
/// Emits print dispatch table.
void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
private:
/// Emits parse calls to construct given kind.
void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
ArrayRef<Init *> args, ArrayRef<std::string> argNames,
StringRef failure, mlir::raw_indented_ostream &ios);
/// Emits print instructions.
void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
StringRef name, mlir::raw_indented_ostream &ios);
raw_ostream &output;
};
} // namespace
/// Helper to replace set of from strings to target in `s`.
/// Assumed: non-overlapping replacements.
static std::string format(StringRef templ,
std::map<std::string, std::string> &&map) {
std::string s = templ.str();
for (const auto &[from, to] : map)
// All replacements start with $, don't treat as anchor.
s = std::regex_replace(s, std::regex("\\" + from), to);
return s;
}
/// Return string with first character capitalized.
static std::string capitalize(StringRef str) {
return ((Twine)toUpper(str[0]) + str.drop_front()).str();
}
/// Return the C++ type for the given record.
static std::string getCType(Record *def) {
std::string format = "{0}";
if (def->isSubClassOf("Array")) {
def = def->getValueAsDef("elemT");
format = "SmallVector<{0}>";
}
StringRef cType = def->getValueAsString("cType");
if (cType.empty()) {
if (def->isAnonymous())
PrintFatalError(def->getLoc(), "Unable to determine cType");
return formatv(format.c_str(), def->getName().str());
}
return formatv(format.c_str(), cType.str());
}
void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
mlir::raw_indented_ostream os(output);
char const *head =
R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
os << formatv(head, capitalize(kind));
auto funScope = os.scope(" {\n", "}\n\n");
os << "uint64_t kind;\n";
os << "if (failed(reader.readVarInt(kind)))\n"
<< " return " << capitalize(kind) << "();\n";
os << "switch (kind) ";
{
auto switchScope = os.scope("{\n", "}\n");
for (const auto &it : llvm::enumerate(vec)) {
os << formatv("case {1}:\n return read{0}(context, reader);\n",
it.value()->getName(), it.index());
}
os << "default:\n"
<< " reader.emitError() << \"unknown attribute code: \" "
<< "<< kind;\n"
<< " return " << capitalize(kind) << "();\n";
}
os << "return " << capitalize(kind) << "();\n";
}
void Generator::emitParse(StringRef kind, Record &x) {
char const *head =
R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
mlir::raw_indented_ostream os(output);
std::string returnType = getCType(&x);
os << formatv(head, returnType, x.getName());
DagInit *members = x.getValueAsDag("members");
SmallVector<std::string> argNames =
llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
return init->getAsUnquotedString();
}));
StringRef builder = x.getValueAsString("cBuilder");
emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
returnType + "()", os);
os << "\n\n";
}
void printParseConditional(mlir::raw_indented_ostream &ios,
ArrayRef<Init *> args,
ArrayRef<std::string> argNames) {
ios << "if ";
auto parenScope = ios.scope("(", ") {");
ios.indent();
auto listHelperName = [](StringRef name) {
return formatv("read{0}", capitalize(name));
};
auto parsedArgs =
llvm::to_vector(make_filter_range(args, [](Init *const attr) {
Record *def = cast<DefInit>(attr)->getDef();
if (def->isSubClassOf("Array"))
return true;
return !def->getValueAsString("cParser").empty();
}));
interleave(
zip(parsedArgs, argNames),
[&](std::tuple<llvm::Init *&, const std::string &> it) {
Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
std::string parser;
if (auto optParser = attr->getValueAsOptionalString("cParser")) {
parser = *optParser;
} else if (attr->isSubClassOf("Array")) {
Record *def = attr->getValueAsDef("elemT");
bool composite = def->isSubClassOf("CompositeBytecode");
if (!composite && def->isSubClassOf("AttributeKind"))
parser = "succeeded($_reader.readAttributes($_var))";
else if (!composite && def->isSubClassOf("TypeKind"))
parser = "succeeded($_reader.readTypes($_var))";
else
parser = ("succeeded($_reader.readList($_var, " +
listHelperName(std::get<1>(it)) + "))")
.str();
} else {
PrintFatalError(attr->getLoc(), "No parser specified");
}
std::string type = getCType(attr);
ios << format(parser, {{"$_reader", "reader"},
{"$_resultType", type},
{"$_var", std::get<1>(it)}});
},
[&]() { ios << " &&\n"; });
}
void Generator::emitParseHelper(StringRef kind, StringRef returnType,
StringRef builder, ArrayRef<Init *> args,
ArrayRef<std::string> argNames,
StringRef failure,
mlir::raw_indented_ostream &ios) {
auto funScope = ios.scope("{\n", "}");
if (args.empty()) {
ios << formatv("return get<{0}>(context);\n", returnType);
return;
}
// Print decls.
std::string lastCType = "";
for (auto [arg, name] : zip(args, argNames)) {
DefInit *first = dyn_cast<DefInit>(arg);
if (!first)
PrintFatalError("Unexpected type for " + name);
Record *def = first->getDef();
// Create variable decls, if there are a block of same type then create
// comma separated list of them.
std::string cType = getCType(def);
if (lastCType == cType) {
ios << ", ";
} else {
if (!lastCType.empty())
ios << ";\n";
ios << cType << " ";
}
ios << name;
lastCType = cType;
}
ios << ";\n";
// Returns the name of the helper used in list parsing. E.g., the name of the
// lambda passed to array parsing.
auto listHelperName = [](StringRef name) {
return formatv("read{0}", capitalize(name));
};
// Emit list helper functions.
for (auto [arg, name] : zip(args, argNames)) {
Record *attr = cast<DefInit>(arg)->getDef();
if (!attr->isSubClassOf("Array"))
continue;
// TODO: Dedupe readers.
Record *def = attr->getValueAsDef("elemT");
if (!def->isSubClassOf("CompositeBytecode") &&
(def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
continue;
std::string returnType = getCType(def);
ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
<< returnType << "> ";
SmallVector<Init *> args;
SmallVector<std::string> argNames;
if (def->isSubClassOf("CompositeBytecode")) {
DagInit *members = def->getValueAsDag("members");
args = llvm::to_vector(members->getArgs());
argNames = llvm::to_vector(
map_range(members->getArgNames(), [](StringInit *init) {
return init->getAsUnquotedString();
}));
} else {
args = {def->getDefInit()};
argNames = {"temp"};
}
StringRef builder = def->getValueAsString("cBuilder");
emitParseHelper(kind, returnType, builder, args, argNames, "failure()",
ios);
ios << ";\n";
}
// Print parse conditional.
printParseConditional(ios, args, argNames);
// Compute args to pass to create method.
auto passedArgs = llvm::to_vector(make_filter_range(
argNames, [](StringRef str) { return !str.starts_with("_"); }));
std::string argStr;
raw_string_ostream argStream(argStr);
interleaveComma(passedArgs, argStream,
[&](const std::string &str) { argStream << str; });
// Return the invoked constructor.
ios << "\nreturn "
<< format(builder, {{"$_resultType", returnType.str()},
{"$_args", argStream.str()}})
<< ";\n";
ios.unindent();
// TODO: Emit error in debug.
// This assumes the result types in error case can always be empty
// constructed.
ios << "}\nreturn " << failure << ";\n";
}
void Generator::emitPrint(StringRef kind, StringRef type,
ArrayRef<std::pair<int64_t, Record *>> vec) {
char const *head =
R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
mlir::raw_indented_ostream os(output);
os << formatv(head, type, kind);
auto funScope = os.scope("{\n", "}\n\n");
// Check that predicates specified if multiple bytecode instances.
for (llvm::Record *rec : make_second_range(vec)) {
StringRef pred = rec->getValueAsString("printerPredicate");
if (vec.size() > 1 && pred.empty()) {
for (auto [index, rec] : vec) {
(void)index;
StringRef pred = rec->getValueAsString("printerPredicate");
if (vec.size() > 1 && pred.empty())
PrintError(rec->getLoc(),
"Requires parsing predicate given common cType");
}
PrintFatalError("Unspecified for shared cType " + type);
}
}
for (auto [index, rec] : vec) {
StringRef pred = rec->getValueAsString("printerPredicate");
if (!pred.empty()) {
os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
os.indent();
}
os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
<< ");\n";
auto *members = rec->getValueAsDag("members");
for (auto [arg, name] :
llvm::zip(members->getArgs(), members->getArgNames())) {
DefInit *def = dyn_cast<DefInit>(arg);
assert(def);
Record *memberRec = def->getDef();
emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
}
if (!pred.empty()) {
os.unindent();
os << "}\n";
}
}
}
void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
StringRef parent, StringRef name,
mlir::raw_indented_ostream &ios) {
std::string getter;
if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
cGetter && !cGetter->empty()) {
getter = format(
*cGetter,
{{"$_attrType", parent.str()},
{"$_member", name.str()},
{"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}});
} else {
getter =
formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true))
.str();
}
if (memberRec->isSubClassOf("Array")) {
Record *def = memberRec->getValueAsDef("elemT");
if (!def->isSubClassOf("CompositeBytecode")) {
if (def->isSubClassOf("AttributeKind")) {
ios << "writer.writeAttributes(" << getter << ");\n";
return;
}
if (def->isSubClassOf("TypeKind")) {
ios << "writer.writeTypes(" << getter << ");\n";
return;
}
}
std::string returnType = getCType(def);
ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
<< kind << ") ";
auto lambdaScope = ios.scope("{\n", "});\n");
return emitPrintHelper(def, kind, kind, kind, ios);
}
if (memberRec->isSubClassOf("CompositeBytecode")) {
auto *members = memberRec->getValueAsDag("members");
for (auto [arg, argName] :
zip(members->getArgs(), members->getArgNames())) {
DefInit *def = dyn_cast<DefInit>(arg);
assert(def);
emitPrintHelper(def->getDef(), kind, parent,
argName->getAsUnquotedString(), ios);
}
}
if (std::string printer = memberRec->getValueAsString("cPrinter").str();
!printer.empty())
ios << format(printer, {{"$_writer", "writer"},
{"$_name", kind.str()},
{"$_getter", getter}})
<< ";\n";
}
void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
mlir::raw_indented_ostream os(output);
char const *head = R"(static LogicalResult write{0}({0} {1},
DialectBytecodeWriter &writer))";
os << formatv(head, capitalize(kind), kind);
auto funScope = os.scope(" {\n", "}\n\n");
os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind
<< ")";
auto switchScope = os.scope("", "");
for (StringRef type : vec) {
os << "\n.Case([&](" << type << " t)";
auto caseScope = os.scope(" {\n", "})");
os << "return write(t, writer), success();\n";
}
os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n";
}
namespace {
/// Container of Attribute or Type for Dialect.
struct AttrOrType {
std::vector<Record *> attr, type;
};
} // namespace
static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
MapVector<StringRef, AttrOrType> dialectAttrOrType;
for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) {
if (!selectedBcDialect.empty() &&
it->getValueAsString("dialect") != selectedBcDialect)
continue;
dialectAttrOrType[it->getValueAsString("dialect")].attr =
it->getValueAsListOfDefs("elems");
}
for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) {
if (!selectedBcDialect.empty() &&
it->getValueAsString("dialect") != selectedBcDialect)
continue;
dialectAttrOrType[it->getValueAsString("dialect")].type =
it->getValueAsListOfDefs("elems");
}
if (dialectAttrOrType.size() != 1)
PrintFatalError("Single dialect per invocation required (either only "
"one in input file or specified via dialect option)");
auto it = dialectAttrOrType.front();
Generator gen(os);
SmallVector<std::vector<Record *> *, 2> vecs;
SmallVector<std::string, 2> kinds;
vecs.push_back(&it.second.attr);
kinds.push_back("attribute");
vecs.push_back(&it.second.type);
kinds.push_back("type");
for (auto [vec, kind] : zip(vecs, kinds)) {
// Handle Attribute/Type emission.
std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
for (auto kt : llvm::enumerate(*vec))
perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
for (const auto &jt : perType) {
for (auto kt : jt.second)
gen.emitParse(kind, *std::get<1>(kt));
gen.emitPrint(kind, jt.first, jt.second);
}
gen.emitParseDispatch(kind, *vec);
SmallVector<std::string> types;
for (const auto &it : perType) {
types.push_back(it.first);
}
gen.emitPrintDispatch(kind, types);
}
return false;
}
static mlir::GenRegistration
genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
[](const RecordKeeper &records, raw_ostream &os) {
return emitBCRW(records, os);
});

View File

@ -9,6 +9,7 @@ add_tablegen(mlir-tblgen MLIR
EXPORT MLIR
AttrOrTypeDefGen.cpp
AttrOrTypeFormatGen.cpp
BytecodeDialectGen.cpp
DialectGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp

View File

@ -91,6 +91,7 @@ td_library(
],
includes = ["include"],
deps = [
":BytecodeTdFiles",
":CallInterfacesTdFiles",
":CastInterfacesTdFiles",
":DataLayoutInterfacesTdFiles",
@ -117,6 +118,20 @@ gentbl_cc_library(
deps = [":BuiltinDialectTdFiles"],
)
gentbl_cc_library(
name = "BuiltinDialectBytecodeGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-bytecode", "-bytecode-dialect=Builtin"],
"include/mlir/IR/BuiltinDialectBytecode.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/IR/BuiltinDialectBytecode.td",
deps = [":BuiltinDialectTdFiles"],
)
gentbl_cc_library(
name = "BuiltinAttributesIncGen",
strip_include_prefix = "include",
@ -277,6 +292,7 @@ cc_library(
deps = [
":BuiltinAttributeInterfacesIncGen",
":BuiltinAttributesIncGen",
":BuiltinDialectBytecodeGen",
":BuiltinDialectIncGen",
":BuiltinLocationAttributesIncGen",
":BuiltinOpsIncGen",
@ -929,6 +945,12 @@ td_library(
],
)
td_library(
name = "BytecodeTdFiles",
srcs = ["include/mlir/IR/BytecodeBase.td"],
includes = ["include"],
)
td_library(
name = "CallInterfacesTdFiles",
srcs = ["include/mlir/Interfaces/CallInterfaces.td"],