Remove autotune sharing.

xla_gpu_shard_autotuning can be used now instead and it is enabled by default.

PiperOrigin-RevId: 705792463
This commit is contained in:
jax authors 2024-12-13 01:21:44 -08:00
parent d0f63da4b5
commit a123d4e39e
2 changed files with 0 additions and 139 deletions

View File

@ -18,8 +18,6 @@ from __future__ import annotations
from collections.abc import Sequence
import logging
import os
import tempfile
import time
from typing import Any, Callable
import warnings
@ -449,22 +447,6 @@ def compile_or_get_cached(
cache_key,
min_device_process_id
)
elif (
config.share_autotune_config_between_hosts.value
and is_multi_process
and distributed.global_state.client is not None
):
log_persistent_cache_miss(module_name, cache_key)
return _compile_and_write_autotune_config(
backend,
computation,
compile_options,
host_callbacks,
distributed.global_state.client,
module_name,
cache_key,
min_device_process_id
)
else:
log_persistent_cache_miss(module_name, cache_key)
return _compile_and_write_cache(
@ -608,113 +590,6 @@ def _share_fdo_profiles(
_share_fdo_profiles.modules_profiles = {}
# The process with the first_process_id should compile the module and write an
# autotune config to the K-V storage.
def _compile_and_write_autotune_config(
backend: xc.Client,
computation: ir.Module,
compile_options: xc.CompileOptions,
host_callbacks: Sequence[Any],
global_client: lib.xla_extension.DistributedRuntimeClient,
module_name: str,
cache_key: str,
first_process_id: int
) -> xc.LoadedExecutable:
share_timeout = config.share_binary_between_hosts_timeout_ms.value
debug_options = compile_options.executable_build_options.debug_options
if _compile_and_write_autotune_config.autotune_configs_dir is None:
_compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()
autotune_tmp_file = os.path.join(
_compile_and_write_autotune_config.autotune_configs_dir, cache_key
)
if os.path.exists(autotune_tmp_file):
logger.debug(
"Compiling module: %s. Use existing autotune config file: %s",
module_name,
autotune_tmp_file,
)
debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
return _compile_and_write_cache(
backend,
computation,
compile_options,
host_callbacks,
module_name,
cache_key,
)
if distributed.global_state.process_id == first_process_id:
debug_options.xla_gpu_dump_autotune_results_to = autotune_tmp_file
logger.debug("Process %d compiling and dumping autotune for module: %s",
first_process_id, module_name)
executable = _compile_and_write_cache(
backend,
computation,
compile_options,
host_callbacks,
module_name,
cache_key,
)
logger.debug(
"Writing autotune config for module %s to %s",
module_name,
autotune_tmp_file,
)
with open(autotune_tmp_file, "rb") as f:
autotune_config = f.read()
autotune_config = compilation_cache.compress_executable(autotune_config)
global_client.key_value_set_bytes(cache_key, autotune_config)
logger.debug(
"Autotune config for module %s with size %d shared by cache_key %s",
module_name,
len(autotune_config),
cache_key,
)
else:
logger.debug(
"Compiling module %s, waiting for config to be shared by cache_key %s"
"from process %d",
module_name,
cache_key,
first_process_id
)
autotune_config = global_client.blocking_key_value_get_bytes(
cache_key, share_timeout
)
logger.debug(
"Received autotune config for module %s of size %d",
module_name,
len(autotune_config),
)
autotune_config = compilation_cache.decompress_executable(autotune_config)
with open(autotune_tmp_file, "wb") as f:
f.write(autotune_config)
logger.debug(
"Compiling module %s, using autotune config from %s",
module_name,
autotune_tmp_file,
)
debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
executable = _compile_and_write_cache(
backend,
computation,
compile_options,
host_callbacks,
module_name,
cache_key,
)
return executable
_compile_and_write_autotune_config.autotune_configs_dir = None
# The process with the first_process_id should compile the module and write it
# to the K-V storage.
def _compile_and_share_module(

View File

@ -1169,20 +1169,6 @@ traceback_in_locations_limit = int_state(
),
)
share_autotune_config_between_hosts = bool_state(
name='jax_share_autotune_config_between_hosts',
default=False,
help=(
'If set to True, the coordinator process will share autotune configs '
'other participants. This will increase overall compilation time, but '
'will lead to equal compiled modules in each process. '
'If both jax_share_binary_between_hosts and '
'jax_share_autotune_config_between_hosts are set, compiled HLO will be '
"shared when it's possible and autotune config sharing will be used "
'as a fallback.'
),
)
share_binary_between_hosts = bool_state(
name='jax_share_binary_between_hosts',
default=False,