mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Disable threading in MLIR contexts.
If threading is enabled, each MLIR context owns a threadpool of typically 3 threads. We might keep thousands of MLIR module objects alive in our caches, and we don't want to keep thousands of idle threads around. Disable the MLIR threading support. We don't do heavy compilation work from the Python-generated context anyway. Fixes #16272
This commit is contained in:
parent
eea03ced0a
commit
08184cec86
@ -360,6 +360,14 @@ def _source_info_to_location(
|
||||
def make_ir_context() -> ir.Context:
|
||||
"""Creates an MLIR context suitable for JAX IR."""
|
||||
context = ir.Context()
|
||||
|
||||
# If threading is enabled, each MLIR context will keep alive a thread pool.
|
||||
# Since we cache MLIR modules (and hence contexts), this means we might keep
|
||||
# several threads alive for each cache entry. This is a terrible idea. However
|
||||
# we don't do any heavy computation on MLIR modules from Python anyway, so we
|
||||
# just disable threading.
|
||||
context.enable_multithreading(False)
|
||||
|
||||
dialects.mhlo.register_mhlo_dialect(context)
|
||||
dialects.chlo.register_dialect(context)
|
||||
dialects.stablehlo.register_dialect(context)
|
||||
|
Loading…
x
Reference in New Issue
Block a user