//===- MLIRContext.cpp - MLIR Type Classes --------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/MLIRContext.h" #include "AffineExprDetail.h" #include "AffineMapDetail.h" #include "AttributeDetail.h" #include "IntegerSetDetail.h" #include "TypeDetail.h" #include "mlir/IR/Action.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Mutex.h" #include "llvm/Support/RWMutex.h" #include "llvm/Support/ThreadPool.h" #include "llvm/Support/raw_ostream.h" #include #include #define DEBUG_TYPE "mlircontext" using namespace mlir; using namespace mlir::detail; //===----------------------------------------------------------------------===// // MLIRContext CommandLine Options //===----------------------------------------------------------------------===// namespace { /// This struct contains command line options that can be used to initialize /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need /// for global command line options. 
struct MLIRContextOptions {
  // NOTE(review): the `llvm::cl::opt` members below (and `llvm::ManagedStatic`
  // / `llvm::sys::SmartRWMutex` further down) have no template argument list
  // in this copy of the file; the `<...>` parts look lost in extraction --
  // confirm against the original.

  /// Option "mlir-disable-threading".
  llvm::cl::opt disableThreading{
      "mlir-disable-threading",
      llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
                     "further call to MLIRContext::enableMultiThreading()")};

  /// Option "mlir-print-op-on-diagnostic"; initialized to true.
  llvm::cl::opt printOpOnDiagnostic{
      "mlir-print-op-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
                     "the operation as an attached note"),
      llvm::cl::init(true)};

  /// Option "mlir-print-stacktrace-on-diagnostic".
  llvm::cl::opt printStackTraceOnDiagnostic{
      "mlir-print-stacktrace-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
                     "as an attached note")};
};
} // namespace

/// Holder for the command line options above. It may not have been constructed
/// yet, which is why readers check `isConstructed()` before dereferencing it.
static llvm::ManagedStatic clOptions;

/// Returns true when threading is disabled for every context: either
/// LLVM_ENABLE_THREADS is 0, or the options have been constructed and
/// `disableThreading` is set.
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS != 0
  return clOptions.isConstructed() && clOptions->disableThreading;
#else
  return true;
#endif
}

/// Register a set of useful command-line options that can be used to configure
/// various flags within the MLIRContext. These flags are used when constructing
/// an MLIR context for initialization.
void mlir::registerMLIRContextCLOptions() {
  // Make sure that the options struct has been initialized. The dereference is
  // done only for that effect; its result is discarded.
  *clOptions;
}

//===----------------------------------------------------------------------===//
// Locking Utilities
//===----------------------------------------------------------------------===//

namespace {
/// Utility writer lock that takes a runtime flag that specifies if we really
/// need to lock. When `shouldLock` is false the stored pointer is null and
/// both the constructor and the destructor do nothing.
struct ScopedWriterLock {
  ScopedWriterLock(llvm::sys::SmartRWMutex &mutexParam, bool shouldLock)
      : mutex(shouldLock ? &mutexParam : nullptr) {
    if (mutex)
      mutex->lock();
  }
  ~ScopedWriterLock() {
    if (mutex)
      mutex->unlock();
  }
  /// The mutex held by this lock, or nullptr when locking was not requested.
  llvm::sys::SmartRWMutex *mutex;
};
} // namespace

//===----------------------------------------------------------------------===//
// MLIRContextImpl
//===----------------------------------------------------------------------===//

namespace mlir {
/// This is the implementation of the MLIRContext class, using the pImpl idiom.
/// This class is completely private to this file, so everything is public. class MLIRContextImpl { public: //===--------------------------------------------------------------------===// // Debugging //===--------------------------------------------------------------------===// /// An action handler for handling actions that are dispatched through this /// context. std::function, const tracing::Action &)> actionHandler; //===--------------------------------------------------------------------===// // Diagnostics //===--------------------------------------------------------------------===// DiagnosticEngine diagEngine; //===--------------------------------------------------------------------===// // Options //===--------------------------------------------------------------------===// /// In most cases, creating operation in unregistered dialect is not desired /// and indicate a misconfiguration of the compiler. This option enables to /// detect such use cases bool allowUnregisteredDialects = false; /// Enable support for multi-threading within MLIR. bool threadingIsEnabled = true; /// Track if we are currently executing in a threaded execution environment /// (like the pass-manager): this is only a debugging feature to help reducing /// the chances of data races one some context APIs. #ifndef NDEBUG std::atomic multiThreadedExecutionContext{0}; #endif /// If the operation should be attached to diagnostics printed via the /// Operation::emit methods. bool printOpOnDiagnostic = true; /// If the current stack trace should be attached when emitting diagnostics. bool printStackTraceOnDiagnostic = false; //===--------------------------------------------------------------------===// // Other //===--------------------------------------------------------------------===// /// This points to the ThreadPool used when processing MLIR tasks in parallel. /// It can't be nullptr when multi-threading is enabled. 
/// Otherwise if multi-threading is disabled, and the threadpool wasn't
  /// externally provided using `setThreadPool`, this will be nullptr.
  llvm::ThreadPoolInterface *threadPool = nullptr;

  /// In case where the thread pool is owned by the context, this ensures
  /// destruction with the context.
  // NOTE(review): the container and smart-pointer members in this section have
  // no template argument list in this copy of the file; the `<...>` parts look
  // lost in extraction -- confirm against the original.
  std::unique_ptr ownedThreadPool;

  /// An allocator used for AbstractAttribute and AbstractType objects.
  llvm::BumpPtrAllocator abstractDialectSymbolAllocator;

  /// This is a mapping from operation name to the operation info describing it.
  llvm::StringMap> operations;

  /// Lookup tables for registered operations, keyed by TypeID and by name
  /// respectively.
  llvm::DenseMap registeredOperations;
  llvm::StringMap registeredOperationsByName;

  /// This is a sorted container of registered operations for a deterministic
  /// and efficient `getRegisteredOperations` implementation.
  SmallVector sortedRegisteredOperations;

  /// This is a list of dialects that are created referring to this context.
  /// The MLIRContext owns the objects. These need to be declared after the
  /// registered operations to ensure correct destruction order.
  DenseMap> loadedDialects;
  DialectRegistry dialectsRegistry;

