mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Cleanup deprecated compilation cache APIs.
Since the compilation cache is now initialized lazily, existing APIs initialize_cache() and is_initialized() are confusing. Deprecate these APIs. Introduce a new API set_cache_dir() to explicitly set the cache directory path in code. Testing: revised unit tests, test workload. PiperOrigin-RevId: 598073423
This commit is contained in:
parent
2fdaef4cee
commit
b8b119d9b9
@ -10,4 +10,5 @@ API
|
||||
|
||||
.. autofunction:: is_initialized
|
||||
.. autofunction:: initialize_cache
|
||||
.. autofunction:: set_cache_dir
|
||||
.. autofunction:: reset_cache
|
||||
|
@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import warnings
|
||||
import zlib
|
||||
|
||||
import numpy as np
|
||||
@ -66,11 +67,26 @@ def get_file_cache(path: str) -> CacheInterface:
|
||||
return GFileCache(path)
|
||||
|
||||
|
||||
def set_cache_dir(path) -> None:
|
||||
"""
|
||||
Sets the persistent compilation cache directory.
|
||||
|
||||
After calling this, jit-compiled functions are saved to `path`, so they
|
||||
do not need be recompiled if the process is restarted or otherwise run again.
|
||||
This also tells Jax where to look for compiled functions before compiling.
|
||||
"""
|
||||
config.config.update("jax_compilation_cache_dir", path)
|
||||
|
||||
|
||||
def initialize_cache(path) -> None:
|
||||
"""
|
||||
This API is deprecated; use set_cache_dir instead.
|
||||
|
||||
Set the path. To take effect, should be called prior to any calls to
|
||||
get_executable_and_time() and put_executable_and_time().
|
||||
"""
|
||||
warnings.warn("initialize_cache is deprecated; use set_cache_dir instead",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
config.config.update("jax_compilation_cache_dir", path)
|
||||
|
||||
|
||||
@ -207,10 +223,14 @@ def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
|
||||
|
||||
def is_initialized() -> bool:
|
||||
"""
|
||||
Deprecated.
|
||||
|
||||
Return whether the cache is enabled. Initialization can be deferred, so
|
||||
initialized status is not checked. The name is retained for backwards
|
||||
compatibility.
|
||||
"""
|
||||
warnings.warn("is_initialized is deprecated; do not use",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
return _is_cache_enabled()
|
||||
|
||||
|
||||
|
@ -254,7 +254,7 @@ def compile_or_get_cached(
|
||||
supported_platforms = ["tpu", "gpu"]
|
||||
if xla_extension_version >= 230:
|
||||
supported_platforms.append("cpu")
|
||||
use_compilation_cache = (compilation_cache.is_initialized() and
|
||||
use_compilation_cache = (config.enable_compilation_cache.value and
|
||||
backend.platform in supported_platforms)
|
||||
|
||||
if not use_compilation_cache:
|
||||
|
@ -1034,9 +1034,9 @@ enable_compilation_cache = define_bool_state(
|
||||
name='jax_enable_compilation_cache',
|
||||
default=True,
|
||||
help=('If set to False, the compilation cache will be disabled regardless '
|
||||
'of whether initialize_cache() was called. If set to True, the '
|
||||
'of whether set_cache_dir() was called. If set to True, the '
|
||||
'path could be set to a default value or via a call to '
|
||||
'initialize_cache().'),
|
||||
'set_cache_dir().'),
|
||||
)
|
||||
|
||||
compilation_cache_dir = define_string_state(
|
||||
@ -1044,7 +1044,7 @@ compilation_cache_dir = define_string_state(
|
||||
default=None,
|
||||
help=('Path for the cache. '
|
||||
'Precedence: '
|
||||
'1. A call to compilation_cache.initialize_cache(). '
|
||||
'1. A call to compilation_cache.set_cache_dir(). '
|
||||
'2. The value of this flag set in the command line or by default.'),
|
||||
)
|
||||
|
||||
|
@ -898,7 +898,7 @@ def promote_like_jnp(fun, inexact=False):
|
||||
"""Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`.
|
||||
|
||||
jnp and np have different type promotion semantics; this decorator allows
|
||||
tests make an np reference implementation act more like an jnp
|
||||
tests make an np reference implementation act more like a jnp
|
||||
implementation.
|
||||
"""
|
||||
_promote = promote_dtypes_inexact if inexact else promote_dtypes
|
||||
@ -955,9 +955,8 @@ class JaxTestCase(parameterized.TestCase):
|
||||
stack.enter_context(config.persistent_cache_min_entry_size_bytes(0))
|
||||
|
||||
tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
|
||||
compilation_cache.initialize_cache(tmp_dir)
|
||||
stack.callback(lambda: compilation_cache.reset_cache()
|
||||
if compilation_cache.is_initialized() else None)
|
||||
compilation_cache.set_cache_dir(tmp_dir)
|
||||
stack.callback(lambda: compilation_cache.reset_cache())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.compilation_cache import (
|
||||
is_initialized as is_initialized,
|
||||
initialize_cache as initialize_cache,
|
||||
is_initialized as is_initialized, # deprecated
|
||||
initialize_cache as initialize_cache, # deprecated; use set_cache_dir instead
|
||||
set_cache_dir as set_cache_dir,
|
||||
reset_cache as reset_cache,
|
||||
)
|
||||
|
@ -78,18 +78,15 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
"serialize executable only works on " + ",".join(supported_platforms)
|
||||
)
|
||||
|
||||
# Reset cache if already initialized by JaxTestCase
|
||||
if cc.is_initialized():
|
||||
cc.reset_cache()
|
||||
cc.reset_cache()
|
||||
|
||||
def tearDown(self):
|
||||
if cc.is_initialized():
|
||||
cc.reset_cache()
|
||||
cc.reset_cache()
|
||||
super().tearDown()
|
||||
|
||||
def test_get_no_executable(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = compiler.get_compile_options(
|
||||
@ -104,7 +101,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def test_diff_executables(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
|
||||
computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir())
|
||||
compile_options = compiler.get_compile_options(
|
||||
@ -124,7 +121,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def test_put_executable(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
computation = (
|
||||
jax.jit(lambda x, y: x + y)
|
||||
.lower(np.int32(1), np.int32(1))
|
||||
@ -156,7 +153,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def test_pmap(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i")
|
||||
x = np.arange(jax.device_count(), dtype=np.int64)
|
||||
f(x)
|
||||
@ -170,7 +167,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def test_jit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
f = jit(lambda x: x * x)
|
||||
f(1)
|
||||
files_in_directory = len(os.listdir(tmpdir))
|
||||
@ -183,7 +180,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
original_profile_version = config.jax_xla_profile_version.value
|
||||
with (tempfile.TemporaryDirectory() as tmpdir,
|
||||
config.jax_xla_profile_version(original_profile_version + 1)):
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
f = jit(lambda x: x * x)
|
||||
f(1)
|
||||
files_in_cache_directory = os.listdir(tmpdir)
|
||||
@ -200,7 +197,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
@jtu.with_mesh([("x", 2)])
|
||||
def test_pjit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
|
||||
@partial(pjit, in_shardings=(P("x"), P("x")), out_shardings=None)
|
||||
def f(x, y):
|
||||
@ -219,7 +216,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
@jtu.with_mesh([("x", 2)])
|
||||
def test_xmap(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
|
||||
def f(x):
|
||||
return x * 2
|
||||
@ -242,7 +239,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def test_cache_write_warning(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
f = jit(lambda x: x * x)
|
||||
|
||||
with (
|
||||
@ -263,7 +260,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def test_cache_read_warning(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
f = jit(lambda x: x * x)
|
||||
|
||||
with (
|
||||
@ -290,7 +287,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
config.persistent_cache_min_compile_time_secs(0),
|
||||
config.persistent_cache_min_entry_size_bytes(1048576), # 1MiB
|
||||
):
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
|
||||
jit(lambda x: x + 1)(1)
|
||||
files_in_cache = len(os.listdir(tmpdir))
|
||||
@ -302,7 +299,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
config.persistent_cache_min_compile_time_secs(2),
|
||||
config.persistent_cache_min_entry_size_bytes(0),
|
||||
):
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
|
||||
# Mock time to progress in small intervals so compilation time is small.
|
||||
with mock.patch("time.monotonic", side_effect=np.arange(0, 10, 0.1)):
|
||||
@ -322,7 +319,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
config.persistent_cache_min_compile_time_secs(2),
|
||||
config.persistent_cache_min_entry_size_bytes(0),
|
||||
):
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
|
||||
durations = Counter() # Map metric name to time duration.
|
||||
def append_metric_duration(metric, duration):
|
||||
@ -354,7 +351,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def test_task_using_cache_metric(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
count_before_first_use = _counts[
|
||||
"/jax/compilation_cache/tasks_using_cache"]
|
||||
jit(lambda x: x + 1)(1)
|
||||
@ -371,7 +368,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
def test_compile_requests_use_cache_metric(self):
|
||||
previous_counts = Counter(_counts)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
|
||||
jit(lambda x: x + 1)(1)
|
||||
jit(lambda x: x + 2)(1)
|
||||
@ -390,7 +387,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
config.persistent_cache_min_compile_time_secs(2),
|
||||
config.persistent_cache_min_entry_size_bytes(min_entry_size),
|
||||
):
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
|
||||
# Mock time to create a long compilation time and make cache misses.
|
||||
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
|
||||
@ -415,7 +412,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
config.persistent_cache_min_compile_time_secs(2),
|
||||
config.persistent_cache_min_entry_size_bytes(0),
|
||||
):
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
|
||||
# Mock time to create a long compilation time, cache saved.
|
||||
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
|
||||
@ -438,17 +435,14 @@ class CompilationCacheDisabledTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# Reset cache if already initialized by JaxTestCase
|
||||
if cc.is_initialized():
|
||||
cc.reset_cache()
|
||||
cc.reset_cache()
|
||||
|
||||
def tearDown(self):
|
||||
if cc.is_initialized():
|
||||
cc.reset_cache()
|
||||
cc.reset_cache()
|
||||
super().tearDown()
|
||||
|
||||
# If the cache is disabled, there should be no files in the cache directory.
|
||||
# A call to initialize_cache() does not affect this.
|
||||
# A call to set_cache_dir() does not affect this.
|
||||
def test_jit(self):
|
||||
# Sequence of flag settings for config.jax_enable_compilation_cache:
|
||||
# 1. Flag is disabled by @jtu.with_config() above.
|
||||
@ -459,7 +453,7 @@ class CompilationCacheDisabledTest(jtu.JaxTestCase):
|
||||
tempfile.TemporaryDirectory() as tmpdir,
|
||||
config.enable_compilation_cache(False),
|
||||
):
|
||||
cc.initialize_cache(tmpdir)
|
||||
cc.set_cache_dir(tmpdir)
|
||||
f = jit(lambda x: x * x)
|
||||
f(1)
|
||||
files_in_directory = len(os.listdir(tmpdir))
|
||||
|
Loading…
x
Reference in New Issue
Block a user