From 0c4c0207161c59784123f07532ac8d25c314e49e Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Jul 2023 15:17:04 -0700 Subject: [PATCH] Include compile time along with executable in cache entry. In order to measure cache savings, we add compilation time to the cache entry along with the serialized executable. The compile time can then be retrieved on a cache hit. Testing: updated tests. PiperOrigin-RevId: 549439628 --- jax/_src/compilation_cache.py | 83 ++++++++++++++++++++++++--------- jax/_src/dispatch.py | 29 +++++++----- tests/compilation_cache_test.py | 28 +++++++---- 3 files changed, 97 insertions(+), 43 deletions(-) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 87f0fc3ee..d11669e72 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -71,49 +71,59 @@ def initialize_cache(path): logger.warning("Initialized persistent compilation cache at %s", path) -def get_executable( +def get_executable_and_time( cache_key: str, compile_options, backend -) -> Optional[xla_client.LoadedExecutable]: - """Returns the cached executable if present, or None otherwise.""" - assert ( - _cache is not None - ), "initialize_cache must be called before you can call get_executable()" - serialized_executable = _cache.get(cache_key) - if not serialized_executable: - return None +) -> tuple[Optional[xla_client.LoadedExecutable], Optional[int]]: + """Returns the cached executable and its compilation time if present, or None + otherwise. 
+ """ + assert _cache is not None, ( + "initialize_cache must be called before you can call" + " get_executable_and_time()" + ) + executable_and_time = _cache.get(cache_key) + if not executable_and_time: + return None, None if zstandard: decompressor = zstandard.ZstdDecompressor() - serialized_executable = decompressor.decompress(serialized_executable) + executable_and_time = decompressor.decompress(executable_and_time) else: - serialized_executable = zlib.decompress(serialized_executable) + executable_and_time = zlib.decompress(executable_and_time) + serialized_executable, compile_time = extract_executable_and_time( + executable_and_time) xla_executable_deserialized = backend.deserialize_executable( - serialized_executable, compile_options - ) - return xla_executable_deserialized + serialized_executable, compile_options) + return xla_executable_deserialized, compile_time -def put_executable( +def put_executable_and_time( cache_key: str, module_name: str, executable: xla_client.LoadedExecutable, backend, + compile_time: int ) -> None: - """Adds 'executable' to the cache, possibly evicting older entries.""" - assert ( - _cache is not None - ), "initialize_cache must be called before you can call put_executable()" + """Adds the 'executable' and its compilation time to the cache repository, + possibly evicting older entries. 
+ """ + assert _cache is not None, ( + "initialize_cache must be called before you can call" + "put_executable_and_time()" + ) logger.info( "Writing %s to persistent compilation cache with key %s.", module_name, cache_key, ) serialized_executable = backend.serialize_executable(executable) + executable_and_time = combine_executable_and_time( + serialized_executable, compile_time) if zstandard: compressor = zstandard.ZstdCompressor() - serialized_executable = compressor.compress(serialized_executable) + executable_and_time = compressor.compress(executable_and_time) else: - serialized_executable = zlib.compress(serialized_executable) - _cache.put(cache_key, serialized_executable) + executable_and_time = zlib.compress(executable_and_time) + _cache.put(cache_key, executable_and_time) def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn): @@ -375,3 +385,32 @@ def reset_cache(): assert is_initialized() logger.info("Resetting cache at %s.", _cache._path) _cache = None + + +def combine_executable_and_time( + serialized_executable: bytes, compile_time: int +) -> bytes: + """Given the serialized executable and the compilation time, produce a cache + entry in the format shown below. + + The cache entry is of the form: + Byte: 0 1 2 3 4 ... + Content: compilation time serialized executable + (big-endian int) + """ + return int(compile_time).to_bytes(4, byteorder='big') + serialized_executable + + +def extract_executable_and_time( + exectuable_and_time: bytes +) -> tuple[bytes, int]: + """Given the cache entry in the format shown below, extract the serialized + executable and the compilation time. + + The cache entry 'executable_and_time' is of the form: + Byte: 0 1 2 3 4 ... 
+ Content: compilation time serialized executable + (big-endian int) + """ + return exectuable_and_time[4:], int.from_bytes( + exectuable_and_time[:4], byteorder='big') diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 6a1363c96..fc6f27d7e 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -500,11 +500,13 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray, cache_key = compilation_cache.get_cache_key( computation, devices, compile_options, backend) - cached_executable = _cache_read(module_name, cache_key, compile_options, - backend) - if cached_executable is not None: + executable, compile_time_retrieved = _cache_read( + module_name, cache_key, compile_options, backend) + if executable is not None: + # TODO(b/289098047): Will instrument a metric which uses the 'compile_time' + # to measure the savings due to the cache hit. logger.info("Persistent compilation cache hit for '%s'", module_name) - return cached_executable + return executable else: start_time = time.monotonic() executable = backend_compile(backend, computation, @@ -517,17 +519,20 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray, def _cache_read( module_name: str, cache_key: str, compile_options, backend -) -> Optional[xc.LoadedExecutable]: - """Looks up `computation` in the persistent compilation cache.""" +) -> tuple[Optional[xc.LoadedExecutable], Optional[int]]: + """Looks up the `computation` and its compilation time in the persistent + compilation cache repository. 
+ """ try: - return compilation_cache.get_executable(cache_key, compile_options, backend) + return compilation_cache.get_executable_and_time( + cache_key, compile_options, backend) except Exception as ex: if config.jax_raise_persistent_cache_errors: raise warnings.warn( f"Error reading persistent compilation cache entry for " f"'{module_name}': {type(ex).__name__}: {ex}") - return None + return None, None def _cache_write(cache_key: str, @@ -535,7 +540,9 @@ def _cache_write(cache_key: str, module_name: str, backend: Backend, executable: xc.LoadedExecutable, host_callbacks: list[Any]): - """Writes `serialized_computation` to the persistent compilation cache.""" + """Writes the `serialized_computation` and its compilation time to the + persistent compilation cache repository. + """ if host_callbacks: logger.info( "Not writing persistent cache entry for '%s' because it uses host " @@ -557,8 +564,8 @@ def _cache_write(cache_key: str, compile_time_secs) try: - compilation_cache.put_executable(cache_key, module_name, executable, - backend) + compilation_cache.put_executable_and_time( + cache_key, module_name, executable, backend, int(compile_time_secs)) except Exception as ex: if config.jax_raise_persistent_cache_errors: raise diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 25b6e752c..57f109ae3 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -46,6 +46,8 @@ import numpy as np config.parse_flags_with_absl() FLAGS = config.FLAGS +FAKE_COMPILE_TIME = 10 + @jtu.with_config( jax_raise_persistent_cache_errors=True, @@ -272,9 +274,10 @@ class CompilationCacheTest(jtu.JaxTestCase): ) backend = xla_bridge.get_backend() key = cc.get_cache_key(computation, devices, compile_options, backend) - self.assertEqual( - cc.get_executable(key, compile_options, backend), None - ) + executable, compile_time = cc.get_executable_and_time( + key, compile_options, backend) + self.assertIsNone(executable) + 
self.assertIsNone(compile_time) def test_diff_executables(self): with tempfile.TemporaryDirectory() as tmpdir: @@ -287,11 +290,13 @@ class CompilationCacheTest(jtu.JaxTestCase): backend = xla_bridge.get_backend() executable1 = backend.compile(computation1, compile_options) executable2 = backend.compile(computation2, compile_options) - cc.put_executable("key1", "computation1", executable1, backend) - cc.put_executable("key2", "computation2", executable2, backend) + cc.put_executable_and_time( + "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) + cc.put_executable_and_time( + "key2", "computation2", executable2, backend, FAKE_COMPILE_TIME) self.assertNotEqual( - cc.get_executable("key1", compile_options, backend), - cc.get_executable("key2", compile_options, backend), + cc.get_executable_and_time("key1", compile_options, backend)[0], + cc.get_executable_and_time("key2", compile_options, backend)[0] ) def test_put_executable(self): @@ -309,8 +314,10 @@ class CompilationCacheTest(jtu.JaxTestCase): backend = xla_bridge.get_backend() executable = backend.compile(str(computation), compile_options) key = cc.get_cache_key(computation, devices, compile_options, backend) - cc.put_executable(key, "alambda", executable, backend) - deserialized_executable = cc.get_executable(key, compile_options, backend) + cc.put_executable_and_time( + key, "alambda", executable, backend, FAKE_COMPILE_TIME) + executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( + key, compile_options, backend) inputs_to_executable = ( np.array(1, dtype=np.int32), np.array(2, dtype=np.int32), @@ -319,9 +326,10 @@ class CompilationCacheTest(jtu.JaxTestCase): executable, inputs_to_executable, backend ) actual = xla_client.execute_with_python_values( - deserialized_executable, inputs_to_executable, backend + executable_retrieved, inputs_to_executable, backend ) self.assertEqual(expected, actual) + self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved) def test_pmap(self): 
with tempfile.TemporaryDirectory() as tmpdir: