mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove autotune sharing.
xla_gpu_shard_autotuning can be used now instead and it is enabled by default. PiperOrigin-RevId: 705792463
This commit is contained in:
parent
d0f63da4b5
commit
a123d4e39e
@ -18,8 +18,6 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Callable
|
||||
import warnings
|
||||
@ -449,22 +447,6 @@ def compile_or_get_cached(
|
||||
cache_key,
|
||||
min_device_process_id
|
||||
)
|
||||
elif (
|
||||
config.share_autotune_config_between_hosts.value
|
||||
and is_multi_process
|
||||
and distributed.global_state.client is not None
|
||||
):
|
||||
log_persistent_cache_miss(module_name, cache_key)
|
||||
return _compile_and_write_autotune_config(
|
||||
backend,
|
||||
computation,
|
||||
compile_options,
|
||||
host_callbacks,
|
||||
distributed.global_state.client,
|
||||
module_name,
|
||||
cache_key,
|
||||
min_device_process_id
|
||||
)
|
||||
else:
|
||||
log_persistent_cache_miss(module_name, cache_key)
|
||||
return _compile_and_write_cache(
|
||||
@ -608,113 +590,6 @@ def _share_fdo_profiles(
|
||||
|
||||
_share_fdo_profiles.modules_profiles = {}
|
||||
|
||||
|
||||
# The process with the first_process_id should compile the module and write an
|
||||
# autotune config to the K-V storage.
|
||||
def _compile_and_write_autotune_config(
|
||||
backend: xc.Client,
|
||||
computation: ir.Module,
|
||||
compile_options: xc.CompileOptions,
|
||||
host_callbacks: Sequence[Any],
|
||||
global_client: lib.xla_extension.DistributedRuntimeClient,
|
||||
module_name: str,
|
||||
cache_key: str,
|
||||
first_process_id: int
|
||||
) -> xc.LoadedExecutable:
|
||||
share_timeout = config.share_binary_between_hosts_timeout_ms.value
|
||||
debug_options = compile_options.executable_build_options.debug_options
|
||||
|
||||
if _compile_and_write_autotune_config.autotune_configs_dir is None:
|
||||
_compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()
|
||||
|
||||
autotune_tmp_file = os.path.join(
|
||||
_compile_and_write_autotune_config.autotune_configs_dir, cache_key
|
||||
)
|
||||
|
||||
if os.path.exists(autotune_tmp_file):
|
||||
logger.debug(
|
||||
"Compiling module: %s. Use existing autotune config file: %s",
|
||||
module_name,
|
||||
autotune_tmp_file,
|
||||
)
|
||||
debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
|
||||
return _compile_and_write_cache(
|
||||
backend,
|
||||
computation,
|
||||
compile_options,
|
||||
host_callbacks,
|
||||
module_name,
|
||||
cache_key,
|
||||
)
|
||||
|
||||
if distributed.global_state.process_id == first_process_id:
|
||||
debug_options.xla_gpu_dump_autotune_results_to = autotune_tmp_file
|
||||
logger.debug("Process %d compiling and dumping autotune for module: %s",
|
||||
first_process_id, module_name)
|
||||
executable = _compile_and_write_cache(
|
||||
backend,
|
||||
computation,
|
||||
compile_options,
|
||||
host_callbacks,
|
||||
module_name,
|
||||
cache_key,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Writing autotune config for module %s to %s",
|
||||
module_name,
|
||||
autotune_tmp_file,
|
||||
)
|
||||
with open(autotune_tmp_file, "rb") as f:
|
||||
autotune_config = f.read()
|
||||
|
||||
autotune_config = compilation_cache.compress_executable(autotune_config)
|
||||
global_client.key_value_set_bytes(cache_key, autotune_config)
|
||||
logger.debug(
|
||||
"Autotune config for module %s with size %d shared by cache_key %s",
|
||||
module_name,
|
||||
len(autotune_config),
|
||||
cache_key,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Compiling module %s, waiting for config to be shared by cache_key %s"
|
||||
"from process %d",
|
||||
module_name,
|
||||
cache_key,
|
||||
first_process_id
|
||||
)
|
||||
autotune_config = global_client.blocking_key_value_get_bytes(
|
||||
cache_key, share_timeout
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Received autotune config for module %s of size %d",
|
||||
module_name,
|
||||
len(autotune_config),
|
||||
)
|
||||
autotune_config = compilation_cache.decompress_executable(autotune_config)
|
||||
with open(autotune_tmp_file, "wb") as f:
|
||||
f.write(autotune_config)
|
||||
|
||||
logger.debug(
|
||||
"Compiling module %s, using autotune config from %s",
|
||||
module_name,
|
||||
autotune_tmp_file,
|
||||
)
|
||||
debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
|
||||
executable = _compile_and_write_cache(
|
||||
backend,
|
||||
computation,
|
||||
compile_options,
|
||||
host_callbacks,
|
||||
module_name,
|
||||
cache_key,
|
||||
)
|
||||
return executable
|
||||
|
||||
_compile_and_write_autotune_config.autotune_configs_dir = None
|
||||
|
||||
# The process with the first_process_id should compile the module and write it
|
||||
# to the K-V storage.
|
||||
def _compile_and_share_module(
|
||||
|
@ -1169,20 +1169,6 @@ traceback_in_locations_limit = int_state(
|
||||
),
|
||||
)
|
||||
|
||||
share_autotune_config_between_hosts = bool_state(
|
||||
name='jax_share_autotune_config_between_hosts',
|
||||
default=False,
|
||||
help=(
|
||||
'If set to True, the coordinator process will share autotune configs '
|
||||
'other participants. This will increase overall compilation time, but '
|
||||
'will lead to equal compiled modules in each process. '
|
||||
'If both jax_share_binary_between_hosts and '
|
||||
'jax_share_autotune_config_between_hosts are set, compiled HLO will be '
|
||||
"shared when it's possible and autotune config sharing will be used "
|
||||
'as a fallback.'
|
||||
),
|
||||
)
|
||||
|
||||
share_binary_between_hosts = bool_state(
|
||||
name='jax_share_binary_between_hosts',
|
||||
default=False,
|
||||
|
Loading…
x
Reference in New Issue
Block a user