Remove autotune sharing.

xla_gpu_shard_autotuning can now be used instead, and it is enabled by default.
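For clusters that want to control the replacement behavior explicitly, the XLA flag can be toggled through `XLA_FLAGS`. A minimal sketch, assuming the flag is spelled `--xla_gpu_shard_autotuning` as in recent XLA releases; since it is enabled by default, setting it is only needed to opt out or to be explicit:

```python
import os

# XLA_FLAGS must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
)

import jax  # imported after XLA_FLAGS so the flag takes effect
```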

PiperOrigin-RevId: 705792463
commit a123d4e39e (parent d0f63da4b5)
Author: jax authors
Date:   2024-12-13 01:21:44 -08:00

2 changed files with 0 additions and 139 deletions
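For context, the deleted helper `_compile_and_write_autotune_config` (removed in full below) implemented a leader-compiles-and-shares pattern over the distributed KV store: the lead process dumps XLA's autotune results to a file, publishes the bytes under the compilation cache key, and every other process blocks until it can load the same config. The sketch below is a condensed, hypothetical illustration of that pattern (function name, signature, and the absence of compression/error handling are simplifications), not the removed code itself:

```python
import os
import tempfile


def share_autotune_config(client, debug_options, cache_key, is_leader,
                          compile_fn, timeout_ms=60_000):
  """Hypothetical helper; `client` is a DistributedRuntimeClient."""
  tmp_file = os.path.join(tempfile.mkdtemp(), cache_key)
  if is_leader:
    # Leader compiles first and dumps the autotune results it discovers,
    # then publishes them under the cache key for the other processes.
    debug_options.xla_gpu_dump_autotune_results_to = tmp_file
    executable = compile_fn()
    with open(tmp_file, "rb") as f:
      client.key_value_set_bytes(cache_key, f.read())
    return executable
  # Followers wait for the leader's config, then compile with it preloaded
  # so every process ends up with the same autotuning decisions.
  config_bytes = client.blocking_key_value_get_bytes(cache_key, timeout_ms)
  with open(tmp_file, "wb") as f:
    f.write(config_bytes)
  debug_options.xla_gpu_load_autotune_results_from = tmp_file
  return compile_fn()
```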


@@ -18,8 +18,6 @@ from __future__ import annotations
 from collections.abc import Sequence
 import logging
-import os
-import tempfile
 import time
 from typing import Any, Callable
 import warnings
@@ -449,22 +447,6 @@ def compile_or_get_cached(
         cache_key,
         min_device_process_id
     )
-  elif (
-      config.share_autotune_config_between_hosts.value
-      and is_multi_process
-      and distributed.global_state.client is not None
-  ):
-    log_persistent_cache_miss(module_name, cache_key)
-    return _compile_and_write_autotune_config(
-        backend,
-        computation,
-        compile_options,
-        host_callbacks,
-        distributed.global_state.client,
-        module_name,
-        cache_key,
-        min_device_process_id
-    )
   else:
     log_persistent_cache_miss(module_name, cache_key)
     return _compile_and_write_cache(
@@ -608,113 +590,6 @@ def _share_fdo_profiles(
 _share_fdo_profiles.modules_profiles = {}


-# The process with the first_process_id should compile the module and write an
-# autotune config to the K-V storage.
-def _compile_and_write_autotune_config(
-    backend: xc.Client,
-    computation: ir.Module,
-    compile_options: xc.CompileOptions,
-    host_callbacks: Sequence[Any],
-    global_client: lib.xla_extension.DistributedRuntimeClient,
-    module_name: str,
-    cache_key: str,
-    first_process_id: int
-) -> xc.LoadedExecutable:
-  share_timeout = config.share_binary_between_hosts_timeout_ms.value
-  debug_options = compile_options.executable_build_options.debug_options
-
-  if _compile_and_write_autotune_config.autotune_configs_dir is None:
-    _compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()
-
-  autotune_tmp_file = os.path.join(
-      _compile_and_write_autotune_config.autotune_configs_dir, cache_key
-  )
-
-  if os.path.exists(autotune_tmp_file):
-    logger.debug(
-        "Compiling module: %s. Use existing autotune config file: %s",
-        module_name,
-        autotune_tmp_file,
-    )
-    debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
-    return _compile_and_write_cache(
-        backend,
-        computation,
-        compile_options,
-        host_callbacks,
-        module_name,
-        cache_key,
-    )
-
-  if distributed.global_state.process_id == first_process_id:
-    debug_options.xla_gpu_dump_autotune_results_to = autotune_tmp_file
-    logger.debug("Process %d compiling and dumping autotune for module: %s",
-                 first_process_id, module_name)
-    executable = _compile_and_write_cache(
-        backend,
-        computation,
-        compile_options,
-        host_callbacks,
-        module_name,
-        cache_key,
-    )
-    logger.debug(
-        "Writing autotune config for module %s to %s",
-        module_name,
-        autotune_tmp_file,
-    )
-    with open(autotune_tmp_file, "rb") as f:
-      autotune_config = f.read()
-
-    autotune_config = compilation_cache.compress_executable(autotune_config)
-    global_client.key_value_set_bytes(cache_key, autotune_config)
-    logger.debug(
-        "Autotune config for module %s with size %d shared by cache_key %s",
-        module_name,
-        len(autotune_config),
-        cache_key,
-    )
-  else:
-    logger.debug(
-        "Compiling module %s, waiting for config to be shared by cache_key %s"
-        "from process %d",
-        module_name,
-        cache_key,
-        first_process_id
-    )
-    autotune_config = global_client.blocking_key_value_get_bytes(
-        cache_key, share_timeout
-    )
-    logger.debug(
-        "Received autotune config for module %s of size %d",
-        module_name,
-        len(autotune_config),
-    )
-    autotune_config = compilation_cache.decompress_executable(autotune_config)
-    with open(autotune_tmp_file, "wb") as f:
-      f.write(autotune_config)
-
-    logger.debug(
-        "Compiling module %s, using autotune config from %s",
-        module_name,
-        autotune_tmp_file,
-    )
-    debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
-    executable = _compile_and_write_cache(
-        backend,
-        computation,
-        compile_options,
-        host_callbacks,
-        module_name,
-        cache_key,
-    )
-
-  return executable
-
-
-_compile_and_write_autotune_config.autotune_configs_dir = None
-
-
 # The process with the first_process_id should compile the module and write it
 # to the K-V storage.
 def _compile_and_share_module(


@@ -1169,20 +1169,6 @@ traceback_in_locations_limit = int_state(
     ),
 )

-share_autotune_config_between_hosts = bool_state(
-    name='jax_share_autotune_config_between_hosts',
-    default=False,
-    help=(
-        'If set to True, the coordinator process will share autotune configs '
-        'other participants. This will increase overall compilation time, but '
-        'will lead to equal compiled modules in each process. '
-        'If both jax_share_binary_between_hosts and '
-        'jax_share_autotune_config_between_hosts are set, compiled HLO will be '
-        "shared when it's possible and autotune config sharing will be used "
-        'as a fallback.'
-    ),
-)
-
 share_binary_between_hosts = bool_state(
     name='jax_share_binary_between_hosts',
     default=False,