From f87c94db75c3074c7ef9a66a5045ce91564a78db Mon Sep 17 00:00:00 2001 From: Stella S Yan Date: Wed, 15 Jan 2025 00:02:03 +0000 Subject: [PATCH] Fix cache init when JAX Array is created early (#25768) --- jax/_src/compilation_cache.py | 12 ++++++---- tests/compilation_cache_test.py | 41 +++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index d8724e429..3b3b1bc88 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -131,13 +131,19 @@ def _initialize_cache() -> None: with _cache_initialized_mutex: if _cache_initialized: return - _cache_initialized = True + + path: str | None = config.compilation_cache_dir.value + # If the path is not set, the cache will not be built. + if not path: + return # Nothing to do if the cache is disabled. if not _is_cache_enabled(): logger.debug("_initialize_cache: cache is disabled!") return + _cache_initialized = True + # Set the minimum cache size entry only if the flag # --jax_persistent_cache_min_entry_size_bytes has not been set. if config.persistent_cache_min_entry_size_bytes.value == 0: @@ -146,10 +152,6 @@ def _initialize_cache() -> None: global _cache assert _cache is None, "The cache has already been initialized!" - path: str | None = config.compilation_cache_dir.value - # If the path is not set, the cache will not be enabled. - if not path: - return cache_and_path = get_file_cache(path) if cache_and_path is None: diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index ef245bc8d..3b071cfc0 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -23,6 +23,7 @@ import platform import unittest from unittest import mock from unittest import SkipTest +import tempfile from absl.testing import absltest from absl.testing import parameterized @@ -198,6 +199,46 @@ class CompilationCacheTest(CompilationCacheTestCase): f(1.0) self.assertEqual(count_cache_items(), 2) + def test_set_cache_dir_after_backends_init(self): + # This a regression test for #25768 + with config.compilation_cache_dir(None): + cc.reset_cache() + backend = xla_bridge.get_backend() + + a = jnp.zeros((2,3)) + self.assertFalse(cc.is_persistent_cache_enabled()) + cache = cc._get_cache(backend) + self.assertIsNone(cache) # Not able to create cache + + with tempfile.TemporaryDirectory() as tmp_cache_dir: + with config.compilation_cache_dir(tmp_cache_dir): + f = jit(lambda x: x + 1) + f(a) # Compile and cache + self.assertTrue(cc.is_persistent_cache_enabled()) + cache = cc._get_cache(backend) + self.assertIsNotNone(cache) # Cache is created + + def test_enable_compilation_cache(self): + with tempfile.TemporaryDirectory() as tmp_cache_dir: + with ( + config.enable_compilation_cache(False), + config.compilation_cache_dir(tmp_cache_dir) + ): + cc.reset_cache() # reset cache before testing + backend = xla_bridge.get_backend() + f = jit(lambda x: x + 1) + f(1) # Compile and cache + cache = cc._get_cache(backend) + self.assertIsNone(cache) # Cache should not exist + + with config.enable_compilation_cache(True): + cc.reset_cache() + backend = xla_bridge.get_backend() + g = jit(lambda x: x * 3) + g(2) + cache = cc._get_cache(backend) + self.assertIsNotNone(cache) # Cache should be initalized + def test_xla_autofdo_profile_version(self): original_profile_version = config.jax_xla_profile_version.value with config.jax_xla_profile_version(original_profile_version + 1):