From b8b119d9b9df4d4d3ca5522a2063e90dcffccb34 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 12 Jan 2024 22:44:03 -0800 Subject: [PATCH] Cleanup deprecated compilation cache APIs. Since the compilation cache is now initialized lazily, existing APIs initialize_cache() and is_initialized() are confusing. Deprecate these APIs. Introduce a new API set_cache_dir() to explicitly set the cache directory path in code. Testing: revised unit tests, test workload. PiperOrigin-RevId: 598073423 --- docs/jax.experimental.compilation_cache.rst | 1 + jax/_src/compilation_cache.py | 20 +++++++ jax/_src/compiler.py | 2 +- jax/_src/config.py | 6 +-- jax/_src/test_util.py | 7 ++- .../compilation_cache/compilation_cache.py | 5 +- tests/compilation_cache_test.py | 52 ++++++++----------- 7 files changed, 54 insertions(+), 39 deletions(-) diff --git a/docs/jax.experimental.compilation_cache.rst b/docs/jax.experimental.compilation_cache.rst index 60213fe0c..8196b1e9c 100644 --- a/docs/jax.experimental.compilation_cache.rst +++ b/docs/jax.experimental.compilation_cache.rst @@ -10,4 +10,5 @@ API .. autofunction:: is_initialized .. autofunction:: initialize_cache +.. autofunction:: set_cache_dir .. autofunction:: reset_cache diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 80911cca7..65b69e92e 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -16,6 +16,7 @@ from __future__ import annotations import logging import threading +import warnings import zlib import numpy as np @@ -66,11 +67,26 @@ def get_file_cache(path: str) -> CacheInterface: return GFileCache(path) +def set_cache_dir(path) -> None: + """ + Sets the persistent compilation cache directory. + + After calling this, jit-compiled functions are saved to `path`, so they + do not need be recompiled if the process is restarted or otherwise run again. + This also tells Jax where to look for compiled functions before compiling. + """ + config.config.update("jax_compilation_cache_dir", path) + + def initialize_cache(path) -> None: """ + This API is deprecated; use set_cache_dir instead. + Set the path. To take effect, should be called prior to any calls to get_executable_and_time() and put_executable_and_time(). """ + warnings.warn("initialize_cache is deprecated; use set_cache_dir instead", + DeprecationWarning, stacklevel=2) config.config.update("jax_compilation_cache_dir", path) @@ -207,10 +223,14 @@ def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options, def is_initialized() -> bool: """ + Deprecated. + Return whether the cache is enabled. Initialization can be deferred, so initialized status is not checked. The name is retained for backwards compatibility. """ + warnings.warn("is_initialized is deprecated; do not use", + DeprecationWarning, stacklevel=2) return _is_cache_enabled() diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index a86afe6c6..4122303ca 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -254,7 +254,7 @@ def compile_or_get_cached( supported_platforms = ["tpu", "gpu"] if xla_extension_version >= 230: supported_platforms.append("cpu") - use_compilation_cache = (compilation_cache.is_initialized() and + use_compilation_cache = (config.enable_compilation_cache.value and backend.platform in supported_platforms) if not use_compilation_cache: diff --git a/jax/_src/config.py b/jax/_src/config.py index e804e372c..6ff73af0c 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1034,9 +1034,9 @@ enable_compilation_cache = define_bool_state( name='jax_enable_compilation_cache', default=True, help=('If set to False, the compilation cache will be disabled regardless ' - 'of whether initialize_cache() was called. If set to True, the ' + 'of whether set_cache_dir() was called. If set to True, the ' 'path could be set to a default value or via a call to ' - 'initialize_cache().'), + 'set_cache_dir().'), ) compilation_cache_dir = define_string_state( @@ -1044,7 +1044,7 @@ compilation_cache_dir = define_string_state( default=None, help=('Path for the cache. ' 'Precedence: ' - '1. A call to compilation_cache.initialize_cache(). ' + '1. A call to compilation_cache.set_cache_dir(). ' '2. The value of this flag set in the command line or by default.'), ) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index a20c6732e..d9517648e 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -898,7 +898,7 @@ def promote_like_jnp(fun, inexact=False): """Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`. jnp and np have different type promotion semantics; this decorator allows - tests make an np reference implementation act more like an jnp + tests make an np reference implementation act more like a jnp implementation. """ _promote = promote_dtypes_inexact if inexact else promote_dtypes @@ -955,9 +955,8 @@ class JaxTestCase(parameterized.TestCase): stack.enter_context(config.persistent_cache_min_entry_size_bytes(0)) tmp_dir = stack.enter_context(tempfile.TemporaryDirectory()) - compilation_cache.initialize_cache(tmp_dir) - stack.callback(lambda: compilation_cache.reset_cache() - if compilation_cache.is_initialized() else None) + compilation_cache.set_cache_dir(tmp_dir) + stack.callback(lambda: compilation_cache.reset_cache()) @classmethod def tearDownClass(cls): diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index 7cf81d76a..990dd1742 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -13,7 +13,8 @@ # limitations under the License. from jax._src.compilation_cache import ( - is_initialized as is_initialized, - initialize_cache as initialize_cache, + is_initialized as is_initialized, # deprecated + initialize_cache as initialize_cache, # deprecated; use set_cache_dir instead + set_cache_dir as set_cache_dir, reset_cache as reset_cache, ) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 3a861d119..d34a49fc7 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -78,18 +78,15 @@ class CompilationCacheTest(jtu.JaxTestCase): "serialize executable only works on " + ",".join(supported_platforms) ) - # Reset cache if already initialized by JaxTestCase - if cc.is_initialized(): - cc.reset_cache() + cc.reset_cache() def tearDown(self): - if cc.is_initialized(): - cc.reset_cache() + cc.reset_cache() super().tearDown() def test_get_no_executable(self): with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() devices = np.array([[jax.local_devices()[0]]]) compile_options = compiler.get_compile_options( @@ -104,7 +101,7 @@ class CompilationCacheTest(jtu.JaxTestCase): def test_diff_executables(self): with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()) computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()) compile_options = compiler.get_compile_options( @@ -124,7 +121,7 @@ class CompilationCacheTest(jtu.JaxTestCase): def test_put_executable(self): with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) computation = ( jax.jit(lambda x, y: x + y) .lower(np.int32(1), np.int32(1)) @@ -156,7 +153,7 @@ class CompilationCacheTest(jtu.JaxTestCase): def test_pmap(self): with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i") x = np.arange(jax.device_count(), dtype=np.int64) f(x) @@ -170,7 +167,7 @@ class CompilationCacheTest(jtu.JaxTestCase): def test_jit(self): with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) f = jit(lambda x: x * x) f(1) files_in_directory = len(os.listdir(tmpdir)) @@ -183,7 +180,7 @@ class CompilationCacheTest(jtu.JaxTestCase): original_profile_version = config.jax_xla_profile_version.value with (tempfile.TemporaryDirectory() as tmpdir, config.jax_xla_profile_version(original_profile_version + 1)): - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) f = jit(lambda x: x * x) f(1) files_in_cache_directory = os.listdir(tmpdir) @@ -200,7 +197,7 @@ class CompilationCacheTest(jtu.JaxTestCase): @jtu.with_mesh([("x", 2)]) def test_pjit(self): with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) @partial(pjit, in_shardings=(P("x"), P("x")), out_shardings=None) def f(x, y): @@ -219,7 +216,7 @@ class CompilationCacheTest(jtu.JaxTestCase): @jtu.with_mesh([("x", 2)]) def test_xmap(self): with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) def f(x): return x * 2 @@ -242,7 +239,7 @@ class CompilationCacheTest(jtu.JaxTestCase): def test_cache_write_warning(self): with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) f = jit(lambda x: x * x) with ( @@ -263,7 +260,7 @@ class CompilationCacheTest(jtu.JaxTestCase): def test_cache_read_warning(self): with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) f = jit(lambda x: x * x) with ( @@ -290,7 +287,7 @@ class CompilationCacheTest(jtu.JaxTestCase): config.persistent_cache_min_compile_time_secs(0), config.persistent_cache_min_entry_size_bytes(1048576), # 1MiB ): - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) jit(lambda x: x + 1)(1) files_in_cache = len(os.listdir(tmpdir)) @@ -302,7 +299,7 @@ class CompilationCacheTest(jtu.JaxTestCase): config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(0), ): - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) # Mock time to progress in small intervals so compilation time is small. with mock.patch("time.monotonic", side_effect=np.arange(0, 10, 0.1)): @@ -322,7 +319,7 @@ class CompilationCacheTest(jtu.JaxTestCase): config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(0), ): - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) durations = Counter() # Map metric name to time duration. def append_metric_duration(metric, duration): @@ -354,7 +351,7 @@ class CompilationCacheTest(jtu.JaxTestCase): def test_task_using_cache_metric(self): with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) count_before_first_use = _counts[ "/jax/compilation_cache/tasks_using_cache"] jit(lambda x: x + 1)(1) @@ -371,7 +368,7 @@ class CompilationCacheTest(jtu.JaxTestCase): def test_compile_requests_use_cache_metric(self): previous_counts = Counter(_counts) with tempfile.TemporaryDirectory() as tmpdir: - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) jit(lambda x: x + 1)(1) jit(lambda x: x + 2)(1) @@ -390,7 +387,7 @@ class CompilationCacheTest(jtu.JaxTestCase): config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(min_entry_size), ): - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) # Mock time to create a long compilation time and make cache misses. with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)): @@ -415,7 +412,7 @@ class CompilationCacheTest(jtu.JaxTestCase): config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(0), ): - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) # Mock time to create a long compilation time, cache saved. with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)): @@ -438,17 +435,14 @@ class CompilationCacheDisabledTest(jtu.JaxTestCase): def setUp(self): super().setUp() - # Reset cache if already initialized by JaxTestCase - if cc.is_initialized(): - cc.reset_cache() + cc.reset_cache() def tearDown(self): - if cc.is_initialized(): - cc.reset_cache() + cc.reset_cache() super().tearDown() # If the cache is disabled, there should be no files in the cache directory. - # A call to initialize_cache() does not affect this. + # A call to set_cache_dir() does not affect this. def test_jit(self): # Sequence of flag settings for config.jax_enable_compilation_cache: # 1. Flag is disabled by @jtu.with_config() above. @@ -459,7 +453,7 @@ class CompilationCacheDisabledTest(jtu.JaxTestCase): tempfile.TemporaryDirectory() as tmpdir, config.enable_compilation_cache(False), ): - cc.initialize_cache(tmpdir) + cc.set_cache_dir(tmpdir) f = jit(lambda x: x * x) f(1) files_in_directory = len(os.listdir(tmpdir))