mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-24 04:26:07 +00:00
[mlir] Extend the promise interface mechanism
This patch pairs a promised interface with the object (Op/Attr/Type/Dialect) requesting the promise, ie: ``` declarePromisedInterface<MyAttr, MyInterface>(); ``` Allowing to make fine grained promises. It also adds a mechanism to query if `Op/Attr/Type` has an specific promise returning true if the promise is there or if an implementation has been added. Finally it adds a couple of `Attr|TypeConstraints` that can be used in ODS to query if the promise or an implementation is there. This patch tries to solve 2 issues: 1. Different entities cannot use the same promise. ``` declarePromisedInterface<MyInterface>(); // Resolves a promise. MyAttr1::attachInterface<MyInterface>(ctx); // Doesn't resolves a promise, as the previous attachment removed the promise. MyAttr2::attachInterface<MyInterface>(ctx); ``` 2. Is not possible to query if a promise has been declared. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D158464
This commit is contained in:
parent
5857fe0647
commit
d0e6fd99aa
@ -87,6 +87,15 @@ public:
|
||||
|
||||
friend ::llvm::hash_code hash_value(Attribute arg);
|
||||
|
||||
/// Returns true if `InterfaceT` has been promised by the dialect or
|
||||
/// implemented.
|
||||
template <typename InterfaceT>
|
||||
bool hasPromiseOrImplementsInterface() {
|
||||
return dialect_extension_detail::hasPromisedInterface(
|
||||
getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
|
||||
mlir::isa<InterfaceT>(*this);
|
||||
}
|
||||
|
||||
/// Returns true if the type was registered with a particular trait.
|
||||
template <template <typename T> class Trait>
|
||||
bool hasTrait() {
|
||||
@ -289,7 +298,7 @@ private:
|
||||
// Check that the current interface isn't an unresolved promise for the
|
||||
// given attribute.
|
||||
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
|
||||
attr.getDialect(), ConcreteType::getInterfaceID(),
|
||||
attr.getDialect(), attr.getTypeID(), ConcreteType::getInterfaceID(),
|
||||
llvm::getTypeName<ConcreteType>());
|
||||
#endif
|
||||
|
||||
|
@ -160,7 +160,7 @@ public:
|
||||
/// nullptr.
|
||||
DialectInterface *getRegisteredInterface(TypeID interfaceID) {
|
||||
#ifndef NDEBUG
|
||||
handleUseOfUndefinedPromisedInterface(interfaceID);
|
||||
handleUseOfUndefinedPromisedInterface(getTypeID(), interfaceID);
|
||||
#endif
|
||||
|
||||
auto it = registeredInterfaces.find(interfaceID);
|
||||
@ -169,7 +169,8 @@ public:
|
||||
template <typename InterfaceT>
|
||||
InterfaceT *getRegisteredInterface() {
|
||||
#ifndef NDEBUG
|
||||
handleUseOfUndefinedPromisedInterface(InterfaceT::getInterfaceID(),
|
||||
handleUseOfUndefinedPromisedInterface(getTypeID(),
|
||||
InterfaceT::getInterfaceID(),
|
||||
llvm::getTypeName<InterfaceT>());
|
||||
#endif
|
||||
|
||||
@ -209,18 +210,21 @@ public:
|
||||
/// registration. The promised interface type can be an interface of any type
|
||||
/// not just a dialect interface, i.e. it may also be an
|
||||
/// AttributeInterface/OpInterface/TypeInterface/etc.
|
||||
template <typename InterfaceT>
|
||||
template <typename ConcreteT, typename InterfaceT>
|
||||
void declarePromisedInterface() {
|
||||
unresolvedPromisedInterfaces.insert(InterfaceT::getInterfaceID());
|
||||
unresolvedPromisedInterfaces.insert(
|
||||
{TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
|
||||
}
|
||||
|
||||
/// Checks if the given interface, which is attempting to be used, is a
|
||||
/// promised interface of this dialect that has yet to be implemented. If so,
|
||||
/// emits a fatal error. `interfaceName` is an optional string that contains a
|
||||
/// more user readable name for the interface (such as the class name).
|
||||
void handleUseOfUndefinedPromisedInterface(TypeID interfaceID,
|
||||
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
|
||||
TypeID interfaceID,
|
||||
StringRef interfaceName = "") {
|
||||
if (unresolvedPromisedInterfaces.count(interfaceID)) {
|
||||
if (unresolvedPromisedInterfaces.count(
|
||||
{interfaceRequestorID, interfaceID})) {
|
||||
llvm::report_fatal_error(
|
||||
"checking for an interface (`" + interfaceName +
|
||||
"`) that was promised by dialect '" + getNamespace() +
|
||||
@ -229,11 +233,27 @@ public:
|
||||
"registered.");
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if the given interface, which is attempting to be attached to a
|
||||
/// construct owned by this dialect, is a promised interface of this dialect
|
||||
/// that has yet to be implemented. If so, it resolves the interface promise.
|
||||
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceID) {
|
||||
unresolvedPromisedInterfaces.erase(interfaceID);
|
||||
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
|
||||
TypeID interfaceID) {
|
||||
unresolvedPromisedInterfaces.erase({interfaceRequestorID, interfaceID});
|
||||
}
|
||||
|
||||
/// Checks if a promise has been made for the interface/requestor pair.
|
||||
bool hasPromisedInterface(TypeID interfaceRequestorID,
|
||||
TypeID interfaceID) const {
|
||||
return unresolvedPromisedInterfaces.count(
|
||||
{interfaceRequestorID, interfaceID});
|
||||
}
|
||||
|
||||
/// Checks if a promise has been made for the interface/requestor pair.
|
||||
template <typename ConcreteT, typename InterfaceT>
|
||||
bool hasPromisedInterface() const {
|
||||
return hasPromisedInterface(TypeID::get<ConcreteT>(),
|
||||
InterfaceT::getInterfaceID());
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -332,7 +352,7 @@ private:
|
||||
/// A set of interfaces that the dialect (or its constructs, i.e.
|
||||
/// Attributes/Operations/Types/etc.) has promised to implement, but has yet
|
||||
/// to provide an implementation for.
|
||||
DenseSet<TypeID> unresolvedPromisedInterfaces;
|
||||
DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces;
|
||||
|
||||
friend class DialectRegistry;
|
||||
friend void registerDialect();
|
||||
|
@ -102,15 +102,29 @@ namespace dialect_extension_detail {
|
||||
/// Checks if the given interface, which is attempting to be used, is a
|
||||
/// promised interface of this dialect that has yet to be implemented. If so,
|
||||
/// emits a fatal error.
|
||||
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceID,
|
||||
void handleUseOfUndefinedPromisedInterface(Dialect &dialect,
|
||||
TypeID interfaceRequestorID,
|
||||
TypeID interfaceID,
|
||||
StringRef interfaceName);
|
||||
|
||||
/// Checks if the given interface, which is attempting to be attached, is a
|
||||
/// promised interface of this dialect that has yet to be implemented. If so,
|
||||
/// the promised interface is marked as resolved.
|
||||
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect,
|
||||
TypeID interfaceRequestorID,
|
||||
TypeID interfaceID);
|
||||
|
||||
/// Checks if a promise has been made for the interface/requestor pair.
|
||||
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
|
||||
TypeID interfaceID);
|
||||
|
||||
/// Checks if a promise has been made for the interface/requestor pair.
|
||||
template <typename ConcreteT, typename InterfaceT>
|
||||
bool hasPromisedInterface(Dialect &dialect) {
|
||||
return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(),
|
||||
InterfaceT::getInterfaceID());
|
||||
}
|
||||
|
||||
} // namespace dialect_extension_detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -450,6 +450,30 @@ class Results<dag rets> {
|
||||
dag results = rets;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common promised interface constraints
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// This constrait represents a promise or an implementation of an attr interface.
|
||||
class PromisedAttrInterface<AttrInterface interface> : AttrConstraint<
|
||||
CPred<"$_self.hasPromiseOrImplementsInterface<" #
|
||||
!if(!empty(interface.cppNamespace),
|
||||
"",
|
||||
interface.cppNamespace # "::") # interface.cppInterfaceName #">()">,
|
||||
"promising or implementing the `" # interface.cppInterfaceName # "` attr interface">;
|
||||
|
||||
// This predicate checks if the type promises or implementats a type interface.
|
||||
class HasPromiseOrImplementsTypeInterface<TypeInterface interface> :
|
||||
CPred<"$_self.hasPromiseOrImplementsInterface<" #
|
||||
!if(!empty(interface.cppNamespace),
|
||||
"",
|
||||
interface.cppNamespace # "::") # interface.cppInterfaceName #">()">;
|
||||
|
||||
// This constrait represents a promise or an implementation of a type interface.
|
||||
class PromisedTypeInterface<TypeInterface interface> : TypeConstraint<
|
||||
HasPromiseOrImplementsTypeInterface<interface>,
|
||||
"promising or implementing the `" # interface.cppInterfaceName # "` type interface">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common op type constraints
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2075,7 +2075,7 @@ protected:
|
||||
// given operation.
|
||||
if (Dialect *dialect = name.getDialect()) {
|
||||
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
|
||||
*dialect, ConcreteType::getInterfaceID(),
|
||||
*dialect, name.getTypeID(), ConcreteType::getInterfaceID(),
|
||||
llvm::getTypeName<ConcreteType>());
|
||||
}
|
||||
#endif
|
||||
|
@ -698,6 +698,13 @@ public:
|
||||
/// If folding was unsuccessful, this function returns "failure".
|
||||
LogicalResult fold(SmallVectorImpl<OpFoldResult> &results);
|
||||
|
||||
/// Returns true if `InterfaceT` has been promised by the dialect or
|
||||
/// implemented.
|
||||
template <typename InterfaceT>
|
||||
bool hasPromiseOrImplementsInterface() const {
|
||||
return name.hasPromiseOrImplementsInterface<InterfaceT>();
|
||||
}
|
||||
|
||||
/// Returns true if the operation was registered with a particular trait, e.g.
|
||||
/// hasTrait<OperandsAreSignlessIntegerLike>().
|
||||
template <template <typename T> class Trait>
|
||||
|
@ -351,12 +351,21 @@ public:
|
||||
void attachInterface() {
|
||||
// Handle the case where the models resolve a promised interface.
|
||||
(dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
|
||||
*getDialect(), Models::Interface::getInterfaceID()),
|
||||
*getDialect(), getTypeID(), Models::Interface::getInterfaceID()),
|
||||
...);
|
||||
|
||||
getImpl()->getInterfaceMap().insertModels<Models...>();
|
||||
}
|
||||
|
||||
/// Returns true if `InterfaceT` has been promised by the dialect or
|
||||
/// implemented.
|
||||
template <typename InterfaceT>
|
||||
bool hasPromiseOrImplementsInterface() const {
|
||||
return dialect_extension_detail::hasPromisedInterface(
|
||||
getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
|
||||
hasInterface<InterfaceT>();
|
||||
}
|
||||
|
||||
/// Returns true if this operation has the given interface registered to it.
|
||||
template <typename T>
|
||||
bool hasInterface() const {
|
||||
|
@ -163,7 +163,8 @@ public:
|
||||
|
||||
// Handle the case where the models resolve a promised interface.
|
||||
(dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
|
||||
abstract->getDialect(), IfaceModels::Interface::getInterfaceID()),
|
||||
abstract->getDialect(), abstract->getTypeID(),
|
||||
IfaceModels::Interface::getInterfaceID()),
|
||||
...);
|
||||
|
||||
(checkInterfaceTarget<IfaceModels>(), ...);
|
||||
|
@ -180,6 +180,15 @@ public:
|
||||
return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
|
||||
}
|
||||
|
||||
/// Returns true if `InterfaceT` has been promised by the dialect or
|
||||
/// implemented.
|
||||
template <typename InterfaceT>
|
||||
bool hasPromiseOrImplementsInterface() {
|
||||
return dialect_extension_detail::hasPromisedInterface(
|
||||
getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
|
||||
mlir::isa<InterfaceT>(*this);
|
||||
}
|
||||
|
||||
/// Returns true if the type was registered with a particular trait.
|
||||
template <template <typename T> class Trait>
|
||||
bool hasTrait() {
|
||||
@ -274,7 +283,7 @@ private:
|
||||
// Check that the current interface isn't an unresolved promise for the
|
||||
// given type.
|
||||
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
|
||||
type.getDialect(), ConcreteType::getInterfaceID(),
|
||||
type.getDialect(), type.getTypeID(), ConcreteType::getInterfaceID(),
|
||||
llvm::getTypeName<ConcreteType>());
|
||||
#endif
|
||||
|
||||
|
@ -40,7 +40,7 @@ void FuncDialect::initialize() {
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
|
||||
>();
|
||||
declarePromisedInterface<DialectInlinerInterface>();
|
||||
declarePromisedInterface<FuncDialect, DialectInlinerInterface>();
|
||||
}
|
||||
|
||||
/// Materialize a single constant operation from a given attribute value with
|
||||
|
@ -994,8 +994,8 @@ void NVVMDialect::initialize() {
|
||||
// Support unknown operations because not all NVVM operations are
|
||||
// registered.
|
||||
allowUnknownOperations();
|
||||
declarePromisedInterface<ConvertToLLVMPatternInterface>();
|
||||
declarePromisedInterface<gpu::TargetAttrInterface>();
|
||||
declarePromisedInterface<NVVMDialect, ConvertToLLVMPatternInterface>();
|
||||
declarePromisedInterface<NVVMTargetAttr, gpu::TargetAttrInterface>();
|
||||
}
|
||||
|
||||
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
|
||||
|
@ -247,7 +247,7 @@ void ROCDLDialect::initialize() {
|
||||
|
||||
// Support unknown operations because not all ROCDL operations are registered.
|
||||
allowUnknownOperations();
|
||||
declarePromisedInterface<gpu::TargetAttrInterface>();
|
||||
declarePromisedInterface<ROCDLTargetAttr, gpu::TargetAttrInterface>();
|
||||
}
|
||||
|
||||
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
|
||||
|
@ -97,7 +97,7 @@ bool Dialect::isValidNamespace(StringRef str) {
|
||||
/// Register a set of dialect interfaces with this dialect instance.
|
||||
void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
|
||||
// Handle the case where the models resolve a promised interface.
|
||||
handleAdditionOfUndefinedPromisedInterface(interface->getID());
|
||||
handleAdditionOfUndefinedPromisedInterface(getTypeID(), interface->getID());
|
||||
|
||||
auto it = registeredInterfaces.try_emplace(interface->getID(),
|
||||
std::move(interface));
|
||||
@ -125,8 +125,8 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
|
||||
MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
|
||||
for (auto *dialect : ctx->getLoadedDialects()) {
|
||||
#ifndef NDEBUG
|
||||
dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
|
||||
interfaceName);
|
||||
dialect->handleUseOfUndefinedPromisedInterface(
|
||||
dialect->getTypeID(), interfaceKind, interfaceName);
|
||||
#endif
|
||||
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
|
||||
interfaces.insert(interface);
|
||||
@ -151,13 +151,22 @@ DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
|
||||
DialectExtensionBase::~DialectExtensionBase() = default;
|
||||
|
||||
void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
|
||||
Dialect &dialect, TypeID interfaceID, StringRef interfaceName) {
|
||||
dialect.handleUseOfUndefinedPromisedInterface(interfaceID, interfaceName);
|
||||
Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
|
||||
StringRef interfaceName) {
|
||||
dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
|
||||
interfaceID, interfaceName);
|
||||
}
|
||||
|
||||
void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
|
||||
Dialect &dialect, TypeID interfaceID) {
|
||||
dialect.handleAdditionOfUndefinedPromisedInterface(interfaceID);
|
||||
Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
|
||||
dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
|
||||
interfaceID);
|
||||
}
|
||||
|
||||
bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
|
||||
TypeID interfaceRequestorID,
|
||||
TypeID interfaceID) {
|
||||
return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -368,6 +368,20 @@ def DenseArrayNonNegativeOp : TEST_Op<"confined_non_negative_attr"> {
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Promised Interfaces Constraints
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PromisedInterfacesOp : TEST_Op<"promised_interfaces"> {
|
||||
let arguments = (ins
|
||||
ConfinedAttr<AnyAttr,
|
||||
[PromisedAttrInterface<TestExternalAttrInterface>]>:$promisedAttr,
|
||||
ConfinedType<AnyType,
|
||||
[HasPromiseOrImplementsTypeInterface<TestExternalTypeInterface>]
|
||||
>:$promisedType
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Enum Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -417,4 +417,30 @@ TEST(InterfaceAttachment, OperationDelayedContextAppend) {
|
||||
EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
|
||||
}
|
||||
|
||||
TEST(InterfaceAttachmentTest, PromisedInterfaces) {
|
||||
// Attribute interfaces use the exact same mechanism as types, so just check
|
||||
// that the promise mechanism works for attributes.
|
||||
MLIRContext context;
|
||||
auto testDialect = context.getOrLoadDialect<test::TestDialect>();
|
||||
auto attr = test::SimpleAAttr::get(&context);
|
||||
|
||||
// `SimpleAAttr` doesn't implement nor promises the
|
||||
// `TestExternalAttrInterface` interface.
|
||||
EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
|
||||
EXPECT_FALSE(
|
||||
attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
|
||||
|
||||
// Add a promise `TestExternalAttrInterface`.
|
||||
testDialect->declarePromisedInterface<test::SimpleAAttr,
|
||||
TestExternalAttrInterface>();
|
||||
EXPECT_TRUE(
|
||||
attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
|
||||
|
||||
// Attach the interface.
|
||||
test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context);
|
||||
EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
|
||||
EXPECT_TRUE(
|
||||
attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user