Instrument metrics to track cache hit rate of original JAX compilation cache.

Metrics:
1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of  the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true.
2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation.
3) '/jax/compilation_cache/cache_misses' to track the number of times cache is missed and the compiled executable is written to cache repository.

Created a context manager to register/unregister event listeners.

PiperOrigin-RevId: 561771262
This commit is contained in:
jax authors 2023-08-31 15:04:43 -07:00
parent 8a04dfd830
commit 80f6151110
2 changed files with 63 additions and 8 deletions

View File

@ -257,12 +257,14 @@ def compile_or_get_cached(
host_callbacks)
# TODO(b/293308239) Instrument a metric to track the adoption of the new cache
# key implementation once the enabling flag is added.
# key implementation after it is enabled.
global _cache_used
if not _cache_used:
_cache_used = True
monitoring.record_event('/jax/compilation_cache/tasks_using_original_cache')
monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache')
cache_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend,
jax_config.config.jax_use_original_compilation_cache_key_generation,
@ -276,15 +278,20 @@ def compile_or_get_cached(
if retrieved_executable is not None:
assert retrieved_compile_time is not None
logger.info("Persistent compilation cache hit for '%s'", module_name)
# TODO(b/293308239) Instrument metrics for new cache savings and cache hit
# rate after it is enabled.
if jax_config.config.jax_use_original_compilation_cache_key_generation:
# TODO(b/293308239) Remove metrics for the original cache after the new
# compilation cache key implementation is fully rolled out.
monitoring.record_event('/jax/compilation_cache/cache_hits_original')
monitoring.record_event_duration_secs(
"/jax/compilation_cache/original_compile_time_saved_sec",
retrieved_compile_time - cache_retrieval_time)
monitoring.record_event_duration_secs(
"/jax/compilation_cache/cache_retrieval_time_sec", cache_retrieval_time)
# TODO(b/293308239) Instrument a metric for new cache savings once the
# enabling flag is added.
# TODO(b/293308239) Remove the metric for original cache savings after the
# new compilation cache key implementation is fully rolled out.
monitoring.record_event_duration_secs(
"/jax/compilation_cache/original_compile_time_saved_sec",
retrieved_compile_time - cache_retrieval_time)
return retrieved_executable
else:
start_time = time.monotonic()
@ -342,6 +349,7 @@ def _cache_write(cache_key: str,
"'%s' took at least %.2f seconds to compile (%.2fs), writing "
"persistent cache entry", module_name, min_compile_time,
compile_time_secs)
monitoring.record_event('/jax/compilation_cache/cache_misses')
try:
compilation_cache.put_executable_and_time(

View File

@ -34,6 +34,7 @@ from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.config import persistent_cache_min_compile_time_secs
from jax._src.config import raise_persistent_cache_errors
from jax._src.config import use_original_compilation_cache_key_generation
from jax._src.lib import xla_client
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit
@ -329,5 +330,51 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertEqual(
_counts["/jax/compilation_cache/tasks_using_original_cache"], 1)
def test_compile_requests_use_cache_metric(self):
previous_counts = Counter(_counts)
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
jit(lambda x: x + 1)(1)
jit(lambda x: x + 2)(1)
jit(lambda x: x + 1)(1)
self.assertEqual(
_counts["/jax/compilation_cache/compile_requests_use_cache"]
- previous_counts["/jax/compilation_cache/compile_requests_use_cache"],
3)
def test_cache_misses_metric(self):
previous_counts = Counter(_counts)
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
2):
cc.initialize_cache(tmpdir)
# Mock time to create a long compilation time and make cache misses.
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
jit(lambda x: x + 1)(1)
jit(lambda x: x + 2)(1)
self.assertEqual(
_counts["/jax/compilation_cache/cache_misses"]
- previous_counts["/jax/compilation_cache/cache_misses"],
2)
def test_cache_hits_original_metric(self):
previous_counts = Counter(_counts)
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
2), use_original_compilation_cache_key_generation(True):
cc.initialize_cache(tmpdir)
# Mock time to create a long compilation time, cache saved.
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
jit(lambda x: x + 1)(1)
jit(lambda x: x + 1)(1)
self.assertEqual(
_counts["/jax/compilation_cache/cache_hits_original"]
- previous_counts["/jax/compilation_cache/cache_hits_original"],
1)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())