mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add new config jax_persistent_cache_min_compile_time_secs.
This replaces `jax_persistent_cache_min_instruction_count` introduced in https://github.com/google/jax/pull/12798, since gating on the compile time seems strictly better than gating on the instruction count (except maybe that the instruction count is more deterministic, but I don't think that's a big deal).

I defaulted to 1 second as the minimum threshold, based on numbers from the same flax wmt example (https://github.com/google/flax/tree/main/examples/wmt):

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new float config functionality.
This commit is contained in:
parent
f9e7629c3f
commit
cc5171034f
@ -10,9 +10,10 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* Changes
|
* Changes
|
||||||
* JAX should be faster to import. We now import scipy lazily, which accounted
|
* JAX should be faster to import. We now import scipy lazily, which accounted
|
||||||
for a significant fraction of JAX's import time.
|
for a significant fraction of JAX's import time.
|
||||||
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_INSTRUCTION_COUNT=$N` can be
|
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N` can be
|
||||||
used to limit the number of cache entries written to the persistent
|
used to limit the number of cache entries written to the persistent cache.
|
||||||
cache. By default, computations with 6 or more instructions will be cached.
|
By default, computations that take 1 second or more to compile will be
|
||||||
|
cached.
|
||||||
* Added {func}`jax.scipy.stats.mode`.
|
* Added {func}`jax.scipy.stats.mode`.
|
||||||
* The default device order used by `pmap` on TPU if no order is specified now
|
* The default device order used by `pmap` on TPU if no order is specified now
|
||||||
matches `jax.devices()` for single-process jobs. Previously the
|
matches `jax.devices()` for single-process jobs. Previously the
|
||||||
|
@ -126,6 +126,10 @@ class Config:
|
|||||||
update_hook = kwargs.pop("update_hook", None)
|
update_hook = kwargs.pop("update_hook", None)
|
||||||
self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
|
self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
|
||||||
|
|
||||||
|
def DEFINE_float(self, name, default, *args, **kwargs):
  """Register a float-typed option under ``name``.

  Args:
    name: name of the option.
    default: default value for the option.
    *args: extra positional metadata, forwarded untouched to ``add_option``.
    **kwargs: extra keyword metadata, forwarded to ``add_option``. An
      ``update_hook`` entry, if present, is removed from the metadata and
      passed to ``add_option`` as its own ``update_hook`` argument.
  """
  # Pull the hook out first so it does not travel with the flag metadata.
  hook = kwargs.pop("update_hook", None)
  self.add_option(name, default, float, args, kwargs, update_hook=hook)
|
||||||
|
|
||||||
def DEFINE_string(self, name, default, *args, **kwargs):
|
def DEFINE_string(self, name, default, *args, **kwargs):
|
||||||
update_hook = kwargs.pop("update_hook", None)
|
update_hook = kwargs.pop("update_hook", None)
|
||||||
self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
|
self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
|
||||||
@ -144,6 +148,7 @@ class Config:
|
|||||||
self.absl_flags = absl_flags
|
self.absl_flags = absl_flags
|
||||||
absl_defs = { bool: absl_flags.DEFINE_bool,
|
absl_defs = { bool: absl_flags.DEFINE_bool,
|
||||||
int: absl_flags.DEFINE_integer,
|
int: absl_flags.DEFINE_integer,
|
||||||
|
float: absl_flags.DEFINE_float,
|
||||||
str: absl_flags.DEFINE_string,
|
str: absl_flags.DEFINE_string,
|
||||||
'enum': absl_flags.DEFINE_enum }
|
'enum': absl_flags.DEFINE_enum }
|
||||||
|
|
||||||
@ -327,6 +332,47 @@ class Config:
|
|||||||
|
|
||||||
return _StateContextManager(name, help, update_thread_local_hook, validate)
|
return _StateContextManager(name, help, update_thread_local_hook, validate)
|
||||||
|
|
||||||
|
def define_float_state(
    self, name: str, default: Optional[float],
    help: str, update_global_hook: Optional[Callable[[str], None]] = None,
    update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None):
  """Set up thread-local state and return a contextmanager for managing it.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: optional float, default value. If the corresponding environment
      variable is set, its value (parsed as a float) is used as the default
      instead.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: optional callable, forwarded as the ``update_hook`` of
      the underlying float option.
    update_thread_local_hook: optional callable, forwarded to the returned
      context manager.

  Returns:
    A contextmanager to control the thread-local state value.

  Raises:
    ValueError: if the environment variable is set to a value that cannot be
      parsed as a float.

  See docstring for ``define_bool_state``.
  """
  # NOTE(review): the hook annotations mirror the string-state signature;
  # the values handed to the hooks are presumably floats -- confirm.
  name = name.lower()
  # Falls back to ``default`` when the environment variable is not set.
  env_value = os.getenv(name.upper(), default)
  if env_value is not None:
    try:
      default = float(env_value)
    except ValueError as err:
      raise ValueError(
          f"Invalid value \"{env_value}\" for JAX flag {name}") from err
  self.DEFINE_float(name, default, help=help, update_hook=update_global_hook)
  self._contextmanager_flags.add(name)

  def get_state(self):
    # A thread-local override wins; otherwise read the global value.
    val = getattr(_thread_local_state, name, unset)
    return val if val is not unset else self._read(name)
  setattr(Config, name, property(get_state))

  def validate(new_val):
    # Ints are accepted alongside floats (note that bool, being a subclass of
    # int, passes this check as well).
    if new_val is not None and not isinstance(new_val, (float, int)):
      raise ValueError(f'new float config value must be None or of type float, '
                       f'got {new_val} of type {type(new_val)}')

  return _StateContextManager(name, help, update_thread_local_hook, validate)
|
||||||
|
|
||||||
def define_string_state(
|
def define_string_state(
|
||||||
self, name: str, default: Optional[str], help: str,
|
self, name: str, default: Optional[str], help: str,
|
||||||
update_global_hook: Optional[Callable[[str], None]] = None,
|
update_global_hook: Optional[Callable[[str], None]] = None,
|
||||||
@ -765,15 +811,12 @@ raise_persistent_cache_errors = config.define_bool_state(
|
|||||||
'continue. Defaults to false so cache bugs or intermittent issues '
|
'continue. Defaults to false so cache bugs or intermittent issues '
|
||||||
'are non-fatal.'))
|
'are non-fatal.'))
|
||||||
|
|
||||||
persistent_cache_min_instruction_count = config.define_int_state(
|
persistent_cache_min_compile_time_secs = config.define_float_state(
|
||||||
name='jax_persistent_cache_min_instruction_count',
|
name='jax_persistent_cache_min_compile_time_secs',
|
||||||
default=6,
|
default=1,
|
||||||
help=('The minimum number of instructions a computation needs to have to '
|
help=('The minimum compile time of a computation to be written to the '
|
||||||
'be written to the persistent compilation cache. This threshold can '
|
'persistent compilation cache. This threshold can be raised to '
|
||||||
'be raised to decrease the number of entries written to the cache. '
|
'decrease the number of entries written to the cache.'))
|
||||||
'The (unoptimized) instruction count is meant to be a proxy for '
|
|
||||||
'compile time, so programs with longer compile times are still '
|
|
||||||
'cached.'))
|
|
||||||
|
|
||||||
hlo_source_file_canonicalization_regex = config.define_string_state(
|
hlo_source_file_canonicalization_regex = config.define_string_state(
|
||||||
name='jax_hlo_source_file_canonicalization_regex',
|
name='jax_hlo_source_file_canonicalization_regex',
|
||||||
|
@ -1066,9 +1066,11 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
|||||||
logger.info("Persistent compilation cache hit for '%s'", module_name)
|
logger.info("Persistent compilation cache hit for '%s'", module_name)
|
||||||
return cached_executable
|
return cached_executable
|
||||||
else:
|
else:
|
||||||
|
start_time = time.monotonic()
|
||||||
compiled = backend_compile(backend, serialized_computation,
|
compiled = backend_compile(backend, serialized_computation,
|
||||||
compile_options, host_callbacks)
|
compile_options, host_callbacks)
|
||||||
_cache_write(computation, serialized_computation, module_name,
|
compile_time = time.monotonic() - start_time
|
||||||
|
_cache_write(serialized_computation, compile_time, module_name,
|
||||||
compile_options, backend, compiled)
|
compile_options, backend, compiled)
|
||||||
return compiled
|
return compiled
|
||||||
|
|
||||||
@ -1094,27 +1096,27 @@ def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _cache_write(computation: ir.Module,
|
def _cache_write(serialized_computation: Union[str, bytes, ir.Module],
|
||||||
serialized_computation: Union[str, bytes, ir.Module],
|
compile_time_secs: float,
|
||||||
module_name: str, compile_options: CompileOptions,
|
module_name: str, compile_options: CompileOptions,
|
||||||
backend: Backend, compiled: XlaLoadedExecutable):
|
backend: Backend, compiled: XlaLoadedExecutable):
|
||||||
"""Writes `serialized_computation` to the persistent compilation cache."""
|
"""Writes `serialized_computation` to the persistent compilation cache."""
|
||||||
# Avoid import cycle between jax and jax.experimental
|
# Avoid import cycle between jax and jax.experimental
|
||||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||||
|
|
||||||
min_instr_count = config.jax_persistent_cache_min_instruction_count
|
min_compile_time = config.jax_persistent_cache_min_compile_time_secs
|
||||||
if min_instr_count:
|
if min_compile_time:
|
||||||
count = _instruction_count(computation, max_count=min_instr_count)
|
if compile_time_secs < min_compile_time:
|
||||||
if count < min_instr_count:
|
|
||||||
logging.info(
|
logging.info(
|
||||||
"Not writing persistent cache entry for '%s' because it has "
|
"Not writing persistent cache entry for '%s' because it took < %.2f "
|
||||||
"fewer than %i instructions", module_name, min_instr_count)
|
"seconds to compile (%.2fs)", module_name, min_compile_time,
|
||||||
|
compile_time_secs)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
# Don't log `count` because it won't be more than max_count
|
|
||||||
logging.info(
|
logging.info(
|
||||||
"'%s' has at least %i instructions, writing persistent cache entry",
|
"'%s' took at least %.2f seconds to compile (%.2fs), writing "
|
||||||
module_name, min_instr_count)
|
"persistent cache entry", module_name, min_compile_time,
|
||||||
|
compile_time_secs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cc.put_executable(module_name, serialized_computation, compile_options,
|
cc.put_executable(module_name, serialized_computation, compile_options,
|
||||||
|
@ -36,14 +36,14 @@ from jax._src.lib import xla_client
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from jax.config import config
|
from jax.config import config
|
||||||
from jax._src.config import (persistent_cache_min_instruction_count,
|
from jax._src.config import (persistent_cache_min_compile_time_secs,
|
||||||
raise_persistent_cache_errors)
|
raise_persistent_cache_errors)
|
||||||
|
|
||||||
config.parse_flags_with_absl()
|
config.parse_flags_with_absl()
|
||||||
FLAGS = config.FLAGS
|
FLAGS = config.FLAGS
|
||||||
|
|
||||||
@jtu.with_config(jax_raise_persistent_cache_errors=True,
|
@jtu.with_config(jax_raise_persistent_cache_errors=True,
|
||||||
jax_persistent_cache_min_instruction_count=0)
|
jax_persistent_cache_min_compile_time_secs=0)
|
||||||
class CompilationCacheTest(jtu.JaxTestCase):
|
class CompilationCacheTest(jtu.JaxTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -335,21 +335,20 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
|||||||
"for 'jit__lambda_': RuntimeError: test error",
|
"for 'jit__lambda_': RuntimeError: test error",
|
||||||
str(w[0].message))
|
str(w[0].message))
|
||||||
|
|
||||||
def test_min_instruction_count(self):
|
def test_min_compile_time(self):
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir, \
|
||||||
|
persistent_cache_min_compile_time_secs(2):
|
||||||
cc.initialize_cache(tmpdir)
|
cc.initialize_cache(tmpdir)
|
||||||
|
|
||||||
with persistent_cache_min_instruction_count(20):
|
# Mock time to progress in small intervals so compilation time is small.
|
||||||
# 2 instructions at time of writing
|
with mock.patch("time.monotonic", side_effect=np.arange(0, 10, .1)):
|
||||||
jit(lambda x: x * x)(2)
|
jit(lambda x: x + 1)(1)
|
||||||
files_in_cache = len(os.listdir(tmpdir))
|
files_in_cache = len(os.listdir(tmpdir))
|
||||||
self.assertEqual(files_in_cache, 0)
|
self.assertEqual(files_in_cache, 0)
|
||||||
|
|
||||||
def f(xs):
|
# Mock time to progress in large intervals so compilation time is large.
|
||||||
c, b = jax.lax.scan(lambda c, x: (c + x, c + x), 0, xs)
|
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
|
||||||
return c + 1, b
|
jit(lambda x: x + 2)(1)
|
||||||
# 32 instructions at time of writing
|
|
||||||
jit(f)(jax.numpy.ones(8))
|
|
||||||
files_in_cache = len(os.listdir(tmpdir))
|
files_in_cache = len(os.listdir(tmpdir))
|
||||||
self.assertEqual(files_in_cache, 1)
|
self.assertEqual(files_in_cache, 1)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user