  /// A mutex used when accessing operation information.
  llvm::sys::SmartRWMutex operationInfoMutex;

  //===--------------------------------------------------------------------===//
  // Affine uniquing
  //===--------------------------------------------------------------------===//

  // Affine expression, map and integer set uniquing.
  StorageUniquer affineUniquer;

  //===--------------------------------------------------------------------===//
  // Type uniquing
  //===--------------------------------------------------------------------===//

  /// Registered abstract types, looked up by TypeID.
  DenseMap registeredTypes;
  StorageUniquer typeUniquer;

  /// This is a mapping from type name to the abstract type describing it.
  /// It is used by `AbstractType::lookup` to get an `AbstractType` from a name.
/// As this map needs to be populated before `StringAttr` is loaded, we /// cannot use `StringAttr` as the key. The context does not take ownership /// of the key, so the `StringRef` must outlive the context. llvm::DenseMap nameToType; /// Cached Type Instances. Float4E2M1FNType f4E2M1FNTy; Float6E2M3FNType f6E2M3FNTy; Float6E3M2FNType f6E3M2FNTy; Float8E5M2Type f8E5M2Ty; Float8E4M3Type f8E4M3Ty; Float8E4M3FNType f8E4M3FNTy; Float8E5M2FNUZType f8E5M2FNUZTy; Float8E4M3FNUZType f8E4M3FNUZTy; Float8E4M3B11FNUZType f8E4M3B11FNUZTy; Float8E3M4Type f8E3M4Ty; Float8E8M0FNUType f8E8M0FNUTy; BFloat16Type bf16Ty; Float16Type f16Ty; FloatTF32Type tf32Ty; Float32Type f32Ty; Float64Type f64Ty; Float80Type f80Ty; Float128Type f128Ty; IndexType indexTy; IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; NoneType noneType; //===--------------------------------------------------------------------===// // Attribute uniquing //===--------------------------------------------------------------------===// DenseMap registeredAttributes; StorageUniquer attributeUniquer; /// This is a mapping from attribute name to the abstract attribute describing /// it. It is used by `AbstractType::lookup` to get an `AbstractType` from a /// name. /// As this map needs to be populated before `StringAttr` is loaded, we /// cannot use `StringAttr` as the key. The context does not take ownership /// of the key, so the `StringRef` must outlive the context. llvm::DenseMap nameToAttribute; /// Cached Attribute Instances. BoolAttr falseAttr, trueAttr; UnitAttr unitAttr; UnknownLoc unknownLocAttr; DictionaryAttr emptyDictionaryAttr; StringAttr emptyStringAttr; /// Map of string attributes that may reference a dialect, that are awaiting /// that dialect to be loaded. llvm::sys::SmartMutex dialectRefStrAttrMutex; DenseMap> dialectReferencingStrAttrs; /// A distinct attribute allocator that allocates every time since the /// address of the distinct attribute storage serves as unique identifier. 
The /// allocator is thread safe and frees the allocated storage after its /// destruction. DistinctAttributeAllocator distinctAttributeAllocator; public: MLIRContextImpl(bool threadingIsEnabled) : threadingIsEnabled(threadingIsEnabled) { if (threadingIsEnabled) { ownedThreadPool = std::make_unique(); threadPool = ownedThreadPool.get(); } } ~MLIRContextImpl() { for (auto typeMapping : registeredTypes) typeMapping.second->~AbstractType(); for (auto attrMapping : registeredAttributes) attrMapping.second->~AbstractAttribute(); } }; } // namespace mlir MLIRContext::MLIRContext(Threading setting) : MLIRContext(DialectRegistry(), setting) {} MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) : impl(new MLIRContextImpl(setting == Threading::ENABLED && !isThreadingGloballyDisabled())) { // Initialize values based on the command line flags if they were provided. if (clOptions.isConstructed()) { printOpOnDiagnostic(clOptions->printOpOnDiagnostic); printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic); } // Pre-populate the registry. registry.appendTo(impl->dialectsRegistry); // Ensure the builtin dialect is always pre-loaded. getOrLoadDialect(); // Initialize several common attributes and types to avoid the need to lock // the context when accessing them. //// Types. /// Floating-point Types. 
impl->f4E2M1FNTy = TypeUniquer::get(this);
  // NOTE(review): every `TypeUniquer::get` / `AttributeUniquer::get` call in
  // this section has no template argument in this copy of the file; the `<...>`
  // parts look lost in extraction -- confirm against the original.
  impl->f6E2M3FNTy = TypeUniquer::get(this);
  impl->f6E3M2FNTy = TypeUniquer::get(this);
  impl->f8E5M2Ty = TypeUniquer::get(this);
  impl->f8E4M3Ty = TypeUniquer::get(this);
  impl->f8E4M3FNTy = TypeUniquer::get(this);
  impl->f8E5M2FNUZTy = TypeUniquer::get(this);
  impl->f8E4M3FNUZTy = TypeUniquer::get(this);
  impl->f8E4M3B11FNUZTy = TypeUniquer::get(this);
  impl->f8E3M4Ty = TypeUniquer::get(this);
  impl->f8E8M0FNUTy = TypeUniquer::get(this);
  impl->bf16Ty = TypeUniquer::get(this);
  impl->f16Ty = TypeUniquer::get(this);
  impl->tf32Ty = TypeUniquer::get(this);
  impl->f32Ty = TypeUniquer::get(this);
  impl->f64Ty = TypeUniquer::get(this);
  impl->f80Ty = TypeUniquer::get(this);
  impl->f128Ty = TypeUniquer::get(this);
  /// Index Type.
  impl->indexTy = TypeUniquer::get(this);
  /// Integer Types. All cached integer types are signless.
  impl->int1Ty = TypeUniquer::get(this, 1, IntegerType::Signless);
  impl->int8Ty = TypeUniquer::get(this, 8, IntegerType::Signless);
  impl->int16Ty = TypeUniquer::get(this, 16, IntegerType::Signless);
  impl->int32Ty = TypeUniquer::get(this, 32, IntegerType::Signless);
  impl->int64Ty = TypeUniquer::get(this, 64, IntegerType::Signless);
  impl->int128Ty = TypeUniquer::get(this, 128, IntegerType::Signless);
  /// None Type.
  impl->noneType = TypeUniquer::get(this);

