Revert "Added free-threading CPython mode support in MLIR Python bindings (#107103)"

Breaks on 3.8, rolling back to avoid breakage while fixing.

This reverts commit 9dee7c44491635ec9037b90050bcdbd3d5291e38.
This commit is contained in:
Jacques Pienaar 2025-01-12 18:30:42 +00:00
parent 1d2eea962a
commit 3f1486f08e
9 changed files with 16 additions and 649 deletions

View File

@ -668,31 +668,12 @@ function(add_mlir_python_extension libname extname)
elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
nanobind_add_module(${libname}
NB_DOMAIN mlir
FREE_THREADED
${ARG_SOURCES}
)
if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
# Avoids warnings from upstream nanobind.
set(nanobind_target "nanobind-static")
if (NOT TARGET ${nanobind_target})
# Get correct nanobind target name: nanobind-static-ft or something else
# It is set by nanobind_add_module function according to the passed options
get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS)
# Iterate over the list of targets
foreach(target ${all_targets})
# Check if the target name matches the given string
if("${target}" MATCHES "nanobind-")
set(nanobind_target "${target}")
endif()
endforeach()
if (NOT TARGET ${nanobind_target})
message(FATAL_ERROR "Could not find nanobind target to set compile options to")
endif()
endif()
target_compile_options(${nanobind_target}
target_compile_options(nanobind-static
PRIVATE
-Wno-cast-qual
-Wno-zero-length-array

View File

@ -1187,43 +1187,3 @@ or nanobind and
utilities to connect to the rest of Python API. The bindings can be located in a
separate module or in the same module as attributes and types, and
loaded along with the dialect.
## Free-threading (No-GIL) support
Free-threading or no-GIL support refers to CPython interpreter (>=3.13) with Global Interpreter Lock made optional. For details on the topic, please check [PEP-703](https://peps.python.org/pep-0703/) and this [Python free-threading guide](https://py-free-threading.github.io/).
MLIR Python bindings are free-threading compatible with exceptions (discussed below) in the following sense: it is safe to work in multiple threads with **independent** contexts. Below we show an example code of safe usage:
```python
# python3.13t example.py
import concurrent.futures
import mlir.dialects.arith as arith
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
def func(py_value):
with Context() as ctx:
module = Module.create(loc=Location.file("foo.txt", 0, 0))
dtype = IntegerType.get_signless(64)
with InsertionPoint(module.body), Location.name("a"):
arith.constant(dtype, py_value)
return module
num_workers = 8
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for i in range(num_workers):
futures.append(executor.submit(func, i))
assert len(list(f.result() for f in futures)) == num_workers
```
The exceptions to the free-threading compatibility:
- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()` which calls thread-unsafe `llvm::raw_ostream`.
- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `mlir.dialects.transform.interpreter` is unsafe.
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.

View File

@ -24,7 +24,6 @@ namespace mlir {
namespace python {
/// Globals that are always accessible once the extension has been initialized.
/// Methods of this class are thread-safe.
class PyGlobals {
public:
PyGlobals();
@ -38,18 +37,12 @@ public:
/// Get and set the list of parent modules to search for dialect
/// implementation classes.
std::vector<std::string> getDialectSearchPrefixes() {
nanobind::ft_lock_guard lock(mutex);
std::vector<std::string> &getDialectSearchPrefixes() {
return dialectSearchPrefixes;
}
void setDialectSearchPrefixes(std::vector<std::string> newValues) {
nanobind::ft_lock_guard lock(mutex);
dialectSearchPrefixes.swap(newValues);
}
void addDialectSearchPrefix(std::string value) {
nanobind::ft_lock_guard lock(mutex);
dialectSearchPrefixes.push_back(std::move(value));
}
/// Loads a python module corresponding to the given dialect namespace.
/// No-ops if the module has already been loaded or is not found. Raises
@ -116,9 +109,6 @@ public:
private:
static PyGlobals *instance;
nanobind::ft_mutex mutex;
/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.

View File

@ -243,15 +243,9 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,
/// Wrapper for the global LLVM debugging flag.
struct PyGlobalDebugFlag {
static void set(nb::object &o, bool enable) {
nb::ft_lock_guard lock(mutex);
mlirEnableGlobalDebug(enable);
}
static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
static bool get(const nb::object &) {
nb::ft_lock_guard lock(mutex);
return mlirIsGlobalDebugEnabled();
}
static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
static void bind(nb::module_ &m) {
// Debug flags.
@ -261,7 +255,6 @@ struct PyGlobalDebugFlag {
.def_static(
"set_types",
[](const std::string &type) {
nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugType(type.c_str());
},
"types"_a, "Sets specific debug types to be produced by LLVM")
@ -270,17 +263,11 @@ struct PyGlobalDebugFlag {
pointers.reserve(types.size());
for (const std::string &str : types)
pointers.push_back(str.c_str());
nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
});
}
private:
static nb::ft_mutex mutex;
};
nb::ft_mutex PyGlobalDebugFlag::mutex;
struct PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind) {
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
@ -619,7 +606,6 @@ private:
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
nb::gil_scoped_acquire acquire;
nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
liveContexts[context.ptr] = this;
}
@ -629,10 +615,7 @@ PyMlirContext::~PyMlirContext() {
// forContext method, which always puts the associated handle into
// liveContexts.
nb::gil_scoped_acquire acquire;
{
nb::ft_lock_guard lock(live_contexts_mutex);
getLiveContexts().erase(context.ptr);
}
getLiveContexts().erase(context.ptr);
mlirContextDestroy(context);
}
@ -649,7 +632,6 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
nb::gil_scoped_acquire acquire;
nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it == liveContexts.end()) {
@ -665,17 +647,12 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
return PyMlirContextRef(it->second, std::move(pyRef));
}
nb::ft_mutex PyMlirContext::live_contexts_mutex;
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
static LiveContextMap liveContexts;
return liveContexts;
}
size_t PyMlirContext::getLiveCount() {
nb::ft_lock_guard lock(live_contexts_mutex);
return getLiveContexts().size();
}
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }

View File

@ -38,11 +38,8 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
{
nb::ft_lock_guard lock(mutex);
if (loadedDialectModules.contains(dialectNamespace))
return true;
}
if (loadedDialectModules.contains(dialectNamespace))
return true;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
nb::object loaded = nb::none();
@ -65,14 +62,12 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
return false;
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
nb::ft_lock_guard lock(mutex);
loadedDialectModules.insert(dialectNamespace);
return true;
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
nb::callable pyFunc, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = attributeBuilderMap[attributeKind];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
@ -86,7 +81,6 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
nb::callable typeCaster, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = typeCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Type caster is already registered with caster: " +
@ -96,7 +90,6 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
nb::callable valueCaster, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = valueCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Value caster is already registered: " +
@ -106,7 +99,6 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
nb::object pyClass) {
nb::ft_lock_guard lock(mutex);
nb::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
@ -118,7 +110,6 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
void PyGlobals::registerOperationImpl(const std::string &operationName,
nb::object pyClass, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = operationClassMap[operationName];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
@ -130,7 +121,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
std::optional<nb::callable>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
nb::ft_lock_guard lock(mutex);
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
assert(foundIt->second && "attribute builder is defined");
@ -143,7 +133,6 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
nb::ft_lock_guard lock(mutex);
const auto foundIt = typeCasterMap.find(mlirTypeID);
if (foundIt != typeCasterMap.end()) {
assert(foundIt->second && "type caster is defined");
@ -156,7 +145,6 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
nb::ft_lock_guard lock(mutex);
const auto foundIt = valueCasterMap.find(mlirTypeID);
if (foundIt != valueCasterMap.end()) {
assert(foundIt->second && "value caster is defined");
@ -170,7 +158,6 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
if (!loadDialectModule(dialectNamespace))
return std::nullopt;
nb::ft_lock_guard lock(mutex);
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
assert(foundIt->second && "dialect class is defined");
@ -188,7 +175,6 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
if (!loadDialectModule(dialectNamespace))
return std::nullopt;
nb::ft_lock_guard lock(mutex);
auto foundIt = operationClassMap.find(operationName);
if (foundIt != operationClassMap.end()) {
assert(foundIt->second && "OpView is defined");

View File

@ -260,7 +260,6 @@ private:
// Note that this holds a handle, which does not imply ownership.
// Mappings will be removed when the context is destructed.
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();
// Interns all live modules associated with this context. Modules tracked

View File

@ -30,8 +30,12 @@ NB_MODULE(_mlir, m) {
.def_prop_rw("dialect_search_modules",
&PyGlobals::getDialectSearchPrefixes,
&PyGlobals::setDialectSearchPrefixes)
.def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
"module_name"_a)
.def(
"append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
},
"module_name"_a)
.def(
"_check_dialect_module_loaded",
[](PyGlobals &self, const std::string &dialectNamespace) {
@ -72,6 +76,7 @@ NB_MODULE(_mlir, m) {
nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
PyGlobals::get().registerOperationImpl(operationName, opClass,
replace);
// Dict-stuff the new opClass by name onto the dialect class.
nb::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;

View File

@ -2,4 +2,4 @@ nanobind>=2.4, <3.0
numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
ml_dtypes>=0.5.0, <=0.6.0 # provides several NumPy dtype extensions, including the bf16
ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16

View File

@ -1,531 +0,0 @@
# RUN: %PYTHON %s
"""
This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
Tests can be run using pytest:
```bash
python3.13t -mpytest -vvv multithreaded_tests.py
```
IMPORTANT. Running tests are not checking the correctness, but just the execution of the tests in multi-threaded context
and passing if no warnings reported by TSAN and failing otherwise.
Details on the generated tests and execution:
1) Multi-threaded execution: all generated tests are executed independently by
a pool of threads, running each test multiple times, see @multi_threaded for details
2) Tests generation: we use existing tests: test/python/ir/*.py,
test/python/dialects/*.py, etc to generate multi-threaded tests.
In details, we perform the following:
a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py.
In order to import the test file as python module, we remove all executing functions, like
`@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details.
Observed warnings reported by TSAN.
CPython and free-threading known data-races:
1) ctypes related races: https://github.com/python/cpython/issues/127945
2) LLVM related data-races, llvm::raw_ostream is not thread-safe
- mlir pass manager
- dialects/transform_interpreter.py
- ir/diagnostic_handler.py
- ir/module.py
3) Dialect gpu module-to-binary method is unsafe
"""
import concurrent.futures
import gc
import importlib.util
import os
import sys
import threading
import tempfile
import unittest
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Optional
import mlir.dialects.arith as arith
from mlir.dialects import transform
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
def import_from_path(module_name: str, file_path: Path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def copy_and_update(src_filepath: Path, dst_filepath: Path):
# We should remove all calls like `run(testMethod)`
with open(src_filepath, "r") as reader, open(dst_filepath, "w") as writer:
while True:
src_line = reader.readline()
if len(src_line) == 0:
break
skip_lines = [
"run(",
"@run",
"@constructAndPrintInModule",
"run_apply_patterns(",
"@run_apply_patterns",
"@test_in_context",
"@construct_and_print_in_module",
]
if any(src_line.startswith(line) for line in skip_lines):
continue
writer.write(src_line)
# Helper run functions
# They are copied from the test modules (e.g. run function in execution_engine.py)
def run(test_function):
# Generic run tests function used by dialects and ir test modules
test_function()
def run_with_context_and_location(test_function):
# run tests function with a context and a location
# used by the following test modules:
# - dialects/transform_gpu_ext,
# - dialects/vector
# - dialects/gpu/*
with Context(), Location.unknown():
test_function()
return test_function
def run_with_insertion_point_and_context_arg(test_function):
# run tests function used by dialects/index_dialect test module
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
test_function(ctx)
def run_with_insertion_point(test_function):
# Used by a lot of dialects test modules
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
test_function()
return test_function
def run_with_insertion_point_and_module_arg(test_function):
# Used by dialects/transform test module
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
test_function(module)
return test_function
def run_with_insertion_point_all_unreg_dialects(test_function):
# Used by dialects/cf test module
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
module = Module.create()
with InsertionPoint(module.body):
test_function()
return test_function
def run_apply_patterns(test_function):
# Used by dialects/transform_tensor_ext test module
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
)
with InsertionPoint(sequence.body):
apply = transform.ApplyPatternsOp(sequence.bodyTarget)
with InsertionPoint(apply.patterns):
test_function()
transform.YieldOp()
print(module)
return test_function
def run_transform_tensor_ext(test_function):
# Used by test modules:
# - dialects/transform_gpu_ext
# - dialects/transform_sparse_tensor_ext
# - dialects/transform_tensor_ext
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
)
with InsertionPoint(sequence.body):
test_function(sequence.bodyTarget)
transform.YieldOp()
print(module)
return test_function
def run_transform_structured_ext(test_function):
# Used by dialects/transform_structured_ext test module
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
test_function()
module.operation.verify()
print(module)
return test_function
def run_construct_and_print_in_module(test_function):
# Used by test modules:
# - integration/dialects/pdl
# - integration/dialects/transform
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
module = test_function(module)
if module is not None:
print(module)
return test_function
TEST_MODULES = [
("execution_engine", run),
("pass_manager", run),
("dialects/affine", run_with_insertion_point),
("dialects/func", run_with_insertion_point),
("dialects/arith_dialect", run),
("dialects/arith_llvm", run),
("dialects/async_dialect", run),
("dialects/builtin", run),
("dialects/cf", run_with_insertion_point_all_unreg_dialects),
("dialects/complex_dialect", run),
("dialects/func", run_with_insertion_point),
("dialects/index_dialect", run_with_insertion_point_and_context_arg),
("dialects/llvm", run_with_insertion_point),
("dialects/math_dialect", run),
("dialects/memref", run),
("dialects/ml_program", run_with_insertion_point),
("dialects/nvgpu", run_with_insertion_point),
("dialects/nvvm", run_with_insertion_point),
("dialects/ods_helpers", run),
("dialects/openmp_ops", run_with_insertion_point),
("dialects/pdl_ops", run_with_insertion_point),
# ("dialects/python_test", run), # TODO: Need to pass pybind11 or nanobind argv
("dialects/quant", run),
("dialects/rocdl", run_with_insertion_point),
("dialects/scf", run_with_insertion_point),
("dialects/shape", run),
("dialects/spirv_dialect", run),
("dialects/tensor", run),
# ("dialects/tosa", ), # Nothing to test
("dialects/transform_bufferization_ext", run_with_insertion_point),
# ("dialects/transform_extras", ), # Needs a more complicated execution schema
("dialects/transform_gpu_ext", run_transform_tensor_ext),
(
"dialects/transform_interpreter",
run_with_context_and_location,
["print_", "transform_options", "failed", "include"],
),
(
"dialects/transform_loop_ext",
run_with_insertion_point,
["loopOutline"],
),
("dialects/transform_memref_ext", run_with_insertion_point),
("dialects/transform_nvgpu_ext", run_with_insertion_point),
("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext),
("dialects/transform_structured_ext", run_transform_structured_ext),
("dialects/transform_tensor_ext", run_transform_tensor_ext),
(
"dialects/transform_vector_ext",
run_apply_patterns,
["configurable_patterns"],
),
("dialects/transform", run_with_insertion_point_and_module_arg),
("dialects/vector", run_with_context_and_location),
("dialects/gpu/dialect", run_with_context_and_location),
("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location),
("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location),
("dialects/linalg/ops", run),
# TO ADD: No proper tests in this dialects/linalg/opsdsl/*
# ("dialects/linalg/opsdsl/*", ...),
("dialects/sparse_tensor/dialect", run),
("dialects/sparse_tensor/passes", run),
("integration/dialects/pdl", run_construct_and_print_in_module),
("integration/dialects/transform", run_construct_and_print_in_module),
("integration/dialects/linalg/opsrun", run),
("ir/affine_expr", run),
("ir/affine_map", run),
("ir/array_attributes", run),
("ir/attributes", run),
("ir/blocks", run),
("ir/builtin_types", run),
("ir/context_managers", run),
("ir/debug", run),
("ir/diagnostic_handler", run),
("ir/dialects", run),
("ir/exception", run),
("ir/insertion_point", run),
("ir/integer_set", run),
("ir/location", run),
("ir/module", run),
("ir/operation", run),
("ir/symbol_table", run),
("ir/value", run),
]
TESTS_TO_SKIP = [
"test_execution_engine__testNanoTime_multi_threaded", # testNanoTime can't run in multiple threads, even with GIL
"test_execution_engine__testSharedLibLoad_multi_threaded", # testSharedLibLoad can't run in multiple threads, even with GIL
"test_dialects_arith_dialect__testArithValue_multi_threaded", # RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL
"test_ir_dialects__testAppendPrefixSearchPath_multi_threaded", # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals
"test_ir_value__testValueCasters_multi_threaded", # RuntimeError: Value caster is already registered: <function testValueCasters.<locals>.dont_cast_int, even with GIL
# tests indirectly calling thread-unsafe llvm::raw_ostream
"test_execution_engine__testInvalidModule_multi_threaded", # mlirExecutionEngineCreate calls thread-unsafe llvm::raw_ostream
"test_pass_manager__testPrintIrAfterAll_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
"test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded", # IRPrinterInstrumentation::runBeforePass calls thread-unsafe llvm::raw_ostream
"test_pass_manager__testPrintIrLargeLimitElements_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
"test_pass_manager__testPrintIrTree_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
"test_pass_manager__testRunPipeline_multi_threaded", # PrintOpStatsPass::printSummary calls thread-unsafe llvm::raw_ostream
"test_dialects_transform_interpreter__include_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
"test_dialects_transform_interpreter__transform_options_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
"test_dialects_transform_interpreter__print_self_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) call thread-unsafe llvm::raw_ostream
"test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded", # mlirEmitError calls thread-unsafe llvm::raw_ostream
"test_ir_module__testParseSuccess_multi_threaded", # mlirOperationDump calls thread-unsafe llvm::raw_ostream
# False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames()
# Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947
"test_execution_engine__testCapsule_multi_threaded",
"test_execution_engine__testDumpToObjectFile_multi_threaded",
]
TESTS_TO_XFAIL = [
# execution_engine tests:
# - ctypes related data-races: https://github.com/python/cpython/issues/127945
"test_execution_engine__testBF16Memref_multi_threaded",
"test_execution_engine__testBasicCallback_multi_threaded",
"test_execution_engine__testComplexMemrefAdd_multi_threaded",
"test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded",
"test_execution_engine__testDynamicMemrefAdd2D_multi_threaded",
"test_execution_engine__testF16MemrefAdd_multi_threaded",
"test_execution_engine__testF8E5M2Memref_multi_threaded",
"test_execution_engine__testInvokeFloatAdd_multi_threaded",
"test_execution_engine__testInvokeVoid_multi_threaded", # a ctypes race
"test_execution_engine__testMemrefAdd_multi_threaded",
"test_execution_engine__testRankedMemRefCallback_multi_threaded",
"test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded",
"test_execution_engine__testUnrankedMemRefCallback_multi_threaded",
"test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded",
# dialects tests
"test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded", # Related to ctypes data races
"test_dialects_transform_interpreter__print_other_multi_threaded", # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
"test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded", # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation
"test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded",
"test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded",
"test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded",
# integration tests
"test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded", # Related to ctypes data races
"test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded", # Related to ctypes data races
"test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded", # ctypes
"test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded", # ctypes
"test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded", # ctypes
"test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded", # ctypes
"test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded", # ctypes
"test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded", # ctypes
"test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded", # ctypes
"test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded", # ctypes
]
def add_existing_tests(test_modules, test_prefix: str = "_original_test"):
def decorator(test_cls):
this_folder = Path(__file__).parent.absolute()
test_cls.output_folder = tempfile.TemporaryDirectory()
output_folder = Path(test_cls.output_folder.name)
for test_mod_info in test_modules:
# test_mod_info is a tuple of size 2 or 3:
# (test_module_str, run_test_function) or (test_module_str, run_test_function, test_name_patterns_list)
# For example:
# - ("ir/value", run) or
# - ("dialects/transform_loop_ext", run_with_insertion_point, ["loopOutline"])
assert isinstance(test_mod_info, tuple) and len(test_mod_info) in (2, 3)
if len(test_mod_info) == 2:
test_module_name, exec_fn = test_mod_info
test_pattern = None
else:
test_module_name, exec_fn, test_pattern = test_mod_info
src_filepath = this_folder / f"{test_module_name}.py"
dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
if not dst_filepath.parent.exists():
dst_filepath.parent.mkdir(parents=True)
copy_and_update(src_filepath, dst_filepath)
test_mod = import_from_path(test_module_name, dst_filepath)
for attr_name in dir(test_mod):
is_test_fn = test_pattern is None and attr_name.startswith("test")
is_test_fn |= test_pattern is not None and any(
[p in attr_name for p in test_pattern]
)
if is_test_fn:
obj = getattr(test_mod, attr_name)
if callable(obj):
test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"
def wrapped_test_fn(
self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs
):
__exec_fn__(__test_fn__)
setattr(test_cls, test_name, wrapped_test_fn)
return test_cls
return decorator
@contextmanager
def _capture_output(fp):
# Inspired from jax test_utils.py capture_stderr method
# ``None`` means nothing has not been captured yet.
captured = None
def get_output() -> str:
if captured is None:
raise ValueError("get_output() called while the context is active.")
return captured
with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f:
original_fd = os.dup(fp.fileno())
os.dup2(f.fileno(), fp.fileno())
try:
yield get_output
finally:
# Python also has its own buffers, make sure everything is flushed.
fp.flush()
os.fsync(fp.fileno())
f.seek(0)
captured = f.read()
os.dup2(original_fd, fp.fileno())
capture_stdout = partial(_capture_output, sys.stdout)
capture_stderr = partial(_capture_output, sys.stderr)
def multi_threaded(
num_workers: int,
num_runs: int = 5,
skip_tests: Optional[list[str]] = None,
xfail_tests: Optional[list[str]] = None,
test_prefix: str = "_original_test",
multithreaded_test_postfix: str = "_multi_threaded",
):
"""Decorator that runs a test in a multi-threaded environment."""
def decorator(test_cls):
for name, test_fn in test_cls.__dict__.copy().items():
if not (name.startswith(test_prefix) and callable(test_fn)):
continue
name = f"test{name[len(test_prefix):]}"
if skip_tests is not None:
if any(
test_name.replace(multithreaded_test_postfix, "") in name
for test_name in skip_tests
):
continue
def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs):
with capture_stdout(), capture_stderr() as get_output:
barrier = threading.Barrier(num_workers)
def closure():
barrier.wait()
for _ in range(num_runs):
__test_fn__(self, *args, **kwargs)
with concurrent.futures.ThreadPoolExecutor(
max_workers=num_workers
) as executor:
futures = []
for _ in range(num_workers):
futures.append(executor.submit(closure))
# We should call future.result() to re-raise an exception if test has
# failed
assert len(list(f.result() for f in futures)) == num_workers
gc.collect()
assert Context._get_live_count() == 0
captured = get_output()
if len(captured) > 0 and "ThreadSanitizer" in captured:
raise RuntimeError(
f"ThreadSanitizer reported warnings:\n{captured}"
)
test_new_name = f"{name}{multithreaded_test_postfix}"
if xfail_tests is not None and test_new_name in xfail_tests:
multi_threaded_test_fn = unittest.expectedFailure(
multi_threaded_test_fn
)
setattr(test_cls, test_new_name, multi_threaded_test_fn)
return test_cls
return decorator
@multi_threaded(
num_workers=10,
num_runs=20,
skip_tests=TESTS_TO_SKIP,
xfail_tests=TESTS_TO_XFAIL,
)
@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test")
class TestAllMultiThreaded(unittest.TestCase):
@classmethod
def tearDownClass(cls):
if hasattr(cls, "output_folder"):
cls.output_folder.cleanup()
def _original_test_create_context(self):
with Context() as ctx:
print(ctx._get_live_count())
print(ctx._get_live_module_count())
print(ctx._get_live_operation_count())
print(ctx._get_live_operation_objects())
print(ctx._get_context_again() is ctx)
print(ctx._clear_live_operations())
def _original_test_create_module_with_consts(self):
py_values = [123, 234, 345]
with Context() as ctx:
module = Module.create(loc=Location.file("foo.txt", 0, 0))
dtype = IntegerType.get_signless(64)
with InsertionPoint(module.body), Location.name("a"):
arith.constant(dtype, py_values[0])
with InsertionPoint(module.body), Location.name("b"):
arith.constant(dtype, py_values[1])
with InsertionPoint(module.body), Location.name("c"):
arith.constant(dtype, py_values[2])
if __name__ == "__main__":
# Do not run the tests on CPython with GIL
if hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled():
unittest.main()