Cleanup deprecated compilation cache APIs.

Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.

Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.

Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
This commit is contained in:
jax authors 2024-01-12 22:44:03 -08:00
parent 2fdaef4cee
commit b8b119d9b9
7 changed files with 54 additions and 39 deletions

View File

@ -10,4 +10,5 @@ API
.. autofunction:: is_initialized
.. autofunction:: initialize_cache
.. autofunction:: set_cache_dir
.. autofunction:: reset_cache

View File

@ -16,6 +16,7 @@ from __future__ import annotations
import logging
import threading
import warnings
import zlib
import numpy as np
@ -66,11 +67,26 @@ def get_file_cache(path: str) -> CacheInterface:
return GFileCache(path)
def set_cache_dir(path) -> None:
"""
Sets the persistent compilation cache directory.
After calling this, jit-compiled functions are saved to `path`, so they
do not need be recompiled if the process is restarted or otherwise run again.
This also tells Jax where to look for compiled functions before compiling.
"""
config.config.update("jax_compilation_cache_dir", path)
def initialize_cache(path) -> None:
"""
This API is deprecated; use set_cache_dir instead.
Set the path. To take effect, should be called prior to any calls to
get_executable_and_time() and put_executable_and_time().
"""
warnings.warn("initialize_cache is deprecated; use set_cache_dir instead",
DeprecationWarning, stacklevel=2)
config.config.update("jax_compilation_cache_dir", path)
@ -207,10 +223,14 @@ def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
def is_initialized() -> bool:
"""
Deprecated.
Return whether the cache is enabled. Initialization can be deferred, so
initialized status is not checked. The name is retained for backwards
compatibility.
"""
warnings.warn("is_initialized is deprecated; do not use",
DeprecationWarning, stacklevel=2)
return _is_cache_enabled()

View File

