[ODS] Extra Concrete Declarations and Definitions under Traits

Support extra concrete class declarations and definitions under NativeTrait that get injected into the class that specifies the trait. Extra declarations and definitions can be passed in as template arguments for NativeOpTraitNativeAttrTrait and NativeTypeTrait.

Usage examples of this feature include:

- Creating a wrapper Trait for authoring inferReturnTypes with the OpAdaptor by specifying necessary Op specific declarations and definitions directly in the trait
- Refactoring the InferTensorType trait

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D154731
This commit is contained in:
Amanda Tang 2023-07-12 08:14:16 -07:00 committed by Jacques Pienaar
parent bb3e8e90f1
commit 47b0a9b931
14 changed files with 200 additions and 72 deletions

View File

@ -100,6 +100,22 @@ Note: It is generally good practice to define the implementation of the
`foldTrait` hook out-of-line as a free function when possible to avoid
instantiating the implementation for every concrete operation type.
### Extra Declarations and Definitions
A trait may require additional declarations and definitions directly on
the Operation, Attribute or Type instances which specify that trait.
The `extraConcreteClassDeclaration` and `extraConcreteClassDefinition`
fields under the `NativeTrait` class are mechanisms designed for injecting
code directly into generated C++ Operation, Attribute or Type classes.
Code within the `extraConcreteClassDeclaration` field will be formatted and copied
into the generated C++ Operation, Attribute or Type class. Code within
`extraConcreteClassDefinition` will be added to the generated source file inside
the classs C++ namespace. The substitution `$cppClass` is replaced by the C++ class
name.
The intention is to group trait specific logic together and reduce
redundant extra declarations and definitions on the instances themselves.
### Parametric Traits
The above demonstrates the definition of a simple self-contained trait. It is

View File