  //// Attributes.
  //// Note: These must be registered after the types as they may generate one
  //// of the above types internally.
  /// Unknown Location Attribute.
  impl->unknownLocAttr = AttributeUniquer::get(this);
  /// Bool Attributes. Both are built on the cached 1-bit integer type.
  impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
  impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
  /// Unit Attribute.
  impl->unitAttr = AttributeUniquer::get(this);
  /// The empty dictionary attribute.
  impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
  /// The empty string attribute.
  impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);

  // Register the affine storage objects with the uniquer.
impl->affineUniquer
      .registerParametricStorageType();
  // NOTE(review): the five registrations here are textually identical in this
  // copy of the file because their template arguments look lost in extraction
  // -- confirm the storage types against the original.
  impl->affineUniquer
      .registerParametricStorageType();
  impl->affineUniquer.registerParametricStorageType();
  impl->affineUniquer.registerParametricStorageType();
  impl->affineUniquer.registerParametricStorageType();
}

MLIRContext::~MLIRContext() = default;

/// Copy the specified array of elements into memory managed by the provided
/// bump pointer allocator. This assumes the elements are all PODs. The
/// returned ArrayRef refers to the newly allocated copy.
// NOTE(review): the template header and the `ArrayRef` / `Allocate` template
// arguments look lost in extraction -- confirm against the original.
template 
static ArrayRef copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
                                    ArrayRef elements) {
  auto result = allocator.Allocate(elements.size());
  std::uninitialized_copy(elements.begin(), elements.end(), result);
  return ArrayRef(result, elements.size());
}

//===----------------------------------------------------------------------===//
// Action Handling
//===----------------------------------------------------------------------===//

/// Installs `handler` as the action handler of this context, replacing the
/// one currently stored.
void MLIRContext::registerActionHandler(HandlerTy handler) {
  getImpl().actionHandler = std::move(handler);
}

/// Dispatch the provided action to the registered handler. A handler must be
/// set when this is called (asserted); callers can check `hasActionHandler()`.
void MLIRContext::executeActionInternal(function_ref actionFn,
                                        const tracing::Action &action) {
  assert(getImpl().actionHandler);
  getImpl().actionHandler(actionFn, action);
}

/// Returns true if an action handler is currently registered.
bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }

//===----------------------------------------------------------------------===//
// Diagnostic Handlers
//===----------------------------------------------------------------------===//

/// Returns the diagnostic engine for this context.
DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } //===----------------------------------------------------------------------===// // Dialect and Operation Registration //===----------------------------------------------------------------------===// void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { if (registry.isSubsetOf(impl->dialectsRegistry)) return; assert(impl->multiThreadedExecutionContext == 0 && "appending to the MLIRContext dialect registry while in a " "multi-threaded execution context"); registry.appendTo(impl->dialectsRegistry); // For the already loaded dialects, apply any possible extensions immediately. registry.applyExtensions(this); } const DialectRegistry &MLIRContext::getDialectRegistry() { return impl->dialectsRegistry; } /// Return information about all registered IR dialects. std::vector MLIRContext::getLoadedDialects() { std::vector result; result.reserve(impl->loadedDialects.size()); for (auto &dialect : impl->loadedDialects) result.push_back(dialect.second.get()); llvm::array_pod_sort(result.begin(), result.end(), [](Dialect *const *lhs, Dialect *const *rhs) -> int { return (*lhs)->getNamespace() < (*rhs)->getNamespace(); }); return result; } std::vector MLIRContext::getAvailableDialects() { std::vector result; for (auto dialect : impl->dialectsRegistry.getDialectNames()) result.push_back(dialect); return result; } /// Get a registered IR dialect with the given namespace. If none is found, /// then return nullptr. Dialect *MLIRContext::getLoadedDialect(StringRef name) { // Dialects are sorted by name, so we can use binary search for lookup. auto it = impl->loadedDialects.find(name); return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr; } Dialect *MLIRContext::getOrLoadDialect(StringRef name) { Dialect *dialect = getLoadedDialect(name); if (dialect) return dialect; DialectAllocatorFunctionRef allocator = impl->dialectsRegistry.getDialectAllocator(name); return allocator ? 
allocator(this) : nullptr;
}

/// Get a dialect for the provided namespace and TypeID: abort the program if a
/// dialect exist for this namespace with different TypeID. Returns a pointer to
/// the dialect owned by the context.
// NOTE(review): the `function_ref` signature of `ctor` and the `std::unique_ptr`
// element type below look truncated by extraction -- confirm against the
// original.
Dialect *
MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
                              function_ref()> ctor) {
  auto &impl = getImpl();
  // try_emplace either inserts a null placeholder entry for this namespace or
  // returns the entry that already exists; `.second` tells which happened.
  auto dialectIt = impl.loadedDialects.try_emplace(dialectNamespace, nullptr);

  if (dialectIt.second) {
    LLVM_DEBUG(llvm::dbgs()
               << "Load new dialect in Context " << dialectNamespace << "\n");
#ifndef NDEBUG
    if (impl.multiThreadedExecutionContext != 0)
      llvm::report_fatal_error(
          "Loading a dialect (" + dialectNamespace +
          ") while in a multi-threaded execution context (maybe "
          "the PassManager): this can indicate a "
          "missing `dependentDialects` in a pass for example.");
#endif // NDEBUG
    // loadedDialects entry is initialized to nullptr, indicating that the
    // dialect is currently being loaded. Re-lookup the address in
    // loadedDialects because the table might have been rehashed by recursive
    // dialect loading in ctor().
    std::unique_ptr &dialectOwned =
        impl.loadedDialects[dialectNamespace] = ctor();
    Dialect *dialect = dialectOwned.get();
    assert(dialect && "dialect ctor failed");

    // Refresh all the identifiers dialect field, this catches cases where a
    // dialect may be loaded after identifier prefixed with this dialect name
    // were already created.
    auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
    if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
      for (StringAttrStorage *storage : stringAttrsIt->second)
        storage->referencedDialect = dialect;
      // The pending entries for this namespace are resolved; drop them.
      impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
    }

    // Apply any extensions to this newly loaded dialect.
impl.dialectsRegistry.applyExtensions(dialect);
    return dialect;
  }

  // From here on an entry for this namespace already existed.
