Add new config jax_persistent_cache_min_compile_time_secs.

This replaces `jax_persistent_cache_min_instruction_count` introduced
in https://github.com/google/jax/pull/12798, since gating on the
compile time seems strictly better than gating on the instruction
count (except maybe that the instruction count is more deterministic,
but I don't think that's a big deal).

I defaulted to 1 second as the minimum threshold based on the same
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) numbers from

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new float config functionality.
This commit is contained in:
Skye Wanderman-Milne 2022-10-28 23:53:30 +00:00
parent f9e7629c3f
commit cc5171034f
4 changed files with 81 additions and 36 deletions

View File

@ -10,9 +10,10 @@ Remember to align the itemized text with the first line of an item within a list
* Changes
* JAX should be faster to import. We now import scipy lazily, which accounted
for a significant fraction of JAX's import time.
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_INSTRUCTION_COUNT=$N` can be
used to limit the number of cache entries written to the persistent
cache. By default, computations with 6 or more instructions will be cached.
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N` can be
used to limit the number of cache entries written to the persistent cache.
By default, computations that take 1 second or more to compile will be
cached.
* Added {func}`jax.scipy.stats.mode`.
* The default device order used by `pmap` on TPU if no order is specified now
matches `jax.devices()` for single-process jobs. Previously the

View File

