diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fa11c71c..eac3822d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,9 +10,10 @@ Remember to align the itemized text with the first line of an item within a list * Changes * JAX should be faster to import. We now import scipy lazily, which accounted for a significant fraction of JAX's import time. - * Setting the env var `JAX_PERSISTENT_CACHE_MIN_INSTRUCTION_COUNT=$N` can be - used to limit the number of cache entries written to the persistent - cache. By default, computations with 6 or more instructions will be cached. + * Setting the env var `JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N` can be + used to limit the number of cache entries written to the persistent cache. + By default, computations that take 1 second or more to compile will be + cached. * Added {func}`jax.scipy.stats.mode`. * The default device order used by `pmap` on TPU if no order is specified now matches `jax.devices()` for single-process jobs. Previously the diff --git a/jax/_src/config.py b/jax/_src/config.py index 39b9fcc18..233290a59 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -126,6 +126,10 @@ class Config: update_hook = kwargs.pop("update_hook", None) self.add_option(name, default, int, args, kwargs, update_hook=update_hook) + def DEFINE_float(self, name, default, *args, **kwargs): + update_hook = kwargs.pop("update_hook", None) + self.add_option(name, default, float, args, kwargs, update_hook=update_hook) + def DEFINE_string(self, name, default, *args, **kwargs): update_hook = kwargs.pop("update_hook", None) self.add_option(name, default, str, args, kwargs, update_hook=update_hook) @@ -144,6 +148,7 @@ class Config: self.absl_flags = absl_flags absl_defs = { bool: absl_flags.DEFINE_bool, int: absl_flags.DEFINE_integer, + float: absl_flags.DEFINE_float, str: absl_flags.DEFINE_string, 'enum': absl_flags.DEFINE_enum } @@ -327,6 +332,47 @@ class Config: return _StateContextManager(name, help, update_thread_local_hook, 
validate) + def define_float_state( + self, name: str, default: Optional[float], + help: str, update_global_hook: Optional[Callable[[str], None]] = None, + update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \ + = None): + """Set up thread-local state and return a contextmanager for managing it. + Args: + name: string, converted to lowercase to define the name of the config + option (and absl flag). It is converted to uppercase to define the + corresponding shell environment variable. + default: optional float, default value. If the corresponding environment + variable is set, its value is parsed as a float and used instead; a + value that cannot be parsed raises ValueError. + help: string, used to populate the flag help information as well as the + docstring of the returned context manager. + Returns: + A contextmanager to control the thread-local state value. + See docstring for ``define_bool_state``. + """ + name = name.lower() + default_env = os.getenv(name.upper(), default) + if default_env is not None: + try: + default = float(default_env) + except ValueError: + raise ValueError(f"Invalid value \"{default_env}\" for JAX flag {name}") + self.DEFINE_float(name, default, help=help, update_hook=update_global_hook) + self._contextmanager_flags.add(name) + + def get_state(self): + val = getattr(_thread_local_state, name, unset) + return val if val is not unset else self._read(name) + setattr(Config, name, property(get_state)) + + def validate(new_val): + if new_val is not None and not isinstance(new_val, (float, int)): + raise ValueError(f'new float config value must be None or of type float, ' + f'got {new_val} of type {type(new_val)}') + + return _StateContextManager(name, help, update_thread_local_hook, validate) + def define_string_state( self, name: str, default: Optional[str], help: str, update_global_hook: Optional[Callable[[str], None]] = None, @@ -765,15 +811,12 @@ raise_persistent_cache_errors = config.define_bool_state( 'continue. 
Defaults to false so cache bugs or intermittent issues ' 'are non-fatal.')) -persistent_cache_min_instruction_count = config.define_int_state( - name='jax_persistent_cache_min_instruction_count', - default=6, - help=('The minimum number of instructions a computation needs to have to ' - 'be written to the persistent compilation cache. This threshold can ' - 'be raised to decrease the number of entries written to the cache. ' - 'The (unoptimized) instruction count is meant to be a proxy for ' - 'compile time, so programs with longer compile times are still ' - 'cached.')) +persistent_cache_min_compile_time_secs = config.define_float_state( + name='jax_persistent_cache_min_compile_time_secs', + default=1, + help=('The minimum compile time of a computation to be written to the ' + 'persistent compilation cache. This threshold can be raised to ' + 'decrease the number of entries written to the cache.')) hlo_source_file_canonicalization_regex = config.define_string_state( name='jax_hlo_source_file_canonicalization_regex', diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 6eca1183c..df57c62cd 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -1066,9 +1066,11 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options, logger.info("Persistent compilation cache hit for '%s'", module_name) return cached_executable else: + start_time = time.monotonic() compiled = backend_compile(backend, serialized_computation, compile_options, host_callbacks) - _cache_write(computation, serialized_computation, module_name, + compile_time = time.monotonic() - start_time + _cache_write(serialized_computation, compile_time, module_name, compile_options, backend, compiled) return compiled @@ -1094,27 +1096,27 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str, return None -def _cache_write(computation: ir.Module, - serialized_computation: Union[str, bytes, ir.Module], +def _cache_write(serialized_computation: Union[str, 
bytes, ir.Module], + compile_time_secs: float, module_name: str, compile_options: CompileOptions, backend: Backend, compiled: XlaLoadedExecutable): """Writes `serialized_computation` to the persistent compilation cache.""" # Avoid import cycle between jax and jax.experimental from jax.experimental.compilation_cache import compilation_cache as cc - min_instr_count = config.jax_persistent_cache_min_instruction_count - if min_instr_count: - count = _instruction_count(computation, max_count=min_instr_count) - if count < min_instr_count: + min_compile_time = config.jax_persistent_cache_min_compile_time_secs + if min_compile_time: + if compile_time_secs < min_compile_time: logging.info( - "Not writing persistent cache entry for '%s' because it has " - "fewer than %i instructions", module_name, min_instr_count) + "Not writing persistent cache entry for '%s' because it took < %.2f " + "seconds to compile (%.2fs)", module_name, min_compile_time, + compile_time_secs) return else: - # Don't log `count` because it won't be more than max_count logging.info( - "'%s' has at least %i instructions, writing persistent cache entry", - module_name, min_instr_count) + "'%s' took at least %.2f seconds to compile (%.2fs), writing " + "persistent cache entry", module_name, min_compile_time, + compile_time_secs) try: cc.put_executable(module_name, serialized_computation, compile_options, diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index e3248d781..9b7bbcdc7 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -36,14 +36,14 @@ from jax._src.lib import xla_client import numpy as np from jax.config import config -from jax._src.config import (persistent_cache_min_instruction_count, +from jax._src.config import (persistent_cache_min_compile_time_secs, raise_persistent_cache_errors) config.parse_flags_with_absl() FLAGS = config.FLAGS @jtu.with_config(jax_raise_persistent_cache_errors=True, - 
jax_persistent_cache_min_instruction_count=0) + jax_persistent_cache_min_compile_time_secs=0) class CompilationCacheTest(jtu.JaxTestCase): def setUp(self): @@ -335,21 +335,20 @@ class CompilationCacheTest(jtu.JaxTestCase): "for 'jit__lambda_': RuntimeError: test error", str(w[0].message)) - def test_min_instruction_count(self): - with tempfile.TemporaryDirectory() as tmpdir: + def test_min_compile_time(self): + with tempfile.TemporaryDirectory() as tmpdir, \ + persistent_cache_min_compile_time_secs(2): cc.initialize_cache(tmpdir) - with persistent_cache_min_instruction_count(20): - # 2 instructions at time of writing - jit(lambda x: x * x)(2) + # Mock time to progress in small intervals so compilation time is small. + with mock.patch("time.monotonic", side_effect=np.arange(0, 10, .1)): + jit(lambda x: x + 1)(1) files_in_cache = len(os.listdir(tmpdir)) self.assertEqual(files_in_cache, 0) - def f(xs): - c, b = jax.lax.scan(lambda c, x: (c + x, c + x), 0, xs) - return c + 1, b - # 32 instructions at time of writing - jit(f)(jax.numpy.ones(8)) + # Mock time to progress in large intervals so compilation time is large. + with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)): + jit(lambda x: x + 2)(1) files_in_cache = len(os.listdir(tmpdir)) self.assertEqual(files_in_cache, 1)