1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Don't create temp directory when module is getting imported.

PiperOrigin-RevId: 630958402
This commit is contained in:
jax authors 2024-05-06 00:58:00 -07:00
parent 047ea210e8
commit 7681493760

@ -353,6 +353,10 @@ def _compile_and_write_autotune_config(
) -> xc.LoadedExecutable:
share_timeout = config.share_binary_between_hosts_timeout_ms.value
debug_options = compile_options.executable_build_options.debug_options
if _compile_and_write_autotune_config.autotune_configs_dir is None:
_compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()
autotune_tmp_file = os.path.join(
_compile_and_write_autotune_config.autotune_configs_dir, cache_key
)
@ -436,7 +440,7 @@ def _compile_and_write_autotune_config(
)
return executable
_compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()
_compile_and_write_autotune_config.autotune_configs_dir = None
# The process with id 0 should compile the module and write it to the K-V
# storage.