mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

Since the compilation cache is now initialized lazily, existing APIs initialize_cache() and is_initialized() are confusing. Deprecate these APIs. Introduce a new API set_cache_dir() to explicitly set the cache directory path in code. Testing: revised unit tests, test workload. PiperOrigin-RevId: 598073423
277 lines
8.5 KiB
Python
277 lines
8.5 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import threading
|
|
import warnings
|
|
import zlib
|
|
|
|
import numpy as np
|
|
|
|
# If zstandard is installed, we use zstd compression, otherwise we use zlib.
|
|
try:
|
|
import zstandard
|
|
except ImportError:
|
|
zstandard = None
|
|
|
|
from jax._src import cache_key
|
|
from jax._src.compilation_cache_interface import CacheInterface
|
|
from jax._src import config
|
|
from jax._src import monitoring
|
|
from jax._src.gfile_cache import GFileCache
|
|
from jax._src.lib import xla_client
|
|
from jax._src.lib.mlir import ir
|
|
|
|
|
|
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)

# Lock guarding the two one-shot flags declared right below
# (_cache_initialized and _cache_used).
_cache_initialized_mutex = threading.Lock()

# Set to True by _initialize_cache() on its single initialization attempt;
# cleared again by reset_cache().
_cache_initialized: bool = False

# Set to True by set_once_cache_used(); cleared again by reset_cache().
_cache_used: bool = False

# The cache object, or None when no cache has been created (not yet
# initialized, disabled, no directory configured, or after reset_cache()).
_cache: CacheInterface | None = None
|
|
|
|
|
|
def set_once_cache_used(f) -> None:
  """Marks the cache as used and runs `f`, but only on the first call.

  With _cache_initialized_mutex held: when _cache_used is already True
  nothing happens; otherwise _cache_used is set to True and `f` is invoked
  (unless `f` is None). This is a mechanism for executing `f` once per task.
  reset_cache() clears _cache_used, which re-arms this function.

  Args:
    f: Zero-argument callable to run on the first call, or None to only set
      the flag.
  """
  global _cache_used
  with _cache_initialized_mutex:
    if _cache_used:
      # Already fired once; nothing to do.
      return
    _cache_used = True
    if f is not None:
      # The callback runs while the mutex is still held.
      f()
|
|
|
|
|
|
def get_file_cache(path: str) -> CacheInterface:
  """Returns a new GFileCache constructed from `path`."""
  file_cache = GFileCache(path)
  return file_cache
|
|
|
|
|
|
def set_cache_dir(path) -> None:
  """Sets the persistent compilation cache directory.

  After calling this, jit-compiled functions are saved to `path`, so they do
  not need to be recompiled if the process is restarted or otherwise run
  again. This also tells JAX where to look for compiled functions before
  compiling.

  Args:
    path: Value to store in the "jax_compilation_cache_dir" config option.
  """
  option_name = "jax_compilation_cache_dir"
  config.config.update(option_name, path)
|
|
|
|
|
|
def initialize_cache(path) -> None:
  """Deprecated; use set_cache_dir() instead.

  Emits a DeprecationWarning and then stores `path` in the
  "jax_compilation_cache_dir" config option. To take effect, this should be
  called before any call to get_executable_and_time() or
  put_executable_and_time().

  Args:
    path: Value to store in the "jax_compilation_cache_dir" config option.
  """
  deprecation_message = (
      "initialize_cache is deprecated; use set_cache_dir instead")
  # stacklevel=2 attributes the warning to the caller of this function.
  warnings.warn(deprecation_message, category=DeprecationWarning,
                stacklevel=2)
  config.config.update("jax_compilation_cache_dir", path)
|
|
|
|
|
|
def default_min_cache_entry_size() -> int:
  """Returns the default minimum size below which an entry is not cached.

  The value is 0; _initialize_cache() applies it when the
  jax_persistent_cache_min_entry_size_bytes option is still 0.
  """
  default_minimum = 0
  return default_minimum
|
|
|
|
|
|
def _is_cache_enabled() -> bool:
  """Returns the current value of the `enable_compilation_cache` config flag."""
  enabled = config.enable_compilation_cache.value
  return enabled
