mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Remove autotune sharing.
xla_gpu_shard_autotuning can be used now instead and it is enabled by default. PiperOrigin-RevId: 705792463
This commit is contained in:
parent
d0f63da4b5
commit
a123d4e39e
@ -18,8 +18,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
import warnings
|
import warnings
|
||||||
@ -449,22 +447,6 @@ def compile_or_get_cached(
|
|||||||
cache_key,
|
cache_key,
|
||||||
min_device_process_id
|
min_device_process_id
|
||||||
)
|
)
|
||||||
elif (
|
|
||||||
config.share_autotune_config_between_hosts.value
|
|
||||||
and is_multi_process
|
|
||||||
and distributed.global_state.client is not None
|
|
||||||
):
|
|
||||||
log_persistent_cache_miss(module_name, cache_key)
|
|
||||||
return _compile_and_write_autotune_config(
|
|
||||||
backend,
|
|
||||||
computation,
|
|
||||||
compile_options,
|
|
||||||
host_callbacks,
|
|
||||||
distributed.global_state.client,
|
|
||||||
module_name,
|
|
||||||
cache_key,
|
|
||||||
min_device_process_id
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
log_persistent_cache_miss(module_name, cache_key)
|
log_persistent_cache_miss(module_name, cache_key)
|
||||||
return _compile_and_write_cache(
|
return _compile_and_write_cache(
|
||||||
@ -608,113 +590,6 @@ def _share_fdo_profiles(
|
|||||||
|
|
||||||
_share_fdo_profiles.modules_profiles = {}
|
_share_fdo_profiles.modules_profiles = {}
|
||||||
|
|
||||||
|
|
||||||
# The process with the first_process_id should compile the module and write an
|
|
||||||
# autotune config to the K-V storage.
|
|
||||||
def _compile_and_write_autotune_config(
|
|
||||||
backend: xc.Client,
|
|
||||||
computation: ir.Module,
|
|
||||||
compile_options: xc.CompileOptions,
|
|
||||||
host_callbacks: Sequence[Any],
|
|
||||||
global_client: lib.xla_extension.DistributedRuntimeClient,
|
|
||||||
module_name: str,
|
|
||||||
cache_key: str,
|
|
||||||
first_process_id: int
|
|
||||||
) -> xc.LoadedExecutable:
|
|
||||||
share_timeout = config.share_binary_between_hosts_timeout_ms.value
|
|
||||||
debug_options = compile_options.executable_build_options.debug_options
|
|
||||||
|
|
||||||
if _compile_and_write_autotune_config.autotune_configs_dir is None:
|
|
||||||
_compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()
|
|
||||||
|
|
||||||
autotune_tmp_file = os.path.join(
|
|
||||||
_compile_and_write_autotune_config.autotune_configs_dir, cache_key
|
|
||||||
)
|
|
||||||
|
|
||||||
if os.path.exists(autotune_tmp_file):
|
|
||||||
logger.debug(
|
|
||||||
"Compiling module: %s. Use existing autotune config file: %s",
|
|
||||||
module_name,
|
|
||||||
autotune_tmp_file,
|
|
||||||
)
|
|
||||||
debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
|
|
||||||
return _compile_and_write_cache(
|
|
||||||
backend,
|
|
||||||
computation,
|
|
||||||
compile_options,
|
|
||||||
host_callbacks,
|
|
||||||
module_name,
|
|
||||||
cache_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
if distributed.global_state.process_id == first_process_id:
|
|
||||||
debug_options.xla_gpu_dump_autotune_results_to = autotune_tmp_file
|
|
||||||
logger.debug("Process %d compiling and dumping autotune for module: %s",
|
|
||||||
first_process_id, module_name)
|
|
||||||
executable = _compile_and_write_cache(
|
|
||||||
backend,
|
|
||||||
computation,
|
|
||||||
compile_options,
|
|
||||||
host_callbacks,
|
|
||||||
module_name,
|
|
||||||
cache_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Writing autotune config for module %s to %s",
|
|
||||||
module_name,
|
|
||||||
autotune_tmp_file,
|
|
||||||
)
|
|
||||||
with open(autotune_tmp_file, "rb") as f:
|
|
||||||
autotune_config = f.read()
|
|
||||||
|
|
||||||
autotune_config = compilation_cache.compress_executable(autotune_config)
|
|
||||||
global_client.key_value_set_bytes(cache_key, autotune_config)
|
|
||||||
logger.debug(
|
|
||||||
"Autotune config for module %s with size %d shared by cache_key %s",
|
|
||||||
module_name,
|
|
||||||
len(autotune_config),
|
|
||||||
cache_key,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"Compiling module %s, waiting for config to be shared by cache_key %s"
|
|
||||||
"from process %d",
|
|
||||||
module_name,
|
|
||||||
cache_key,
|
|
||||||
first_process_id
|
|
||||||
)
|
|
||||||
autotune_config = global_client.blocking_key_value_get_bytes(
|
|
||||||
cache_key, share_timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Received autotune config for module %s of size %d",
|
|
||||||
module_name,
|
|
||||||
len(autotune_config),
|
|
||||||
)
|
|
||||||
autotune_config = compilation_cache.decompress_executable(autotune_config)
|
|
||||||
with open(autotune_tmp_file, "wb") as f:
|
|
||||||
f.write(autotune_config)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Compiling module %s, using autotune config from %s",
|
|
||||||
module_name,
|
|
||||||
autotune_tmp_file,
|
|
||||||
)
|
|
||||||
debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
|
|
||||||
executable = _compile_and_write_cache(
|
|
||||||
backend,
|
|
||||||
computation,
|
|
||||||
compile_options,
|
|
||||||
host_callbacks,
|
|
||||||
module_name,
|
|
||||||
cache_key,
|
|
||||||
)
|
|
||||||
return executable
|
|
||||||
|
|
||||||
_compile_and_write_autotune_config.autotune_configs_dir = None
|
|
||||||
|
|
||||||
# The process with the first_process_id should compile the module and write it
|
# The process with the first_process_id should compile the module and write it
|
||||||
# to the K-V storage.
|
# to the K-V storage.
|
||||||
def _compile_and_share_module(
|
def _compile_and_share_module(
|
||||||
|
@ -1169,20 +1169,6 @@ traceback_in_locations_limit = int_state(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
share_autotune_config_between_hosts = bool_state(
|
|
||||||
name='jax_share_autotune_config_between_hosts',
|
|
||||||
default=False,
|
|
||||||
help=(
|
|
||||||
'If set to True, the coordinator process will share autotune configs '
|
|
||||||
'other participants. This will increase overall compilation time, but '
|
|
||||||
'will lead to equal compiled modules in each process. '
|
|
||||||
'If both jax_share_binary_between_hosts and '
|
|
||||||
'jax_share_autotune_config_between_hosts are set, compiled HLO will be '
|
|
||||||
"shared when it's possible and autotune config sharing will be used "
|
|
||||||
'as a fallback.'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
share_binary_between_hosts = bool_state(
|
share_binary_between_hosts = bool_state(
|
||||||
name='jax_share_binary_between_hosts',
|
name='jax_share_binary_between_hosts',
|
||||||
default=False,
|
default=False,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user