Add new config jax_persistent_cache_min_instruction_count.

This can be used to limit the number of entries written to the
persistent compilation cache.

I defaulted to setting 6 as the minimum threshold based on running the
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) and logging
the instruction counts and complilation time:

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new int config functionality.

Fixes #12583
This commit is contained in:
Skye Wanderman-Milne 2022-10-13 23:14:49 +00:00
parent 9589e5fca0
commit 81eb3fca55
4 changed files with 119 additions and 12 deletions

View File

@ -4,13 +4,14 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
<!--
Remember to align the itemized text with the first line of an item within a list.
PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->
## jax 0.3.24
* JAX should be faster to import. We now import scipy lazily, which accounted
for a significant fraction of JAX's import time.
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_INSTRUCTION_COUNT=$N` can be
used to limit the number of cache entries written to the persistent
cache. By default, computations with 6 or more instructions will be cached.
## jaxlib 0.3.24

View File

@ -285,6 +285,47 @@ class Config:
return _StateContextManager(name, help, update_thread_local_hook, validate)
def define_int_state(
self, name: str, default: Optional[int],
help: str, update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \
= None):
"""Set up thread-local state and return a contextmanager for managing it.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
enum_values: list of strings representing the possible values for the
option.
default: optional int, default value.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
Returns:
A contextmanager to control the thread-local state value.
See docstring for ``define_bool_state``.
"""
name = name.lower()
default_env = os.getenv(name.upper(), default)
if default_env is not None:
try:
default = int(default_env)
except ValueError:
raise ValueError(f"Invalid value \"{default_env}\" for JAX flag {name}")
self.DEFINE_integer(name, default, help=help, update_hook=update_global_hook)
self._contextmanager_flags.add(name)
def get_state(self):
val = getattr(_thread_local_state, name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))
def validate(new_val):
if new_val is not None and not isinstance(new_val, int):
raise ValueError(f'new int config value must be None or of type int, '
f'got {new_val} of type {type(new_val)}')
return _StateContextManager(name, help, update_thread_local_hook, validate)
def define_string_state(
self, name: str, default: Optional[str], help: str,
update_global_hook: Optional[Callable[[str], None]] = None,
@ -709,6 +750,16 @@ raise_persistent_cache_errors = config.define_bool_state(
'continue. Defaults to false so cache bugs or intermittent issues '
'are non-fatal.'))
persistent_cache_min_instruction_count = config.define_int_state(
name='jax_persistent_cache_min_instruction_count',
default=6,
help=('The minimum number of instructions a computation needs to have to '
'be written to the persistent compilation cache. This threshold can '
'be raised to decrease the number of entries written to the cache. '
'The (unoptimized) instruction count is meant to be a proxy for '
'compile time, so programs with longer compile times are still '
'cached.'))
hlo_source_file_canonicalization_regex = config.define_string_state(
name='jax_hlo_source_file_canonicalization_regex',
default=None,

View File

@ -1046,8 +1046,8 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
else:
compiled = backend_compile(backend, serialized_computation,
compile_options, host_callbacks)
_cache_write(serialized_computation, module_name, compile_options,
backend, compiled)
_cache_write(computation, serialized_computation, module_name,
compile_options, backend, compiled)
return compiled
return backend_compile(backend, serialized_computation, compile_options,
@ -1072,16 +1072,31 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
return None
def _cache_write(computation: Union[str, bytes, ir.Module], module_name: str,
compile_options: CompileOptions, backend: Backend,
compiled: XlaLoadedExecutable):
"""Writes `computation` to the persistent compilation cache."""
def _cache_write(computation: ir.Module,
serialized_computation: Union[str, bytes, ir.Module],
module_name: str, compile_options: CompileOptions,
backend: Backend, compiled: XlaLoadedExecutable):
"""Writes `serialized_computation` to the persistent compilation cache."""
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
min_instr_count = config.jax_persistent_cache_min_instruction_count
if min_instr_count:
count = _instruction_count(computation, max_count=min_instr_count)
if count < min_instr_count:
logging.info(
"Not writing persistent cache entry for '%s' because it has "
"fewer than %i instructions", module_name, min_instr_count)
return
else:
# Don't log `count` because it won't be more than max_count
logging.info(
"'%s' has at least %i instructions, writing persistent cache entry",
module_name, min_instr_count)
try:
cc.put_executable(module_name, computation, compile_options, compiled,
backend)
cc.put_executable(module_name, serialized_computation, compile_options,
compiled, backend)
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise
@ -1089,6 +1104,26 @@ def _cache_write(computation: Union[str, bytes, ir.Module], module_name: str,
f"Error writing persistent compilation cache entry for "
f"'{module_name}': {type(ex).__name__}: {ex}")
def _instruction_count(module: ir.Module, max_count: Optional[int] = None):
def _blocks_count(blocks, count):
for block in blocks:
for op in block.operations:
count += 1
# Untested premature performance optimization
if max_count is not None and count >= max_count:
return max_count
for region in op.regions:
count = _blocks_count(region.blocks, count)
return count
count = 0
for func in module.body.operations:
count = _blocks_count(func.body.blocks, count)
return count
def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects):
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
if ordered_effects or has_unordered_effects:

View File

@ -36,12 +36,14 @@ from jax._src.lib import xla_client
import numpy as np
from jax.config import config
from jax._src.config import raise_persistent_cache_errors
from jax._src.config import (persistent_cache_min_instruction_count,
raise_persistent_cache_errors)
config.parse_flags_with_absl()
FLAGS = config.FLAGS
@jtu.with_config(jax_raise_persistent_cache_errors=True)
@jtu.with_config(jax_raise_persistent_cache_errors=True,
jax_persistent_cache_min_instruction_count=0)
class CompilationCacheTest(jtu.JaxTestCase):
def setUp(self):
@ -331,6 +333,24 @@ class CompilationCacheTest(jtu.JaxTestCase):
"for 'jit__lambda_': RuntimeError: test error",
str(w[0].message))
def test_min_instruction_count(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
with persistent_cache_min_instruction_count(20):
# 2 instructions at time of writing
jit(lambda x: x * x)(2)
files_in_cache = len(os.listdir(tmpdir))
self.assertEqual(files_in_cache, 0)
def f(xs):
c, b = jax.lax.scan(lambda c, x: (c + x, c + x), 0, xs)
return c + 1, b
# 32 instructions at time of writing
jit(f)(jax.numpy.ones(8))
files_in_cache = len(os.listdir(tmpdir))
self.assertEqual(files_in_cache, 1)
def create_new_debug_options(self, debug_options_obj):
debug_options_obj.xla_cpu_enable_fast_math = False
debug_options_obj.xla_cpu_fast_math_honor_infs = False