#ifndef NDEBUG
  // A null entry means the dialect is still being loaded. This check is only
  // compiled when NDEBUG is not defined.
  if (dialectIt.first->second == nullptr)
    llvm::report_fatal_error(
        "Loading (and getting) a dialect (" + dialectNamespace +
        ") while the same dialect is still loading: use loadDialect instead "
        "of getOrLoadDialect.");
#endif // NDEBUG

  // Abort if dialect with namespace has already been registered.
  // NOTE(review): the `std::unique_ptr` element type looks lost in extraction
  // -- confirm against the original.
  std::unique_ptr &dialect = dialectIt.first->second;
  if (dialect->getTypeID() != dialectID)
    llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
                             "' has already been registered");

  return dialect.get();
}

/// Returns true if the dialect with the given namespace has an entry in the
/// loaded dialects map that is still null, i.e. it is in the middle of loading.
bool MLIRContext::isDialectLoading(StringRef dialectNamespace) {
  auto it = getImpl().loadedDialects.find(dialectNamespace);
  // nullptr indicates that the dialect is currently being loaded.
  return it != getImpl().loadedDialects.end() && it->second == nullptr;
}

/// Returns the dynamic dialect registered under `dialectNamespace`, creating
/// and loading it when no dialect with that namespace is loaded yet.
// NOTE(review): the `function_ref` signature of `ctor` looks lost in extraction
// -- confirm against the original.
DynamicDialect *MLIRContext::getOrLoadDynamicDialect(
    StringRef dialectNamespace, function_ref ctor) {
  auto &impl = getImpl();
  // Check whether a dialect with this namespace has already been loaded.
auto dialectIt = impl.loadedDialects.find(dialectNamespace);

  if (dialectIt != impl.loadedDialects.end()) {
    // An entry exists: return it if it is a dynamic dialect, otherwise the
    // namespace is taken by another dialect and this is a fatal error.
    // NOTE(review): the `dyn_cast` target type looks lost in extraction --
    // confirm against the original.
    if (auto *dynDialect = dyn_cast(dialectIt->second.get()))
      return dynDialect;
    llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
                             "' has already been registered");
  }

  LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context "
                          << dialectNamespace << "\n");
#ifndef NDEBUG
  if (impl.multiThreadedExecutionContext != 0)
    llvm::report_fatal_error(
        "Loading a dynamic dialect (" + dialectNamespace +
        ") while in a multi-threaded execution context (maybe "
        "the PassManager): this can indicate a "
        "missing `dependentDialects` in a pass for example.");
#endif

  auto name = StringAttr::get(this, dialectNamespace);
  // The raw pointer is handed to a unique_ptr inside the lambda below, which
  // is where ownership is transferred to the context.
  auto *dialect = new DynamicDialect(name, this);
  (void)getOrLoadDialect(name, dialect->getTypeID(), [dialect, ctor]() {
    ctor(dialect);
    return std::unique_ptr(dialect);
  });
  // This is the same result as `getOrLoadDialect` (if it didn't failed),
  // since it has the same TypeID, and TypeIDs are unique.
  return dialect;
}

/// Loads every dialect whose name is reported by getAvailableDialects().
void MLIRContext::loadAllAvailableDialects() {
  for (StringRef name : getAvailableDialects())
    getOrLoadDialect(name);
}

/// Returns a hash built from the sizes of the registration tables of this
/// context.
llvm::hash_code MLIRContext::getRegistryHash() {
  llvm::hash_code hash(0);
  // Factor in number of loaded dialects, attributes, operations, types.
hash = llvm::hash_combine(hash, impl->loadedDialects.size());
  hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
  hash = llvm::hash_combine(hash, impl->registeredOperations.size());
  hash = llvm::hash_combine(hash, impl->registeredTypes.size());
  return hash;
}

/// Returns true if operations from unregistered dialects are allowed.
bool MLIRContext::allowsUnregisteredDialects() {
  return impl->allowUnregisteredDialects;
}

/// Sets whether operations from unregistered dialects are allowed. Must not be
/// called while in a multi-threaded execution context (asserted).
void MLIRContext::allowUnregisteredDialects(bool allowing) {
  assert(impl->multiThreadedExecutionContext == 0 &&
         "changing MLIRContext `allow-unregistered-dialects` configuration "
         "while in a multi-threaded execution context");
  impl->allowUnregisteredDialects = allowing;
}

/// Return true if multi-threading is enabled by the context.
bool MLIRContext::isMultithreadingEnabled() {
  return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
}

/// Set the flag specifying if multi-threading is disabled by the context.
void MLIRContext::disableMultithreading(bool disable) {
  // This API can be overridden by the global debugging flag
  // --mlir-disable-threading
  if (isThreadingGloballyDisabled())
    return;
  assert(impl->multiThreadedExecutionContext == 0 &&
         "changing MLIRContext `disable-threading` configuration while "
         "in a multi-threaded execution context");

  impl->threadingIsEnabled = !disable;

  // Update the threading mode for each of the uniquers.
  impl->affineUniquer.disableMultithreading(disable);
  impl->attributeUniquer.disableMultithreading(disable);
  impl->typeUniquer.disableMultithreading(disable);

  // Destroy thread pool (stop all threads) if it is no longer needed, or create
  // a new one if multithreading was re-enabled.
  if (disable) {
    // If the thread pool is owned, explicitly set it to nullptr to avoid
    // keeping a dangling pointer around. If the thread pool is externally
    // owned, we don't do anything.
    if (impl->ownedThreadPool) {
      assert(impl->threadPool);
      impl->threadPool = nullptr;
      impl->ownedThreadPool.reset();
    }
  } else if (!impl->threadPool) {
    // The thread pool isn't externally provided.
assert(!impl->ownedThreadPool);
    // NOTE(review): `std::make_unique()` has no template argument in this copy
    // of the file; the `<...>` part looks lost in extraction -- confirm against
    // the original.
    impl->ownedThreadPool = std::make_unique();
    impl->threadPool = impl->ownedThreadPool.get();
  }
}

