mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-24 04:26:07 +00:00
[MLIR] Add a second map for registered OperationName in MLIRContext (NFC) (#87170)
This speeds up registered op creation by 10-11% by allowing lookup by TypeID instead of StringRef. This can break your build/tests at runtime with an error that you're creating an unregistered operation that you have registered. If so you are likely using a class inheriting from the "real" operation. See for example in this patch the case of: class ConstantIndexOp : public arith::ConstantOp { If one is using `builder.create<ConstantIndexOp>()` they actually create an `arith.constant` operation, but the builder will fetch the TypeID for the `ConstantIndexOp` class which does not correspond to any registered operation. To fix it the `ConstantIndexOp` class got this addition: static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
This commit is contained in:
parent
75f7d53f0b
commit
82c6eeed08
@ -53,6 +53,7 @@ namespace arith {
|
||||
class ConstantIntOp : public arith::ConstantOp {
|
||||
public:
|
||||
using arith::ConstantOp::ConstantOp;
|
||||
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
|
||||
|
||||
/// Build a constant int op that produces an integer of the specified width.
|
||||
static void build(OpBuilder &builder, OperationState &result, int64_t value,
|
||||
@ -74,6 +75,7 @@ public:
|
||||
class ConstantFloatOp : public arith::ConstantOp {
|
||||
public:
|
||||
using arith::ConstantOp::ConstantOp;
|
||||
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
|
||||
|
||||
/// Build a constant float op that produces a float of the specified type.
|
||||
static void build(OpBuilder &builder, OperationState &result,
|
||||
@ -90,7 +92,7 @@ public:
|
||||
class ConstantIndexOp : public arith::ConstantOp {
|
||||
public:
|
||||
using arith::ConstantOp::ConstantOp;
|
||||
|
||||
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
|
||||
/// Build a constant int op that produces an index.
|
||||
static void build(OpBuilder &builder, OperationState &result, int64_t value);
|
||||
|
||||
|
@ -252,21 +252,21 @@ private:
|
||||
|
||||
template <typename OpTy>
|
||||
void TransformDialect::addOperationIfNotRegistered() {
|
||||
StringRef name = OpTy::getOperationName();
|
||||
std::optional<RegisteredOperationName> opName =
|
||||
RegisteredOperationName::lookup(name, getContext());
|
||||
RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext());
|
||||
if (!opName) {
|
||||
addOperations<OpTy>();
|
||||
#ifndef NDEBUG
|
||||
StringRef name = OpTy::getOperationName();
|
||||
detail::checkImplementsTransformOpInterface(name, getContext());
|
||||
#endif // NDEBUG
|
||||
return;
|
||||
}
|
||||
|
||||
if (opName->getTypeID() == TypeID::get<OpTy>())
|
||||
if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
|
||||
return;
|
||||
|
||||
reportDuplicateOpRegistration(name);
|
||||
reportDuplicateOpRegistration(OpTy::getOperationName());
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
|
@ -490,7 +490,7 @@ private:
|
||||
template <typename OpT>
|
||||
RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
|
||||
std::optional<RegisteredOperationName> opName =
|
||||
RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
|
||||
RegisteredOperationName::lookup(TypeID::get<OpT>(), ctx);
|
||||
if (LLVM_UNLIKELY(!opName)) {
|
||||
llvm::report_fatal_error(
|
||||
"Building op `" + OpT::getOperationName() +
|
||||
|
@ -1729,8 +1729,7 @@ public:
|
||||
template <typename... Models>
|
||||
static void attachInterface(MLIRContext &context) {
|
||||
std::optional<RegisteredOperationName> info =
|
||||
RegisteredOperationName::lookup(ConcreteType::getOperationName(),
|
||||
&context);
|
||||
RegisteredOperationName::lookup(TypeID::get<ConcreteType>(), &context);
|
||||
if (!info)
|
||||
llvm::report_fatal_error(
|
||||
"Attempting to attach an interface to an unregistered operation " +
|
||||
|
@ -676,6 +676,11 @@ public:
|
||||
static std::optional<RegisteredOperationName> lookup(StringRef name,
|
||||
MLIRContext *ctx);
|
||||
|
||||
/// Lookup the registered operation information for the given operation.
|
||||
/// Returns std::nullopt if the operation isn't registered.
|
||||
static std::optional<RegisteredOperationName> lookup(TypeID typeID,
|
||||
MLIRContext *ctx);
|
||||
|
||||
/// Register a new operation in a Dialect object.
|
||||
/// This constructor is used by Dialect objects when they register the list
|
||||
/// of operations they contain.
|
||||
|
@ -183,7 +183,8 @@ public:
|
||||
llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
|
||||
|
||||
/// A vector of operation info specifically for registered operations.
|
||||
llvm::StringMap<RegisteredOperationName> registeredOperations;
|
||||
llvm::DenseMap<TypeID, RegisteredOperationName> registeredOperations;
|
||||
llvm::StringMap<RegisteredOperationName> registeredOperationsByName;
|
||||
|
||||
/// This is a sorted container of registered operations for a deterministic
|
||||
/// and efficient `getRegisteredOperations` implementation.
|
||||
@ -780,8 +781,8 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
|
||||
// Check the registered info map first. In the overwhelmingly common case,
|
||||
// the entry will be in here and it also removes the need to acquire any
|
||||
// locks.
|
||||
auto registeredIt = ctxImpl.registeredOperations.find(name);
|
||||
if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
|
||||
auto registeredIt = ctxImpl.registeredOperationsByName.find(name);
|
||||
if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) {
|
||||
impl = registeredIt->second.impl;
|
||||
return;
|
||||
}
|
||||
@ -908,11 +909,20 @@ OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) {
|
||||
// RegisteredOperationName
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::optional<RegisteredOperationName>
|
||||
RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) {
|
||||
auto &impl = ctx->getImpl();
|
||||
auto it = impl.registeredOperations.find(typeID);
|
||||
if (it != impl.registeredOperations.end())
|
||||
return it->second;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<RegisteredOperationName>
|
||||
RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
|
||||
auto &impl = ctx->getImpl();
|
||||
auto it = impl.registeredOperations.find(name);
|
||||
if (it != impl.registeredOperations.end())
|
||||
auto it = impl.registeredOperationsByName.find(name);
|
||||
if (it != impl.registeredOperationsByName.end())
|
||||
return it->getValue();
|
||||
return std::nullopt;
|
||||
}
|
||||
@ -945,11 +955,16 @@ void RegisteredOperationName::insert(
|
||||
|
||||
// Update the registered info for this operation.
|
||||
auto emplaced = ctxImpl.registeredOperations.try_emplace(
|
||||
name, RegisteredOperationName(impl));
|
||||
impl->getTypeID(), RegisteredOperationName(impl));
|
||||
assert(emplaced.second && "operation name registration must be successful");
|
||||
auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace(
|
||||
name, RegisteredOperationName(impl));
|
||||
(void)emplacedByName;
|
||||
assert(emplacedByName.second &&
|
||||
"operation name registration must be successful");
|
||||
|
||||
// Add emplaced operation name to the sorted operations container.
|
||||
RegisteredOperationName &value = emplaced.first->getValue();
|
||||
RegisteredOperationName &value = emplaced.first->second;
|
||||
ctxImpl.sortedRegisteredOperations.insert(
|
||||
llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
|
||||
[](auto &lhs, auto &rhs) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user