mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #22386 from rdyro:rdyro/explain_persistent_compilation
PiperOrigin-RevId: 652928829
This commit is contained in:
commit
4907c38742
@ -56,6 +56,7 @@ from jax._src.config import (
|
||||
debug_nans as debug_nans,
|
||||
debug_infs as debug_infs,
|
||||
log_compiles as log_compiles,
|
||||
explain_cache_misses as explain_cache_misses,
|
||||
default_device as default_device,
|
||||
default_matmul_precision as default_matmul_precision,
|
||||
default_prng_impl as default_prng_impl,
|
||||
|
@ -135,13 +135,20 @@ def _initialize_cache() -> None:
|
||||
_cache, path = cache_and_path
|
||||
logger.debug("Initialized persistent compilation cache at %s", path)
|
||||
|
||||
def is_persistent_cache_enabled() -> bool:
  """Report whether the persistent compilation cache is configured and on.

  Returns:
    False when the `compilation_cache_dir` option is unset (None); otherwise
    the value of the `enable_compilation_cache` option.
  """
  # Without a cache directory there is nothing to enable.
  if config.compilation_cache_dir.value is None:
    return False
  return config.enable_compilation_cache.value
|
||||
|
||||
|
||||
def _get_cache(backend) -> CacheInterface | None:
|
||||
# TODO(b/289098047): consider making this an API and changing the callers of
|
||||
# get_executable_and_time() and put_executable_and_time() to call get_cache()
|
||||
# and passing the result to them.
|
||||
if backend.runtime_type in _UNSUPPORTED_RUNTIMES:
|
||||
logger.debug("_get_cache: Unsupported runtime: %s", backend.runtime_type)
|
||||
log_priority = (logging.WARNING if is_persistent_cache_enabled()
|
||||
else logging.DEBUG)
|
||||
logger.log(log_priority, "_get_cache: Unsupported runtime: %s",
|
||||
backend.runtime_type)
|
||||
return None
|
||||
if _cache is None:
|
||||
_initialize_cache() # initialization is done at most once; see above
|
||||
@ -206,9 +213,15 @@ def put_executable_and_time(
|
||||
"""Adds the 'executable' and its compilation time to the cache, possibly
|
||||
evicting older entries.
|
||||
"""
|
||||
log_priority = (logging.WARNING
|
||||
if config.explain_cache_misses.value
|
||||
and is_persistent_cache_enabled()
|
||||
else logging.DEBUG)
|
||||
cache = _get_cache(backend)
|
||||
if cache is None:
|
||||
logger.debug("put_executable_and_time: cache is disabled/not initialized")
|
||||
logger.log(log_priority,
|
||||
"Not writing persistent cache entry with key %s"
|
||||
" since cache is disabled/not initialized", cache_key)
|
||||
return
|
||||
|
||||
serialized_executable = backend.serialize_executable(executable)
|
||||
@ -219,19 +232,14 @@ def put_executable_and_time(
|
||||
min_entry_size = config.persistent_cache_min_entry_size_bytes.value
|
||||
entry_size = len(executable_and_time)
|
||||
if entry_size < min_entry_size:
|
||||
logger.info(
|
||||
"Not writing cache entry with key %s since its size (%d bytes) "
|
||||
"is less than threshold (%d bytes)",
|
||||
cache_key,
|
||||
entry_size,
|
||||
min_entry_size,
|
||||
)
|
||||
logger.log(log_priority,
|
||||
"Not writing persistent cache entry with key %s since its size"
|
||||
" (%d bytes) is less than threshold (%d bytes)", cache_key, entry_size,
|
||||
min_entry_size)
|
||||
else:
|
||||
logger.info(
|
||||
"Writing %s to persistent compilation cache with key %s.",
|
||||
module_name,
|
||||
cache_key
|
||||
)
|
||||
logger.log(log_priority,
|
||||
"Writing %s to persistent compilation cache with key %s.",
|
||||
module_name, cache_key)
|
||||
monitoring.record_event('/jax/compilation_cache/cache_misses')
|
||||
cache.put(cache_key, executable_and_time)
|
||||
|
||||
|
@ -91,6 +91,23 @@ def use_detailed_logging(module: ir.Module) -> bool:
|
||||
return _walk_operations(module.operation, bound) < 0
|
||||
|
||||
|
||||
def log_persistent_cache_hit(module_name: str) -> None:
  """Log that `module_name` was found in the persistent compilation cache.

  The record is emitted at WARNING when the `log_compiles` option is set, and
  at DEBUG otherwise.

  Args:
    module_name: name of the module, interpolated into the log message.
  """
  if config.log_compiles.value:
    hit_log_priority = logging.WARNING
  else:
    hit_log_priority = logging.DEBUG
  logger.log(hit_log_priority, "Persistent compilation cache hit for '%s'",
             module_name)
|
||||
|
||||
|
||||
def log_persistent_cache_miss(module_name: str) -> None:
  """Log that `module_name` was not found in the persistent compilation cache.

  The record is emitted at WARNING only when the `explain_cache_misses` option
  is set and the persistent cache is enabled; otherwise it is emitted at DEBUG.

  Args:
    module_name: name of the module, interpolated into the log message.
  """
  explain = (config.explain_cache_misses.value
             and compilation_cache.is_persistent_cache_enabled())
  miss_log_priority = logging.WARNING if explain else logging.DEBUG
  # all caps to match the tracing cache "TRACING CACHE MISS"
  logger.log(miss_log_priority, "PERSISTENT COMPILATION CACHE MISS for '%s'",
             module_name)
|
||||
|
||||
|
||||
def get_compile_options(
|
||||
num_replicas: int,
|
||||
num_partitions: int,
|
||||
@ -330,7 +347,7 @@ def compile_or_get_cached(
|
||||
|
||||
if retrieved_executable is not None:
|
||||
assert retrieved_compile_time is not None
|
||||
logger.debug("Persistent compilation cache hit for '%s'", module_name)
|
||||
log_persistent_cache_hit(module_name)
|
||||
|
||||
monitoring.record_event('/jax/compilation_cache/cache_hits')
|
||||
monitoring.record_event_duration_secs(
|
||||
@ -349,6 +366,7 @@ def compile_or_get_cached(
|
||||
# them.
|
||||
and len(host_callbacks) == 0
|
||||
):
|
||||
log_persistent_cache_miss(module_name)
|
||||
return _compile_and_share_module(
|
||||
backend,
|
||||
computation,
|
||||
@ -364,6 +382,7 @@ def compile_or_get_cached(
|
||||
and is_multi_process
|
||||
and distributed.global_state.client is not None
|
||||
):
|
||||
log_persistent_cache_miss(module_name)
|
||||
return _compile_and_write_autotune_config(
|
||||
backend,
|
||||
computation,
|
||||
@ -375,6 +394,7 @@ def compile_or_get_cached(
|
||||
min_device_process_id
|
||||
)
|
||||
else:
|
||||
log_persistent_cache_miss(module_name)
|
||||
return _compile_and_write_cache(
|
||||
backend,
|
||||
computation,
|
||||
@ -655,19 +675,26 @@ def _cache_write(cache_key: str,
|
||||
"""
|
||||
# Only write cache entries from the first process. Otherwise we create
|
||||
# problems with contention for writes on some filesystems, e.g., GCS.
|
||||
log_priority = (logging.WARNING
|
||||
if config.explain_cache_misses.value
|
||||
and compilation_cache.is_persistent_cache_enabled()
|
||||
else logging.DEBUG)
|
||||
if distributed.global_state.process_id != 0:
|
||||
logger.debug("Not writing persistent cache entry since process_id != 0")
|
||||
logger.log(log_priority,
|
||||
"Not writing persistent cache entry since process_id != 0")
|
||||
return
|
||||
|
||||
if host_callbacks:
|
||||
logger.debug(
|
||||
logger.log(
|
||||
log_priority,
|
||||
"Not writing persistent cache entry for '%s' because it uses host "
|
||||
"callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
|
||||
return
|
||||
|
||||
min_compile_time = config.persistent_cache_min_compile_time_secs.value
|
||||
if compile_time_secs < min_compile_time:
|
||||
logger.debug(
|
||||
logger.log(
|
||||
log_priority,
|
||||
"Not writing persistent cache entry for '%s' because it took < %.2f "
|
||||
"seconds to compile (%.2fs)", module_name, min_compile_time,
|
||||
compile_time_secs)
|
||||
|
@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from functools import partial
|
||||
import logging
|
||||
import math
|
||||
import platform
|
||||
import unittest
|
||||
@ -28,6 +29,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import pmap
|
||||
from jax._src import compilation_cache as cc
|
||||
from jax._src import compiler
|
||||
@ -62,6 +64,11 @@ def tearDownModule():
|
||||
def increment_event_count(event):
  """Bump the module-level `_counts` tally for `event` by one."""
  _counts[event] = _counts[event] + 1
|
||||
|
||||
def msg_exists_in_logs(msg: str, records: list[logging.LogRecord],
|
||||
level: int | None = None) -> bool:
|
||||
return any(msg in record.getMessage() for record in records
|
||||
if level is None or level == record.levelno)
|
||||
|
||||
|
||||
class InMemoryCache(CacheInterface):
|
||||
"""An in-memory cache for testing purposes."""
|
||||
@ -415,6 +422,109 @@ class CompilationCacheTest(CompilationCacheTestCase):
|
||||
- previous_counts["/jax/compilation_cache/cache_hits"],
|
||||
1)
|
||||
|
||||
def test_persistent_cache_hit_logging(self):
  """With `log_compiles` on, a persistent cache hit is logged at WARNING."""
  # First call: expected to populate the persistent cache for `x + 1`.
  jit(lambda x: x + 1)(1)
  msg = "Persistent compilation cache hit"

  # cache hits with `log_compiles` on should be in WARNING when enabled
  with config.log_compiles(True):
    with self.assertLogs(level="WARNING") as log:
      # A fresh lambda, so the same computation is requested again.
      jit(lambda x: x + 1)(1)
    self.assertTrue(msg_exists_in_logs(msg, log.records, logging.WARNING))
|
||||
|
||||
def test_persistent_cache_hit_no_logging(self):
  """With `log_compiles` off, a persistent cache hit is not logged at WARNING."""
  # First call: expected to populate the persistent cache for `x + 1`.
  jit(lambda x: x + 1)(1)
  msg = "Persistent compilation cache hit"

  # cache hits with `log_compiles` off should NOT be in WARNING
  with config.log_compiles(False):
    # Capture from DEBUG upwards so the absence at WARNING is meaningful.
    with self.assertLogs(level="DEBUG") as log:
      jit(lambda x: x + 1)(1)
    self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING))
|
||||
|
||||
def test_persistent_cache_miss_logging_with_explain(self):
  """With `explain_cache_misses` on, cache-write decisions are logged at WARNING.

  Covers four outcomes: skipped because compilation was too fast, skipped
  because of host callbacks, skipped because the entry is too small, and a
  successful write.
  """
  with (config.explain_cache_misses(True),
        config.compilation_cache_dir("jax-cache")):

    # omitting writing to cache because compilation is too fast
    # NOTE: this must be a function without host callbacks; the host-callback
    # check runs before the compile-time check and would log a different
    # message, leaving the "took <" path untested.
    with config.persistent_cache_min_compile_time_secs(1e5):
      with self.assertLogs(level="DEBUG") as log:
        jit(lambda x: x + 1)(1)
      msg1 = "Not writing persistent cache entry"
      msg2 = "because it took <"
      self.assertTrue(msg_exists_in_logs(msg1, log.records, logging.WARNING))
      self.assertTrue(msg_exists_in_logs(msg2, log.records, logging.WARNING))

    # omitting writing to cache because host callback is present
    pure_fn = lambda a: jnp.array(1, dtype=jnp.int32)
    with self.assertLogs(level="DEBUG") as log:
      jit(lambda x: x +
          jax.pure_callback(pure_fn, jax.ShapeDtypeStruct((), jnp.int32), x)
          )(1)
    msg1 = "Not writing persistent cache entry"
    msg2 = "because it uses host callbacks"
    self.assertTrue(msg_exists_in_logs(msg1, log.records, logging.WARNING))
    self.assertTrue(msg_exists_in_logs(msg2, log.records, logging.WARNING))

    # omitting writing to cache because binary is too small
    with config.persistent_cache_min_entry_size_bytes(int(1e9)):
      with self.assertLogs(level="DEBUG") as log:
        jit(lambda x: x + 2)(1)
      msg1 = "Not writing persistent cache entry"
      msg2 = "is less than threshold"
      self.assertTrue(msg_exists_in_logs(msg1, log.records, logging.WARNING))
      self.assertTrue(msg_exists_in_logs(msg2, log.records, logging.WARNING))

    # successful cache write
    with config.persistent_cache_min_entry_size_bytes(1):
      with self.assertLogs(level="DEBUG") as log:
        jit(lambda x: x ** 2)(1)
      msg = "to persistent compilation cache with key"
      self.assertTrue(msg_exists_in_logs(msg, log.records, logging.WARNING))
|
||||
|
||||
def test_persistent_cache_miss_logging_with_no_explain(self):
  """With `explain_cache_misses` off, cache-write decisions stay below WARNING.

  Runs the same four outcomes (too fast, host callback, too small, successful
  write) and checks that none of their messages appear at WARNING level.
  """
  # test that cache failure messages do not get logged in WARNING
  with (config.explain_cache_misses(False),
        config.compilation_cache_dir("jax-cache")):
    # omitting writing to cache because compilation is too fast
    with config.persistent_cache_min_compile_time_secs(1e3):
      with self.assertLogs(level="DEBUG") as log:
        jit(lambda x: x + 1)(1)
      msg1, msg2 = "Not writing persistent cache entry", "because it took <"
      self.assertFalse(msg_exists_in_logs(msg1, log.records, logging.WARNING))
      self.assertFalse(msg_exists_in_logs(msg2, log.records, logging.WARNING))

    # omitting writing to cache because host callback is present
    pure_fn = lambda a: jnp.array(1, dtype=jnp.int32)
    with self.assertLogs(level="DEBUG") as log:
      jit(lambda x: x +
          jax.pure_callback(pure_fn, jax.ShapeDtypeStruct((), jnp.int32), x)
          )(1)
    msg1 = "Not writing persistent cache entry"
    msg2 = "because it uses host callbacks"
    self.assertFalse(msg_exists_in_logs(msg1, log.records, logging.WARNING))
    self.assertFalse(msg_exists_in_logs(msg2, log.records, logging.WARNING))

    # omitting writing to cache because binary is too small
    with config.persistent_cache_min_entry_size_bytes(int(1e9)):
      with self.assertLogs(level="DEBUG") as log:
        jit(lambda x: x + 2)(1)
      msg1 = "Not writing persistent cache entry"
      msg2 = "is less than threshold"
      self.assertFalse(msg_exists_in_logs(msg1, log.records, logging.WARNING))
      self.assertFalse(msg_exists_in_logs(msg2, log.records, logging.WARNING))

    # successful cache write
    with config.persistent_cache_min_entry_size_bytes(1):
      with self.assertLogs(level="DEBUG") as log:
        jit(lambda x: x ** 2)(1)
      msg = "to persistent compilation cache with key"
      self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING))
|
||||
|
||||
|
||||
@parameterized.parameters(0, 1)
|
||||
def test_cache_write_with_process_restriction(self, process_id):
|
||||
with (
|
||||
|
Loading…
x
Reference in New Issue
Block a user