/// Makes the context use the externally provided `pool`, releases any pool
/// owned by the context, and enables multi-threading. Multi-threading must be
/// disabled when this is called (asserted).
void MLIRContext::setThreadPool(llvm::ThreadPoolInterface &pool) {
  assert(!isMultithreadingEnabled() &&
         "expected multi-threading to be disabled when setting a ThreadPool");
  impl->threadPool = &pool;
  impl->ownedThreadPool.reset();
  enableMultithreading();
}

/// Returns the maximum concurrency of the thread pool when multi-threading is
/// enabled, and 1 otherwise.
unsigned MLIRContext::getNumThreads() {
  if (isMultithreadingEnabled()) {
    assert(impl->threadPool &&
           "multi-threading is enabled but threadpool not set");
    return impl->threadPool->getMaxConcurrency();
  }
  // No multithreading or active thread pool. Return 1 thread.
  return 1;
}

/// Returns the thread pool used by this context. Multi-threading must be
/// enabled and a pool must be set (both asserted).
llvm::ThreadPoolInterface &MLIRContext::getThreadPool() {
  assert(isMultithreadingEnabled() &&
         "expected multi-threading to be enabled within the context");
  assert(impl->threadPool &&
         "multi-threading is enabled but threadpool not set");
  return *impl->threadPool;
}

/// Increments the multi-threaded execution counter. This does nothing when
/// NDEBUG is defined.
void MLIRContext::enterMultiThreadedExecution() {
#ifndef NDEBUG
  ++impl->multiThreadedExecutionContext;
#endif
}
/// Decrements the multi-threaded execution counter. This does nothing when
/// NDEBUG is defined.
void MLIRContext::exitMultiThreadedExecution() {
#ifndef NDEBUG
  --impl->multiThreadedExecutionContext;
#endif
}

/// Return true if we should attach the operation to diagnostics emitted via
/// Operation::emit.
bool MLIRContext::shouldPrintOpOnDiagnostic() {
  return impl->printOpOnDiagnostic;
}

/// Set the flag specifying if we should attach the operation to diagnostics
/// emitted via Operation::emit.
void MLIRContext::printOpOnDiagnostic(bool enable) {
  assert(impl->multiThreadedExecutionContext == 0 &&
         "changing MLIRContext `print-op-on-diagnostic` configuration while in "
         "a multi-threaded execution context");
  impl->printOpOnDiagnostic = enable;
}

/// Return true if we should attach the current stacktrace to diagnostics when
/// emitted.
bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
  return impl->printStackTraceOnDiagnostic;
}

/// Set the flag specifying if we should attach the current stacktrace when
/// emitting diagnostics.
void MLIRContext::printStackTraceOnDiagnostic(bool enable) { assert(impl->multiThreadedExecutionContext == 0 && "changing MLIRContext `print-stacktrace-on-diagnostic` configuration " "while in a multi-threaded execution context"); impl->printStackTraceOnDiagnostic = enable; } /// Return information about all registered operations. ArrayRef MLIRContext::getRegisteredOperations() { return impl->sortedRegisteredOperations; } /// Return information for registered operations by dialect. ArrayRef MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) { auto lowerBound = std::lower_bound(impl->sortedRegisteredOperations.begin(), impl->sortedRegisteredOperations.end(), dialectName, [](auto &lhs, auto &rhs) { return lhs.getDialect().getNamespace().compare(rhs); }); if (lowerBound == impl->sortedRegisteredOperations.end() || lowerBound->getDialect().getNamespace() != dialectName) return ArrayRef(); auto upperBound = std::upper_bound(lowerBound, impl->sortedRegisteredOperations.end(), dialectName, [](auto &lhs, auto &rhs) { return lhs.compare(rhs.getDialect().getNamespace()); }); size_t count = std::distance(lowerBound, upperBound); return ArrayRef(&*lowerBound, count); } bool MLIRContext::isOperationRegistered(StringRef name) { return RegisteredOperationName::lookup(name, this).has_value(); } void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { auto &impl = context->getImpl(); assert(impl.multiThreadedExecutionContext == 0 && "Registering a new type kind while in a multi-threaded execution " "context"); auto *newInfo = new (impl.abstractDialectSymbolAllocator.Allocate()) AbstractType(std::move(typeInfo)); if (!impl.registeredTypes.insert({typeID, newInfo}).second) llvm::report_fatal_error("Dialect Type already registered."); if (!impl.nameToType.insert({newInfo->getName(), newInfo}).second) llvm::report_fatal_error("Dialect Type with name " + newInfo->getName() + " is already registered."); } void Dialect::addAttribute(TypeID typeID, 
AbstractAttribute &&attrInfo) {
  auto &impl = context->getImpl();
  assert(impl.multiThreadedExecutionContext == 0 &&
         "Registering a new attribute kind while in a multi-threaded execution "
         "context");
  // The object is placement-constructed in the context's bump allocator.
  // NOTE(review): the `Allocate` template argument looks lost in extraction --
  // confirm against the original.
  auto *newInfo =
      new (impl.abstractDialectSymbolAllocator.Allocate())
          AbstractAttribute(std::move(attrInfo));
  // Registering the same TypeID or the same attribute name twice is fatal.
  if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
    llvm::report_fatal_error("Dialect Attribute already registered.");
  if (!impl.nameToAttribute.insert({newInfo->getName(), newInfo}).second)
    llvm::report_fatal_error("Dialect Attribute with name " +
                             newInfo->getName() + " is already registered.");
}

//===----------------------------------------------------------------------===//
// AbstractAttribute
//===----------------------------------------------------------------------===//