|
|
|
|
|
|
def _initialize_cache() -> None:
  """Makes the single attempt to create the cache object.

  The attempt is recorded in _cache_initialized before anything else, so
  later calls return immediately (until reset_cache() clears the flag), even
  when this attempt ends without creating a cache. A cache is created only
  if the cache is enabled and the compilation cache directory is set.
  """
  global _cache, _cache_initialized

  with _cache_initialized_mutex:
    # Attempt to initialize the cache at most once.
    if _cache_initialized:
      logger.debug("_initialize_cache: cache has already been initialized!")
      return
    _cache_initialized = True

    # Nothing to do if the cache is disabled.
    if not _is_cache_enabled():
      logger.debug("_initialize_cache: cache is disabled!")
      return

    # Apply the default minimum entry size only when the flag
    # --jax_persistent_cache_min_entry_size_bytes still holds 0.
    min_entry_size_flag = config.persistent_cache_min_entry_size_bytes
    if min_entry_size_flag.value == 0:
      config.config.update("jax_persistent_cache_min_entry_size_bytes",
                           default_min_cache_entry_size())

    assert _cache is None, "The cache has already been initialized!"

    # Without a configured directory no cache object is created.
    cache_dir: str = config.compilation_cache_dir.value
    if not cache_dir:
      return

    _cache = get_file_cache(cache_dir)
    logger.debug("Initialized persistent compilation cache at %s", cache_dir)
|
|
|
|
|
|
def _get_cache() -> CacheInterface | None:
  """Returns the cache object, attempting lazy initialization when it is None.

  Returns:
    The module-level cache, or None if it is still None after
    _initialize_cache() has run.
  """
  # TODO(b/289098047): consider making this an API and changing the callers of
  # get_executable_and_time() and put_executable_and_time() to call get_cache()
  # and passing the result to them.
  if _cache is not None:
    return _cache
  # _initialize_cache() does real work at most once; afterwards it is a no-op.
  _initialize_cache()
  return _cache
|
|
|
|
|
|
def compress_executable(executable):
  """Compresses `executable` with zstandard when available, otherwise zlib.

  Args:
    executable: The data to compress.

  Returns:
    The compressed data.
  """
  if not zstandard:
    # zstandard could not be imported; fall back to zlib.
    return zlib.compress(executable)
  return zstandard.ZstdCompressor().compress(executable)
|
|
|
|
def decompress_executable(executable):
  """Decompresses `executable` with zstandard when available, otherwise zlib.

  This mirrors compress_executable(): the same availability check selects
  the codec.

  Args:
    executable: The compressed data.

  Returns:
    The decompressed data.
  """
  if not zstandard:
    # zstandard could not be imported; fall back to zlib.
    return zlib.decompress(executable)
  return zstandard.ZstdDecompressor().decompress(executable)
|
|
|
|
def get_executable_and_time(
    cache_key: str, compile_options, backend
) -> tuple[xla_client.LoadedExecutable | None, int | None]:
  """Looks up `cache_key` and returns the cached executable and compile time.

  Args:
    cache_key: Key of the cache entry to look up.
    compile_options: Passed to `backend.deserialize_executable`.
    backend: Object whose `deserialize_executable` method rebuilds the
      executable from its serialized form.

  Returns:
    A pair (executable, compile_time). Both are None when no cache is
    available or when the lookup yields nothing (or an empty value).
  """
  cache = _get_cache()
  if cache is None:
    logger.debug("get_executable_and_time: cache is disabled/not initialized")
    return None, None

  compressed_entry = cache.get(cache_key)
  if not compressed_entry:
    return None, None

  # Undo the steps of put_executable_and_time() in reverse order:
  # decompress, split off the compile time, then deserialize.
  entry = decompress_executable(compressed_entry)
  serialized_executable, compile_time = extract_executable_and_time(entry)
  loaded_executable = backend.deserialize_executable(
      serialized_executable, compile_options)
  return loaded_executable, compile_time
