mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Include compile time along with executable in cache entry.
In order to measure cache savings, we add compilation time to the cache entry along with the serialized executable. The compile time can then be retrieved on a cache hit. Testing: updated tests. PiperOrigin-RevId: 549439628
This commit is contained in:
parent
5ae3ac28cd
commit
0c4c020716
@ -71,49 +71,59 @@ def initialize_cache(path):
|
||||
logger.warning("Initialized persistent compilation cache at %s", path)
|
||||
|
||||
|
||||
def get_executable_and_time(
    cache_key: str, compile_options, backend
) -> tuple[Optional[xla_client.LoadedExecutable], Optional[int]]:
  """Looks up a cache entry and returns its executable and compilation time.

  Args:
    cache_key: key the entry was stored under.
    compile_options: passed through to `backend.deserialize_executable`.
    backend: object whose `deserialize_executable` rebuilds the executable.

  Returns:
    `(executable, compile_time)` on a cache hit, where `compile_time` is the
    integer that was stored in the entry next to the serialized executable.
    `(None, None)` when the cache holds no entry (or an empty one) for
    `cache_key`.
  """
  assert _cache is not None, (
      "initialize_cache must be called before you can call"
      " get_executable_and_time()"
  )
  executable_and_time = _cache.get(cache_key)
  if not executable_and_time:
    return None, None
  # The stored entry is compressed: zstandard is used when that name is
  # truthy, zlib otherwise. This mirrors the choice made when writing.
  if zstandard:
    decompressor = zstandard.ZstdDecompressor()
    executable_and_time = decompressor.decompress(executable_and_time)
  else:
    executable_and_time = zlib.decompress(executable_and_time)
  # Split the decompressed entry into its two parts before deserializing.
  serialized_executable, compile_time = extract_executable_and_time(
      executable_and_time)
  xla_executable_deserialized = backend.deserialize_executable(
      serialized_executable, compile_options)
  return xla_executable_deserialized, compile_time
|
||||
|
||||
|
||||
def put_executable_and_time(
    cache_key: str,
    module_name: str,
    executable: xla_client.LoadedExecutable,
    backend,
    compile_time: int
) -> None:
  """Adds the 'executable' and its compilation time to the cache repository,
  possibly evicting older entries.

  Args:
    cache_key: key to store the entry under.
    module_name: only used in the log message.
    executable: executable to serialize with `backend.serialize_executable`.
    backend: object providing `serialize_executable`.
    compile_time: integer stored in the entry alongside the executable.
  """
  # The second literal starts with a space: adjacent string literals are
  # concatenated verbatim, so without it the message reads "callput_...".
  assert _cache is not None, (
      "initialize_cache must be called before you can call"
      " put_executable_and_time()"
  )
  logger.info(
      "Writing %s to persistent compilation cache with key %s.",
      module_name,
      cache_key,
  )
  serialized_executable = backend.serialize_executable(executable)
  # Prepend the compile time, then compress the combined entry as a whole.
  executable_and_time = combine_executable_and_time(
      serialized_executable, compile_time)
  if zstandard:
    compressor = zstandard.ZstdCompressor()
    executable_and_time = compressor.compress(executable_and_time)
  else:
    executable_and_time = zlib.compress(executable_and_time)
  _cache.put(cache_key, executable_and_time)
|
||||
|
||||
|
||||
def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
|
||||
@ -375,3 +385,32 @@ def reset_cache():
|
||||
assert is_initialized()
|
||||
logger.info("Resetting cache at %s.", _cache._path)
|
||||
_cache = None
|
||||
|
||||
|
||||
def combine_executable_and_time(
    serialized_executable: bytes, compile_time: int
) -> bytes:
  """Given the serialized executable and the compilation time, produce a cache
  entry in the format shown below.

  The cache entry is of the form:

    Byte:     0    1    2    3    4 ...
    Content:  compilation time    serialized executable
              (big-endian int)

  Args:
    serialized_executable: payload placed after the 4-byte header.
    compile_time: value stored in the header; it is truncated with `int()`.

  Returns:
    The 4-byte big-endian header followed by `serialized_executable`.

  Raises:
    OverflowError: if `int(compile_time)` is negative or does not fit in the
      4 unsigned header bytes.
  """
  time_value = int(compile_time)
  # Check the range explicitly so the failure names the offending value
  # instead of surfacing the generic `int.to_bytes` message. The exception
  # type is the one `to_bytes` itself would raise.
  if not 0 <= time_value < 2 ** 32:
    raise OverflowError(
        f"compile_time {time_value} does not fit in the 4-byte unsigned "
        "header of a cache entry")
  header = time_value.to_bytes(4, byteorder='big')
  return header + serialized_executable
