Merge pull request #25889 from Stella-S-Yan:cache_reset

PiperOrigin-RevId: 718537398
This commit is contained in:
jax authors 2025-01-22 14:52:05 -08:00
commit f6243ff8e1
2 changed files with 48 additions and 5 deletions

View File

@ -131,13 +131,19 @@ def _initialize_cache() -> None:
with _cache_initialized_mutex:
if _cache_initialized:
return
_cache_initialized = True
path: str | None = config.compilation_cache_dir.value
# If the path is not set, the cache will not be built.
if not path:
return
# Nothing to do if the cache is disabled.
if not _is_cache_enabled():
logger.debug("_initialize_cache: cache is disabled!")
return
_cache_initialized = True
# Set the minimum cache size entry only if the flag
# --jax_persistent_cache_min_entry_size_bytes has not been set.
if config.persistent_cache_min_entry_size_bytes.value == 0:
@ -146,10 +152,6 @@ def _initialize_cache() -> None:
global _cache
assert _cache is None, "The cache has already been initialized!"
path: str | None = config.compilation_cache_dir.value
# If the path is not set, the cache will not be enabled.
if not path:
return
cache_and_path = get_file_cache(path)
if cache_and_path is None:

View File

@ -23,6 +23,7 @@ import platform
import unittest
from unittest import mock
from unittest import SkipTest
import tempfile
from absl.testing import absltest
from absl.testing import parameterized
@ -198,6 +199,46 @@ class CompilationCacheTest(CompilationCacheTestCase):
f(1.0)
self.assertEqual(count_cache_items(), 2)
def test_set_cache_dir_after_backends_init(self):
# This a regression test for #25768
with config.compilation_cache_dir(None):
cc.reset_cache()
backend = xla_bridge.get_backend()
a = jnp.zeros((2,3))
self.assertFalse(cc.is_persistent_cache_enabled())
cache = cc._get_cache(backend)
self.assertIsNone(cache) # Not able to create cache
with tempfile.TemporaryDirectory() as tmp_cache_dir:
with config.compilation_cache_dir(tmp_cache_dir):
f = jit(lambda x: x + 1)
f(a) # Compile and cache
self.assertTrue(cc.is_persistent_cache_enabled())
cache = cc._get_cache(backend)
self.assertIsNotNone(cache) # Cache is created
def test_enable_compilation_cache(self):
with tempfile.TemporaryDirectory() as tmp_cache_dir:
with (
config.enable_compilation_cache(False),
config.compilation_cache_dir(tmp_cache_dir)
):
cc.reset_cache() # reset cache before testing
backend = xla_bridge.get_backend()
f = jit(lambda x: x + 1)
f(1) # Compile and cache
cache = cc._get_cache(backend)
self.assertIsNone(cache) # Cache should not exist
with config.enable_compilation_cache(True):
cc.reset_cache()
backend = xla_bridge.get_backend()
g = jit(lambda x: x * 3)
g(2)
cache = cc._get_cache(backend)
self.assertIsNotNone(cache) # Cache should be initalized
def test_xla_autofdo_profile_version(self):
original_profile_version = config.jax_xla_profile_version.value
with config.jax_xla_profile_version(original_profile_version + 1):