mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add new config jax_persistent_cache_min_instruction_count
.
This can be used to limit the number of entries written to the persistent compilation cache. I defaulted to setting 6 as the minimum threshold based on running the flax wmt example (https://github.com/google/flax/tree/main/examples/wmt) and logging the instruction counts and complilation time: name | instruction_count | compile_time_secs ---- | ----------------- | ----------------- `broadcast_in_dim` | 2 | 0.01633763313 `convert_element_type` | 2 | 0.01704716682 `reshape` | 2 | 0.01730203629 `_squareit` | 2 | 0.01730823517 `broadcast_in_dim` | 2 | 0.0182030201 `convert_element_type` | 2 | 0.01982188225 `concatenate` | 2 | 0.02102327347 `true_divide` | 2 | 0.02172231674 `broadcast_in_dim` | 2 | 0.02370619774 `broadcast_in_dim` | 2 | 0.02393102646 `broadcast_in_dim` | 2 | 0.02488565445 `broadcast_in_dim` | 2 | 0.03395628929 `broadcast_in_dim` | 2 | 0.03428125381 `broadcast_in_dim` | 2 | 0.0394551754 `shift_right_logical` | 2 | 0.06500506401 `<lambda>` | 3 | 0.01793265343 `_unstack` | 5 | 0.01975226402 `_reduce_sum` | 5 | 0.0210878849 `_reduce_sum` | 5 | 0.02416801453 `_multi_slice` | 9 | 0.09065580368 `_threefry_split` | 232 | 0.09037566185 `_threefry_split` | 232 | 0.09161829948 `<unnamed wrapped function>` | 2668 | 7.701903343 `<unnamed wrapped function>` | 3455 | 17.57672167 `<unnamed wrapped function>` | 46580 | 166.2570884 `init` | 60361 | 26.35722399 `<unnamed wrapped function>` | 78010 | 3.879326344 Also adds new int config functionality. Fixes #12583
This commit is contained in:
parent
9589e5fca0
commit
81eb3fca55
@ -4,13 +4,14 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
|
||||
|
||||
<!--
|
||||
Remember to align the itemized text with the first line of an item within a list.
|
||||
|
||||
PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
-->
|
||||
|
||||
## jax 0.3.24
|
||||
* JAX should be faster to import. We now import scipy lazily, which accounted
|
||||
for a significant fraction of JAX's import time.
|
||||
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_INSTRUCTION_COUNT=$N` can be
|
||||
used to limit the number of cache entries written to the persistent
|
||||
cache. By default, computations with 6 or more instructions will be cached.
|
||||
|
||||
## jaxlib 0.3.24
|
||||
|
||||
|
@ -285,6 +285,47 @@ class Config:
|
||||
|
||||
return _StateContextManager(name, help, update_thread_local_hook, validate)
|
||||
|
||||
def define_int_state(
|
||||
self, name: str, default: Optional[int],
|
||||
help: str, update_global_hook: Optional[Callable[[str], None]] = None,
|
||||
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \
|
||||
= None):
|
||||
"""Set up thread-local state and return a contextmanager for managing it.
|
||||
Args:
|
||||
name: string, converted to lowercase to define the name of the config
|
||||
option (and absl flag). It is converted to uppercase to define the
|
||||
corresponding shell environment variable.
|
||||
enum_values: list of strings representing the possible values for the
|
||||
option.
|
||||
default: optional int, default value.
|
||||
help: string, used to populate the flag help information as well as the
|
||||
docstring of the returned context manager.
|
||||
Returns:
|
||||
A contextmanager to control the thread-local state value.
|
||||
See docstring for ``define_bool_state``.
|
||||
"""
|
||||
name = name.lower()
|
||||
default_env = os.getenv(name.upper(), default)
|
||||
if default_env is not None:
|
||||
try:
|
||||
default = int(default_env)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid value \"{default_env}\" for JAX flag {name}")
|
||||
self.DEFINE_integer(name, default, help=help, update_hook=update_global_hook)
|
||||
self._contextmanager_flags.add(name)
|
||||
|
||||
def get_state(self):
|
||||
val = getattr(_thread_local_state, name, unset)
|
||||
return val if val is not unset else self._read(name)
|
||||
setattr(Config, name, property(get_state))
|
||||
|
||||
def validate(new_val):
|
||||
if new_val is not None and not isinstance(new_val, int):
|
||||
raise ValueError(f'new int config value must be None or of type int, '
|
||||
f'got {new_val} of type {type(new_val)}')
|
||||
|
||||
return _StateContextManager(name, help, update_thread_local_hook, validate)
|
||||
|
||||
def define_string_state(
|
||||
self, name: str, default: Optional[str], help: str,
|
||||
update_global_hook: Optional[Callable[[str], None]] = None,
|
||||
@ -709,6 +750,16 @@ raise_persistent_cache_errors = config.define_bool_state(
|
||||
'continue. Defaults to false so cache bugs or intermittent issues '
|
||||
'are non-fatal.'))
|
||||
|
||||
persistent_cache_min_instruction_count = config.define_int_state(
|
||||
name='jax_persistent_cache_min_instruction_count',
|
||||
default=6,
|
||||
help=('The minimum number of instructions a computation needs to have to '
|
||||
'be written to the persistent compilation cache. This threshold can '
|
||||
'be raised to decrease the number of entries written to the cache. '
|
||||
'The (unoptimized) instruction count is meant to be a proxy for '
|
||||
'compile time, so programs with longer compile times are still '
|
||||
'cached.'))
|
||||
|
||||
hlo_source_file_canonicalization_regex = config.define_string_state(
|
||||
name='jax_hlo_source_file_canonicalization_regex',
|
||||
default=None,
|
||||
|
@ -1046,8 +1046,8 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
||||
else:
|
||||
compiled = backend_compile(backend, serialized_computation,
|
||||
compile_options, host_callbacks)
|
||||
_cache_write(serialized_computation, module_name, compile_options,
|
||||
backend, compiled)
|
||||
_cache_write(computation, serialized_computation, module_name,
|
||||
compile_options, backend, compiled)
|
||||
return compiled
|
||||
|
||||
return backend_compile(backend, serialized_computation, compile_options,
|
||||
@ -1072,16 +1072,31 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
|
||||
return None
|
||||
|
||||
|
||||
def _cache_write(computation: Union[str, bytes, ir.Module], module_name: str,
|
||||
compile_options: CompileOptions, backend: Backend,
|
||||
compiled: XlaLoadedExecutable):
|
||||
"""Writes `computation` to the persistent compilation cache."""
|
||||
def _cache_write(computation: ir.Module,
|
||||
serialized_computation: Union[str, bytes, ir.Module],
|
||||
module_name: str, compile_options: CompileOptions,
|
||||
backend: Backend, compiled: XlaLoadedExecutable):
|
||||
"""Writes `serialized_computation` to the persistent compilation cache."""
|
||||
# Avoid import cycle between jax and jax.experimental
|
||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||
|
||||
min_instr_count = config.jax_persistent_cache_min_instruction_count
|
||||
if min_instr_count:
|
||||
count = _instruction_count(computation, max_count=min_instr_count)
|
||||
if count < min_instr_count:
|
||||
logging.info(
|
||||
"Not writing persistent cache entry for '%s' because it has "
|
||||
"fewer than %i instructions", module_name, min_instr_count)
|
||||
return
|
||||
else:
|
||||
# Don't log `count` because it won't be more than max_count
|
||||
logging.info(
|
||||
"'%s' has at least %i instructions, writing persistent cache entry",
|
||||
module_name, min_instr_count)
|
||||
|
||||
try:
|
||||
cc.put_executable(module_name, computation, compile_options, compiled,
|
||||
backend)
|
||||
cc.put_executable(module_name, serialized_computation, compile_options,
|
||||
compiled, backend)
|
||||
except Exception as ex:
|
||||
if config.jax_raise_persistent_cache_errors:
|
||||
raise
|
||||
@ -1089,6 +1104,26 @@ def _cache_write(computation: Union[str, bytes, ir.Module], module_name: str,
|
||||
f"Error writing persistent compilation cache entry for "
|
||||
f"'{module_name}': {type(ex).__name__}: {ex}")
|
||||
|
||||
|
||||
def _instruction_count(module: ir.Module, max_count: Optional[int] = None):
|
||||
|
||||
def _blocks_count(blocks, count):
|
||||
for block in blocks:
|
||||
for op in block.operations:
|
||||
count += 1
|
||||
# Untested premature performance optimization
|
||||
if max_count is not None and count >= max_count:
|
||||
return max_count
|
||||
for region in op.regions:
|
||||
count = _blocks_count(region.blocks, count)
|
||||
return count
|
||||
|
||||
count = 0
|
||||
for func in module.body.operations:
|
||||
count = _blocks_count(func.body.blocks, count)
|
||||
return count
|
||||
|
||||
|
||||
def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects):
|
||||
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
|
||||
if ordered_effects or has_unordered_effects:
|
||||
|
@ -36,12 +36,14 @@ from jax._src.lib import xla_client
|
||||
import numpy as np
|
||||
|
||||
from jax.config import config
|
||||
from jax._src.config import raise_persistent_cache_errors
|
||||
from jax._src.config import (persistent_cache_min_instruction_count,
|
||||
raise_persistent_cache_errors)
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
@jtu.with_config(jax_raise_persistent_cache_errors=True)
|
||||
@jtu.with_config(jax_raise_persistent_cache_errors=True,
|
||||
jax_persistent_cache_min_instruction_count=0)
|
||||
class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -331,6 +333,24 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
"for 'jit__lambda_': RuntimeError: test error",
|
||||
str(w[0].message))
|
||||
|
||||
def test_min_instruction_count(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
|
||||
with persistent_cache_min_instruction_count(20):
|
||||
# 2 instructions at time of writing
|
||||
jit(lambda x: x * x)(2)
|
||||
files_in_cache = len(os.listdir(tmpdir))
|
||||
self.assertEqual(files_in_cache, 0)
|
||||
|
||||
def f(xs):
|
||||
c, b = jax.lax.scan(lambda c, x: (c + x, c + x), 0, xs)
|
||||
return c + 1, b
|
||||
# 32 instructions at time of writing
|
||||
jit(f)(jax.numpy.ones(8))
|
||||
files_in_cache = len(os.listdir(tmpdir))
|
||||
self.assertEqual(files_in_cache, 1)
|
||||
|
||||
def create_new_debug_options(self, debug_options_obj):
|
||||
debug_options_obj.xla_cpu_enable_fast_math = False
|
||||
debug_options_obj.xla_cpu_fast_math_honor_infs = False
|
||||
|
Loading…
x
Reference in New Issue
Block a user