@ -254,7 +254,7 @@ def compile_or_get_cached(
supported_platforms = ["tpu", "gpu"]
if xla_extension_version >= 230:
supported_platforms.append("cpu")
use_compilation_cache = (compilation_cache.is_initialized() and
use_compilation_cache = (config.enable_compilation_cache.value and
backend.platform in supported_platforms)
if not use_compilation_cache:

View File

@ -1034,9 +1034,9 @@ enable_compilation_cache = define_bool_state(
name='jax_enable_compilation_cache',
default=True,
help=('If set to False, the compilation cache will be disabled regardless '
'of whether initialize_cache() was called. If set to True, the '
'of whether set_cache_dir() was called. If set to True, the '
'path could be set to a default value or via a call to '
'initialize_cache().'),
'set_cache_dir().'),
)
compilation_cache_dir = define_string_state(
@ -1044,7 +1044,7 @@ compilation_cache_dir = define_string_state(
default=None,
help=('Path for the cache. '
'Precedence: '
'1. A call to compilation_cache.initialize_cache(). '
'1. A call to compilation_cache.set_cache_dir(). '
'2. The value of this flag set in the command line or by default.'),
)

View File

@ -898,7 +898,7 @@ def promote_like_jnp(fun, inexact=False):
"""Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`.
jnp and np have different type promotion semantics; this decorator allows
tests make an np reference implementation act more like an jnp
tests make an np reference implementation act more like a jnp
implementation.
"""
_promote = promote_dtypes_inexact if inexact else promote_dtypes
@ -955,9 +955,8 @@ class JaxTestCase(parameterized.TestCase):
stack.enter_context(config.persistent_cache_min_entry_size_bytes(0))
tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
compilation_cache.initialize_cache(tmp_dir)
stack.callback(lambda: compilation_cache.reset_cache()
if compilation_cache.is_initialized() else None)
compilation_cache.set_cache_dir(tmp_dir)
stack.callback(lambda: compilation_cache.reset_cache())
@classmethod
def tearDownClass(cls):

View File

@ -13,7 +13,8 @@
# limitations under the License.
from jax._src.compilation_cache import (
is_initialized as is_initialized,
initialize_cache as initialize_cache,
is_initialized as is_initialized, # deprecated
initialize_cache as initialize_cache, # deprecated; use set_cache_dir instead
set_cache_dir as set_cache_dir,
reset_cache as reset_cache,
)

View File

@ -78,18 +78,15 @@ class CompilationCacheTest(jtu.JaxTestCase):
"serialize executable only works on " + ",".join(supported_platforms)
)
# Reset cache if already initialized by JaxTestCase
if cc.is_initialized():
cc.reset_cache()
cc.reset_cache()
def tearDown(self):
if cc.is_initialized():
cc.reset_cache()
cc.reset_cache()
super().tearDown()
def test_get_no_executable(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = compiler.get_compile_options(
@ -104,7 +101,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_diff_executables(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir())
compile_options = compiler.get_compile_options(
@ -124,7 +121,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_put_executable(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
computation = (
jax.jit(lambda x, y: x + y)
.lower(np.int32(1), np.int32(1))
@ -156,7 +153,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_pmap(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i")
x = np.arange(jax.device_count(), dtype=np.int64)
f(x)
@ -170,7 +167,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_jit(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
f = jit(lambda x: x * x)
f(1)
files_in_directory = len(os.listdir(tmpdir))
@ -183,7 +180,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
original_profile_version = config.jax_xla_profile_version.value
with (tempfile.TemporaryDirectory() as tmpdir,
config.jax_xla_profile_version(original_profile_version + 1)):
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
f = jit(lambda x: x * x)
f(1)
files_in_cache_directory = os.listdir(tmpdir)
@ -200,7 +197,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
@jtu.with_mesh([("x", 2)])
def test_pjit(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
@partial(pjit, in_shardings=(P("x"), P("x")), out_shardings=None)
def f(x, y):
@ -219,7 +216,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
@jtu.with_mesh([("x", 2)])
def test_xmap(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
def f(x):
return x * 2
@ -242,7 +239,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_cache_write_warning(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
f = jit(lambda x: x * x)
with (
@ -263,7 +260,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_cache_read_warning(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
f = jit(lambda x: x * x)
with (
@ -290,7 +287,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
config.persistent_cache_min_compile_time_secs(0),
config.persistent_cache_min_entry_size_bytes(1048576), # 1MiB
):
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
jit(lambda x: x + 1)(1)
files_in_cache = len(os.listdir(tmpdir))
@ -302,7 +299,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
config.persistent_cache_min_compile_time_secs(2),
config.persistent_cache_min_entry_size_bytes(0),
):
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
# Mock time to progress in small intervals so compilation time is small.
with mock.patch("time.monotonic", side_effect=np.arange(0, 10, 0.1)):
@ -322,7 +319,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
config.persistent_cache_min_compile_time_secs(2),
config.persistent_cache_min_entry_size_bytes(0),
):
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
durations = Counter() # Map metric name to time duration.
def append_metric_duration(metric, duration):
@ -354,7 +351,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_task_using_cache_metric(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
count_before_first_use = _counts[
"/jax/compilation_cache/tasks_using_cache"]
jit(lambda x: x + 1)(1)
@ -371,7 +368,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_compile_requests_use_cache_metric(self):
previous_counts = Counter(_counts)
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
jit(lambda x: x + 1)(1)
jit(lambda x: x + 2)(1)
@ -390,7 +387,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
config.persistent_cache_min_compile_time_secs(2),
config.persistent_cache_min_entry_size_bytes(min_entry_size),
):
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
# Mock time to create a long compilation time and make cache misses.
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
@ -415,7 +412,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
config.persistent_cache_min_compile_time_secs(2),
config.persistent_cache_min_entry_size_bytes(0),
):
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
# Mock time to create a long compilation time, cache saved.
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
@ -438,17 +435,14 @@ class CompilationCacheDisabledTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
# Reset cache if already initialized by JaxTestCase
if cc.is_initialized():
cc.reset_cache()
cc.reset_cache()
def tearDown(self):
if cc.is_initialized():
cc.reset_cache()
cc.reset_cache()
super().tearDown()
# If the cache is disabled, there should be no files in the cache directory.
# A call to initialize_cache() does not affect this.
# A call to set_cache_dir() does not affect this.
def test_jit(self):
# Sequence of flag settings for config.jax_enable_compilation_cache:
# 1. Flag is disabled by @jtu.with_config() above.
@ -459,7 +453,7 @@ class CompilationCacheDisabledTest(jtu.JaxTestCase):
tempfile.TemporaryDirectory() as tmpdir,
config.enable_compilation_cache(False),
):
cc.initialize_cache(tmpdir)
cc.set_cache_dir(tmpdir)
f = jit(lambda x: x * x)
f(1)
files_in_directory = len(os.listdir(tmpdir))