|
||||
|
||||
|
||||
def extract_executable_and_time(
    exectuable_and_time: bytes
) -> tuple[bytes, int]:
  """Given the cache entry in the format shown below, extract the serialized
  executable and the compilation time.

  The cache entry 'exectuable_and_time' is of the form:

    Byte:     0    1    2    3    4 ...
    Content:  compilation time    serialized executable
              (big-endian int)

  Args:
    exectuable_and_time: the uncompressed cache entry. The misspelled name is
      kept so existing keyword callers continue to work.

  Returns:
    `(serialized_executable, compile_time)`: the bytes after the 4-byte
    header, and the header decoded as a big-endian integer.

  Raises:
    ValueError: if the entry is shorter than the 4-byte header.
  """
  # Slicing a truncated entry would silently yield an empty executable and a
  # compile time decoded from fewer than 4 bytes, so reject it explicitly.
  if len(exectuable_and_time) < 4:
    raise ValueError(
        "cache entry is too short to contain the 4-byte compilation time "
        f"header: got {len(exectuable_and_time)} byte(s)")
  compile_time = int.from_bytes(exectuable_and_time[:4], byteorder='big')
  return exectuable_and_time[4:], compile_time
|
||||
|
@ -500,11 +500,13 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
|
||||
cache_key = compilation_cache.get_cache_key(
|
||||
computation, devices, compile_options, backend)
|
||||
|
||||
cached_executable = _cache_read(module_name, cache_key, compile_options,
|
||||
backend)
|
||||
if cached_executable is not None:
|
||||
executable, compile_time_retrieved = _cache_read(
|
||||
module_name, cache_key, compile_options, backend)
|
||||
if executable is not None:
|
||||
# TODO(b/289098047): Will instrument a metric which uses the 'compile_time'
|
||||
# to measure the savings due to the cache hit.
|
||||
logger.info("Persistent compilation cache hit for '%s'", module_name)
|
||||
return cached_executable
|
||||
return executable
|
||||
else:
|
||||
start_time = time.monotonic()
|
||||
executable = backend_compile(backend, computation,
|
||||
@ -517,17 +519,20 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
|
||||
|
||||
def _cache_read(
    module_name: str, cache_key: str, compile_options, backend
) -> tuple[Optional[xc.LoadedExecutable], Optional[int]]:
  """Looks up the executable for `cache_key` and its compilation time in the
  persistent compilation cache repository.

  Args:
    module_name: only used in the warning text when the read fails.
    cache_key: key passed to `compilation_cache.get_executable_and_time`.
    compile_options: passed through to the cache lookup.
    backend: passed through to the cache lookup.

  Returns:
    Whatever `compilation_cache.get_executable_and_time` returns, or
    `(None, None)` if the lookup raised and errors are not configured to
    propagate.

  Raises:
    Exception: re-raises the lookup error when
      `config.jax_raise_persistent_cache_errors` is set.
  """
  try:
    return compilation_cache.get_executable_and_time(
        cache_key, compile_options, backend)
  except Exception as ex:
    if config.jax_raise_persistent_cache_errors:
      raise
    # Best effort: a failed read is reported as a warning and then handled
    # like a cache miss.
    warnings.warn(
        f"Error reading persistent compilation cache entry for "
        f"'{module_name}': {type(ex).__name__}: {ex}")
    return None, None
|
||||
|
||||
|
||||
def _cache_write(cache_key: str,
|
||||
@ -535,7 +540,9 @@ def _cache_write(cache_key: str,
|
||||
module_name: str,
|
||||
backend: Backend, executable: xc.LoadedExecutable,
|
||||
host_callbacks: list[Any]):
|
||||
"""Writes `serialized_computation` to the persistent compilation cache."""
|
||||
"""Writes the `serialized_computation` and its compilation time to the
|
||||
persistent compilation cache repository.
|
||||
"""
|
||||
if host_callbacks:
|
||||
logger.info(
|
||||
"Not writing persistent cache entry for '%s' because it uses host "
|
||||
@ -557,8 +564,8 @@ def _cache_write(cache_key: str,
|
||||
compile_time_secs)
|
||||
|
||||
try:
|
||||
compilation_cache.put_executable(cache_key, module_name, executable,
|
||||
backend)
|
||||
compilation_cache.put_executable_and_time(
|
||||
cache_key, module_name, executable, backend, int(compile_time_secs))
|
||||
except Exception as ex:
|
||||
if config.jax_raise_persistent_cache_errors:
|
||||
raise
|
||||
|
@ -46,6 +46,8 @@ import numpy as np
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
FAKE_COMPILE_TIME = 10
|
||||
|
||||
|
||||
@jtu.with_config(
|
||||
jax_raise_persistent_cache_errors=True,
|
||||
@ -272,9 +274,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
key = cc.get_cache_key(computation, devices, compile_options, backend)
|
||||
self.assertEqual(
|
||||
cc.get_executable(key, compile_options, backend), None
|
||||
)
|
||||
executable, compile_time = cc.get_executable_and_time(
|
||||
key, compile_options, backend)
|
||||
self.assertIsNone(executable)
|
||||
self.assertIsNone(compile_time)
|
||||
|
||||
def test_diff_executables(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@ -287,11 +290,13 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
backend = xla_bridge.get_backend()
|
||||
executable1 = backend.compile(computation1, compile_options)
|
||||
executable2 = backend.compile(computation2, compile_options)
|
||||
cc.put_executable("key1", "computation1", executable1, backend)
|
||||
cc.put_executable("key2", "computation2", executable2, backend)
|
||||
cc.put_executable_and_time(
|
||||
"key1", "computation1", executable1, backend, FAKE_COMPILE_TIME)
|
||||
cc.put_executable_and_time(
|
||||
"key2", "computation2", executable2, backend, FAKE_COMPILE_TIME)
|
||||
self.assertNotEqual(
|
||||
cc.get_executable("key1", compile_options, backend),
|
||||
cc.get_executable("key2", compile_options, backend),
|
||||
cc.get_executable_and_time("key1", compile_options, backend)[0],
|
||||
cc.get_executable_and_time("key2", compile_options, backend)[0]
|
||||
)
|
||||
|
||||
def test_put_executable(self):
|
||||
@ -309,8 +314,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
backend = xla_bridge.get_backend()
|
||||
executable = backend.compile(str(computation), compile_options)
|
||||
key = cc.get_cache_key(computation, devices, compile_options, backend)
|
||||
cc.put_executable(key, "alambda", executable, backend)
|
||||
deserialized_executable = cc.get_executable(key, compile_options, backend)
|
||||
cc.put_executable_and_time(
|
||||
key, "alambda", executable, backend, FAKE_COMPILE_TIME)
|
||||
executable_retrieved, compile_time_retrieved = cc.get_executable_and_time(
|
||||
key, compile_options, backend)
|
||||
inputs_to_executable = (
|
||||
np.array(1, dtype=np.int32),
|
||||
np.array(2, dtype=np.int32),
|
||||
@ -319,9 +326,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
executable, inputs_to_executable, backend
|
||||
)
|
||||
actual = xla_client.execute_with_python_values(
|
||||
deserialized_executable, inputs_to_executable, backend
|
||||
executable_retrieved, inputs_to_executable, backend
|
||||
)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved)
|
||||
|
||||
def test_pmap(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
Loading…
x
Reference in New Issue
Block a user