/// Get the abstract attribute registered for the provided typeid. Reports a
/// fatal error if no attribute is registered for it in this context.
const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
                                                   MLIRContext *context) {
  const AbstractAttribute *abstract = lookupMutable(typeID, context);
  if (!abstract)
    llvm::report_fatal_error("Trying to create an Attribute that was not "
                             "registered in this MLIRContext.");
  return *abstract;
}

/// Returns the result of looking `typeID` up in the context's registered
/// attributes map.
AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
                                                    MLIRContext *context) {
  auto &impl = context->getImpl();
  return impl.registeredAttributes.lookup(typeID);
}

/// Looks an abstract attribute up by name; returns std::nullopt when no
/// attribute with that name is registered.
// NOTE(review): the return type is visibly truncated in this copy of the file
// -- confirm against the original.
std::optional>
AbstractAttribute::lookup(StringRef name, MLIRContext *context) {
  MLIRContextImpl &impl = context->getImpl();
  const AbstractAttribute *type = impl.nameToAttribute.lookup(name);

  if (!type)
    return std::nullopt;
  return {*type};
}

//===----------------------------------------------------------------------===//
// OperationName
//===----------------------------------------------------------------------===//

/// Builds the name as a StringAttr in the context of `dialect` and delegates
/// to the constructor taking a StringAttr.
OperationName::Impl::Impl(StringRef name, Dialect *dialect, TypeID typeID,
                          detail::InterfaceMap interfaceMap)
    : Impl(StringAttr::get(dialect->getContext(), name), dialect, typeID,
           std::move(interfaceMap)) {}
/// Resolves `name` to its implementation object in `context`, creating an
/// entry for an unregistered operation when the name has not been seen yet.
OperationName::OperationName(StringRef name, MLIRContext *context) {
  MLIRContextImpl &ctxImpl = context->getImpl();

  // Check for an existing name in read-only mode.
  bool isMultithreadingEnabled = context->isMultithreadingEnabled();
  if (isMultithreadingEnabled) {
    // Check the registered info map first. In the overwhelmingly common case,
    // the entry will be in here and it also removes the need to acquire any
    // locks.
    auto registeredIt = ctxImpl.registeredOperationsByName.find(name);
    if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) {
      impl = registeredIt->second.impl;
      return;
    }

    // Not a registered operation: look in the general operations map under a
    // reader lock.
    llvm::sys::SmartScopedReader contextLock(ctxImpl.operationInfoMutex);
    auto it = ctxImpl.operations.find(name);
    if (it != ctxImpl.operations.end()) {
      impl = it->second.get();
      return;
    }
  }

  // Acquire a writer-lock so that we can safely create the new instance. The
  // lock is only taken when multi-threading is enabled.
  ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);

  // `insert` leaves an existing entry untouched; `it.second` is true only when
  // the null placeholder was actually inserted.
  auto it = ctxImpl.operations.insert({name, nullptr});
  if (it.second) {
    auto nameAttr = StringAttr::get(context, name);
    // NOTE(review): the template arguments of `std::make_unique` and
    // `TypeID::get` look lost in extraction -- confirm against the original.
    it.first->second = std::make_unique(
        nameAttr, nameAttr.getReferencedDialect(), TypeID::get(),
        detail::InterfaceMap());
  }
  impl = it.first->second.get();
}

/// Returns the namespace of the dialect of this operation. When no dialect is
/// attached, the part of the name before the first '.' is returned.
StringRef OperationName::getDialectNamespace() const {
  if (Dialect *dialect = getDialect())
    return dialect->getNamespace();
  return getStringRef().split('.').first;
}

/// Folding an unregistered operation always fails.
// NOTE(review): the `ArrayRef` / `SmallVectorImpl` element types look lost in
// extraction -- confirm against the original.
LogicalResult
OperationName::UnregisteredOpModel::foldHook(Operation *, ArrayRef,
                                             SmallVectorImpl &) {
  return failure();
}
/// Unregistered operations contribute no canonicalization patterns.
void OperationName::UnregisteredOpModel::getCanonicalizationPatterns(
    RewritePatternSet &, MLIRContext *) {}
/// Unregistered operations report no traits.
bool OperationName::UnregisteredOpModel::hasTrait(TypeID) { return false; }

/// Always reports a fatal error.
OperationName::ParseAssemblyFn
OperationName::UnregisteredOpModel::getParseAssemblyFn() {
  llvm::report_fatal_error("getParseAssemblyFn hook called on unregistered op");
}
/// Unregistered operations have no default attributes to populate.
void OperationName::UnregisteredOpModel::populateDefaultAttrs(
    const OperationName &, NamedAttrList &) {}
void
OperationName::UnregisteredOpModel::printAssembly(
    Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
  // Unregistered operations are printed through the generic printer;
  // `defaultDialect` is not used.
  p.printGenericOp(op);
}
/// Always succeeds.
LogicalResult
OperationName::UnregisteredOpModel::verifyInvariants(Operation *) {
  return success();
}
/// Always succeeds.
LogicalResult
OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation *) {
  return success();
}

/// Looks `name` up in the properties of `op` viewed as an attribute. Returns
/// std::nullopt when that view is null or has no entry for `name`.
// NOTE(review): the template arguments of `std::optional`, `dyn_cast_or_null`
// and `as()` in this section look lost in extraction -- confirm against the
// original.
std::optional
OperationName::UnregisteredOpModel::getInherentAttr(Operation *op,
                                                    StringRef name) {
  auto dict = dyn_cast_or_null(getPropertiesAsAttr(op));
  if (!dict)
    return std::nullopt;
  if (Attribute attr = dict.get(name))
    return attr;
  return std::nullopt;
}
/// Sets `name` to `value` in a copy of the properties attribute list and
/// stores the rebuilt dictionary back into the properties storage of `op`.
void OperationName::UnregisteredOpModel::setInherentAttr(Operation *op,
                                                         StringAttr name,
                                                         Attribute value) {
  auto dict = dyn_cast_or_null(getPropertiesAsAttr(op));
  assert(dict);
  NamedAttrList attrs(dict);
  attrs.set(name, value);
  *op->getPropertiesStorage().as() =
      attrs.getDictionary(op->getContext());
}
/// Intentionally empty.
void OperationName::UnregisteredOpModel::populateInherentAttrs(
    Operation *op, NamedAttrList &attrs) {}