@ -903,7 +903,7 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
Pure,
SameVariadicResultSize,
ViewLikeOpInterface,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
InferTypeOpInterfaceAdaptor]> {
let summary = "Extracts a buffer base with offset and strides";
let description = [{
Extracts a base buffer, offset and strides. This op allows additional layers

View File

@ -21,7 +21,14 @@ include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// These classes are used to define attribute specific traits.
class NativeAttrTrait<string name> : NativeTrait<name, "Attribute">;
// Specify attribute specific declarations and definitions in `extraAttrDeclaration`
// and `extraAttrDefinition` template arguments.
class NativeAttrTrait<string name,
code extraAttrDeclaration = [{}],
code extraAttrDefinition = [{}]>
: NativeTrait<name, "Attribute", extraAttrDeclaration, extraAttrDefinition>;
class ParamNativeAttrTrait<string prop, string params>
: ParamNativeTrait<prop, params, "Attribute">;
class GenInternalAttrTrait<string prop> : GenInternalTrait<prop, "Attribute">;
@ -32,7 +39,14 @@ class PredAttrTrait<string descr, Pred pred> : PredTrait<descr, pred>;
//===----------------------------------------------------------------------===//
// These classes are used to define type specific traits.
class NativeTypeTrait<string name> : NativeTrait<name, "Type">;
// Specify type specific declarations and definitions in `extraTypeDeclaration`
// and `extraTypeDefinition` template arguments.
class NativeTypeTrait<string name,
code extraTypeDeclaration = [{}],
code extraTypeDefinition = [{}]>
: NativeTrait<name, "Type", extraTypeDeclaration, extraTypeDefinition>;
class ParamNativeTypeTrait<string prop, string params>
: ParamNativeTrait<prop, params, "Type">;
class GenInternalTypeTrait<string prop> : GenInternalTrait<prop, "Type">;

View File

@ -1958,9 +1958,16 @@ class TraitList<list<Trait> props> : Trait {
// NativeTrait corresponds to the MLIR C++ trait mechanism. The purpose to wrap
// around C++ symbol string with this class is to make traits specified for
// entities in TableGen less alien and more integrated.
class NativeTrait<string name, string entityType> : Trait {
// `extraConcreteClassDeclaration` and `extraConcreteClassDefinition` code
// get injected into the entities in which the NativeTrait is specified for.
class NativeTrait<string name, string entityType,
code extraClassDeclaration = [{}],
code extraClassDefinition = [{}]> : Trait {
string trait = name;
string cppNamespace = "::mlir::" # entityType # "Trait";
code extraConcreteClassDeclaration = extraClassDeclaration;
code extraConcreteClassDefinition = extraClassDefinition;
}
// ParamNativeTrait corresponds to the template-parameterized traits in the C++
@ -1993,8 +2000,13 @@ class PredTrait<string descr, Pred pred> : Trait {
class StructuralOpTrait;
// These classes are used to define operation specific traits.
class NativeOpTrait<string name, list<Trait> traits = []>
: NativeTrait<name, "Op"> {
// Specify op specific declarations and definitions in `extraOpDeclaration`
// and `extraOpDefinition` template arguments.
class NativeOpTrait<string name, list<Trait> traits = [],
code extraOpDeclaration = [{}],
code extraOpDefinition = [{}]>
: NativeTrait<name, "Op", extraOpDeclaration, extraOpDefinition> {
// Specify the list of traits that need to be verified before the verification
// of this NativeOpTrait.
list<Trait> dependentTraits = traits;

View File

@ -237,19 +237,9 @@ private:
namespace detail {
// Helper function to infer return tensor returns types given element and
// shape inference function.
//
// TODO: Consider generating typedefs for trait member functions if this usage
// becomes more common.
LogicalResult inferReturnTensorTypes(
function_ref<
LogicalResult(MLIRContext *, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &retComponents)>
componentTypeFn,
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes);
LogicalResult
inferReturnTensorTypes(ArrayRef<ShapedTypeComponents> retComponents,
SmallVectorImpl<Type> &inferredReturnTypes);
/// Verifies that the inferred result types match the actual result types for
/// the op. Precondition: op implements InferTypeOpInterface.
@ -268,6 +258,10 @@ class InferTensorType;
namespace mlir {
namespace OpTrait {
template <typename ConcreteType>
class InferTypeOpInterfaceAdaptor
: public TraitBase<ConcreteType, InferTypeOpInterfaceAdaptor> {};
/// Tensor type inference trait that constructs a tensor from the inferred
/// shape and elemental types.
/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
@ -276,24 +270,7 @@ namespace OpTrait {
/// trait is currently only used where the interfaces are, so keep it
/// restricted for now).
template <typename ConcreteType>
class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
public:
static LogicalResult
inferReturnTypes(MLIRContext *context, std::optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
static_assert(
ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
"requires InferShapedTypeOpInterface to ensure succesful invocation");
static_assert(
ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
"requires InferTypeOpInterface to ensure succesful invocation");
return ::mlir::detail::inferReturnTensorTypes(
ConcreteType::inferReturnTypeComponents, context, location, operands,
attributes, properties, regions, inferredReturnTypes);
}
};
class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {};
} // namespace OpTrait
} // namespace mlir

View File

@ -184,18 +184,69 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
];
}
// Convenient trait to define a wrapper to inferReturnTypes that passes in the
// Op Adaptor directly
def InferTypeOpInterfaceAdaptor : TraitList<
[
// Op implements infer type op interface.
DeclareOpInterfaceMethods<InferTypeOpInterface>,
NativeOpTrait<
/*name=*/"InferTypeOpInterfaceAdaptor",
/*traits=*/[],
/*extraOpDeclaration=*/[{
static LogicalResult
inferReturnTypesAdaptor(MLIRContext *context,
std::optional<Location> location,
Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes);
}],
/*extraOpDefinition=*/[{
LogicalResult
$cppClass::inferReturnTypes(MLIRContext *context,
std::optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
$cppClass::Adaptor adaptor(operands, attributes, properties, regions);
return $cppClass::inferReturnTypesAdaptor(context,
location, adaptor, inferredReturnTypes);
}
}]
>
]>;
// Convenience class grouping together type and shaped type op interfaces for
// ops that have tensor return types.
class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
[
// Op implements infer type op interface.
InferTypeOpInterface,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
// The op will have methods implementing the ShapedType type inference
// interface.
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
// The op produces tensors and will use the ShapedType type infer interface
// along with knowledge that it is producing Tensors to infer the type.
NativeOpTrait<"InferTensorType">
NativeOpTrait<
/*name=*/"InferTensorType",
/*traits=*/[],
/*extraOpDeclaration=*/[{}],
/*extraOpDefinition=*/[{
LogicalResult
$cppClass::inferReturnTypes(MLIRContext *context,
std::optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
SmallVector<ShapedTypeComponents, 2> retComponents;
if (failed($cppClass::inferReturnTypeComponents(context, location,
operands, attributes, properties, regions,
retComponents)))
return failure();
return ::mlir::detail::inferReturnTensorTypes(retComponents,
inferredReturnTypes);
}
}]
>
]>;
def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>;

View File

@ -68,6 +68,14 @@ public:
// Returns if this is a structural op trait.
bool isStructuralOpTrait() const;
// Returns extra class declaration code to be added to the concrete instance
// when the trait is specified
StringRef getExtraConcreteClassDeclaration() const;
// Returns extra class definition code to be added to the concrete instance
// when the trait is specified
StringRef getExtraConcreteClassDefinition() const;
static bool classof(const Trait *t) { return t->getKind() == Kind::Native; }
};

View File

@ -1355,14 +1355,11 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
/// The number and type of the results are inferred from the
/// shape of the source.
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
LogicalResult ExtractStridedMetadataOp::inferReturnTypesAdaptor(
MLIRContext *context, std::optional<Location> location,
ExtractStridedMetadataOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes,
properties);
auto sourceType =
llvm::dyn_cast<MemRefType>(extractAdaptor.getSource().getType());
auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
if (!sourceType)
return failure();

View File

@ -217,19 +217,8 @@ ShapeAdaptor ValueShapeRange::getShape(int index) const {
}
LogicalResult mlir::detail::inferReturnTensorTypes(
function_ref<
LogicalResult(MLIRContext *, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &retComponents)>
componentTypeFn,
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
ArrayRef<ShapedTypeComponents> retComponents,
SmallVectorImpl<Type> &inferredReturnTypes) {
SmallVector<ShapedTypeComponents, 2> retComponents;
if (failed(componentTypeFn(context, location, operands, attributes,
properties, regions, retComponents)))
return failure();
for (const auto &shapeAndType : retComponents) {
Type elementTy = shapeAndType.getElementType();
assert(elementTy && "element type required to construct tensor");

View File

@ -54,6 +54,14 @@ bool NativeTrait::isStructuralOpTrait() const {
return def->isSubClassOf("StructuralOpTrait");
}
StringRef NativeTrait::getExtraConcreteClassDeclaration() const {
return def->getValueAsString("extraConcreteClassDeclaration");
}
StringRef NativeTrait::getExtraConcreteClassDefinition() const {
return def->getValueAsString("extraConcreteClassDefinition");
}
//===----------------------------------------------------------------------===//
// InternalTrait
//===----------------------------------------------------------------------===//

View File

@ -214,14 +214,42 @@ void DefGen::createParentWithTraits() {
defCls.addParent(std::move(defParent));
}
/// Include declarations specified on NativeTrait
static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
SmallVector<StringRef> extraDeclarations;
// Include extra class declarations from NativeTrait
for (const auto &trait : def.getTraits()) {
if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
if (value.empty())
continue;
extraDeclarations.push_back(value);
}
}
if (std::optional<StringRef> extraDecl = def.getExtraDecls()) {
extraDeclarations.push_back(*extraDecl);
}
return llvm::join(extraDeclarations, "\n");
}
/// Extra class definitions have a `$cppClass` substitution that is to be
/// replaced by the C++ class name.
static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName());
return tgfmt(*extraDef, &ctx).str();
SmallVector<StringRef> extraDefinitions;
// Include extra class definitions from NativeTrait
for (const auto &trait : def.getTraits()) {
if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
if (value.empty())
continue;
extraDefinitions.push_back(value);
}
}
return "";
if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
extraDefinitions.push_back(*extraDef);
}
FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName());
return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
}
void DefGen::emitTopLevelDeclarations() {
@ -230,9 +258,9 @@ void DefGen::emitTopLevelDeclarations() {
defCls.declare<UsingDeclaration>("Base::Base");
// Emit the extra declarations first in case there's a definition in there.
std::optional<StringRef> extraDecl = def.getExtraDecls();
std::string extraDecl = formatExtraDeclarations(def);
std::string extraDef = formatExtraDefinitions(def);
defCls.declare<ExtraClassDeclaration>(extraDecl ? *extraDecl : "",
defCls.declare<ExtraClassDeclaration>(std::move(extraDecl),
std::move(extraDef));
}

View File

@ -15,9 +15,10 @@ using namespace mlir::tblgen;
// OpClass definitions
//===----------------------------------------------------------------------===//
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
OpClass::OpClass(StringRef name, std::string extraClassDeclaration,
std::string extraClassDefinition)
: Class(name.str()), extraClassDeclaration(extraClassDeclaration),
: Class(name.str()),
extraClassDeclaration(std::move(extraClassDeclaration)),
extraClassDefinition(std::move(extraClassDefinition)),
parent(addParent("::mlir::Op")) {
parent.addTemplateParam(getClassName().str());
@ -37,6 +38,5 @@ OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
void OpClass::finalize() {
Class::finalize();
declare<VisibilityDeclaration>(Visibility::Public);
declare<ExtraClassDeclaration>(extraClassDeclaration.str(),
extraClassDefinition);
declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
}

View File

@ -25,7 +25,7 @@ public:
/// - inheritance of `print`
/// - a type alias for the associated adaptor class
///
OpClass(StringRef name, StringRef extraClassDeclaration,
OpClass(StringRef name, std::string extraClassDeclaration,
std::string extraClassDefinition);
/// Add an op trait.
@ -39,7 +39,7 @@ public:
private:
/// Hand-written extra class declarations.
StringRef extraClassDeclaration;
std::string extraClassDeclaration;
/// Hand-written extra class definitions.
std::string extraClassDefinition;
/// The parent class, which also contains the traits to be inherited.

View File

@ -853,17 +853,45 @@ while (true) {{
emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
}
/// Include declarations specified on NativeTrait
static std::string formatExtraDeclarations(const Operator &op) {
SmallVector<StringRef> extraDeclarations;
// Include extra class declarations from NativeTrait
for (const auto &trait : op.getTraits()) {
if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
StringRef value = opTrait->getExtraConcreteClassDeclaration();
if (value.empty())
continue;
extraDeclarations.push_back(value);
}
}
extraDeclarations.push_back(op.getExtraClassDeclaration());
return llvm::join(extraDeclarations, "\n");
}
/// Op extra class definitions have a `$cppClass` substitution that is to be
/// replaced by the C++ class name.
/// Include declarations specified on NativeTrait
static std::string formatExtraDefinitions(const Operator &op) {
SmallVector<StringRef> extraDefinitions;
// Include extra class definitions from NativeTrait
for (const auto &trait : op.getTraits()) {
if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
StringRef value = opTrait->getExtraConcreteClassDefinition();
if (value.empty())
continue;
extraDefinitions.push_back(value);
}
}
extraDefinitions.push_back(op.getExtraClassDefinition());
FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
return tgfmt(op.getExtraClassDefinition(), &ctx).str();
return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
}
OpEmitter::OpEmitter(const Operator &op,
const StaticVerifierFunctionEmitter &staticVerifierEmitter)
: def(op.getDef()), op(op),
opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
opClass(op.getCppClassName(), formatExtraDeclarations(op),
formatExtraDefinitions(op)),
staticVerifierEmitter(staticVerifierEmitter),
emitHelper(op, /*emitForOp=*/true) {