[mlir] Extend the promise interface mechanism

This patch pairs a promised interface with the object (Op/Attr/Type/Dialect) requesting the promise, ie:
```
declarePromisedInterface<MyAttr, MyInterface>();
```
Allowing to make fine grained promises. It also adds a mechanism to query if `Op/Attr/Type` has an specific promise returning true if the promise is there or if an implementation has been added. Finally it adds a couple of `Attr|TypeConstraints` that can be used in ODS to query if the promise or an implementation is there.

This patch tries to solve 2 issues:
1. Different entities cannot use the same promise.
```
declarePromisedInterface<MyInterface>();
// Resolves a promise.
MyAttr1::attachInterface<MyInterface>(ctx);
// Doesn't resolves a promise, as the previous attachment removed the promise.
MyAttr2::attachInterface<MyInterface>(ctx);
```
2. Is not possible to query if a promise has been declared.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D158464
This commit is contained in:
Fabian Mora 2023-09-05 09:17:54 -04:00
parent 5857fe0647
commit d0e6fd99aa
15 changed files with 168 additions and 26 deletions

View File

@ -87,6 +87,15 @@ public:
friend ::llvm::hash_code hash_value(Attribute arg);
/// Returns true if `InterfaceT` has been promised by the dialect or
/// implemented.
template <typename InterfaceT>
bool hasPromiseOrImplementsInterface() {
return dialect_extension_detail::hasPromisedInterface(
getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
mlir::isa<InterfaceT>(*this);
}
/// Returns true if the type was registered with a particular trait.
template <template <typename T> class Trait>
bool hasTrait() {
@ -289,7 +298,7 @@ private:
// Check that the current interface isn't an unresolved promise for the
// given attribute.
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
attr.getDialect(), ConcreteType::getInterfaceID(),
attr.getDialect(), attr.getTypeID(), ConcreteType::getInterfaceID(),
llvm::getTypeName<ConcreteType>());
#endif

View File

@ -160,7 +160,7 @@ public:
/// nullptr.
DialectInterface *getRegisteredInterface(TypeID interfaceID) {
#ifndef NDEBUG
handleUseOfUndefinedPromisedInterface(interfaceID);
handleUseOfUndefinedPromisedInterface(getTypeID(), interfaceID);
#endif
auto it = registeredInterfaces.find(interfaceID);
@ -169,7 +169,8 @@ public:
template <typename InterfaceT>
InterfaceT *getRegisteredInterface() {
#ifndef NDEBUG
handleUseOfUndefinedPromisedInterface(InterfaceT::getInterfaceID(),
handleUseOfUndefinedPromisedInterface(getTypeID(),
InterfaceT::getInterfaceID(),
llvm::getTypeName<InterfaceT>());
#endif
@ -209,18 +210,21 @@ public:
/// registration. The promised interface type can be an interface of any type
/// not just a dialect interface, i.e. it may also be an
/// AttributeInterface/OpInterface/TypeInterface/etc.
template <typename InterfaceT>
template <typename ConcreteT, typename InterfaceT>
void declarePromisedInterface() {
unresolvedPromisedInterfaces.insert(InterfaceT::getInterfaceID());
unresolvedPromisedInterfaces.insert(
{TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
}
/// Checks if the given interface, which is attempting to be used, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// emits a fatal error. `interfaceName` is an optional string that contains a
/// more user readable name for the interface (such as the class name).
void handleUseOfUndefinedPromisedInterface(TypeID interfaceID,
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
TypeID interfaceID,
StringRef interfaceName = "") {
if (unresolvedPromisedInterfaces.count(interfaceID)) {
if (unresolvedPromisedInterfaces.count(
{interfaceRequestorID, interfaceID})) {
llvm::report_fatal_error(
"checking for an interface (`" + interfaceName +
"`) that was promised by dialect '" + getNamespace() +
@ -229,11 +233,27 @@ public:
"registered.");
}
}
/// Checks if the given interface, which is attempting to be attached to a
/// construct owned by this dialect, is a promised interface of this dialect
/// that has yet to be implemented. If so, it resolves the interface promise.
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceID) {
unresolvedPromisedInterfaces.erase(interfaceID);
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
TypeID interfaceID) {
unresolvedPromisedInterfaces.erase({interfaceRequestorID, interfaceID});
}
/// Checks if a promise has been made for the interface/requestor pair.
bool hasPromisedInterface(TypeID interfaceRequestorID,
TypeID interfaceID) const {
return unresolvedPromisedInterfaces.count(
{interfaceRequestorID, interfaceID});
}
/// Checks if a promise has been made for the interface/requestor pair.
template <typename ConcreteT, typename InterfaceT>
bool hasPromisedInterface() const {
return hasPromisedInterface(TypeID::get<ConcreteT>(),
InterfaceT::getInterfaceID());
}
protected:
@ -332,7 +352,7 @@ private:
/// A set of interfaces that the dialect (or its constructs, i.e.
/// Attributes/Operations/Types/etc.) has promised to implement, but has yet
/// to provide an implementation for.
DenseSet<TypeID> unresolvedPromisedInterfaces;
DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces;
friend class DialectRegistry;
friend void registerDialect();

View File

@ -102,15 +102,29 @@ namespace dialect_extension_detail {
/// Checks if the given interface, which is attempting to be used, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// emits a fatal error.
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceID,
void handleUseOfUndefinedPromisedInterface(Dialect &dialect,
TypeID interfaceRequestorID,
TypeID interfaceID,
StringRef interfaceName);
/// Checks if the given interface, which is attempting to be attached, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// the promised interface is marked as resolved.
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect,
TypeID interfaceRequestorID,
TypeID interfaceID);
/// Checks if a promise has been made for the interface/requestor pair.
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
TypeID interfaceID);
/// Checks if a promise has been made for the interface/requestor pair.
template <typename ConcreteT, typename InterfaceT>
bool hasPromisedInterface(Dialect &dialect) {
return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(),
InterfaceT::getInterfaceID());
}
} // namespace dialect_extension_detail
//===----------------------------------------------------------------------===//

View File

@ -450,6 +450,30 @@ class Results<dag rets> {
dag results = rets;
}
//===----------------------------------------------------------------------===//
// Common promised interface constraints
//===----------------------------------------------------------------------===//
// This constrait represents a promise or an implementation of an attr interface.
class PromisedAttrInterface<AttrInterface interface> : AttrConstraint<
CPred<"$_self.hasPromiseOrImplementsInterface<" #
!if(!empty(interface.cppNamespace),
"",
interface.cppNamespace # "::") # interface.cppInterfaceName #">()">,
"promising or implementing the `" # interface.cppInterfaceName # "` attr interface">;
// This predicate checks if the type promises or implementats a type interface.
class HasPromiseOrImplementsTypeInterface<TypeInterface interface> :
CPred<"$_self.hasPromiseOrImplementsInterface<" #
!if(!empty(interface.cppNamespace),
"",
interface.cppNamespace # "::") # interface.cppInterfaceName #">()">;
// This constrait represents a promise or an implementation of a type interface.
class PromisedTypeInterface<TypeInterface interface> : TypeConstraint<
HasPromiseOrImplementsTypeInterface<interface>,
"promising or implementing the `" # interface.cppInterfaceName # "` type interface">;
//===----------------------------------------------------------------------===//
// Common op type constraints
//===----------------------------------------------------------------------===//

View File

@ -2075,7 +2075,7 @@ protected:
// given operation.
if (Dialect *dialect = name.getDialect()) {
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
*dialect, ConcreteType::getInterfaceID(),
*dialect, name.getTypeID(), ConcreteType::getInterfaceID(),
llvm::getTypeName<ConcreteType>());
}
#endif

View File

@ -698,6 +698,13 @@ public:
/// If folding was unsuccessful, this function returns "failure".
LogicalResult fold(SmallVectorImpl<OpFoldResult> &results);
/// Returns true if `InterfaceT` has been promised by the dialect or
/// implemented.
template <typename InterfaceT>
bool hasPromiseOrImplementsInterface() const {
return name.hasPromiseOrImplementsInterface<InterfaceT>();
}
/// Returns true if the operation was registered with a particular trait, e.g.
/// hasTrait<OperandsAreSignlessIntegerLike>().
template <template <typename T> class Trait>

View File

@ -351,12 +351,21 @@ public:
void attachInterface() {
// Handle the case where the models resolve a promised interface.
(dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
*getDialect(), Models::Interface::getInterfaceID()),
*getDialect(), getTypeID(), Models::Interface::getInterfaceID()),
...);
getImpl()->getInterfaceMap().insertModels<Models...>();
}
/// Returns true if `InterfaceT` has been promised by the dialect or
/// implemented.
template <typename InterfaceT>
bool hasPromiseOrImplementsInterface() const {
return dialect_extension_detail::hasPromisedInterface(
getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
hasInterface<InterfaceT>();
}
/// Returns true if this operation has the given interface registered to it.
template <typename T>
bool hasInterface() const {

View File

@ -163,7 +163,8 @@ public:
// Handle the case where the models resolve a promised interface.
(dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
abstract->getDialect(), IfaceModels::Interface::getInterfaceID()),
abstract->getDialect(), abstract->getTypeID(),
IfaceModels::Interface::getInterfaceID()),
...);
(checkInterfaceTarget<IfaceModels>(), ...);

View File

@ -180,6 +180,15 @@ public:
return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
}
/// Returns true if `InterfaceT` has been promised by the dialect or
/// implemented.
template <typename InterfaceT>
bool hasPromiseOrImplementsInterface() {
return dialect_extension_detail::hasPromisedInterface(
getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
mlir::isa<InterfaceT>(*this);
}
/// Returns true if the type was registered with a particular trait.
template <template <typename T> class Trait>
bool hasTrait() {
@ -274,7 +283,7 @@ private:
// Check that the current interface isn't an unresolved promise for the
// given type.
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
type.getDialect(), ConcreteType::getInterfaceID(),
type.getDialect(), type.getTypeID(), ConcreteType::getInterfaceID(),
llvm::getTypeName<ConcreteType>());
#endif

View File

@ -40,7 +40,7 @@ void FuncDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
>();
declarePromisedInterface<DialectInlinerInterface>();
declarePromisedInterface<FuncDialect, DialectInlinerInterface>();
}
/// Materialize a single constant operation from a given attribute value with

View File

@ -994,8 +994,8 @@ void NVVMDialect::initialize() {
// Support unknown operations because not all NVVM operations are
// registered.
allowUnknownOperations();
declarePromisedInterface<ConvertToLLVMPatternInterface>();
declarePromisedInterface<gpu::TargetAttrInterface>();
declarePromisedInterface<NVVMDialect, ConvertToLLVMPatternInterface>();
declarePromisedInterface<NVVMTargetAttr, gpu::TargetAttrInterface>();
}
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,

View File

@ -247,7 +247,7 @@ void ROCDLDialect::initialize() {
// Support unknown operations because not all ROCDL operations are registered.
allowUnknownOperations();
declarePromisedInterface<gpu::TargetAttrInterface>();
declarePromisedInterface<ROCDLTargetAttr, gpu::TargetAttrInterface>();
}
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,

View File

@ -97,7 +97,7 @@ bool Dialect::isValidNamespace(StringRef str) {
/// Register a set of dialect interfaces with this dialect instance.
void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
// Handle the case where the models resolve a promised interface.
handleAdditionOfUndefinedPromisedInterface(interface->getID());
handleAdditionOfUndefinedPromisedInterface(getTypeID(), interface->getID());
auto it = registeredInterfaces.try_emplace(interface->getID(),
std::move(interface));
@ -125,8 +125,8 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
for (auto *dialect : ctx->getLoadedDialects()) {
#ifndef NDEBUG
dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
interfaceName);
dialect->handleUseOfUndefinedPromisedInterface(
dialect->getTypeID(), interfaceKind, interfaceName);
#endif
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
interfaces.insert(interface);
@ -151,13 +151,22 @@ DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
DialectExtensionBase::~DialectExtensionBase() = default;
void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
Dialect &dialect, TypeID interfaceID, StringRef interfaceName) {
dialect.handleUseOfUndefinedPromisedInterface(interfaceID, interfaceName);
Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
StringRef interfaceName) {
dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
interfaceID, interfaceName);
}
void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
Dialect &dialect, TypeID interfaceID) {
dialect.handleAdditionOfUndefinedPromisedInterface(interfaceID);
Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
interfaceID);
}
bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
TypeID interfaceRequestorID,
TypeID interfaceID) {
return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
}
//===----------------------------------------------------------------------===//

