Include compile time along with executable in cache entry.

In order to measure cache savings, we add compilation time to the cache entry along with the serialized executable. The compile time can then be retrieved on a cache hit.

Testing: updated tests.
PiperOrigin-RevId: 549439628
This commit is contained in:
jax authors 2023-07-19 15:17:04 -07:00
parent 5ae3ac28cd
commit 0c4c020716
3 changed files with 97 additions and 43 deletions

View File

@ -71,49 +71,59 @@ def initialize_cache(path):
logger.warning("Initialized persistent compilation cache at %s", path)
def get_executable(
def get_executable_and_time(
cache_key: str, compile_options, backend
) -> Optional[xla_client.LoadedExecutable]:
"""Returns the cached executable if present, or None otherwise."""
assert (
_cache is not None
), "initialize_cache must be called before you can call get_executable()"
serialized_executable = _cache.get(cache_key)
if not serialized_executable:
return None
) -> tuple[Optional[xla_client.LoadedExecutable], Optional[int]]:
"""Returns the cached executable and its compilation time if present, or None
otherwise.
"""
assert _cache is not None, (
"initialize_cache must be called before you can call"
" get_executable_and_time()"
)
executable_and_time = _cache.get(cache_key)
if not executable_and_time:
return None, None
if zstandard:
decompressor = zstandard.ZstdDecompressor()
serialized_executable = decompressor.decompress(serialized_executable)
executable_and_time = decompressor.decompress(executable_and_time)
else:
serialized_executable = zlib.decompress(serialized_executable)
executable_and_time = zlib.decompress(executable_and_time)
serialized_executable, compile_time = extract_executable_and_time(
executable_and_time)
xla_executable_deserialized = backend.deserialize_executable(
serialized_executable, compile_options
)
return xla_executable_deserialized
serialized_executable, compile_options)
return xla_executable_deserialized, compile_time
def put_executable(
def put_executable_and_time(
cache_key: str,
module_name: str,
executable: xla_client.LoadedExecutable,
backend,
compile_time: int
) -> None:
"""Adds 'executable' to the cache, possibly evicting older entries."""
assert (
_cache is not None
), "initialize_cache must be called before you can call put_executable()"
"""Adds the 'executable' and its compilation time to the cache repository,
possibly evicting older entries.
"""
# Fail fast if the cache has not been initialized. The two adjacent string
# literals are concatenated as-is, so the second one must start with a space
# to keep "call" and the function name from running together.
assert _cache is not None, (
    "initialize_cache must be called before you can call"
    " put_executable_and_time()"
)
logger.info(
"Writing %s to persistent compilation cache with key %s.",
module_name,
cache_key,
)
serialized_executable = backend.serialize_executable(executable)
executable_and_time = combine_executable_and_time(
serialized_executable, compile_time)
if zstandard:
compressor = zstandard.ZstdCompressor()
serialized_executable = compressor.compress(serialized_executable)
executable_and_time = compressor.compress(executable_and_time)
else:
serialized_executable = zlib.compress(serialized_executable)
_cache.put(cache_key, serialized_executable)
executable_and_time = zlib.compress(executable_and_time)
_cache.put(cache_key, executable_and_time)
def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
@ -375,3 +385,32 @@ def reset_cache():
assert is_initialized()
logger.info("Resetting cache at %s.", _cache._path)
_cache = None
def combine_executable_and_time(
    serialized_executable: bytes, compile_time: int
) -> bytes:
  """Builds a cache entry from a serialized executable and its compile time.

  Layout of the returned entry:

    Bytes 0..3: compilation time, as a 4-byte big-endian unsigned integer.
    Bytes 4.. : the serialized executable, unchanged.

  Args:
    serialized_executable: the executable's serialized bytes.
    compile_time: the compilation time; it is passed through `int()` before
      being encoded, so any fractional part is dropped.

  Returns:
    The 4-byte time header followed by `serialized_executable`.

  Raises:
    OverflowError: if the integer compile time is negative or does not fit in
      4 bytes.
  """
  time_header = int(compile_time).to_bytes(4, byteorder='big')
  return time_header + serialized_executable
def extract_executable_and_time(
    exectuable_and_time: bytes
) -> tuple[bytes, int]:
  """Splits a cache entry into the serialized executable and compile time.

  Expected layout of the entry:

    Bytes 0..3: compilation time, as a 4-byte big-endian unsigned integer.
    Bytes 4.. : the serialized executable.

  Args:
    exectuable_and_time: the cache entry. The spelling of this parameter name
      is kept as-is so keyword callers keep working.

  Returns:
    A `(serialized_executable, compile_time)` tuple.

  Raises:
    ValueError: if the entry is shorter than the 4-byte time header.
  """
  header_size = 4
  # Without this check a truncated entry would be parsed silently: the
  # slices below would yield an empty executable and a time decoded from
  # fewer than four bytes.
  if len(exectuable_and_time) < header_size:
    raise ValueError(
        "Cache entry is too short to contain a compilation time header:"
        f" expected at least {header_size} bytes, got"
        f" {len(exectuable_and_time)}.")
  compile_time = int.from_bytes(
      exectuable_and_time[:header_size], byteorder='big')
  return exectuable_and_time[header_size:], compile_time

View File

@ -500,11 +500,13 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
cache_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend)
cached_executable = _cache_read(module_name, cache_key, compile_options,
backend)
if cached_executable is not None:
executable, compile_time_retrieved = _cache_read(
module_name, cache_key, compile_options, backend)
if executable is not None:
# TODO(b/289098047): Will instrument a metric which uses the 'compile_time'
# to measure the savings due to the cache hit.
logger.info("Persistent compilation cache hit for '%s'", module_name)
return cached_executable
return executable
else:
start_time = time.monotonic()
executable = backend_compile(backend, computation,
@ -517,17 +519,20 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
def _cache_read(
module_name: str, cache_key: str, compile_options, backend
) -> Optional[xc.LoadedExecutable]:
"""Looks up `computation` in the persistent compilation cache."""
) -> tuple[Optional[xc.LoadedExecutable], Optional[int]]:
"""Looks up the `computation` and its compilation time in the persistent
compilation cache repository.
"""
try:
return compilation_cache.get_executable(cache_key, compile_options, backend)
return compilation_cache.get_executable_and_time(
cache_key, compile_options, backend)
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise
warnings.warn(
f"Error reading persistent compilation cache entry for "
f"'{module_name}': {type(ex).__name__}: {ex}")
return None
return None, None
def _cache_write(cache_key: str,
@ -535,7 +540,9 @@ def _cache_write(cache_key: str,
module_name: str,
backend: Backend, executable: xc.LoadedExecutable,
host_callbacks: list[Any]):
"""Writes `serialized_computation` to the persistent compilation cache."""
"""Writes the `serialized_computation` and its compilation time to the
persistent compilation cache repository.
"""
if host_callbacks:
logger.info(
"Not writing persistent cache entry for '%s' because it uses host "
@ -557,8 +564,8 @@ def _cache_write(cache_key: str,
compile_time_secs)
try:
compilation_cache.put_executable(cache_key, module_name, executable,
backend)
compilation_cache.put_executable_and_time(
cache_key, module_name, executable, backend, int(compile_time_secs))
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise

View File

@ -46,6 +46,8 @@ import numpy as np
config.parse_flags_with_absl()
FLAGS = config.FLAGS
FAKE_COMPILE_TIME = 10
@jtu.with_config(
jax_raise_persistent_cache_errors=True,
@ -272,9 +274,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
)
backend = xla_bridge.get_backend()
key = cc.get_cache_key(computation, devices, compile_options, backend)
self.assertEqual(
cc.get_executable(key, compile_options, backend), None
)
executable, compile_time = cc.get_executable_and_time(
key, compile_options, backend)
self.assertIsNone(executable)
self.assertIsNone(compile_time)
def test_diff_executables(self):
with tempfile.TemporaryDirectory() as tmpdir:
@ -287,11 +290,13 @@ class CompilationCacheTest(jtu.JaxTestCase):
backend = xla_bridge.get_backend()
executable1 = backend.compile(computation1, compile_options)
executable2 = backend.compile(computation2, compile_options)
cc.put_executable("key1", "computation1", executable1, backend)
cc.put_executable("key2", "computation2", executable2, backend)
cc.put_executable_and_time(
"key1", "computation1", executable1, backend, FAKE_COMPILE_TIME)
cc.put_executable_and_time(
"key2", "computation2", executable2, backend, FAKE_COMPILE_TIME)
self.assertNotEqual(
cc.get_executable("key1", compile_options, backend),
cc.get_executable("key2", compile_options, backend),
cc.get_executable_and_time("key1", compile_options, backend)[0],
cc.get_executable_and_time("key2", compile_options, backend)[0]
)
def test_put_executable(self):
@ -309,8 +314,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
backend = xla_bridge.get_backend()
executable = backend.compile(str(computation), compile_options)
key = cc.get_cache_key(computation, devices, compile_options, backend)
cc.put_executable(key, "alambda", executable, backend)
deserialized_executable = cc.get_executable(key, compile_options, backend)
cc.put_executable_and_time(
key, "alambda", executable, backend, FAKE_COMPILE_TIME)
executable_retrieved, compile_time_retrieved = cc.get_executable_and_time(
key, compile_options, backend)
inputs_to_executable = (
np.array(1, dtype=np.int32),
np.array(2, dtype=np.int32),
@ -319,9 +326,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
executable, inputs_to_executable, backend
)
actual = xla_client.execute_with_python_values(
deserialized_executable, inputs_to_executable, backend
executable_retrieved, inputs_to_executable, backend
)
self.assertEqual(expected, actual)
self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved)
def test_pmap(self):
with tempfile.TemporaryDirectory() as tmpdir: