Fix the missing cache_misses metric when min compile time is set to zero.

Remove the code which checks if the min compile time is greater than zero. After this change, we can catch cache_misses when min compile time is zero.

Testing: revised unit test.
PiperOrigin-RevId: 579951415
This commit is contained in:
jax authors 2023-11-06 14:03:48 -08:00
parent 5f4d4797b2
commit 7e372944f9
2 changed files with 14 additions and 15 deletions

View File

@ -389,19 +389,17 @@ def _cache_write(cache_key: str,
return
min_compile_time = config.persistent_cache_min_compile_time_secs.value
if min_compile_time:
if compile_time_secs < min_compile_time:
logger.debug(
"Not writing persistent cache entry for '%s' because it took < %.2f "
"seconds to compile (%.2fs)", module_name, min_compile_time,
compile_time_secs)
return
else:
logger.debug(
"'%s' took at least %.2f seconds to compile (%.2fs), writing "
"persistent cache entry", module_name, min_compile_time,
compile_time_secs)
monitoring.record_event('/jax/compilation_cache/cache_misses')
if compile_time_secs < min_compile_time:
logger.debug(
"Not writing persistent cache entry for '%s' because it took < %.2f "
"seconds to compile (%.2fs)", module_name, min_compile_time,
compile_time_secs)
return
else:
logger.debug(
"'%s' took at least %.2f seconds to compile (%.2fs), writing persistent"
" cache entry", module_name, min_compile_time, compile_time_secs)
monitoring.record_event('/jax/compilation_cache/cache_misses')
try:
compilation_cache.put_executable_and_time(

View File

@ -385,11 +385,12 @@ class CompilationCacheTest(jtu.JaxTestCase):
- previous_counts["/jax/compilation_cache/compile_requests_use_cache"],
3)
def test_cache_misses_metric(self):
@parameterized.parameters(0, 2)
def test_cache_misses_metric(self, min_compile_time_secs):
previous_counts = Counter(_counts)
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
config.persistent_cache_min_compile_time_secs(min_compile_time_secs),
):
cc.initialize_cache(tmpdir)