Merge pull request #22386 from rdyro:rdyro/explain_persistent_compilation

PiperOrigin-RevId: 652928829
This commit is contained in:
jax authors 2024-07-16 12:00:37 -07:00
commit 4907c38742
4 changed files with 164 additions and 18 deletions

View File

@ -56,6 +56,7 @@ from jax._src.config import (
debug_nans as debug_nans,
debug_infs as debug_infs,
log_compiles as log_compiles,
explain_cache_misses as explain_cache_misses,
default_device as default_device,
default_matmul_precision as default_matmul_precision,
default_prng_impl as default_prng_impl,

View File

@ -135,13 +135,20 @@ def _initialize_cache() -> None:
_cache, path = cache_and_path
logger.debug("Initialized persistent compilation cache at %s", path)
def is_persistent_cache_enabled() -> bool:
return (config.compilation_cache_dir.value is not None
and config.enable_compilation_cache.value)
def _get_cache(backend) -> CacheInterface | None:
# TODO(b/289098047): consider making this an API and changing the callers of
# get_executable_and_time() and put_executable_and_time() to call get_cache()
# and passing the result to them.
if backend.runtime_type in _UNSUPPORTED_RUNTIMES:
logger.debug("_get_cache: Unsupported runtime: %s", backend.runtime_type)
log_priority = (logging.WARNING if is_persistent_cache_enabled()
else logging.DEBUG)
logger.log(log_priority, "_get_cache: Unsupported runtime: %s",
backend.runtime_type)
return None
if _cache is None:
_initialize_cache() # initialization is done at most once; see above
@ -206,9 +213,15 @@ def put_executable_and_time(
"""Adds the 'executable' and its compilation time to the cache, possibly
evicting older entries.
"""
log_priority = (logging.WARNING
if config.explain_cache_misses.value
and is_persistent_cache_enabled()
else logging.DEBUG)
cache = _get_cache(backend)
if cache is None:
logger.debug("put_executable_and_time: cache is disabled/not initialized")
logger.log(log_priority,
"Not writing persistent cache entry with key %s"
" since cache is disabled/not initialized", cache_key)
return
serialized_executable = backend.serialize_executable(executable)
@ -219,19 +232,14 @@ def put_executable_and_time(
min_entry_size = config.persistent_cache_min_entry_size_bytes.value
entry_size = len(executable_and_time)
if entry_size < min_entry_size:
logger.info(
"Not writing cache entry with key %s since its size (%d bytes) "
"is less than threshold (%d bytes)",
cache_key,
entry_size,
min_entry_size,
)
logger.log(log_priority,
"Not writing persistent cache entry with key %s since its size"
" (%d bytes) is less than threshold (%d bytes)", cache_key, entry_size,
min_entry_size)
else:
logger.info(
"Writing %s to persistent compilation cache with key %s.",
module_name,
cache_key
)
logger.log(log_priority,
"Writing %s to persistent compilation cache with key %s.",
module_name, cache_key)
monitoring.record_event('/jax/compilation_cache/cache_misses')
cache.put(cache_key, executable_and_time)

View File

@ -91,6 +91,23 @@ def use_detailed_logging(module: ir.Module) -> bool:
return _walk_operations(module.operation, bound) < 0
def log_persistent_cache_hit(module_name: str) -> None:
hit_log_priority = (logging.WARNING if config.log_compiles.value
else logging.DEBUG)
logger.log(hit_log_priority, "Persistent compilation cache hit for '%s'",
module_name)
def log_persistent_cache_miss(module_name: str) -> None:
miss_log_priority = (logging.WARNING
if config.explain_cache_misses.value
and compilation_cache.is_persistent_cache_enabled()
else logging.DEBUG)
# all caps to match the tracing cache "TRACING CACHE MISS"
logger.log(miss_log_priority, "PERSISTENT COMPILATION CACHE MISS for '%s'",
module_name)
def get_compile_options(
num_replicas: int,
num_partitions: int,
@ -330,7 +347,7 @@ def compile_or_get_cached(
if retrieved_executable is not None:
assert retrieved_compile_time is not None
logger.debug("Persistent compilation cache hit for '%s'", module_name)
log_persistent_cache_hit(module_name)
monitoring.record_event('/jax/compilation_cache/cache_hits')
monitoring.record_event_duration_secs(
@ -349,6 +366,7 @@ def compile_or_get_cached(
# them.
and len(host_callbacks) == 0
):
log_persistent_cache_miss(module_name)
return _compile_and_share_module(
backend,
computation,
@ -364,6 +382,7 @@ def compile_or_get_cached(
and is_multi_process
and distributed.global_state.client is not None
):
log_persistent_cache_miss(module_name)
return _compile_and_write_autotune_config(
backend,
computation,
@ -375,6 +394,7 @@ def compile_or_get_cached(
min_device_process_id
)
else:
log_persistent_cache_miss(module_name)
return _compile_and_write_cache(
backend,
computation,
@ -655,19 +675,26 @@ def _cache_write(cache_key: str,
"""
# Only write cache entries from the first process. Otherwise we create
# problems with contention for writes on some filesystems, e.g., GCS.
log_priority = (logging.WARNING
if config.explain_cache_misses.value
and compilation_cache.is_persistent_cache_enabled()
else logging.DEBUG)
if distributed.global_state.process_id != 0:
logger.debug("Not writing persistent cache entry since process_id != 0")
logger.log(log_priority,
"Not writing persistent cache entry since process_id != 0")
return
if host_callbacks:
logger.debug(
logger.log(
log_priority,
"Not writing persistent cache entry for '%s' because it uses host "
"callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
return
min_compile_time = config.persistent_cache_min_compile_time_secs.value
if compile_time_secs < min_compile_time:
logger.debug(
logger.log(
log_priority,
"Not writing persistent cache entry for '%s' because it took < %.2f "
"seconds to compile (%.2fs)", module_name, min_compile_time,
compile_time_secs)

View File

@ -16,6 +16,7 @@ from __future__ import annotations
from collections import Counter
from functools import partial
import logging
import math
import platform
import unittest
@ -28,6 +29,7 @@ from absl.testing import parameterized
import jax
from jax import jit
from jax import lax
from jax import numpy as jnp
from jax import pmap
from jax._src import compilation_cache as cc
from jax._src import compiler
@ -62,6 +64,11 @@ def tearDownModule():
def increment_event_count(event):
_counts[event] += 1
def msg_exists_in_logs(msg: str, records: list[logging.LogRecord],
level: int | None = None) -> bool:
return any(msg in record.getMessage() for record in records
if level is None or level == record.levelno)
class InMemoryCache(CacheInterface):
"""An in-memory cache for testing purposes."""
@ -415,6 +422,109 @@ class CompilationCacheTest(CompilationCacheTestCase):
- previous_counts["/jax/compilation_cache/cache_hits"],
1)
def test_persistent_cache_hit_logging(self):
jit(lambda x: x + 1)(1)
msg = "Persistent compilation cache hit"
# cache hits with `log_compiles` on should be in WARNING when enabled
with config.log_compiles(True):
with self.assertLogs(level="WARNING") as log:
jit(lambda x: x + 1)(1)
self.assertTrue(msg_exists_in_logs(msg, log.records, logging.WARNING))
def test_persistent_cache_hit_no_logging(self):
jit(lambda x: x + 1)(1)
msg = "Persistent compilation cache hit"
# cache hits with `log_compiles` off should NOT be in WARNING
with config.log_compiles(False):
with self.assertLogs(level="DEBUG") as log:
jit(lambda x: x + 1)(1)
self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING))
def test_persistent_cache_miss_logging_with_explain(self):
with (config.explain_cache_misses(True),
config.compilation_cache_dir("jax-cache")):
# omitting writing to cache because compilation is too fast
pure_fn = lambda a: jnp.array(1, dtype=jnp.int32)
with config.persistent_cache_min_compile_time_secs(1e5):
with self.assertLogs(level="DEBUG") as log:
jit(lambda x: x +
jax.pure_callback(pure_fn, jax.ShapeDtypeStruct((), jnp.int32), x)
)(1)
msg1 = "Not writing persistent cache entry"
msg2 = "because it uses host callbacks"
self.assertTrue(msg_exists_in_logs(msg1, log.records, logging.WARNING))
self.assertTrue(msg_exists_in_logs(msg2, log.records, logging.WARNING))
# omitting writing to cache because host callback is present
pure_fn = lambda a: jnp.array(1, dtype=jnp.int32)
with self.assertLogs(level="DEBUG") as log:
jit(lambda x: x +
jax.pure_callback(pure_fn, jax.ShapeDtypeStruct((), jnp.int32), x)
)(1)
msg1 = "Not writing persistent cache entry"
msg2 = "because it uses host callbacks"
self.assertTrue(msg_exists_in_logs(msg1, log.records, logging.WARNING))
self.assertTrue(msg_exists_in_logs(msg2, log.records, logging.WARNING))
# omitting writing to cache because binary is too small
with config.persistent_cache_min_entry_size_bytes(int(1e9)):
with self.assertLogs(level="DEBUG") as log:
jit(lambda x: x + 2)(1)
msg1 = "Not writing persistent cache entry"
msg2 = "is less than threshold"
self.assertTrue(msg_exists_in_logs(msg1, log.records, logging.WARNING))
self.assertTrue(msg_exists_in_logs(msg2, log.records, logging.WARNING))
# successful cache write
with config.persistent_cache_min_entry_size_bytes(1):
with self.assertLogs(level="DEBUG") as log:
jit(lambda x: x ** 2)(1)
msg = "to persistent compilation cache with key"
self.assertTrue(msg_exists_in_logs(msg, log.records, logging.WARNING))
def test_persistent_cache_miss_logging_with_no_explain(self):
# test that cache failure messages do not get logged in WARNING
with (config.explain_cache_misses(False),
config.compilation_cache_dir("jax-cache")):
# omitting writing to cache because compilation is too fast
with config.persistent_cache_min_compile_time_secs(1e3):
with self.assertLogs(level="DEBUG") as log:
jit(lambda x: x + 1)(1)
msg1, msg2 = "Not writing persistent cache entry", "because it took <"
self.assertFalse(msg_exists_in_logs(msg1, log.records, logging.WARNING))
self.assertFalse(msg_exists_in_logs(msg2, log.records, logging.WARNING))
# omitting writing to cache because host callback is present
pure_fn = lambda a: jnp.array(1, dtype=jnp.int32)
with self.assertLogs(level="DEBUG") as log:
jit(lambda x: x +
jax.pure_callback(pure_fn, jax.ShapeDtypeStruct((), jnp.int32), x)
)(1)
msg1 = "Not writing persistent cache entry"
msg2 = "because it uses host callbacks"
self.assertFalse(msg_exists_in_logs(msg1, log.records, logging.WARNING))
self.assertFalse(msg_exists_in_logs(msg2, log.records, logging.WARNING))
# omitting writing to cache because binary is too small
with config.persistent_cache_min_entry_size_bytes(int(1e9)):
with self.assertLogs(level="DEBUG") as log:
jit(lambda x: x + 2)(1)
msg1 = "Not writing persistent cache entry"
msg2 = "is less than threshold"
self.assertFalse(msg_exists_in_logs(msg1, log.records, logging.WARNING))
self.assertFalse(msg_exists_in_logs(msg2, log.records, logging.WARNING))
# successful cache write
with config.persistent_cache_min_entry_size_bytes(1):
with self.assertLogs(level="DEBUG") as log:
jit(lambda x: x ** 2)(1)
msg = "to persistent compilation cache with key"
self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING))
@parameterized.parameters(0, 1)
def test_cache_write_with_process_restriction(self, process_id):
with (