mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #25889 from Stella-S-Yan:cache_reset
PiperOrigin-RevId: 718537398
This commit is contained in:
commit
f6243ff8e1
@ -131,13 +131,19 @@ def _initialize_cache() -> None:
|
||||
with _cache_initialized_mutex:
|
||||
if _cache_initialized:
|
||||
return
|
||||
_cache_initialized = True
|
||||
|
||||
path: str | None = config.compilation_cache_dir.value
|
||||
# If the path is not set, the cache will not be built.
|
||||
if not path:
|
||||
return
|
||||
|
||||
# Nothing to do if the cache is disabled.
|
||||
if not _is_cache_enabled():
|
||||
logger.debug("_initialize_cache: cache is disabled!")
|
||||
return
|
||||
|
||||
_cache_initialized = True
|
||||
|
||||
# Set the minimum cache size entry only if the flag
|
||||
# --jax_persistent_cache_min_entry_size_bytes has not been set.
|
||||
if config.persistent_cache_min_entry_size_bytes.value == 0:
|
||||
@ -146,10 +152,6 @@ def _initialize_cache() -> None:
|
||||
|
||||
global _cache
|
||||
assert _cache is None, "The cache has already been initialized!"
|
||||
path: str | None = config.compilation_cache_dir.value
|
||||
# If the path is not set, the cache will not be enabled.
|
||||
if not path:
|
||||
return
|
||||
|
||||
cache_and_path = get_file_cache(path)
|
||||
if cache_and_path is None:
|
||||
|
@ -23,6 +23,7 @@ import platform
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from unittest import SkipTest
|
||||
import tempfile
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -198,6 +199,46 @@ class CompilationCacheTest(CompilationCacheTestCase):
|
||||
f(1.0)
|
||||
self.assertEqual(count_cache_items(), 2)
|
||||
|
||||
def test_set_cache_dir_after_backends_init(self):
|
||||
# This a regression test for #25768
|
||||
with config.compilation_cache_dir(None):
|
||||
cc.reset_cache()
|
||||
backend = xla_bridge.get_backend()
|
||||
|
||||
a = jnp.zeros((2,3))
|
||||
self.assertFalse(cc.is_persistent_cache_enabled())
|
||||
cache = cc._get_cache(backend)
|
||||
self.assertIsNone(cache) # Not able to create cache
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_cache_dir:
|
||||
with config.compilation_cache_dir(tmp_cache_dir):
|
||||
f = jit(lambda x: x + 1)
|
||||
f(a) # Compile and cache
|
||||
self.assertTrue(cc.is_persistent_cache_enabled())
|
||||
cache = cc._get_cache(backend)
|
||||
self.assertIsNotNone(cache) # Cache is created
|
||||
|
||||
def test_enable_compilation_cache(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_cache_dir:
|
||||
with (
|
||||
config.enable_compilation_cache(False),
|
||||
config.compilation_cache_dir(tmp_cache_dir)
|
||||
):
|
||||
cc.reset_cache() # reset cache before testing
|
||||
backend = xla_bridge.get_backend()
|
||||
f = jit(lambda x: x + 1)
|
||||
f(1) # Compile and cache
|
||||
cache = cc._get_cache(backend)
|
||||
self.assertIsNone(cache) # Cache should not exist
|
||||
|
||||
with config.enable_compilation_cache(True):
|
||||
cc.reset_cache()
|
||||
backend = xla_bridge.get_backend()
|
||||
g = jit(lambda x: x * 3)
|
||||
g(2)
|
||||
cache = cc._get_cache(backend)
|
||||
self.assertIsNotNone(cache) # Cache should be initalized
|
||||
|
||||
def test_xla_autofdo_profile_version(self):
|
||||
original_profile_version = config.jax_xla_profile_version.value
|
||||
with config.jax_xla_profile_version(original_profile_version + 1):
|
||||
|
Loading…
x
Reference in New Issue
Block a user