mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add new config jax_persistent_cache_min_compile_time_secs
.
This replaces `jax_persistent_cache_min_instruction_count` introduced in https://github.com/google/jax/pull/12798, since gating on the compile time seems strictly better than gating on the instruction count (except maybe that the instruction count is more deterministic, but I don't think that's a big deal). I defaulted to 1 second as the minimum threshold based on the same flax wmt example (https://github.com/google/flax/tree/main/examples/wmt) numbers from name | instruction_count | compile_time_secs ---- | ----------------- | ----------------- `broadcast_in_dim` | 2 | 0.01633763313 `convert_element_type` | 2 | 0.01704716682 `reshape` | 2 | 0.01730203629 `_squareit` | 2 | 0.01730823517 `broadcast_in_dim` | 2 | 0.0182030201 `convert_element_type` | 2 | 0.01982188225 `concatenate` | 2 | 0.02102327347 `true_divide` | 2 | 0.02172231674 `broadcast_in_dim` | 2 | 0.02370619774 `broadcast_in_dim` | 2 | 0.02393102646 `broadcast_in_dim` | 2 | 0.02488565445 `broadcast_in_dim` | 2 | 0.03395628929 `broadcast_in_dim` | 2 | 0.03428125381 `broadcast_in_dim` | 2 | 0.0394551754 `shift_right_logical` | 2 | 0.06500506401 `<lambda>` | 3 | 0.01793265343 `_unstack` | 5 | 0.01975226402 `_reduce_sum` | 5 | 0.0210878849 `_reduce_sum` | 5 | 0.02416801453 `_multi_slice` | 9 | 0.09065580368 `_threefry_split` | 232 | 0.09037566185 `_threefry_split` | 232 | 0.09161829948 `<unnamed wrapped function>` | 2668 | 7.701903343 `<unnamed wrapped function>` | 3455 | 17.57672167 `<unnamed wrapped function>` | 46580 | 166.2570884 `init` | 60361 | 26.35722399 `<unnamed wrapped function>` | 78010 | 3.879326344 Also adds new float config functionality.
This commit is contained in:
parent
f9e7629c3f
commit
cc5171034f
@ -10,9 +10,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Changes
|
||||
* JAX should be faster to import. We now import scipy lazily, which accounted
|
||||
for a significant fraction of JAX's import time.
|
||||
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_INSTRUCTION_COUNT=$N` can be
|
||||
used to limit the number of cache entries written to the persistent
|
||||
cache. By default, computations with 6 or more instructions will be cached.
|
||||
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N` can be
|
||||
used to limit the number of cache entries written to the persistent cache.
|
||||
By default, computations that take 1 second or more to compile will be
|
||||
cached.
|
||||
* Added {func}`jax.scipy.stats.mode`.
|
||||
* The default device order used by `pmap` on TPU if no order is specified now
|
||||
matches `jax.devices()` for single-process jobs. Previously the
|
||||
|
@ -126,6 +126,10 @@ class Config:
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
|
||||
|
||||
def DEFINE_float(self, name, default, *args, **kwargs):
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, float, args, kwargs, update_hook=update_hook)
|
||||
|
||||
def DEFINE_string(self, name, default, *args, **kwargs):
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
|
||||
@ -144,6 +148,7 @@ class Config:
|
||||
self.absl_flags = absl_flags
|
||||
absl_defs = { bool: absl_flags.DEFINE_bool,
|
||||
int: absl_flags.DEFINE_integer,
|
||||
float: absl_flags.DEFINE_float,
|
||||
str: absl_flags.DEFINE_string,
|
||||
'enum': absl_flags.DEFINE_enum }
|
||||
|
||||
@ -327,6 +332,47 @@ class Config:
|
||||
|
||||
return _StateContextManager(name, help, update_thread_local_hook, validate)
|
||||
|
||||
def define_float_state(
|
||||
self, name: str, default: Optional[float],
|
||||
help: str, update_global_hook: Optional[Callable[[str], None]] = None,
|
||||
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \
|
||||
= None):
|
||||
"""Set up thread-local state and return a contextmanager for managing it.
|
||||
Args:
|
||||
name: string, converted to lowercase to define the name of the config
|
||||
option (and absl flag). It is converted to uppercase to define the
|
||||
corresponding shell environment variable.
|
||||
enum_values: list of strings representing the possible values for the
|
||||
option.
|
||||
default: optional float, default value.
|
||||
help: string, used to populate the flag help information as well as the
|
||||
docstring of the returned context manager.
|
||||
Returns:
|
||||
A contextmanager to control the thread-local state value.
|
||||
See docstring for ``define_bool_state``.
|
||||
"""
|
||||
name = name.lower()
|
||||
default_env = os.getenv(name.upper(), default)
|
||||
if default_env is not None:
|
||||
try:
|
||||
default = float(default_env)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid value \"{default_env}\" for JAX flag {name}")
|
||||
self.DEFINE_float(name, default, help=help, update_hook=update_global_hook)
|
||||
self._contextmanager_flags.add(name)
|
||||
|
||||
def get_state(self):
|
||||
val = getattr(_thread_local_state, name, unset)
|
||||
return val if val is not unset else self._read(name)
|
||||
setattr(Config, name, property(get_state))
|
||||
|
||||
def validate(new_val):
|
||||
if new_val is not None and not isinstance(new_val, (float, int)):
|
||||
raise ValueError(f'new float config value must be None or of type float, '
|
||||
f'got {new_val} of type {type(new_val)}')
|
||||
|
||||
return _StateContextManager(name, help, update_thread_local_hook, validate)
|
||||
|
||||
def define_string_state(
|
||||
self, name: str, default: Optional[str], help: str,
|
||||
update_global_hook: Optional[Callable[[str], None]] = None,
|
||||
@ -765,15 +811,12 @@ raise_persistent_cache_errors = config.define_bool_state(
|
||||
'continue. Defaults to false so cache bugs or intermittent issues '
|
||||
'are non-fatal.'))
|
||||
|
||||
persistent_cache_min_instruction_count = config.define_int_state(
|
||||
name='jax_persistent_cache_min_instruction_count',
|
||||
default=6,
|
||||
help=('The minimum number of instructions a computation needs to have to '
|
||||
'be written to the persistent compilation cache. This threshold can '
|
||||
'be raised to decrease the number of entries written to the cache. '
|
||||
'The (unoptimized) instruction count is meant to be a proxy for '
|
||||
'compile time, so programs with longer compile times are still '
|
||||
'cached.'))
|
||||
persistent_cache_min_compile_time_secs = config.define_float_state(
|
||||
name='jax_persistent_cache_min_compile_time_secs',
|
||||
default=1,
|
||||
help=('The minimum compile time of a computation to be written to the '
|
||||
'persistent compilation cache. This threshold can be raised to '
|
||||
'decrease the number of entries written to the cache.'))
|
||||
|
||||
hlo_source_file_canonicalization_regex = config.define_string_state(
|
||||
name='jax_hlo_source_file_canonicalization_regex',
|
||||
|
@ -1066,9 +1066,11 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
||||
logger.info("Persistent compilation cache hit for '%s'", module_name)
|
||||
return cached_executable
|
||||
else:
|
||||
start_time = time.monotonic()
|
||||
compiled = backend_compile(backend, serialized_computation,
|
||||
compile_options, host_callbacks)
|
||||
_cache_write(computation, serialized_computation, module_name,
|
||||
compile_time = time.monotonic() - start_time
|
||||
_cache_write(serialized_computation, compile_time, module_name,
|
||||
compile_options, backend, compiled)
|
||||
return compiled
|
||||
|
||||
@ -1094,27 +1096,27 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
|
||||
return None
|
||||
|
||||
|
||||
def _cache_write(computation: ir.Module,
|
||||
serialized_computation: Union[str, bytes, ir.Module],
|
||||
def _cache_write(serialized_computation: Union[str, bytes, ir.Module],
|
||||
compile_time_secs: float,
|
||||
module_name: str, compile_options: CompileOptions,
|
||||
backend: Backend, compiled: XlaLoadedExecutable):
|
||||
"""Writes `serialized_computation` to the persistent compilation cache."""
|
||||
# Avoid import cycle between jax and jax.experimental
|
||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||
|
||||
min_instr_count = config.jax_persistent_cache_min_instruction_count
|
||||
if min_instr_count:
|
||||
count = _instruction_count(computation, max_count=min_instr_count)
|
||||
if count < min_instr_count:
|
||||
min_compile_time = config.jax_persistent_cache_min_compile_time_secs
|
||||
if min_compile_time:
|
||||
if compile_time_secs < min_compile_time:
|
||||
logging.info(
|
||||
"Not writing persistent cache entry for '%s' because it has "
|
||||
"fewer than %i instructions", module_name, min_instr_count)
|
||||
"Not writing persistent cache entry for '%s' because it took < %.2f "
|
||||
"seconds to compile (%.2fs)", module_name, min_compile_time,
|
||||
compile_time_secs)
|
||||
return
|
||||
else:
|
||||
# Don't log `count` because it won't be more than max_count
|
||||
logging.info(
|
||||
"'%s' has at least %i instructions, writing persistent cache entry",
|
||||
module_name, min_instr_count)
|
||||
"'%s' took at least %.2f seconds to compile (%.2fs), writing "
|
||||
"persistent cache entry", module_name, min_compile_time,
|
||||
compile_time_secs)
|
||||
|
||||
try:
|
||||
cc.put_executable(module_name, serialized_computation, compile_options,
|
||||
|
@ -36,14 +36,14 @@ from jax._src.lib import xla_client
|
||||
import numpy as np
|
||||
|
||||
from jax.config import config
|
||||
from jax._src.config import (persistent_cache_min_instruction_count,
|
||||
from jax._src.config import (persistent_cache_min_compile_time_secs,
|
||||
raise_persistent_cache_errors)
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
@jtu.with_config(jax_raise_persistent_cache_errors=True,
|
||||
jax_persistent_cache_min_instruction_count=0)
|
||||
jax_persistent_cache_min_compile_time_secs=0)
|
||||
class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -335,21 +335,20 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
"for 'jit__lambda_': RuntimeError: test error",
|
||||
str(w[0].message))
|
||||
|
||||
def test_min_instruction_count(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
def test_min_compile_time(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir, \
|
||||
persistent_cache_min_compile_time_secs(2):
|
||||
cc.initialize_cache(tmpdir)
|
||||
|
||||
with persistent_cache_min_instruction_count(20):
|
||||
# 2 instructions at time of writing
|
||||
jit(lambda x: x * x)(2)
|
||||
# Mock time to progress in small intervals so compilation time is small.
|
||||
with mock.patch("time.monotonic", side_effect=np.arange(0, 10, .1)):
|
||||
jit(lambda x: x + 1)(1)
|
||||
files_in_cache = len(os.listdir(tmpdir))
|
||||
self.assertEqual(files_in_cache, 0)
|
||||
|
||||
def f(xs):
|
||||
c, b = jax.lax.scan(lambda c, x: (c + x, c + x), 0, xs)
|
||||
return c + 1, b
|
||||
# 32 instructions at time of writing
|
||||
jit(f)(jax.numpy.ones(8))
|
||||
# Mock time to progress in large intervals so compilation time is large.
|
||||
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
|
||||
jit(lambda x: x + 2)(1)
|
||||
files_in_cache = len(os.listdir(tmpdir))
|
||||
self.assertEqual(files_in_cache, 1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user