Disable threading in MLIR contexts.

If threading is enabled, each MLIR context owns a threadpool of typically 3 threads. We might keep thousands of MLIR module objects alive in our caches, and we don't want to keep thousands of idle threads around. Disable the MLIR threading support. We don't do heavy compilation work from the Python-generated context anyway.

Fixes #16272
This commit is contained in:
Peter Hawkins 2023-06-06 19:46:30 -04:00
parent eea03ced0a
commit 08184cec86

View File

@ -360,6 +360,14 @@ def _source_info_to_location(
def make_ir_context() -> ir.Context:
"""Creates an MLIR context suitable for JAX IR."""
context = ir.Context()
# If threading is enabled, each MLIR context will keep alive a thread pool.
# Since we cache MLIR modules (and hence contexts), this means we might keep
# several threads alive for each cache entry. This is a terrible idea. However
# we don't do any heavy computation on MLIR modules from Python anyway, so we
# just disable threading.
context.enable_multithreading(False)
dialects.mhlo.register_mhlo_dialect(context)
dialects.chlo.register_dialect(context)
dialects.stablehlo.register_dialect(context)