/// Always succeeds.
LogicalResult OperationName::UnregisteredOpModel::verifyInherentAttrs(
    OperationName opName, NamedAttrList &attributes,
    function_ref emitError) {
  return success();
}
/// The properties storage of an unregistered operation holds one Attribute.
int OperationName::UnregisteredOpModel::getOpPropertyByteSize() {
  return sizeof(Attribute);
}
/// Constructs a default Attribute in `storage`; `init` is not used.
void OperationName::UnregisteredOpModel::initProperties(
    OperationName opName, OpaqueProperties storage, OpaqueProperties init) {
  new (storage.as()) Attribute();
}
/// Destroys the Attribute held in `prop`.
void OperationName::UnregisteredOpModel::deleteProperties(
    OpaqueProperties prop) {
  prop.as()->~Attribute();
}
/// Intentionally empty.
void OperationName::UnregisteredOpModel::populateDefaultProperties(
    OperationName opName, OpaqueProperties properties) {}
/// Stores `attr` as-is in the properties storage and succeeds.
LogicalResult OperationName::UnregisteredOpModel::setPropertiesFromAttr(
    OperationName opName, OpaqueProperties properties, Attribute attr,
    function_ref emitError) {
  *properties.as() = attr;
  return success();
}
/// Returns the Attribute held in the properties storage of `op`.
Attribute
OperationName::UnregisteredOpModel::getPropertiesAsAttr(Operation *op) {
  return
*op->getPropertiesStorage().as(); } void OperationName::UnregisteredOpModel::copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) { *lhs.as() = *rhs.as(); } bool OperationName::UnregisteredOpModel::compareProperties( OpaqueProperties lhs, OpaqueProperties rhs) { return *lhs.as() == *rhs.as(); } llvm::hash_code OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) { return llvm::hash_combine(*prop.as()); } //===----------------------------------------------------------------------===// // RegisteredOperationName //===----------------------------------------------------------------------===// std::optional RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) { auto &impl = ctx->getImpl(); auto it = impl.registeredOperations.find(typeID); if (it != impl.registeredOperations.end()) return it->second; return std::nullopt; } std::optional RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) { auto &impl = ctx->getImpl(); auto it = impl.registeredOperationsByName.find(name); if (it != impl.registeredOperationsByName.end()) return it->getValue(); return std::nullopt; } void RegisteredOperationName::insert( std::unique_ptr ownedImpl, ArrayRef attrNames) { RegisteredOperationName::Impl *impl = ownedImpl.get(); MLIRContext *ctx = impl->getDialect()->getContext(); auto &ctxImpl = ctx->getImpl(); assert(ctxImpl.multiThreadedExecutionContext == 0 && "registering a new operation kind while in a multi-threaded execution " "context"); // Register the attribute names of this operation. MutableArrayRef cachedAttrNames; if (!attrNames.empty()) { cachedAttrNames = MutableArrayRef( ctxImpl.abstractDialectSymbolAllocator.Allocate( attrNames.size()), attrNames.size()); for (unsigned i : llvm::seq(0, attrNames.size())) new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i])); impl->attributeNames = cachedAttrNames; } StringRef name = impl->getName().strref(); // Insert the operation info if it doesn't exist yet. 
ctxImpl.operations[name] = std::move(ownedImpl); // Update the registered info for this operation. auto emplaced = ctxImpl.registeredOperations.try_emplace( impl->getTypeID(), RegisteredOperationName(impl)); assert(emplaced.second && "operation name registration must be successful"); auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace( name, RegisteredOperationName(impl)); (void)emplacedByName; assert(emplacedByName.second && "operation name registration must be successful"); // Add emplaced operation name to the sorted operations container. RegisteredOperationName &value = emplaced.first->second; ctxImpl.sortedRegisteredOperations.insert( llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value, [](auto &lhs, auto &rhs) { return lhs.getIdentifier().compare( rhs.getIdentifier()); }), value); } //===----------------------------------------------------------------------===// // AbstractType //===----------------------------------------------------------------------===// const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) { const AbstractType *type = lookupMutable(typeID, context); if (!type) llvm::report_fatal_error( "Trying to create a Type that was not registered in this MLIRContext."); return *type; } AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) { auto &impl = context->getImpl(); return impl.registeredTypes.lookup(typeID); } std::optional> AbstractType::lookup(StringRef name, MLIRContext *context) { MLIRContextImpl &impl = context->getImpl(); const AbstractType *type = impl.nameToType.lookup(name); if (!type) return std::nullopt; return {*type}; } //===----------------------------------------------------------------------===// // Type uniquing //===----------------------------------------------------------------------===// /// Returns the storage uniquer used for constructing type storage instances. /// This should not be used directly. 
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) { return context->getImpl().f4E2M1FNTy; } Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) { return context->getImpl().f6E2M3FNTy; } Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) { return context->getImpl().f6E3M2FNTy; } Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) { return context->getImpl().f8E5M2Ty; } Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) { return context->getImpl().f8E4M3Ty; } Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) { return context->getImpl().f8E4M3FNTy; } Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) { return context->getImpl().f8E5M2FNUZTy; } Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) { return context->getImpl().f8E4M3FNUZTy; } Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) { return context->getImpl().f8E4M3B11FNUZTy; } Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) { return context->getImpl().f8E3M4Ty; } Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) { return context->getImpl().f8E8M0FNUTy; } BFloat16Type BFloat16Type::get(MLIRContext *context) { return context->getImpl().bf16Ty; } Float16Type Float16Type::get(MLIRContext *context) { return context->getImpl().f16Ty; } FloatTF32Type FloatTF32Type::get(MLIRContext *context) { return context->getImpl().tf32Ty; } Float32Type Float32Type::get(MLIRContext *context) { return context->getImpl().f32Ty; } Float64Type Float64Type::get(MLIRContext *context) { return context->getImpl().f64Ty; } Float80Type Float80Type::get(MLIRContext *context) { return context->getImpl().f80Ty; } Float128Type Float128Type::get(MLIRContext *context) { return context->getImpl().f128Ty; } /// Get an instance of the IndexType. 
IndexType IndexType::get(MLIRContext *context) { return context->getImpl().indexTy; } /// Return an existing integer type instance if one is cached within the /// context. static IntegerType getCachedIntegerType(unsigned width, IntegerType::SignednessSemantics signedness, MLIRContext *context) { if (signedness != IntegerType::Signless) return IntegerType(); switch (width) { case 1: return context->getImpl().int1Ty; case 8: return context->getImpl().int8Ty; case 16: return context->getImpl().int16Ty; case 32: return context->getImpl().int32Ty; case 64: return context->getImpl().int64Ty; case 128: return context->getImpl().int128Ty; default: return IntegerType(); } } IntegerType IntegerType::get(MLIRContext *context, unsigned width, IntegerType::SignednessSemantics signedness) { if (auto cached = getCachedIntegerType(width, signedness, context)) return cached; return Base::get(context, width, signedness); } IntegerType IntegerType::getChecked(function_ref emitError, MLIRContext *context, unsigned width, SignednessSemantics signedness) { if (auto cached = getCachedIntegerType(width, signedness, context)) return cached; return Base::getChecked(emitError, context, width, signedness); } /// Get an instance of the NoneType. NoneType NoneType::get(MLIRContext *context) { if (NoneType cachedInst = context->getImpl().noneType) return cachedInst; // Note: May happen when initializing the singleton attributes of the builtin // dialect. return Base::get(context); } //===----------------------------------------------------------------------===// // Attribute uniquing //===----------------------------------------------------------------------===// /// Returns the storage uniquer used for constructing attribute storage /// instances. This should not be used directly. StorageUniquer &MLIRContext::getAttributeUniquer() { return getImpl().attributeUniquer; } /// Initialize the given attribute storage instance. 
void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, MLIRContext *ctx, TypeID attrID) { storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx)); } BoolAttr BoolAttr::get(MLIRContext *context, bool value) { return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; } UnitAttr UnitAttr::get(MLIRContext *context) { return context->getImpl().unitAttr; } UnknownLoc UnknownLoc::get(MLIRContext *context) { return context->getImpl().unknownLocAttr; } DistinctAttrStorage * detail::DistinctAttributeUniquer::allocateStorage(MLIRContext *context, Attribute referencedAttr) { return context->getImpl().distinctAttributeAllocator.allocate(referencedAttr); } /// Return empty dictionary. DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) { return context->getImpl().emptyDictionaryAttr; } void StringAttrStorage::initialize(MLIRContext *context) { // Check for a dialect namespace prefix, if there isn't one we don't need to // do any additional initialization. auto dialectNamePair = value.split('.'); if (dialectNamePair.first.empty() || dialectNamePair.second.empty()) return; // If one exists, we check to see if this dialect is loaded. If it is, we set // the dialect now, if it isn't we record this storage for initialization // later if the dialect ever gets loaded. if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first))) return; MLIRContextImpl &impl = context->getImpl(); llvm::sys::SmartScopedLock lock(impl.dialectRefStrAttrMutex); impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this); } /// Return an empty string. 
StringAttr StringAttr::get(MLIRContext *context) { return context->getImpl().emptyStringAttr; } //===----------------------------------------------------------------------===// // AffineMap uniquing //===----------------------------------------------------------------------===// StorageUniquer &MLIRContext::getAffineUniquer() { return getImpl().affineUniquer; } AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, ArrayRef results, MLIRContext *context) { auto &impl = context->getImpl(); auto *storage = impl.affineUniquer.get( [&](AffineMapStorage *storage) { storage->context = context; }, dimCount, symbolCount, results); return AffineMap(storage); } /// Check whether the arguments passed to the AffineMap::get() are consistent. /// This method checks whether the highest index of dimensional identifier /// present in result expressions is less than `dimCount` and the highest index /// of symbolic identifier present in result expressions is less than /// `symbolCount`. LLVM_ATTRIBUTE_UNUSED static bool willBeValidAffineMap(unsigned dimCount, unsigned symbolCount, ArrayRef results) { int64_t maxDimPosition = -1; int64_t maxSymbolPosition = -1; getMaxDimAndSymbol(ArrayRef>(results), maxDimPosition, maxSymbolPosition); if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) { LLVM_DEBUG( llvm::dbgs() << "maximum dimensional identifier position in result expression must " "be less than `dimCount` and maximum symbolic identifier position " "in result expression must be less than `symbolCount`\n"); return false; } return true; } AffineMap AffineMap::get(MLIRContext *context) { return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context); } AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, MLIRContext *context) { return getImpl(dimCount, symbolCount, /*results=*/{}, context); } AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, AffineExpr result) { assert(willBeValidAffineMap(dimCount, symbolCount, 
{result})); return getImpl(dimCount, symbolCount, {result}, result.getContext()); } AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, ArrayRef results, MLIRContext *context) { assert(willBeValidAffineMap(dimCount, symbolCount, results)); return getImpl(dimCount, symbolCount, results, context); } //===----------------------------------------------------------------------===// // Integer Sets: these are allocated into the bump pointer, and are immutable. // Unlike AffineMap's, these are uniqued only if they are small. //===----------------------------------------------------------------------===// IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, ArrayRef constraints, ArrayRef eqFlags) { // The number of constraints can't be zero. assert(!constraints.empty()); assert(constraints.size() == eqFlags.size()); auto &impl = constraints[0].getContext()->getImpl(); auto *storage = impl.affineUniquer.get( [](IntegerSetStorage *) {}, dimCount, symbolCount, constraints, eqFlags); return IntegerSet(storage); } //===----------------------------------------------------------------------===// // StorageUniquerSupport //===----------------------------------------------------------------------===// /// Utility method to generate a callback that can be used to generate a /// diagnostic when checking the construction invariants of a storage object. /// This is defined out-of-line to avoid the need to include Location.h. llvm::unique_function mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) { return [ctx] { return emitError(UnknownLoc::get(ctx)); }; } llvm::unique_function mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) { return [=] { return emitError(loc); }; }