mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Instrument metrics to track cache hit rate of original JAX compilation cache.
Metrics: 1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true. 2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation. 3) '/jax/compilation_cache/cache_misses' to track the number of times the cache is missed and the compiled executable is written to the cache repository. Created a context manager to register/unregister event listeners. PiperOrigin-RevId: 561771262
This commit is contained in:
parent
8a04dfd830
commit
80f6151110
@ -257,12 +257,14 @@ def compile_or_get_cached(
|
||||
host_callbacks)
|
||||
|
||||
# TODO(b/293308239) Instrument a metric to track the adoption of the new cache
|
||||
# key implementation once the enabling flag is added.
|
||||
# key implementation after it is enabled.
|
||||
global _cache_used
|
||||
if not _cache_used:
|
||||
_cache_used = True
|
||||
monitoring.record_event('/jax/compilation_cache/tasks_using_original_cache')
|
||||
|
||||
monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache')
|
||||
|
||||
cache_key = compilation_cache.get_cache_key(
|
||||
computation, devices, compile_options, backend,
|
||||
jax_config.config.jax_use_original_compilation_cache_key_generation,
|
||||
@ -276,15 +278,20 @@ def compile_or_get_cached(
|
||||
if retrieved_executable is not None:
|
||||
assert retrieved_compile_time is not None
|
||||
logger.info("Persistent compilation cache hit for '%s'", module_name)
|
||||
|
||||
# TODO(b/293308239) Instrument metrics for new cache savings and cache hit
|
||||
# rate after it is enabled.
|
||||
if jax_config.config.jax_use_original_compilation_cache_key_generation:
|
||||
# TODO(b/293308239) Remove metrics for the original cache after the new
|
||||
# compilation cache key implementation is fully rolled out.
|
||||
monitoring.record_event('/jax/compilation_cache/cache_hits_original')
|
||||
monitoring.record_event_duration_secs(
|
||||
"/jax/compilation_cache/original_compile_time_saved_sec",
|
||||
retrieved_compile_time - cache_retrieval_time)
|
||||
|
||||
monitoring.record_event_duration_secs(
|
||||
"/jax/compilation_cache/cache_retrieval_time_sec", cache_retrieval_time)
|
||||
# TODO(b/293308239) Instrument a metric for new cache savings once the
|
||||
# enabling flag is added.
|
||||
# TODO(b/293308239) Remove the metric for original cache savings after the
|
||||
# new compilation cache key implementation is fully rolled out.
|
||||
monitoring.record_event_duration_secs(
|
||||
"/jax/compilation_cache/original_compile_time_saved_sec",
|
||||
retrieved_compile_time - cache_retrieval_time)
|
||||
|
||||
return retrieved_executable
|
||||
else:
|
||||
start_time = time.monotonic()
|
||||
@ -342,6 +349,7 @@ def _cache_write(cache_key: str,
|
||||
"'%s' took at least %.2f seconds to compile (%.2fs), writing "
|
||||
"persistent cache entry", module_name, min_compile_time,
|
||||
compile_time_secs)
|
||||
monitoring.record_event('/jax/compilation_cache/cache_misses')
|
||||
|
||||
try:
|
||||
compilation_cache.put_executable_and_time(
|
||||
|
@ -34,6 +34,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.config import persistent_cache_min_compile_time_secs
|
||||
from jax._src.config import raise_persistent_cache_errors
|
||||
from jax._src.config import use_original_compilation_cache_key_generation
|
||||
from jax._src.lib import xla_client
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.pjit import pjit
|
||||
@ -329,5 +330,51 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
self.assertEqual(
|
||||
_counts["/jax/compilation_cache/tasks_using_original_cache"], 1)
|
||||
|
||||
def test_compile_requests_use_cache_metric(self):
|
||||
previous_counts = Counter(_counts)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
|
||||
jit(lambda x: x + 1)(1)
|
||||
jit(lambda x: x + 2)(1)
|
||||
jit(lambda x: x + 1)(1)
|
||||
|
||||
self.assertEqual(
|
||||
_counts["/jax/compilation_cache/compile_requests_use_cache"]
|
||||
- previous_counts["/jax/compilation_cache/compile_requests_use_cache"],
|
||||
3)
|
||||
|
||||
def test_cache_misses_metric(self):
|
||||
previous_counts = Counter(_counts)
|
||||
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
|
||||
2):
|
||||
cc.initialize_cache(tmpdir)
|
||||
|
||||
# Mock time to create a long compilation time and make cache misses.
|
||||
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
|
||||
jit(lambda x: x + 1)(1)
|
||||
jit(lambda x: x + 2)(1)
|
||||
|
||||
self.assertEqual(
|
||||
_counts["/jax/compilation_cache/cache_misses"]
|
||||
- previous_counts["/jax/compilation_cache/cache_misses"],
|
||||
2)
|
||||
|
||||
def test_cache_hits_original_metric(self):
|
||||
previous_counts = Counter(_counts)
|
||||
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
|
||||
2), use_original_compilation_cache_key_generation(True):
|
||||
cc.initialize_cache(tmpdir)
|
||||
|
||||
# Mock time to create a long compilation time, cache saved.
|
||||
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
|
||||
jit(lambda x: x + 1)(1)
|
||||
jit(lambda x: x + 1)(1)
|
||||
|
||||
self.assertEqual(
|
||||
_counts["/jax/compilation_cache/cache_hits_original"]
|
||||
- previous_counts["/jax/compilation_cache/cache_hits_original"],
|
||||
1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user