View File

@ -368,6 +368,20 @@ def DenseArrayNonNegativeOp : TEST_Op<"confined_non_negative_attr"> {
);
}
//===----------------------------------------------------------------------===//
// Test Promised Interfaces Constraints
//===----------------------------------------------------------------------===//
def PromisedInterfacesOp : TEST_Op<"promised_interfaces"> {
let arguments = (ins
ConfinedAttr<AnyAttr,
[PromisedAttrInterface<TestExternalAttrInterface>]>:$promisedAttr,
ConfinedType<AnyType,
[HasPromiseOrImplementsTypeInterface<TestExternalTypeInterface>]
>:$promisedType
);
}
//===----------------------------------------------------------------------===//
// Test Enum Attributes
//===----------------------------------------------------------------------===//

View File

@ -417,4 +417,30 @@ TEST(InterfaceAttachment, OperationDelayedContextAppend) {
EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
}
TEST(InterfaceAttachmentTest, PromisedInterfaces) {
// Attribute interfaces use the exact same mechanism as types, so just check
// that the promise mechanism works for attributes.
MLIRContext context;
auto testDialect = context.getOrLoadDialect<test::TestDialect>();
auto attr = test::SimpleAAttr::get(&context);
// `SimpleAAttr` doesn't implement nor promises the
// `TestExternalAttrInterface` interface.
EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
EXPECT_FALSE(
attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
// Add a promise `TestExternalAttrInterface`.
testDialect->declarePromisedInterface<test::SimpleAAttr,
TestExternalAttrInterface>();
EXPECT_TRUE(
attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
// Attach the interface.
test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context);
EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
EXPECT_TRUE(
attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
}
} // namespace