Add new config jax_persistent_cache_min_compile_time_secs.

This replaces `jax_persistent_cache_min_instruction_count` introduced
in https://github.com/google/jax/pull/12798, since gating on the
compile time seems strictly better than gating on the instruction
count (except maybe that the instruction count is more deterministic,
but I don't think that's a big deal).

I defaulted to 1 second as the minimum threshold based on the same
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) numbers from

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new float config functionality.
This commit is contained in:
Skye Wanderman-Milne 2022-10-28 23:53:30 +00:00
parent f9e7629c3f
commit cc5171034f
4 changed files with 81 additions and 36 deletions

View File

@ -10,9 +10,10 @@ Remember to align the itemized text with the first line of an item within a list
* Changes * Changes
* JAX should be faster to import. We now import scipy lazily, which accounted * JAX should be faster to import. We now import scipy lazily, which accounted
for a significant fraction of JAX's import time. for a significant fraction of JAX's import time.
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_INSTRUCTION_COUNT=$N` can be * Setting the env var `JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N` can be
used to limit the number of cache entries written to the persistent used to limit the number of cache entries written to the persistent cache.
cache. By default, computations with 6 or more instructions will be cached. By default, computations that take 1 second or more to compile will be
cached.
* Added {func}`jax.scipy.stats.mode`. * Added {func}`jax.scipy.stats.mode`.
* The default device order used by `pmap` on TPU if no order is specified now * The default device order used by `pmap` on TPU if no order is specified now
matches `jax.devices()` for single-process jobs. Previously the matches `jax.devices()` for single-process jobs. Previously the

View File

@ -126,6 +126,10 @@ class Config:
update_hook = kwargs.pop("update_hook", None) update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, int, args, kwargs, update_hook=update_hook) self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
def DEFINE_float(self, name, default, *args, **kwargs):
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, float, args, kwargs, update_hook=update_hook)
def DEFINE_string(self, name, default, *args, **kwargs): def DEFINE_string(self, name, default, *args, **kwargs):
update_hook = kwargs.pop("update_hook", None) update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, str, args, kwargs, update_hook=update_hook) self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
@ -144,6 +148,7 @@ class Config:
self.absl_flags = absl_flags self.absl_flags = absl_flags
absl_defs = { bool: absl_flags.DEFINE_bool, absl_defs = { bool: absl_flags.DEFINE_bool,
int: absl_flags.DEFINE_integer, int: absl_flags.DEFINE_integer,
float: absl_flags.DEFINE_float,
str: absl_flags.DEFINE_string, str: absl_flags.DEFINE_string,
'enum': absl_flags.DEFINE_enum } 'enum': absl_flags.DEFINE_enum }
@ -327,6 +332,47 @@ class Config:
return _StateContextManager(name, help, update_thread_local_hook, validate) return _StateContextManager(name, help, update_thread_local_hook, validate)
def define_float_state(
self, name: str, default: Optional[float],
help: str, update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \
= None):
"""Set up thread-local state and return a contextmanager for managing it.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
enum_values: list of strings representing the possible values for the
option.
default: optional float, default value.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
Returns:
A contextmanager to control the thread-local state value.
See docstring for ``define_bool_state``.
"""
name = name.lower()
default_env = os.getenv(name.upper(), default)
if default_env is not None:
try:
default = float(default_env)
except ValueError:
raise ValueError(f"Invalid value \"{default_env}\" for JAX flag {name}")
self.DEFINE_float(name, default, help=help, update_hook=update_global_hook)
self._contextmanager_flags.add(name)
def get_state(self):
val = getattr(_thread_local_state, name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))
def validate(new_val):
if new_val is not None and not isinstance(new_val, (float, int)):
raise ValueError(f'new float config value must be None or of type float, '
f'got {new_val} of type {type(new_val)}')
return _StateContextManager(name, help, update_thread_local_hook, validate)
def define_string_state( def define_string_state(
self, name: str, default: Optional[str], help: str, self, name: str, default: Optional[str], help: str,
update_global_hook: Optional[Callable[[str], None]] = None, update_global_hook: Optional[Callable[[str], None]] = None,
@ -765,15 +811,12 @@ raise_persistent_cache_errors = config.define_bool_state(
'continue. Defaults to false so cache bugs or intermittent issues ' 'continue. Defaults to false so cache bugs or intermittent issues '
'are non-fatal.')) 'are non-fatal.'))
persistent_cache_min_instruction_count = config.define_int_state( persistent_cache_min_compile_time_secs = config.define_float_state(
name='jax_persistent_cache_min_instruction_count', name='jax_persistent_cache_min_compile_time_secs',
default=6, default=1,
help=('The minimum number of instructions a computation needs to have to ' help=('The minimum compile time of a computation to be written to the '
'be written to the persistent compilation cache. This threshold can ' 'persistent compilation cache. This threshold can be raised to '
'be raised to decrease the number of entries written to the cache. ' 'decrease the number of entries written to the cache.'))
'The (unoptimized) instruction count is meant to be a proxy for '
'compile time, so programs with longer compile times are still '
'cached.'))
hlo_source_file_canonicalization_regex = config.define_string_state( hlo_source_file_canonicalization_regex = config.define_string_state(
name='jax_hlo_source_file_canonicalization_regex', name='jax_hlo_source_file_canonicalization_regex',

View File

@ -1066,9 +1066,11 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
logger.info("Persistent compilation cache hit for '%s'", module_name) logger.info("Persistent compilation cache hit for '%s'", module_name)
return cached_executable return cached_executable
else: else:
start_time = time.monotonic()
compiled = backend_compile(backend, serialized_computation, compiled = backend_compile(backend, serialized_computation,
compile_options, host_callbacks) compile_options, host_callbacks)
_cache_write(computation, serialized_computation, module_name, compile_time = time.monotonic() - start_time
_cache_write(serialized_computation, compile_time, module_name,
compile_options, backend, compiled) compile_options, backend, compiled)
return compiled return compiled
@ -1094,27 +1096,27 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
return None return None
def _cache_write(computation: ir.Module, def _cache_write(serialized_computation: Union[str, bytes, ir.Module],
serialized_computation: Union[str, bytes, ir.Module], compile_time_secs: float,
module_name: str, compile_options: CompileOptions, module_name: str, compile_options: CompileOptions,
backend: Backend, compiled: XlaLoadedExecutable): backend: Backend, compiled: XlaLoadedExecutable):
"""Writes `serialized_computation` to the persistent compilation cache.""" """Writes `serialized_computation` to the persistent compilation cache."""
# Avoid import cycle between jax and jax.experimental # Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc from jax.experimental.compilation_cache import compilation_cache as cc
min_instr_count = config.jax_persistent_cache_min_instruction_count min_compile_time = config.jax_persistent_cache_min_compile_time_secs
if min_instr_count: if min_compile_time:
count = _instruction_count(computation, max_count=min_instr_count) if compile_time_secs < min_compile_time:
if count < min_instr_count:
logging.info( logging.info(
"Not writing persistent cache entry for '%s' because it has " "Not writing persistent cache entry for '%s' because it took < %.2f "
"fewer than %i instructions", module_name, min_instr_count) "seconds to compile (%.2fs)", module_name, min_compile_time,
compile_time_secs)
return return
else: else:
# Don't log `count` because it won't be more than max_count
logging.info( logging.info(
"'%s' has at least %i instructions, writing persistent cache entry", "'%s' took at least %.2f seconds to compile (%.2fs), writing "
module_name, min_instr_count) "persistent cache entry", module_name, min_compile_time,
compile_time_secs)
try: try:
cc.put_executable(module_name, serialized_computation, compile_options, cc.put_executable(module_name, serialized_computation, compile_options,

View File

@ -36,14 +36,14 @@ from jax._src.lib import xla_client
import numpy as np import numpy as np
from jax.config import config from jax.config import config
from jax._src.config import (persistent_cache_min_instruction_count, from jax._src.config import (persistent_cache_min_compile_time_secs,
raise_persistent_cache_errors) raise_persistent_cache_errors)
config.parse_flags_with_absl() config.parse_flags_with_absl()
FLAGS = config.FLAGS FLAGS = config.FLAGS
@jtu.with_config(jax_raise_persistent_cache_errors=True, @jtu.with_config(jax_raise_persistent_cache_errors=True,
jax_persistent_cache_min_instruction_count=0) jax_persistent_cache_min_compile_time_secs=0)
class CompilationCacheTest(jtu.JaxTestCase): class CompilationCacheTest(jtu.JaxTestCase):
def setUp(self): def setUp(self):
@ -335,21 +335,20 @@ class CompilationCacheTest(jtu.JaxTestCase):
"for 'jit__lambda_': RuntimeError: test error", "for 'jit__lambda_': RuntimeError: test error",
str(w[0].message)) str(w[0].message))
def test_min_instruction_count(self): def test_min_compile_time(self):
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir, \
persistent_cache_min_compile_time_secs(2):
cc.initialize_cache(tmpdir) cc.initialize_cache(tmpdir)
with persistent_cache_min_instruction_count(20): # Mock time to progress in small intervals so compilation time is small.
# 2 instructions at time of writing with mock.patch("time.monotonic", side_effect=np.arange(0, 10, .1)):
jit(lambda x: x * x)(2) jit(lambda x: x + 1)(1)
files_in_cache = len(os.listdir(tmpdir)) files_in_cache = len(os.listdir(tmpdir))
self.assertEqual(files_in_cache, 0) self.assertEqual(files_in_cache, 0)
def f(xs): # Mock time to progress in large intervals so compilation time is large.
c, b = jax.lax.scan(lambda c, x: (c + x, c + x), 0, xs) with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
return c + 1, b jit(lambda x: x + 2)(1)
# 32 instructions at time of writing
jit(f)(jax.numpy.ones(8))
files_in_cache = len(os.listdir(tmpdir)) files_in_cache = len(os.listdir(tmpdir))
self.assertEqual(files_in_cache, 1) self.assertEqual(files_in_cache, 1)