@ -126,6 +126,10 @@ class Config:
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
def DEFINE_float(self, name, default, *args, **kwargs):
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, float, args, kwargs, update_hook=update_hook)
def DEFINE_string(self, name, default, *args, **kwargs):
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
@ -144,6 +148,7 @@ class Config:
self.absl_flags = absl_flags
absl_defs = { bool: absl_flags.DEFINE_bool,
int: absl_flags.DEFINE_integer,
float: absl_flags.DEFINE_float,
str: absl_flags.DEFINE_string,
'enum': absl_flags.DEFINE_enum }
@ -327,6 +332,47 @@ class Config:
return _StateContextManager(name, help, update_thread_local_hook, validate)
def define_float_state(
self, name: str, default: Optional[float],
help: str, update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \
= None):
"""Set up thread-local state and return a contextmanager for managing it.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
enum_values: list of strings representing the possible values for the
option.
default: optional float, default value.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
Returns:
A contextmanager to control the thread-local state value.
See docstring for ``define_bool_state``.
"""
name = name.lower()
default_env = os.getenv(name.upper(), default)
if default_env is not None:
try:
default = float(default_env)
except ValueError:
raise ValueError(f"Invalid value \"{default_env}\" for JAX flag {name}")
self.DEFINE_float(name, default, help=help, update_hook=update_global_hook)
self._contextmanager_flags.add(name)
def get_state(self):
val = getattr(_thread_local_state, name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))
def validate(new_val):
if new_val is not None and not isinstance(new_val, (float, int)):
raise ValueError(f'new float config value must be None or of type float, '
f'got {new_val} of type {type(new_val)}')
return _StateContextManager(name, help, update_thread_local_hook, validate)
def define_string_state(
self, name: str, default: Optional[str], help: str,
update_global_hook: Optional[Callable[[str], None]] = None,
@ -765,15 +811,12 @@ raise_persistent_cache_errors = config.define_bool_state(
'continue. Defaults to false so cache bugs or intermittent issues '
'are non-fatal.'))
persistent_cache_min_instruction_count = config.define_int_state(
name='jax_persistent_cache_min_instruction_count',
default=6,
help=('The minimum number of instructions a computation needs to have to '
'be written to the persistent compilation cache. This threshold can '
'be raised to decrease the number of entries written to the cache. '
'The (unoptimized) instruction count is meant to be a proxy for '
'compile time, so programs with longer compile times are still '
'cached.'))
persistent_cache_min_compile_time_secs = config.define_float_state(
name='jax_persistent_cache_min_compile_time_secs',
default=1,
help=('The minimum compile time of a computation to be written to the '
'persistent compilation cache. This threshold can be raised to '
'decrease the number of entries written to the cache.'))
hlo_source_file_canonicalization_regex = config.define_string_state(
name='jax_hlo_source_file_canonicalization_regex',

View File

@ -1066,9 +1066,11 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
logger.info("Persistent compilation cache hit for '%s'", module_name)
return cached_executable
else:
start_time = time.monotonic()
compiled = backend_compile(backend, serialized_computation,
compile_options, host_callbacks)
_cache_write(computation, serialized_computation, module_name,
compile_time = time.monotonic() - start_time
_cache_write(serialized_computation, compile_time, module_name,
compile_options, backend, compiled)
return compiled
@ -1094,27 +1096,27 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
return None
def _cache_write(computation: ir.Module,
serialized_computation: Union[str, bytes, ir.Module],
def _cache_write(serialized_computation: Union[str, bytes, ir.Module],
compile_time_secs: float,
module_name: str, compile_options: CompileOptions,
backend: Backend, compiled: XlaLoadedExecutable):
"""Writes `serialized_computation` to the persistent compilation cache."""
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
min_instr_count = config.jax_persistent_cache_min_instruction_count
if min_instr_count:
count = _instruction_count(computation, max_count=min_instr_count)
if count < min_instr_count:
min_compile_time = config.jax_persistent_cache_min_compile_time_secs
if min_compile_time:
if compile_time_secs < min_compile_time:
logging.info(
"Not writing persistent cache entry for '%s' because it has "
"fewer than %i instructions", module_name, min_instr_count)
"Not writing persistent cache entry for '%s' because it took < %.2f "
"seconds to compile (%.2fs)", module_name, min_compile_time,
compile_time_secs)
return
else:
# Don't log `count` because it won't be more than max_count
logging.info(
"'%s' has at least %i instructions, writing persistent cache entry",
module_name, min_instr_count)
"'%s' took at least %.2f seconds to compile (%.2fs), writing "
"persistent cache entry", module_name, min_compile_time,
compile_time_secs)
try:
cc.put_executable(module_name, serialized_computation, compile_options,

View File

@ -36,14 +36,14 @@ from jax._src.lib import xla_client
import numpy as np
from jax.config import config
from jax._src.config import (persistent_cache_min_instruction_count,
from jax._src.config import (persistent_cache_min_compile_time_secs,
raise_persistent_cache_errors)
config.parse_flags_with_absl()
FLAGS = config.FLAGS
@jtu.with_config(jax_raise_persistent_cache_errors=True,
jax_persistent_cache_min_instruction_count=0)
jax_persistent_cache_min_compile_time_secs=0)
class CompilationCacheTest(jtu.JaxTestCase):
def setUp(self):
@ -335,21 +335,20 @@ class CompilationCacheTest(jtu.JaxTestCase):
"for 'jit__lambda_': RuntimeError: test error",
str(w[0].message))
def test_min_instruction_count(self):
with tempfile.TemporaryDirectory() as tmpdir:
def test_min_compile_time(self):
with tempfile.TemporaryDirectory() as tmpdir, \
persistent_cache_min_compile_time_secs(2):
cc.initialize_cache(tmpdir)
with persistent_cache_min_instruction_count(20):
# 2 instructions at time of writing
jit(lambda x: x * x)(2)
# Mock time to progress in small intervals so compilation time is small.
with mock.patch("time.monotonic", side_effect=np.arange(0, 10, .1)):
jit(lambda x: x + 1)(1)
files_in_cache = len(os.listdir(tmpdir))
self.assertEqual(files_in_cache, 0)
def f(xs):
c, b = jax.lax.scan(lambda c, x: (c + x, c + x), 0, xs)
return c + 1, b
# 32 instructions at time of writing
jit(f)(jax.numpy.ones(8))
# Mock time to progress in large intervals so compilation time is large.
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
jit(lambda x: x + 2)(1)
files_in_cache = len(os.listdir(tmpdir))
self.assertEqual(files_in_cache, 1)