mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 02:16:05 +00:00
Revert "Added free-threading CPython mode support in MLIR Python bindings (#107103)"
Breaks on 3.8, rolling back to avoid breakage while fixing. This reverts commit 9dee7c44491635ec9037b90050bcdbd3d5291e38.
This commit is contained in:
parent
1d2eea962a
commit
3f1486f08e
@ -668,31 +668,12 @@ function(add_mlir_python_extension libname extname)
|
||||
elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
|
||||
nanobind_add_module(${libname}
|
||||
NB_DOMAIN mlir
|
||||
FREE_THREADED
|
||||
${ARG_SOURCES}
|
||||
)
|
||||
|
||||
if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
|
||||
# Avoids warnings from upstream nanobind.
|
||||
set(nanobind_target "nanobind-static")
|
||||
if (NOT TARGET ${nanobind_target})
|
||||
# Get correct nanobind target name: nanobind-static-ft or something else
|
||||
# It is set by nanobind_add_module function according to the passed options
|
||||
get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS)
|
||||
|
||||
# Iterate over the list of targets
|
||||
foreach(target ${all_targets})
|
||||
# Check if the target name matches the given string
|
||||
if("${target}" MATCHES "nanobind-")
|
||||
set(nanobind_target "${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if (NOT TARGET ${nanobind_target})
|
||||
message(FATAL_ERROR "Could not find nanobind target to set compile options to")
|
||||
endif()
|
||||
endif()
|
||||
target_compile_options(${nanobind_target}
|
||||
target_compile_options(nanobind-static
|
||||
PRIVATE
|
||||
-Wno-cast-qual
|
||||
-Wno-zero-length-array
|
||||
|
@ -1187,43 +1187,3 @@ or nanobind and
|
||||
utilities to connect to the rest of Python API. The bindings can be located in a
|
||||
separate module or in the same module as attributes and types, and
|
||||
loaded along with the dialect.
|
||||
|
||||
## Free-threading (No-GIL) support
|
||||
|
||||
Free-threading or no-GIL support refers to CPython interpreter (>=3.13) with Global Interpreter Lock made optional. For details on the topic, please check [PEP-703](https://peps.python.org/pep-0703/) and this [Python free-threading guide](https://py-free-threading.github.io/).
|
||||
|
||||
MLIR Python bindings are free-threading compatible with exceptions (discussed below) in the following sense: it is safe to work in multiple threads with **independent** contexts. Below we show an example code of safe usage:
|
||||
|
||||
```python
|
||||
# python3.13t example.py
|
||||
import concurrent.futures
|
||||
|
||||
import mlir.dialects.arith as arith
|
||||
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
|
||||
|
||||
|
||||
def func(py_value):
|
||||
with Context() as ctx:
|
||||
module = Module.create(loc=Location.file("foo.txt", 0, 0))
|
||||
|
||||
dtype = IntegerType.get_signless(64)
|
||||
with InsertionPoint(module.body), Location.name("a"):
|
||||
arith.constant(dtype, py_value)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
num_workers = 8
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = []
|
||||
for i in range(num_workers):
|
||||
futures.append(executor.submit(func, i))
|
||||
assert len(list(f.result() for f in futures)) == num_workers
|
||||
```
|
||||
|
||||
The exceptions to the free-threading compatibility:
|
||||
- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()` which calls thread-unsafe `llvm::raw_ostream`.
|
||||
- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
|
||||
- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
|
||||
- Usage of `mlir.dialects.transform.interpreter` is unsafe.
|
||||
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
|
@ -24,7 +24,6 @@ namespace mlir {
|
||||
namespace python {
|
||||
|
||||
/// Globals that are always accessible once the extension has been initialized.
|
||||
/// Methods of this class are thread-safe.
|
||||
class PyGlobals {
|
||||
public:
|
||||
PyGlobals();
|
||||
@ -38,18 +37,12 @@ public:
|
||||
|
||||
/// Get and set the list of parent modules to search for dialect
|
||||
/// implementation classes.
|
||||
std::vector<std::string> getDialectSearchPrefixes() {
|
||||
nanobind::ft_lock_guard lock(mutex);
|
||||
std::vector<std::string> &getDialectSearchPrefixes() {
|
||||
return dialectSearchPrefixes;
|
||||
}
|
||||
void setDialectSearchPrefixes(std::vector<std::string> newValues) {
|
||||
nanobind::ft_lock_guard lock(mutex);
|
||||
dialectSearchPrefixes.swap(newValues);
|
||||
}
|
||||
void addDialectSearchPrefix(std::string value) {
|
||||
nanobind::ft_lock_guard lock(mutex);
|
||||
dialectSearchPrefixes.push_back(std::move(value));
|
||||
}
|
||||
|
||||
/// Loads a python module corresponding to the given dialect namespace.
|
||||
/// No-ops if the module has already been loaded or is not found. Raises
|
||||
@ -116,9 +109,6 @@ public:
|
||||
|
||||
private:
|
||||
static PyGlobals *instance;
|
||||
|
||||
nanobind::ft_mutex mutex;
|
||||
|
||||
/// Module name prefixes to search under for dialect implementation modules.
|
||||
std::vector<std::string> dialectSearchPrefixes;
|
||||
/// Map of dialect namespace to external dialect class object.
|
||||
|
@ -243,15 +243,9 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,
|
||||
|
||||
/// Wrapper for the global LLVM debugging flag.
|
||||
struct PyGlobalDebugFlag {
|
||||
static void set(nb::object &o, bool enable) {
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
mlirEnableGlobalDebug(enable);
|
||||
}
|
||||
static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
|
||||
|
||||
static bool get(const nb::object &) {
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
return mlirIsGlobalDebugEnabled();
|
||||
}
|
||||
static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
|
||||
|
||||
static void bind(nb::module_ &m) {
|
||||
// Debug flags.
|
||||
@ -261,7 +255,6 @@ struct PyGlobalDebugFlag {
|
||||
.def_static(
|
||||
"set_types",
|
||||
[](const std::string &type) {
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
mlirSetGlobalDebugType(type.c_str());
|
||||
},
|
||||
"types"_a, "Sets specific debug types to be produced by LLVM")
|
||||
@ -270,17 +263,11 @@ struct PyGlobalDebugFlag {
|
||||
pointers.reserve(types.size());
|
||||
for (const std::string &str : types)
|
||||
pointers.push_back(str.c_str());
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
static nb::ft_mutex mutex;
|
||||
};
|
||||
|
||||
nb::ft_mutex PyGlobalDebugFlag::mutex;
|
||||
|
||||
struct PyAttrBuilderMap {
|
||||
static bool dunderContains(const std::string &attributeKind) {
|
||||
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
|
||||
@ -619,7 +606,6 @@ private:
|
||||
|
||||
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
|
||||
nb::gil_scoped_acquire acquire;
|
||||
nb::ft_lock_guard lock(live_contexts_mutex);
|
||||
auto &liveContexts = getLiveContexts();
|
||||
liveContexts[context.ptr] = this;
|
||||
}
|
||||
@ -629,10 +615,7 @@ PyMlirContext::~PyMlirContext() {
|
||||
// forContext method, which always puts the associated handle into
|
||||
// liveContexts.
|
||||
nb::gil_scoped_acquire acquire;
|
||||
{
|
||||
nb::ft_lock_guard lock(live_contexts_mutex);
|
||||
getLiveContexts().erase(context.ptr);
|
||||
}
|
||||
getLiveContexts().erase(context.ptr);
|
||||
mlirContextDestroy(context);
|
||||
}
|
||||
|
||||
@ -649,7 +632,6 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
|
||||
|
||||
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
|
||||
nb::gil_scoped_acquire acquire;
|
||||
nb::ft_lock_guard lock(live_contexts_mutex);
|
||||
auto &liveContexts = getLiveContexts();
|
||||
auto it = liveContexts.find(context.ptr);
|
||||
if (it == liveContexts.end()) {
|
||||
@ -665,17 +647,12 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
|
||||
return PyMlirContextRef(it->second, std::move(pyRef));
|
||||
}
|
||||
|
||||
nb::ft_mutex PyMlirContext::live_contexts_mutex;
|
||||
|
||||
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
|
||||
static LiveContextMap liveContexts;
|
||||
return liveContexts;
|
||||
}
|
||||
|
||||
size_t PyMlirContext::getLiveCount() {
|
||||
nb::ft_lock_guard lock(live_contexts_mutex);
|
||||
return getLiveContexts().size();
|
||||
}
|
||||
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
|
||||
|
||||
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
|
||||
|
||||
|
@ -38,11 +38,8 @@ PyGlobals::PyGlobals() {
|
||||
PyGlobals::~PyGlobals() { instance = nullptr; }
|
||||
|
||||
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
|
||||
{
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
if (loadedDialectModules.contains(dialectNamespace))
|
||||
return true;
|
||||
}
|
||||
if (loadedDialectModules.contains(dialectNamespace))
|
||||
return true;
|
||||
// Since re-entrancy is possible, make a copy of the search prefixes.
|
||||
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
|
||||
nb::object loaded = nb::none();
|
||||
@ -65,14 +62,12 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
|
||||
return false;
|
||||
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
|
||||
// may have occurred, which may do anything.
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
loadedDialectModules.insert(dialectNamespace);
|
||||
return true;
|
||||
}
|
||||
|
||||
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
|
||||
nb::callable pyFunc, bool replace) {
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
nb::object &found = attributeBuilderMap[attributeKind];
|
||||
if (found && !replace) {
|
||||
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
|
||||
@ -86,7 +81,6 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
|
||||
|
||||
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
|
||||
nb::callable typeCaster, bool replace) {
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
nb::object &found = typeCasterMap[mlirTypeID];
|
||||
if (found && !replace)
|
||||
throw std::runtime_error("Type caster is already registered with caster: " +
|
||||
@ -96,7 +90,6 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
|
||||
|
||||
void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
|
||||
nb::callable valueCaster, bool replace) {
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
nb::object &found = valueCasterMap[mlirTypeID];
|
||||
if (found && !replace)
|
||||
throw std::runtime_error("Value caster is already registered: " +
|
||||
@ -106,7 +99,6 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
|
||||
|
||||
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
|
||||
nb::object pyClass) {
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
nb::object &found = dialectClassMap[dialectNamespace];
|
||||
if (found) {
|
||||
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
|
||||
@ -118,7 +110,6 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
|
||||
|
||||
void PyGlobals::registerOperationImpl(const std::string &operationName,
|
||||
nb::object pyClass, bool replace) {
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
nb::object &found = operationClassMap[operationName];
|
||||
if (found && !replace) {
|
||||
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
|
||||
@ -130,7 +121,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
|
||||
|
||||
std::optional<nb::callable>
|
||||
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
const auto foundIt = attributeBuilderMap.find(attributeKind);
|
||||
if (foundIt != attributeBuilderMap.end()) {
|
||||
assert(foundIt->second && "attribute builder is defined");
|
||||
@ -143,7 +133,6 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
|
||||
MlirDialect dialect) {
|
||||
// Try to load dialect module.
|
||||
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
const auto foundIt = typeCasterMap.find(mlirTypeID);
|
||||
if (foundIt != typeCasterMap.end()) {
|
||||
assert(foundIt->second && "type caster is defined");
|
||||
@ -156,7 +145,6 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
|
||||
MlirDialect dialect) {
|
||||
// Try to load dialect module.
|
||||
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
const auto foundIt = valueCasterMap.find(mlirTypeID);
|
||||
if (foundIt != valueCasterMap.end()) {
|
||||
assert(foundIt->second && "value caster is defined");
|
||||
@ -170,7 +158,6 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
|
||||
// Make sure dialect module is loaded.
|
||||
if (!loadDialectModule(dialectNamespace))
|
||||
return std::nullopt;
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
const auto foundIt = dialectClassMap.find(dialectNamespace);
|
||||
if (foundIt != dialectClassMap.end()) {
|
||||
assert(foundIt->second && "dialect class is defined");
|
||||
@ -188,7 +175,6 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
|
||||
if (!loadDialectModule(dialectNamespace))
|
||||
return std::nullopt;
|
||||
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
auto foundIt = operationClassMap.find(operationName);
|
||||
if (foundIt != operationClassMap.end()) {
|
||||
assert(foundIt->second && "OpView is defined");
|
||||
|
@ -260,7 +260,6 @@ private:
|
||||
// Note that this holds a handle, which does not imply ownership.
|
||||
// Mappings will be removed when the context is destructed.
|
||||
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
|
||||
static nanobind::ft_mutex live_contexts_mutex;
|
||||
static LiveContextMap &getLiveContexts();
|
||||
|
||||
// Interns all live modules associated with this context. Modules tracked
|
||||
|
@ -30,8 +30,12 @@ NB_MODULE(_mlir, m) {
|
||||
.def_prop_rw("dialect_search_modules",
|
||||
&PyGlobals::getDialectSearchPrefixes,
|
||||
&PyGlobals::setDialectSearchPrefixes)
|
||||
.def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
|
||||
"module_name"_a)
|
||||
.def(
|
||||
"append_dialect_search_prefix",
|
||||
[](PyGlobals &self, std::string moduleName) {
|
||||
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
|
||||
},
|
||||
"module_name"_a)
|
||||
.def(
|
||||
"_check_dialect_module_loaded",
|
||||
[](PyGlobals &self, const std::string &dialectNamespace) {
|
||||
@ -72,6 +76,7 @@ NB_MODULE(_mlir, m) {
|
||||
nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
|
||||
PyGlobals::get().registerOperationImpl(operationName, opClass,
|
||||
replace);
|
||||
|
||||
// Dict-stuff the new opClass by name onto the dialect class.
|
||||
nb::object opClassName = opClass.attr("__name__");
|
||||
dialectClass.attr(opClassName) = opClass;
|
||||
|
@ -2,4 +2,4 @@ nanobind>=2.4, <3.0
|
||||
numpy>=1.19.5, <=2.1.2
|
||||
pybind11>=2.10.0, <=2.13.6
|
||||
PyYAML>=5.4.0, <=6.0.1
|
||||
ml_dtypes>=0.5.0, <=0.6.0 # provides several NumPy dtype extensions, including the bf16
|
||||
ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16
|
||||
|
@ -1,531 +0,0 @@
|
||||
# RUN: %PYTHON %s
|
||||
"""
|
||||
This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
|
||||
Tests can be run using pytest:
|
||||
```bash
|
||||
python3.13t -mpytest -vvv multithreaded_tests.py
|
||||
```
|
||||
|
||||
IMPORTANT. Running tests are not checking the correctness, but just the execution of the tests in multi-threaded context
|
||||
and passing if no warnings reported by TSAN and failing otherwise.
|
||||
|
||||
|
||||
Details on the generated tests and execution:
|
||||
1) Multi-threaded execution: all generated tests are executed independently by
|
||||
a pool of threads, running each test multiple times, see @multi_threaded for details
|
||||
|
||||
2) Tests generation: we use existing tests: test/python/ir/*.py,
|
||||
test/python/dialects/*.py, etc to generate multi-threaded tests.
|
||||
In details, we perform the following:
|
||||
a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
|
||||
b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
|
||||
c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py.
|
||||
In order to import the test file as python module, we remove all executing functions, like
|
||||
`@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details.
|
||||
|
||||
|
||||
Observed warnings reported by TSAN.
|
||||
|
||||
CPython and free-threading known data-races:
|
||||
1) ctypes related races: https://github.com/python/cpython/issues/127945
|
||||
2) LLVM related data-races, llvm::raw_ostream is not thread-safe
|
||||
- mlir pass manager
|
||||
- dialects/transform_interpreter.py
|
||||
- ir/diagnostic_handler.py
|
||||
- ir/module.py
|
||||
3) Dialect gpu module-to-binary method is unsafe
|
||||
"""
|
||||
import concurrent.futures
|
||||
import gc
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlir.dialects.arith as arith
|
||||
from mlir.dialects import transform
|
||||
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
|
||||
|
||||
|
||||
def import_from_path(module_name: str, file_path: Path):
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def copy_and_update(src_filepath: Path, dst_filepath: Path):
|
||||
# We should remove all calls like `run(testMethod)`
|
||||
with open(src_filepath, "r") as reader, open(dst_filepath, "w") as writer:
|
||||
while True:
|
||||
src_line = reader.readline()
|
||||
if len(src_line) == 0:
|
||||
break
|
||||
skip_lines = [
|
||||
"run(",
|
||||
"@run",
|
||||
"@constructAndPrintInModule",
|
||||
"run_apply_patterns(",
|
||||
"@run_apply_patterns",
|
||||
"@test_in_context",
|
||||
"@construct_and_print_in_module",
|
||||
]
|
||||
if any(src_line.startswith(line) for line in skip_lines):
|
||||
continue
|
||||
writer.write(src_line)
|
||||
|
||||
|
||||
# Helper run functions
|
||||
# They are copied from the test modules (e.g. run function in execution_engine.py)
|
||||
def run(test_function):
|
||||
# Generic run tests function used by dialects and ir test modules
|
||||
test_function()
|
||||
|
||||
|
||||
def run_with_context_and_location(test_function):
|
||||
# run tests function with a context and a location
|
||||
# used by the following test modules:
|
||||
# - dialects/transform_gpu_ext,
|
||||
# - dialects/vector
|
||||
# - dialects/gpu/*
|
||||
with Context(), Location.unknown():
|
||||
test_function()
|
||||
return test_function
|
||||
|
||||
|
||||
def run_with_insertion_point_and_context_arg(test_function):
|
||||
# run tests function used by dialects/index_dialect test module
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
test_function(ctx)
|
||||
|
||||
|
||||
def run_with_insertion_point(test_function):
|
||||
# Used by a lot of dialects test modules
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
test_function()
|
||||
return test_function
|
||||
|
||||
|
||||
def run_with_insertion_point_and_module_arg(test_function):
|
||||
# Used by dialects/transform test module
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
test_function(module)
|
||||
return test_function
|
||||
|
||||
|
||||
def run_with_insertion_point_all_unreg_dialects(test_function):
|
||||
# Used by dialects/cf test module
|
||||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
test_function()
|
||||
return test_function
|
||||
|
||||
|
||||
def run_apply_patterns(test_function):
|
||||
# Used by dialects/transform_tensor_ext test module
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
apply = transform.ApplyPatternsOp(sequence.bodyTarget)
|
||||
with InsertionPoint(apply.patterns):
|
||||
test_function()
|
||||
transform.YieldOp()
|
||||
print(module)
|
||||
return test_function
|
||||
|
||||
|
||||
def run_transform_tensor_ext(test_function):
|
||||
# Used by test modules:
|
||||
# - dialects/transform_gpu_ext
|
||||
# - dialects/transform_sparse_tensor_ext
|
||||
# - dialects/transform_tensor_ext
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
test_function(sequence.bodyTarget)
|
||||
transform.YieldOp()
|
||||
print(module)
|
||||
return test_function
|
||||
|
||||
|
||||
def run_transform_structured_ext(test_function):
|
||||
# Used by dialects/transform_structured_ext test module
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
test_function()
|
||||
module.operation.verify()
|
||||
print(module)
|
||||
return test_function
|
||||
|
||||
|
||||
def run_construct_and_print_in_module(test_function):
|
||||
# Used by test modules:
|
||||
# - integration/dialects/pdl
|
||||
# - integration/dialects/transform
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
module = test_function(module)
|
||||
if module is not None:
|
||||
print(module)
|
||||
return test_function
|
||||
|
||||
|
||||
TEST_MODULES = [
|
||||
("execution_engine", run),
|
||||
("pass_manager", run),
|
||||
("dialects/affine", run_with_insertion_point),
|
||||
("dialects/func", run_with_insertion_point),
|
||||
("dialects/arith_dialect", run),
|
||||
("dialects/arith_llvm", run),
|
||||
("dialects/async_dialect", run),
|
||||
("dialects/builtin", run),
|
||||
("dialects/cf", run_with_insertion_point_all_unreg_dialects),
|
||||
("dialects/complex_dialect", run),
|
||||
("dialects/func", run_with_insertion_point),
|
||||
("dialects/index_dialect", run_with_insertion_point_and_context_arg),
|
||||
("dialects/llvm", run_with_insertion_point),
|
||||
("dialects/math_dialect", run),
|
||||
("dialects/memref", run),
|
||||
("dialects/ml_program", run_with_insertion_point),
|
||||
("dialects/nvgpu", run_with_insertion_point),
|
||||
("dialects/nvvm", run_with_insertion_point),
|
||||
("dialects/ods_helpers", run),
|
||||
("dialects/openmp_ops", run_with_insertion_point),
|
||||
("dialects/pdl_ops", run_with_insertion_point),
|
||||
# ("dialects/python_test", run), # TODO: Need to pass pybind11 or nanobind argv
|
||||
("dialects/quant", run),
|
||||
("dialects/rocdl", run_with_insertion_point),
|
||||
("dialects/scf", run_with_insertion_point),
|
||||
("dialects/shape", run),
|
||||
("dialects/spirv_dialect", run),
|
||||
("dialects/tensor", run),
|
||||
# ("dialects/tosa", ), # Nothing to test
|
||||
("dialects/transform_bufferization_ext", run_with_insertion_point),
|
||||
# ("dialects/transform_extras", ), # Needs a more complicated execution schema
|
||||
("dialects/transform_gpu_ext", run_transform_tensor_ext),
|
||||
(
|
||||
"dialects/transform_interpreter",
|
||||
run_with_context_and_location,
|
||||
["print_", "transform_options", "failed", "include"],
|
||||
),
|
||||
(
|
||||
"dialects/transform_loop_ext",
|
||||
run_with_insertion_point,
|
||||
["loopOutline"],
|
||||
),
|
||||
("dialects/transform_memref_ext", run_with_insertion_point),
|
||||
("dialects/transform_nvgpu_ext", run_with_insertion_point),
|
||||
("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext),
|
||||
("dialects/transform_structured_ext", run_transform_structured_ext),
|
||||
("dialects/transform_tensor_ext", run_transform_tensor_ext),
|
||||
(
|
||||
"dialects/transform_vector_ext",
|
||||
run_apply_patterns,
|
||||
["configurable_patterns"],
|
||||
),
|
||||
("dialects/transform", run_with_insertion_point_and_module_arg),
|
||||
("dialects/vector", run_with_context_and_location),
|
||||
("dialects/gpu/dialect", run_with_context_and_location),
|
||||
("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location),
|
||||
("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location),
|
||||
("dialects/linalg/ops", run),
|
||||
# TO ADD: No proper tests in this dialects/linalg/opsdsl/*
|
||||
# ("dialects/linalg/opsdsl/*", ...),
|
||||
("dialects/sparse_tensor/dialect", run),
|
||||
("dialects/sparse_tensor/passes", run),
|
||||
("integration/dialects/pdl", run_construct_and_print_in_module),
|
||||
("integration/dialects/transform", run_construct_and_print_in_module),
|
||||
("integration/dialects/linalg/opsrun", run),
|
||||
("ir/affine_expr", run),
|
||||
("ir/affine_map", run),
|
||||
("ir/array_attributes", run),
|
||||
("ir/attributes", run),
|
||||
("ir/blocks", run),
|
||||
("ir/builtin_types", run),
|
||||
("ir/context_managers", run),
|
||||
("ir/debug", run),
|
||||
("ir/diagnostic_handler", run),
|
||||
("ir/dialects", run),
|
||||
("ir/exception", run),
|
||||
("ir/insertion_point", run),
|
||||
("ir/integer_set", run),
|
||||
("ir/location", run),
|
||||
("ir/module", run),
|
||||
("ir/operation", run),
|
||||
("ir/symbol_table", run),
|
||||
("ir/value", run),
|
||||
]
|
||||
|
||||
TESTS_TO_SKIP = [
|
||||
"test_execution_engine__testNanoTime_multi_threaded", # testNanoTime can't run in multiple threads, even with GIL
|
||||
"test_execution_engine__testSharedLibLoad_multi_threaded", # testSharedLibLoad can't run in multiple threads, even with GIL
|
||||
"test_dialects_arith_dialect__testArithValue_multi_threaded", # RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL
|
||||
"test_ir_dialects__testAppendPrefixSearchPath_multi_threaded", # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals
|
||||
"test_ir_value__testValueCasters_multi_threaded", # RuntimeError: Value caster is already registered: <function testValueCasters.<locals>.dont_cast_int, even with GIL
|
||||
# tests indirectly calling thread-unsafe llvm::raw_ostream
|
||||
"test_execution_engine__testInvalidModule_multi_threaded", # mlirExecutionEngineCreate calls thread-unsafe llvm::raw_ostream
|
||||
"test_pass_manager__testPrintIrAfterAll_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
|
||||
"test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded", # IRPrinterInstrumentation::runBeforePass calls thread-unsafe llvm::raw_ostream
|
||||
"test_pass_manager__testPrintIrLargeLimitElements_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
|
||||
"test_pass_manager__testPrintIrTree_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
|
||||
"test_pass_manager__testRunPipeline_multi_threaded", # PrintOpStatsPass::printSummary calls thread-unsafe llvm::raw_ostream
|
||||
"test_dialects_transform_interpreter__include_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
|
||||
"test_dialects_transform_interpreter__transform_options_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
|
||||
"test_dialects_transform_interpreter__print_self_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) call thread-unsafe llvm::raw_ostream
|
||||
"test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded", # mlirEmitError calls thread-unsafe llvm::raw_ostream
|
||||
"test_ir_module__testParseSuccess_multi_threaded", # mlirOperationDump calls thread-unsafe llvm::raw_ostream
|
||||
# False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames()
|
||||
# Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947
|
||||
"test_execution_engine__testCapsule_multi_threaded",
|
||||
"test_execution_engine__testDumpToObjectFile_multi_threaded",
|
||||
]
|
||||
|
||||
TESTS_TO_XFAIL = [
|
||||
# execution_engine tests:
|
||||
# - ctypes related data-races: https://github.com/python/cpython/issues/127945
|
||||
"test_execution_engine__testBF16Memref_multi_threaded",
|
||||
"test_execution_engine__testBasicCallback_multi_threaded",
|
||||
"test_execution_engine__testComplexMemrefAdd_multi_threaded",
|
||||
"test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded",
|
||||
"test_execution_engine__testDynamicMemrefAdd2D_multi_threaded",
|
||||
"test_execution_engine__testF16MemrefAdd_multi_threaded",
|
||||
"test_execution_engine__testF8E5M2Memref_multi_threaded",
|
||||
"test_execution_engine__testInvokeFloatAdd_multi_threaded",
|
||||
"test_execution_engine__testInvokeVoid_multi_threaded", # a ctypes race
|
||||
"test_execution_engine__testMemrefAdd_multi_threaded",
|
||||
"test_execution_engine__testRankedMemRefCallback_multi_threaded",
|
||||
"test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded",
|
||||
"test_execution_engine__testUnrankedMemRefCallback_multi_threaded",
|
||||
"test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded",
|
||||
# dialects tests
|
||||
"test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded", # Related to ctypes data races
|
||||
"test_dialects_transform_interpreter__print_other_multi_threaded", # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
|
||||
"test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded", # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation
|
||||
"test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded",
|
||||
"test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded",
|
||||
"test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded",
|
||||
# integration tests
|
||||
"test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded", # Related to ctypes data races
|
||||
"test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded", # Related to ctypes data races
|
||||
"test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded", # ctypes
|
||||
"test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded", # ctypes
|
||||
"test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded", # ctypes
|
||||
"test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded", # ctypes
|
||||
"test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded", # ctypes
|
||||
"test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded", # ctypes
|
||||
"test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded", # ctypes
|
||||
"test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded", # ctypes
|
||||
]
|
||||
|
||||
|
||||
def add_existing_tests(test_modules, test_prefix: str = "_original_test"):
|
||||
def decorator(test_cls):
|
||||
this_folder = Path(__file__).parent.absolute()
|
||||
test_cls.output_folder = tempfile.TemporaryDirectory()
|
||||
output_folder = Path(test_cls.output_folder.name)
|
||||
|
||||
for test_mod_info in test_modules:
|
||||
# test_mod_info is a tuple of size 2 or 3:
|
||||
# (test_module_str, run_test_function) or (test_module_str, run_test_function, test_name_patterns_list)
|
||||
# For example:
|
||||
# - ("ir/value", run) or
|
||||
# - ("dialects/transform_loop_ext", run_with_insertion_point, ["loopOutline"])
|
||||
assert isinstance(test_mod_info, tuple) and len(test_mod_info) in (2, 3)
|
||||
if len(test_mod_info) == 2:
|
||||
test_module_name, exec_fn = test_mod_info
|
||||
test_pattern = None
|
||||
else:
|
||||
test_module_name, exec_fn, test_pattern = test_mod_info
|
||||
|
||||
src_filepath = this_folder / f"{test_module_name}.py"
|
||||
dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
|
||||
if not dst_filepath.parent.exists():
|
||||
dst_filepath.parent.mkdir(parents=True)
|
||||
copy_and_update(src_filepath, dst_filepath)
|
||||
test_mod = import_from_path(test_module_name, dst_filepath)
|
||||
for attr_name in dir(test_mod):
|
||||
is_test_fn = test_pattern is None and attr_name.startswith("test")
|
||||
is_test_fn |= test_pattern is not None and any(
|
||||
[p in attr_name for p in test_pattern]
|
||||
)
|
||||
if is_test_fn:
|
||||
obj = getattr(test_mod, attr_name)
|
||||
if callable(obj):
|
||||
test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"
|
||||
|
||||
def wrapped_test_fn(
|
||||
self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs
|
||||
):
|
||||
__exec_fn__(__test_fn__)
|
||||
|
||||
setattr(test_cls, test_name, wrapped_test_fn)
|
||||
return test_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _capture_output(fp):
|
||||
# Inspired from jax test_utils.py capture_stderr method
|
||||
# ``None`` means nothing has not been captured yet.
|
||||
captured = None
|
||||
|
||||
def get_output() -> str:
|
||||
if captured is None:
|
||||
raise ValueError("get_output() called while the context is active.")
|
||||
return captured
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f:
|
||||
original_fd = os.dup(fp.fileno())
|
||||
os.dup2(f.fileno(), fp.fileno())
|
||||
try:
|
||||
yield get_output
|
||||
finally:
|
||||
# Python also has its own buffers, make sure everything is flushed.
|
||||
fp.flush()
|
||||
os.fsync(fp.fileno())
|
||||
f.seek(0)
|
||||
captured = f.read()
|
||||
os.dup2(original_fd, fp.fileno())
|
||||
|
||||
|
||||
capture_stdout = partial(_capture_output, sys.stdout)
|
||||
capture_stderr = partial(_capture_output, sys.stderr)
|
||||
|
||||
|
||||
def multi_threaded(
|
||||
num_workers: int,
|
||||
num_runs: int = 5,
|
||||
skip_tests: Optional[list[str]] = None,
|
||||
xfail_tests: Optional[list[str]] = None,
|
||||
test_prefix: str = "_original_test",
|
||||
multithreaded_test_postfix: str = "_multi_threaded",
|
||||
):
|
||||
"""Decorator that runs a test in a multi-threaded environment."""
|
||||
|
||||
def decorator(test_cls):
|
||||
for name, test_fn in test_cls.__dict__.copy().items():
|
||||
if not (name.startswith(test_prefix) and callable(test_fn)):
|
||||
continue
|
||||
|
||||
name = f"test{name[len(test_prefix):]}"
|
||||
if skip_tests is not None:
|
||||
if any(
|
||||
test_name.replace(multithreaded_test_postfix, "") in name
|
||||
for test_name in skip_tests
|
||||
):
|
||||
continue
|
||||
|
||||
def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs):
|
||||
with capture_stdout(), capture_stderr() as get_output:
|
||||
barrier = threading.Barrier(num_workers)
|
||||
|
||||
def closure():
|
||||
barrier.wait()
|
||||
for _ in range(num_runs):
|
||||
__test_fn__(self, *args, **kwargs)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=num_workers
|
||||
) as executor:
|
||||
futures = []
|
||||
for _ in range(num_workers):
|
||||
futures.append(executor.submit(closure))
|
||||
# We should call future.result() to re-raise an exception if test has
|
||||
# failed
|
||||
assert len(list(f.result() for f in futures)) == num_workers
|
||||
|
||||
gc.collect()
|
||||
assert Context._get_live_count() == 0
|
||||
|
||||
captured = get_output()
|
||||
if len(captured) > 0 and "ThreadSanitizer" in captured:
|
||||
raise RuntimeError(
|
||||
f"ThreadSanitizer reported warnings:\n{captured}"
|
||||
)
|
||||
|
||||
test_new_name = f"{name}{multithreaded_test_postfix}"
|
||||
if xfail_tests is not None and test_new_name in xfail_tests:
|
||||
multi_threaded_test_fn = unittest.expectedFailure(
|
||||
multi_threaded_test_fn
|
||||
)
|
||||
|
||||
setattr(test_cls, test_new_name, multi_threaded_test_fn)
|
||||
|
||||
return test_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@multi_threaded(
|
||||
num_workers=10,
|
||||
num_runs=20,
|
||||
skip_tests=TESTS_TO_SKIP,
|
||||
xfail_tests=TESTS_TO_XFAIL,
|
||||
)
|
||||
@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test")
|
||||
class TestAllMultiThreaded(unittest.TestCase):
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if hasattr(cls, "output_folder"):
|
||||
cls.output_folder.cleanup()
|
||||
|
||||
def _original_test_create_context(self):
|
||||
with Context() as ctx:
|
||||
print(ctx._get_live_count())
|
||||
print(ctx._get_live_module_count())
|
||||
print(ctx._get_live_operation_count())
|
||||
print(ctx._get_live_operation_objects())
|
||||
print(ctx._get_context_again() is ctx)
|
||||
print(ctx._clear_live_operations())
|
||||
|
||||
def _original_test_create_module_with_consts(self):
|
||||
py_values = [123, 234, 345]
|
||||
with Context() as ctx:
|
||||
module = Module.create(loc=Location.file("foo.txt", 0, 0))
|
||||
|
||||
dtype = IntegerType.get_signless(64)
|
||||
with InsertionPoint(module.body), Location.name("a"):
|
||||
arith.constant(dtype, py_values[0])
|
||||
|
||||
with InsertionPoint(module.body), Location.name("b"):
|
||||
arith.constant(dtype, py_values[1])
|
||||
|
||||
with InsertionPoint(module.body), Location.name("c"):
|
||||
arith.constant(dtype, py_values[2])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Do not run the tests on CPython with GIL
|
||||
if hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled():
|
||||
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user