|
|
|
|
|
|
def put_executable_and_time(
    cache_key: str,
    module_name: str,
    executable: xla_client.LoadedExecutable,
    backend,
    compile_time: int
) -> None:
  """Adds `executable` and its compilation time to the cache.

  The executable is serialized by `backend`, combined with `compile_time`,
  and compressed. The resulting entry is written only if its size is at
  least the `persistent_cache_min_entry_size_bytes` config value.

  Args:
    cache_key: Key under which the entry is stored.
    module_name: Name used only in the debug log message.
    executable: The executable to serialize and store.
    backend: Object whose `serialize_executable` method serializes
      `executable`.
    compile_time: Compilation time stored alongside the executable.
  """
  cache = _get_cache()
  if cache is None:
    logger.debug("put_executable_and_time: cache is disabled/not initialized")
    return

  serialized_executable = backend.serialize_executable(executable)
  entry = compress_executable(
      combine_executable_and_time(serialized_executable, compile_time))

  min_entry_size = config.persistent_cache_min_entry_size_bytes.value
  entry_size = len(entry)
  if entry_size < min_entry_size:
    # Too small to be worth storing; skip the write.
    logger.info(
        "Not writing cache entry with key %s since its size (%d bytes) "
        "is less than threshold (%d bytes)",
        cache_key,
        entry_size,
        min_entry_size,
    )
    return

  logger.debug(
      "Writing %s to persistent compilation cache with key %s.",
      module_name,
      cache_key
  )
  # The cache-miss event is recorded here, at the point an entry is written.
  monitoring.record_event('/jax/compilation_cache/cache_misses')
  cache.put(cache_key, entry)
|
|
|
|
|
|
def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
                  backend) -> str:
  """Returns the cache key computed by `cache_key.get` for these inputs.

  The name of the compression codec in use ("zstandard" when that module was
  imported, otherwise "zlib") is passed to `cache_key.get` as its last
  argument.
  """
  compression_algorithm = "zlib" if zstandard is None else "zstandard"
  return cache_key.get(module, devices, compile_options, backend,
                       compression_algorithm)
|
|
|
|
|
|
def is_initialized() -> bool:
  """Deprecated.

  Emits a DeprecationWarning and returns whether the cache is enabled.
  Initialization can be deferred, so the initialized status is not checked.
  The name is retained for backwards compatibility.
  """
  # stacklevel=2 attributes the warning to the caller of this function.
  warnings.warn("is_initialized is deprecated; do not use",
                category=DeprecationWarning, stacklevel=2)
  enabled = _is_cache_enabled()
  return enabled
|
|
|
|
|
|
def reset_cache() -> None:
  """Get back to pristine, uninitialized state.

  Clears the cache object together with the _cache_initialized and
  _cache_used flags. All three are cleared while holding
  _cache_initialized_mutex so that a reset cannot interleave with an
  initialization that assigns _cache under the same mutex; otherwise _cache
  could end up set while _cache_initialized is False.
  """
  global _cache
  global _cache_initialized
  global _cache_used
  with _cache_initialized_mutex:
    previous_cache = _cache
    _cache = None
    _cache_initialized = False
    _cache_used = False
  # Log outside the critical section; `previous_cache` is the object that was
  # actually dropped by this reset.
  logger.debug("Resetting cache at %s.",
               previous_cache._path if previous_cache is not None
               else "<empty>")
|
|
|
|
|
|
def combine_executable_and_time(
    serialized_executable: bytes, compile_time: int
) -> bytes:
  """Builds a cache entry from a serialized executable and its compile time.

  The cache entry is of the form:

  Byte:     0    1    2    3    4 ...
  Content:  compilation time    serialized executable
            (big-endian int)

  Args:
    serialized_executable: The serialized executable bytes.
    compile_time: The compilation time; converted with int() and encoded in
      4 bytes.

  Returns:
    The 4-byte big-endian compile time followed by `serialized_executable`.
  """
  time_header = int(compile_time).to_bytes(4, byteorder='big')
  return time_header + serialized_executable
|
|
|
|
|
|
def extract_executable_and_time(
    exectuable_and_time: bytes
) -> tuple[bytes, int]:
  """Splits a cache entry into the serialized executable and compile time.

  The cache entry is of the form:

  Byte:     0    1    2    3    4 ...
  Content:  compilation time    serialized executable
            (big-endian int)

  Args:
    exectuable_and_time: The cache entry, laid out as shown above.

  Returns:
    A pair (serialized_executable, compile_time): the bytes from offset 4
    onwards, and the first 4 bytes decoded as a big-endian integer.
  """
  time_header = exectuable_and_time[:4]
  serialized_executable = exectuable_and_time[4:]
  compile_time = int.from_bytes(time_header, byteorder='big')
  return serialized_executable, compile_time
|