Changed the naming of internal config APIs

The new naming highlights that we have two kinds of configuration options:
flags, set at most once, and states, which can be changed locally per thread
via a context manager.

The renames are

* FlagHolder -> Flag
* DEFINE_<type> -> <type>_flag
* _StateContextManager -> State
* define_<type>_state -> <type>_state
This commit is contained in:
Sergei Lebedev 2024-04-15 10:35:50 +01:00
parent dfcfb36062
commit ce0d9e9b9f
11 changed files with 172 additions and 164 deletions

View File

@ -37,13 +37,13 @@ from jax._src.lib.mlir import ir
import numpy as np
_DISABLE_MOST_OPTIMIZATIONS = config.DEFINE_bool(
_DISABLE_MOST_OPTIMIZATIONS = config.bool_flag(
'jax_disable_most_optimizations',
config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
_COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer(
_COMPILER_DETAILED_LOGGING_MIN_OPS = config.int_flag(
"jax_compiler_detailed_logging_min_ops",
config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
help=(

View File

@ -14,7 +14,7 @@
from __future__ import annotations
from collections.abc import Hashable, Iterator
from collections.abc import Hashable, Iterator, Sequence
import contextlib
import functools
import itertools
@ -22,7 +22,9 @@ import logging
import os
import sys
import threading
from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast
from typing import (
Any, Callable, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast,
)
from jax._src import lib
from jax._src.lib import jax_jit
@ -60,24 +62,24 @@ def int_env(varname: str, default: int) -> int:
return int(os.getenv(varname, str(default)))
UPGRADE_BOOL_HELP = (
" This will be enabled by default in future versions of JAX, at which "
"point all uses of the flag will be considered deprecated (following "
"the `API compatibility policy "
"<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_).")
class ValueHolder(Protocol[_T]):
"""A holder for a configuration value.
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
There are two kinds of value holders: ``Flag``, which is assigned exactly
once and never modified after; and ``State``, which can be changed locally
within a thread via a context manager.
"""
value: _T
def _set(self, value: _T) -> None: ...
class Config:
_HAS_DYNAMIC_ATTRIBUTES = True
def __init__(self):
# There are two kinds of value holders: FlagHolders, which hold global
# flags, and StateContextManagers, which hold state that can be changed
# locally within a thread. A value holder needs a `.value` property and a
# `._set()` method.
self._value_holders = {}
self._value_holders: dict[str, ValueHolder] = {}
self.meta = {}
self.use_absl = False
self._contextmanager_flags = set()
@ -113,7 +115,7 @@ class Config:
def config_with_absl(self):
"""Registers absl flags for the JAX configs.
E.g., for each JAX config defined using define_bool_state(), this method
E.g., for each JAX config defined using bool_state(), this method
registers an absl boolean flag, with the same name.
This is the recommended method to call if you use `app.run(main)` and you
@ -237,7 +239,8 @@ unset = _Unset()
_thread_local_state = threading.local()
class _StateContextManager(Generic[_T]):
class State(Generic[_T]):
__slots__ = (
'_name', '_value', '_update_thread_local_hook', '_update_global_hook',
'_validator', '_default_context_manager_value', '__doc__', '__name__',
@ -318,7 +321,16 @@ class _StateContextManager(Generic[_T]):
update_global_hook(self._value)
def define_bool_state(
UPGRADE_BOOL_HELP = (
" This will be enabled by default in future versions of JAX, at which "
"point all uses of the flag will be considered deprecated (following "
"the `API compatibility policy "
"<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_).")
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
def bool_state(
name: str,
default: bool,
help: str,
@ -327,7 +339,7 @@ def define_bool_state(
update_thread_local_hook: Callable[[bool | None], None] | None = None,
upgrade: bool = False,
extra_description: str = '',
) -> _StateContextManager[bool]:
) -> State[bool]:
"""Set up thread-local state and return a contextmanager for managing it.
This function is a convenience wrapper. It defines a flag, environment
@ -360,7 +372,7 @@ def define_bool_state(
Example:
enable_foo = config.define_bool_state(
ENABLE_FOO = config.bool_state(
name='jax_enable_foo',
default=False,
help='Enable foo.')
@ -388,7 +400,7 @@ def define_bool_state(
extra_description += UPGRADE_BOOL_EXTRA_DESC
config._contextmanager_flags.add(name)
s = _StateContextManager[bool](
s = State[bool](
name, default, help, update_global_hook=update_global_hook,
update_thread_local_hook=update_thread_local_hook,
extra_description=extra_description, default_context_manager_value=True)
@ -397,18 +409,18 @@ def define_bool_state(
return s
def define_enum_state(
def enum_state(
name: str,
enum_values: list[str],
enum_values: Sequence[str],
default: str,
help: str,
*,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str]:
) -> State[str]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``define_bool_state``.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
@ -437,7 +449,7 @@ def define_enum_state(
raise ValueError(f"new enum value must be in {enum_values}, "
f"got {new_val} of type {type(new_val)}.")
s = _StateContextManager[str](
s = State[str](
name,
default,
help,
@ -454,18 +466,18 @@ def define_enum_state(
return s
def define_optional_enum_state(
def optional_enum_state(
name: str,
enum_values: list[str],
enum_values: Sequence[str],
default: str | None,
help: str,
*,
update_global_hook: Callable[[str | None], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str | None]:
) -> State[str | None]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``define_bool_state``.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
@ -495,7 +507,7 @@ def define_optional_enum_state(
raise ValueError(f"new enum value must be None or in {enum_values}, "
f"got {new_val} of type {type(new_val)}.")
s = _StateContextManager['str | None'](
s = State['str | None'](
name, default, help, update_global_hook, update_thread_local_hook,
validate
)
@ -508,17 +520,17 @@ def define_optional_enum_state(
return s
def define_int_state(
def int_state(
name: str,
default: int,
help: str,
*,
update_global_hook: Callable[[int], None] | None = None,
update_thread_local_hook: Callable[[int | None], None] | None = None,
) -> _StateContextManager[int]:
) -> State[int]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``define_bool_state``.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
@ -548,24 +560,24 @@ def define_int_state(
raise ValueError(f'new int config value must be None or of type int, '
f'got {new_val} of type {type(new_val)}')
s = _StateContextManager[int](name, default, help, update_global_hook,
update_thread_local_hook, validate)
s = State[int](name, default, help, update_global_hook,
update_thread_local_hook, validate)
config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help})
setattr(Config, name, property(lambda _: s.value))
return s
def define_float_state(
def float_state(
name: str,
default: float,
help: str,
*,
update_global_hook: Callable[[float], None] | None = None,
update_thread_local_hook: Callable[[float | None], None] | None = None,
) -> _StateContextManager[float]:
) -> State[float]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``define_bool_state``.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
@ -596,24 +608,24 @@ def define_float_state(
f'new float config value must be None or of type float, '
f'got {new_val} of type {type(new_val)}')
s = _StateContextManager[float](name, default, help, update_global_hook,
update_thread_local_hook, validate)
s = State[float](name, default, help, update_global_hook,
update_thread_local_hook, validate)
config.add_option(name, s, float, meta_args=[], meta_kwargs={"help": help})
setattr(Config, name, property(lambda _: s.value))
return s
def define_string_state(
def string_state(
name: str,
default: str,
help: str,
*,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str]:
) -> State[str]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``define_bool_state``.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
@ -640,24 +652,24 @@ def define_string_state(
raise TypeError('new string config value must be of type str,'
f' got {new_val} of type {type(new_val)}.')
return define_string_or_object_state(
return string_or_object_state(
name, default, help,
update_global_hook=update_global_hook,
update_thread_local_hook=update_thread_local_hook,
validator=validator)
def define_optional_string_state(
def optional_string_state(
name: str,
default: str | None,
help: str,
*,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str | None]:
) -> State[str | None]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``define_bool_state``.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
@ -684,13 +696,13 @@ def define_optional_string_state(
raise ValueError('new string config value must be None or of type str,'
f' got {new_val} of type {type(new_val)}.')
return define_string_or_object_state(
return string_or_object_state(
name, default, help,
update_global_hook=update_global_hook,
update_thread_local_hook=update_thread_local_hook,
validator=validator)
def define_string_or_object_state(
def string_or_object_state(
name: str,
default: Any,
help: str,
@ -698,10 +710,10 @@ def define_string_or_object_state(
update_global_hook: Callable[[Any], None] | None = None,
update_thread_local_hook: Callable[[Any], None] | None = None,
validator: Callable[[Any], None] | None = None,
) -> _StateContextManager[Any]:
) -> State[Any]:
"""Set up thread-local state and return a contextmanager for managing it.
Similar to ``define_string_state``, except the context manager will accept
Similar to ``string_state``, except the context manager will accept
any object, not just a string. Any value passed via command line flag or
environment variable will be treated as a string.
@ -728,7 +740,7 @@ def define_string_or_object_state(
default = os.getenv(name.upper(), default)
config._contextmanager_flags.add(name)
s = _StateContextManager[Any](
s = State[Any](
name, default, help, update_global_hook, update_thread_local_hook,
validator)
setattr(Config, name, property(lambda _: s.value))
@ -736,7 +748,8 @@ def define_string_or_object_state(
return s
class FlagHolder(Generic[_T]):
class Flag(Generic[_T]):
__slots__ = ("_name", "value", "_update_hook")
_name: str
@ -761,42 +774,37 @@ class FlagHolder(Generic[_T]):
self._update_hook(value)
def check_exists(name):
if name not in config._value_holders:
raise AttributeError(f"Unrecognized config option: {name}")
def DEFINE_bool(name, default, *args, **kwargs) -> FlagHolder[bool]:
def bool_flag(name, default, *args, **kwargs) -> Flag[bool]:
update_hook = kwargs.pop("update_hook", None)
holder = FlagHolder(name, default, update_hook)
holder = Flag(name, default, update_hook)
config.add_option(name, holder, bool, args, kwargs)
return holder
def DEFINE_integer(name, default, *args, **kwargs) -> FlagHolder[int]:
def int_flag(name, default, *args, **kwargs) -> Flag[int]:
update_hook = kwargs.pop("update_hook", None)
holder = FlagHolder(name, default, update_hook)
holder = Flag(name, default, update_hook)
config.add_option(name, holder, int, args, kwargs)
return holder
def DEFINE_float(name, default, *args, **kwargs) -> FlagHolder[float]:
def float_flag(name, default, *args, **kwargs) -> Flag[float]:
update_hook = kwargs.pop("update_hook", None)
holder = FlagHolder(name, default, update_hook)
holder = Flag(name, default, update_hook)
config.add_option(name, holder, float, args, kwargs)
return holder
def DEFINE_string(name, default, *args, **kwargs) -> FlagHolder[str]:
def string_flag(name, default, *args, **kwargs) -> Flag[str]:
update_hook = kwargs.pop("update_hook", None)
holder = FlagHolder(name, default, update_hook)
holder = Flag(name, default, update_hook)
config.add_option(name, holder, str, args, kwargs)
return holder
def DEFINE_enum(name, default, *args, **kwargs) -> FlagHolder[str]:
def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
update_hook = kwargs.pop("update_hook", None)
holder = FlagHolder(name, default, update_hook)
holder = Flag(name, default, update_hook)
config.add_option(name, holder, 'enum', args, kwargs)
return holder
@ -885,7 +893,7 @@ def update_thread_local_jit_state(**kw):
# TODO(b/214340779): remove flag when XLA:CPU is improved.
jax2tf_associative_scan_reductions = define_bool_state(
jax2tf_associative_scan_reductions = bool_state(
name='jax2tf_associative_scan_reductions',
default=False,
help=(
@ -900,7 +908,7 @@ jax2tf_associative_scan_reductions = define_bool_state(
)
)
jax2tf_default_native_serialization = define_bool_state(
jax2tf_default_native_serialization = bool_state(
name='jax2tf_default_native_serialization',
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True),
help=(
@ -910,7 +918,7 @@ jax2tf_default_native_serialization = define_bool_state(
)
)
jax_serialization_version = define_int_state(
jax_serialization_version = int_state(
name='jax_serialization_version',
default=int_env('JAX_SERIALIZATION_VERSION', 0), # We use 0 to detect default.
help=(
@ -918,7 +926,7 @@ jax_serialization_version = define_int_state(
)
)
jax_export_calling_convention_version = define_int_state(
jax_export_calling_convention_version = int_state(
name='jax_export_calling_convention_version',
# Note: bump the default calling convention version at least one month after
# we update XlaCallModule to support the new version, so that serialized
@ -933,7 +941,7 @@ jax_export_calling_convention_version = define_int_state(
)
)
jax_platforms = define_optional_string_state(
jax_platforms = optional_string_state(
name='jax_platforms',
default=None,
help=(
@ -949,18 +957,18 @@ jax_platforms = define_optional_string_state(
'otherwise.'
))
jax_pjrt_client_create_options = define_optional_string_state(
jax_pjrt_client_create_options = optional_string_state(
name='jax_pjrt_client_create_options',
default=None,
help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings '
'provided to a device platform pjrt client as extra arguments.'))
enable_checks = define_bool_state(
enable_checks = bool_state(
name='jax_enable_checks',
default=False,
help='Turn on invariant checking for JAX internals. Makes things slower.')
debug_key_reuse = define_bool_state(
debug_key_reuse = bool_state(
name='jax_debug_key_reuse',
default=False,
help=('Turn on experimental key reuse checking. With this configuration enabled,'
@ -969,7 +977,7 @@ debug_key_reuse = define_bool_state(
' an error. Currently enabling this leads to a small Python overhead on'
' every call to a JIT-compiled function with keys as inputs or outputs.'))
check_tracer_leaks = define_bool_state(
check_tracer_leaks = bool_state(
name='jax_check_tracer_leaks',
default=False,
help=('Turn on checking for leaked tracers as soon as a trace completes. '
@ -979,7 +987,7 @@ check_tracer_leaks = define_bool_state(
'to disable any debuggers while leak checking is enabled.'))
checking_leaks = functools.partial(check_tracer_leaks, True)
debug_nans = define_bool_state(
debug_nans = bool_state(
name='jax_debug_nans',
default=False,
help=('Add nan checks to every operation. When a nan is detected on the '
@ -987,7 +995,7 @@ debug_nans = define_bool_state(
'version in an attempt to more precisely identify the operation '
'which produced the nan.'))
debug_infs = define_bool_state(
debug_infs = bool_state(
name='jax_debug_infs',
default=False,
help=('Add inf checks to every operation. When an inf is detected on the '
@ -995,7 +1003,7 @@ debug_infs = define_bool_state(
'version in an attempt to more precisely identify the operation '
'which produced the inf.'))
log_compiles = define_bool_state(
log_compiles = bool_state(
name='jax_log_compiles',
default=False,
help=('Log a message each time `jit` or `pmap` compiles an XLA '
@ -1003,7 +1011,7 @@ log_compiles = define_bool_state(
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))
explain_cache_misses = define_bool_state(
explain_cache_misses = bool_state(
name='jax_explain_cache_misses',
default=False,
help=('Each time there is a miss on one of the main caches (e.g. the '
@ -1011,14 +1019,14 @@ explain_cache_misses = define_bool_state(
'`logging`. When this option is set, the log level is WARNING; '
'otherwise the level is DEBUG.'))
log_checkpoint_residuals = define_bool_state(
log_checkpoint_residuals = bool_state(
name='jax_log_checkpoint_residuals',
default=False,
help=('Log a message every time jax.checkpoint (aka jax.remat) is '
'partially evaluated (e.g. for autodiff), printing what residuals '
'are saved.'))
pmap_shmap_merge = define_bool_state(
pmap_shmap_merge = bool_state(
name='jax_pmap_shmap_merge',
default=False,
upgrade=True,
@ -1030,7 +1038,7 @@ def _update_jax_memories_global(val):
def _update_jax_memories_thread_local(val):
lib.jax_jit.thread_local_state().enable_memories = val
enable_memories = define_bool_state(
enable_memories = bool_state(
'jax_enable_memories',
default=False,
upgrade=True,
@ -1039,7 +1047,7 @@ enable_memories = define_bool_state(
help=("If True, will allow fetching memory kinds available on executable "
"and annotate Shardings with it."))
spmd_mode = define_enum_state(
spmd_mode = enum_state(
name='jax_spmd_mode',
enum_values=['allow_all', 'allow_jit'],
default='allow_jit',
@ -1052,14 +1060,14 @@ spmd_mode = define_enum_state(
" execute on non-fully addressable `jax.Array`s."))
distributed_debug = define_bool_state(
distributed_debug = bool_state(
name='jax_distributed_debug',
default=False,
help=('Enable logging useful for debugging multi-process distributed '
'computations. Logging is performed with `logging` at WARNING '
'level.'))
random_seed_offset = define_int_state(
random_seed_offset = int_state(
name='jax_random_seed_offset',
default=0,
help=('Offset to all random seeds (e.g. argument to jax.random.key()).'),
@ -1069,7 +1077,7 @@ random_seed_offset = define_int_state(
random_seed_offset=val)
)
legacy_prng_key = define_enum_state(
legacy_prng_key = enum_state(
name='jax_legacy_prng_key',
enum_values=['allow', 'warn', 'error'],
default='allow',
@ -1077,21 +1085,21 @@ legacy_prng_key = define_enum_state(
'jax.random APIs.')
)
enable_custom_prng = define_bool_state(
enable_custom_prng = bool_state(
name='jax_enable_custom_prng',
default=False,
upgrade=True,
help=('Enables an internal upgrade that allows one to define custom '
'pseudo-random number generator implementations.'))
default_prng_impl = define_enum_state(
default_prng_impl = enum_state(
name='jax_default_prng_impl',
enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'],
default='threefry2x32',
help=('Select the default PRNG implementation, used when one is not '
'explicitly provided at seeding time.'))
threefry_partitionable = define_bool_state(
threefry_partitionable = bool_state(
name='jax_threefry_partitionable',
default=False,
upgrade=True,
@ -1106,7 +1114,7 @@ threefry_partitionable = define_bool_state(
update_thread_local_hook=lambda val: update_thread_local_jit_state(
threefry_partitionable=val))
threefry_gpu_kernel_lowering = define_bool_state(
threefry_gpu_kernel_lowering = bool_state(
name='jax_threefry_gpu_kernel_lowering',
default=False,
help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
@ -1118,7 +1126,7 @@ threefry_gpu_kernel_lowering = define_bool_state(
threefry_gpu_kernel_lowering=val))
softmax_custom_jvp = define_bool_state(
softmax_custom_jvp = bool_state(
name='jax_softmax_custom_jvp',
default=False,
upgrade=True,
@ -1131,14 +1139,14 @@ softmax_custom_jvp = define_bool_state(
softmax_custom_jvp=val))
enable_custom_vjp_by_custom_transpose = define_bool_state(
enable_custom_vjp_by_custom_transpose = bool_state(
name='jax_enable_custom_vjp_by_custom_transpose',
default=False,
upgrade=True,
help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))
raise_persistent_cache_errors = define_bool_state(
raise_persistent_cache_errors = bool_state(
name='jax_raise_persistent_cache_errors',
default=False,
help=('If true, exceptions raised when reading or writing to the '
@ -1148,14 +1156,14 @@ raise_persistent_cache_errors = define_bool_state(
'continue. Defaults to false so cache bugs or intermittent issues '
'are non-fatal.'))
persistent_cache_min_compile_time_secs = define_float_state(
persistent_cache_min_compile_time_secs = float_state(
name='jax_persistent_cache_min_compile_time_secs',
default=1.,
help=('The minimum compile time of a computation to be written to the '
'persistent compilation cache. This threshold can be raised to '
'decrease the number of entries written to the cache.'))
persistent_cache_min_entry_size_bytes = define_int_state(
persistent_cache_min_entry_size_bytes = int_state(
name='jax_persistent_cache_min_entry_size_bytes',
default=0,
help=('The minimum size (in bytes) of an entry that will be cached in the '
@ -1166,7 +1174,7 @@ persistent_cache_min_entry_size_bytes = define_int_state(
' filesystem being used for the cache. '
'* > 0: the actual minimum size desired; no overrides.'))
compilation_cache_include_metadata_in_key = define_bool_state(
compilation_cache_include_metadata_in_key = bool_state(
name='jax_compilation_cache_include_metadata_in_key',
default=False,
help=(
@ -1178,7 +1186,7 @@ compilation_cache_include_metadata_in_key = define_bool_state(
),
)
hlo_source_file_canonicalization_regex = define_optional_string_state(
hlo_source_file_canonicalization_regex = optional_string_state(
name='jax_hlo_source_file_canonicalization_regex',
default=None,
help=('Used to canonicalize the source_path metadata of HLO instructions '
@ -1188,7 +1196,7 @@ hlo_source_file_canonicalization_regex = define_optional_string_state(
'persistent compilation cache, which includes HLO metadata in the '
'cache key.'))
include_full_tracebacks_in_locations = define_bool_state(
include_full_tracebacks_in_locations = bool_state(
name='jax_include_full_tracebacks_in_locations',
default=True,
help=(
@ -1196,7 +1204,7 @@ include_full_tracebacks_in_locations = define_bool_state(
),
)
traceback_in_locations_limit = define_int_state(
traceback_in_locations_limit = int_state(
name='jax_traceback_in_locations_limit',
default=10,
help=(
@ -1206,7 +1214,7 @@ traceback_in_locations_limit = define_int_state(
),
)
share_autotune_config_between_hosts = define_bool_state(
share_autotune_config_between_hosts = bool_state(
name='jax_share_autotune_config_between_hosts',
default=False,
help=(
@ -1220,7 +1228,7 @@ share_autotune_config_between_hosts = define_bool_state(
),
)
share_binary_between_hosts = define_bool_state(
share_binary_between_hosts = bool_state(
name='jax_share_binary_between_hosts',
default=False,
help=(
@ -1229,13 +1237,13 @@ share_binary_between_hosts = define_bool_state(
),
)
share_binary_between_hosts_timeout_ms = define_int_state(
share_binary_between_hosts_timeout_ms = int_state(
name='jax_share_binary_between_hosts_timeout_ms',
default=20 * 60 * 1000,
help='Timeout for the compiled module share.',
)
enable_pgle = define_bool_state(
enable_pgle = bool_state(
name='jax_enable_pgle',
default=False,
help=(
@ -1249,7 +1257,7 @@ enable_pgle = define_bool_state(
enable_pgle=val),
)
pgle_profiling_runs = define_int_state(
pgle_profiling_runs = int_state(
name='jax_pgle_profiling_runs',
default=3,
help=(
@ -1264,14 +1272,14 @@ pgle_profiling_runs = define_int_state(
),
)
pgle_aggregation_percentile = define_int_state(
pgle_aggregation_percentile = int_state(
name='jax_pgle_aggregation_percentile',
default=90,
help='Percentile used to aggregate performance data between devices when '
'PGLE is used.',
)
enable_compilation_cache = define_bool_state(
enable_compilation_cache = bool_state(
name='jax_enable_compilation_cache',
default=True,
help=('If set to False, the compilation cache will be disabled regardless '
@ -1280,7 +1288,7 @@ enable_compilation_cache = define_bool_state(
'set_cache_dir().'),
)
compilation_cache_dir = define_optional_string_state(
compilation_cache_dir = optional_string_state(
name='jax_compilation_cache_dir',
default=None,
help=('Path for the cache. '
@ -1289,7 +1297,7 @@ compilation_cache_dir = define_optional_string_state(
'2. The value of this flag set in the command line or by default.'),
)
compilation_cache_max_size = define_int_state(
compilation_cache_max_size = int_state(
name='jax_compilation_cache_max_size',
default=-1,
help=('The maximum size (in bytes) allowed for the persistent compilation '
@ -1301,7 +1309,7 @@ compilation_cache_max_size = define_int_state(
'size to grow indefinitely.'),
)
default_dtype_bits = define_enum_state(
default_dtype_bits = enum_state(
name='jax_default_dtype_bits',
enum_values=['32', '64'],
default='64',
@ -1309,7 +1317,7 @@ default_dtype_bits = define_enum_state(
'This is a temporary flag that will be used during the process '
'of deprecating the ``jax_enable_x64`` flag.'))
numpy_dtype_promotion = define_enum_state(
numpy_dtype_promotion = enum_state(
name='jax_numpy_dtype_promotion',
enum_values=['standard', 'strict'],
default='standard',
@ -1328,7 +1336,7 @@ def _update_x64_global(val):
def _update_x64_thread_local(val):
lib.jax_jit.thread_local_state().enable_x64 = val
enable_x64 = define_bool_state(
enable_x64 = bool_state(
name='jax_enable_x64',
default=False,
help='Enable 64-bit types to be used',
@ -1363,7 +1371,7 @@ def _validate_default_device(val):
# TODO(skye): default_device only accepts devices for now. Make it work with
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
default_device = define_string_or_object_state(
default_device = string_or_object_state(
name='jax_default_device',
default=None,
help=(
@ -1383,7 +1391,7 @@ def _update_disable_jit_global(val):
def _update_disable_jit_thread_local(val):
lib.jax_jit.thread_local_state().disable_jit = val
disable_jit = define_bool_state(
disable_jit = bool_state(
name='jax_disable_jit',
default=False,
help=('Disable JIT compilation and just call original Python.'),
@ -1391,7 +1399,7 @@ disable_jit = define_bool_state(
update_thread_local_hook=_update_disable_jit_thread_local)
numpy_rank_promotion = define_enum_state(
numpy_rank_promotion = enum_state(
name='jax_numpy_rank_promotion',
enum_values=['allow', 'warn', 'raise'],
default='allow',
@ -1402,7 +1410,7 @@ numpy_rank_promotion = define_enum_state(
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(numpy_rank_promotion=val))
default_matmul_precision = define_optional_enum_state(
default_matmul_precision = optional_enum_state(
name='jax_default_matmul_precision',
enum_values=['bfloat16', 'tensorfloat32', 'float32'],
default=None,
@ -1427,7 +1435,7 @@ default_matmul_precision = define_optional_enum_state(
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(default_matmul_precision=val))
traceback_filtering = define_enum_state(
traceback_filtering = enum_state(
name = 'jax_traceback_filtering',
enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames",
"auto"],
@ -1448,14 +1456,14 @@ traceback_filtering = define_enum_state(
# This flag is for internal use.
# TODO(tianjianlu): Removes once we always enable cusparse lowering.
# TODO(b/262050896): Set to true after bug is fixed
bcoo_cusparse_lowering = define_bool_state(
bcoo_cusparse_lowering = bool_state(
name='jax_bcoo_cusparse_lowering',
default=False,
help=('Enables lowering BCOO ops to cuSparse.'))
# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging
# if the intended backend can handle lowering the result
dynamic_shapes = define_bool_state(
dynamic_shapes = bool_state(
name='jax_dynamic_shapes',
default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
help=('Enables experimental features for staging out computations with '
@ -1467,26 +1475,26 @@ dynamic_shapes = define_bool_state(
# This flag is temporary during rollout of the remat barrier.
# TODO(parkers): Remove if there are no complaints.
remat_opt_barrier = define_bool_state(
remat_opt_barrier = bool_state(
name='jax_remat_opt_barrier',
default=True,
help=('Enables using optimization-barrier op for lowering remat.'))
# TODO(sharadmv,mattjj): set default to True, then remove
eager_pmap = define_bool_state(
eager_pmap = bool_state(
name='jax_eager_pmap',
default=True,
upgrade=True,
help='Enable eager-mode pmap when jax_disable_jit is activated.')
# TODO(mattjj): remove once we land mutable array plumbing, or face great shame
custom_vjp_disable_shape_check = define_bool_state(
custom_vjp_disable_shape_check = bool_state(
name='jax_custom_vjp_disable_shape_check',
default=False,
upgrade=True,
help='Disable the check from #19009 to enable some custom_vjp hacks.')
xla_runtime_errors = define_bool_state(
xla_runtime_errors = bool_state(
name='jax_experimental_unsafe_xla_runtime_errors',
default=False,
help=('Enable XLA runtime errors for jax.experimental.checkify.checks '
@ -1496,7 +1504,7 @@ xla_runtime_errors = define_bool_state(
'work under pmap/pjit.')
)
jax_xla_profile_version = define_int_state(
jax_xla_profile_version = int_state(
name='jax_xla_profile_version',
default=0,
help=(
@ -1548,7 +1556,7 @@ def _update_transfer_guard(state, key, val):
else:
assert False, f'Invalid transfer guard level {val}'
transfer_guard_host_to_device = define_optional_enum_state(
transfer_guard_host_to_device = optional_enum_state(
name='jax_transfer_guard_host_to_device',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
@ -1563,7 +1571,7 @@ transfer_guard_host_to_device = define_optional_enum_state(
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'host_to_device', val))
transfer_guard_device_to_device = define_optional_enum_state(
transfer_guard_device_to_device = optional_enum_state(
name='jax_transfer_guard_device_to_device',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
@ -1578,7 +1586,7 @@ transfer_guard_device_to_device = define_optional_enum_state(
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'device_to_device', val))
transfer_guard_device_to_host = define_optional_enum_state(
transfer_guard_device_to_host = optional_enum_state(
name='jax_transfer_guard_device_to_host',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
@ -1599,7 +1607,7 @@ def _update_all_transfer_guard_global(val):
'jax_transfer_guard_device_to_host'):
config.update(name, val)
_transfer_guard = define_optional_enum_state(
_transfer_guard = optional_enum_state(
name='jax_transfer_guard',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
@ -1643,7 +1651,7 @@ def _update_debug_log_modules(module_names_str: str | None):
logging_config.enable_debug_logging(module_name)
# Don't define a context manager since this isn't threadsafe.
define_string_state(
string_state(
name='jax_debug_log_modules',
default='',
help=('Comma-separated list of module names (e.g. "jax" or '
@ -1651,7 +1659,7 @@ define_string_state(
'for.'),
update_global_hook=_update_debug_log_modules)
pmap_no_rank_reduction = define_bool_state(
pmap_no_rank_reduction = bool_state(
name='jax_pmap_no_rank_reduction',
default=False,
help=(

View File

@ -62,7 +62,7 @@ zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.DEFINE_integer(
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.int_flag(
'jax_tracer_error_num_traceback_frames',
config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
help='Set the number of stack frames in JAX tracer error messages.'

View File

@ -67,14 +67,14 @@ Value = Any # = ir.Value
# mypy implicitly sets this variable to true when type checking.
MYPY = False
_JAX_DUMP_IR_TO = config.DEFINE_string(
_JAX_DUMP_IR_TO = config.string_flag(
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
help="Path to which the IR that is emitted by JAX should be dumped as "
"text files. If omitted, JAX will not dump IR. "
"Supports the special value 'sponge' to pick the path from the "
"environment variable TEST_UNDECLARED_OUTPUTS_DIR.")
_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.DEFINE_string(
_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.string_flag(
'jax_include_debug_info_in_dumps',
os.getenv('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', "True"),
help="Determine whether or not to keep debug symbols and location information "

View File

@ -1850,7 +1850,7 @@ def _ensure_spmd_and(f):
return update
SPMD_LOWERING = config.define_bool_state(
SPMD_LOWERING = config.bool_state(
name="experimental_xmap_spmd_lowering",
default=False,
help=("When set, multi-device xmap computations will be compiled through "
@ -1858,7 +1858,7 @@ SPMD_LOWERING = config.define_bool_state(
"Not supported on CPU!"),
update_global_hook=_clear_compilation_cache,
update_thread_local_hook=_thread_local_flag_unsupported)
SPMD_LOWERING_MANUAL = config.define_bool_state(
SPMD_LOWERING_MANUAL = config.bool_state(
name="experimental_xmap_spmd_lowering_manual",
default=False,
help=("When set, multi-device xmap computations will be compiled using "
@ -1867,7 +1867,7 @@ SPMD_LOWERING_MANUAL = config.define_bool_state(
"Requires experimental_xmap_spmd_lowering!"),
update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
update_thread_local_hook=_thread_local_flag_unsupported)
_ENSURE_FIXED_SHARDING = config.define_bool_state(
_ENSURE_FIXED_SHARDING = config.bool_state(
name="experimental_xmap_ensure_fixed_sharding",
default=False,
help=("When set and `experimental_xmap_spmd_lowering` is enabled, the lowering will "

View File

@ -925,7 +925,7 @@ def _extract_function_name(f: Callable, name: str | None) -> str:
return name
_PALLAS_USE_MOSAIC_GPU = config.DEFINE_bool(
_PALLAS_USE_MOSAIC_GPU = config.bool_flag(
"jax_pallas_use_mosaic_gpu",
default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False),
help=(

View File

@ -42,7 +42,7 @@ except ImportError:
colorama = None
_PPRINT_USE_COLOR = config.DEFINE_bool(
_PPRINT_USE_COLOR = config.bool_flag(
'jax_pprint_use_color',
config.bool_env('JAX_PPRINT_USE_COLOR', True),
help='Enable jaxpr pretty-printing with colorful syntax highlighting.'

View File

@ -66,18 +66,18 @@ import numpy.random as npr
# jax.test_util. Functionality appearing here is for internal use only, and
# may be changed or removed at any time and without any deprecation cycle.
_TEST_DUT = config.DEFINE_string(
_TEST_DUT = config.string_flag(
'jax_test_dut', '',
help=
'Describes the device under test in case special consideration is required.'
)
NUM_GENERATED_CASES = config.DEFINE_integer(
NUM_GENERATED_CASES = config.int_flag(
'jax_num_generated_cases',
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
help='Number of generated cases to test')
_MAX_CASES_SAMPLING_RETRIES = config.DEFINE_integer(
_MAX_CASES_SAMPLING_RETRIES = config.int_flag(
'max_cases_sampling_retries',
int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
'Number of times a failed test sample should be retried. '
@ -85,23 +85,23 @@ _MAX_CASES_SAMPLING_RETRIES = config.DEFINE_integer(
'sampling process is terminated.'
)
_SKIP_SLOW_TESTS = config.DEFINE_bool(
_SKIP_SLOW_TESTS = config.bool_flag(
'jax_skip_slow_tests',
config.bool_env('JAX_SKIP_SLOW_TESTS', False),
help='Skip tests marked as slow (> 5 sec).'
)
_TEST_TARGETS = config.DEFINE_string(
_TEST_TARGETS = config.string_flag(
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
'Regular expression specifying which tests to run, called via re.search on '
'the test name. If empty or unspecified, run all tests.'
)
_EXCLUDE_TEST_TARGETS = config.DEFINE_string(
_EXCLUDE_TEST_TARGETS = config.string_flag(
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
'Regular expression specifying which tests NOT to run, called via re.search '
'on the test name. If empty or unspecified, run all tests.'
)
TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.DEFINE_bool(
TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag(
'jax_test_with_persistent_compilation_cache',
config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
help='If enabled, the persistent compilation cache will be enabled for all '

View File

@ -47,7 +47,7 @@ try:
except ImportError:
FLAGS = {}
_MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state(
_MOSAIC_USE_PYTHON_PIPELINE = config.bool_state(
name="mosaic_use_python_pipeline",
default=False,
help=(
@ -57,7 +57,7 @@ _MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state(
),
)
_MOSAIC_ALLOW_HLO = config.define_bool_state(
_MOSAIC_ALLOW_HLO = config.bool_state(
name="jax_mosaic_allow_hlo",
default=False,
help="Allow hlo dialects in Mosaic",

View File

@ -65,46 +65,46 @@ XlaBackend = xla_client.Client
MIN_COMPUTE_CAPABILITY = 52
# TODO(phawkins): Remove jax_xla_backend.
_XLA_BACKEND = config.DEFINE_string(
_XLA_BACKEND = config.string_flag(
'jax_xla_backend', '',
'Deprecated, please use --jax_platforms instead.')
BACKEND_TARGET = config.DEFINE_string(
BACKEND_TARGET = config.string_flag(
'jax_backend_target',
os.getenv('JAX_BACKEND_TARGET', '').lower(),
'Either "local" or "rpc:address" to connect to a remote service target.')
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
_PLATFORM_NAME = config.DEFINE_string(
_PLATFORM_NAME = config.string_flag(
'jax_platform_name',
os.getenv('JAX_PLATFORM_NAME', '').lower(),
'Deprecated, please use --jax_platforms instead.')
CUDA_VISIBLE_DEVICES = config.DEFINE_string(
CUDA_VISIBLE_DEVICES = config.string_flag(
'jax_cuda_visible_devices', 'all',
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
'comma-separate list of integer device IDs.')
_ROCM_VISIBLE_DEVICES = config.DEFINE_string(
_ROCM_VISIBLE_DEVICES = config.string_flag(
'jax_rocm_visible_devices', 'all',
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
'comma-separate list of integer device IDs.')
_USE_MOCK_GPU_CLIENT = config.DEFINE_bool(
_USE_MOCK_GPU_CLIENT = config.bool_flag(
name="use_mock_gpu_client",
default=False,
help="If True, use a mock GPU client instead of a real one.",
)
_MOCK_NUM_GPUS = config.DEFINE_integer(
_MOCK_NUM_GPUS = config.int_flag(
name="mock_num_gpus",
default=1,
help="Mock GPU client number of gpus.",
)
_CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
_CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag(
name="jax_cpu_enable_gloo_collectives",
default=False,
help="Deprecated, please use jax_cpu_collectives_implementation instead.",
)
_CPU_COLLECTIVES_IMPLEMENTATION = config.DEFINE_string(
_CPU_COLLECTIVES_IMPLEMENTATION = config.string_flag(
name='jax_cpu_collectives_implementation',
default='none',
help='Cross-process collective implementation used on CPU. Either "none", '
@ -113,7 +113,7 @@ _CPU_COLLECTIVES_IMPLEMENTATION = config.DEFINE_string(
# TODO(yueshengys): turn default back to True after resolving memory increase
# issue.
_CPU_ENABLE_ASYNC_DISPATCH = config.DEFINE_bool(
_CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
name="jax_cpu_enable_async_dispatch",
default=False,
help="Only applies to non-parallel computations. If False, run computations"
@ -417,7 +417,7 @@ def _check_cuda_versions(raise_on_first_error: bool = False,
def make_gpu_client(
*, platform_name: str, visible_devices_flag: config.FlagHolder[str]
*, platform_name: str, visible_devices_flag: config.Flag[str]
) -> xla_client.Client:
visible_devices = visible_devices_flag.value
allowed_devices = None

View File

@ -541,12 +541,12 @@ from jax._src.lib.mlir.dialects import hlo
import numpy as np
_HOST_CALLBACK_INLINE = config.DEFINE_bool(
_HOST_CALLBACK_INLINE = config.bool_flag(
'jax_host_callback_inline',
config.bool_env('JAX_HOST_CALLBACK_INLINE', False),
help='Inline the host_callback, if not in a staged context.'
)
_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer(
_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.int_flag(
'jax_host_callback_max_queue_byte_size',
config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)),
help=('The size in bytes of the buffer used to hold outfeeds from each '
@ -555,7 +555,7 @@ _HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer(
'until the Python callback consume more outfeeds.'),
lower_bound=int(16 * 1e6)
)
_HOST_CALLBACK_OUTFEED = config.DEFINE_bool(
_HOST_CALLBACK_OUTFEED = config.bool_flag(
'jax_host_callback_outfeed',
config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False),
help=(
@ -564,7 +564,7 @@ _HOST_CALLBACK_OUTFEED = config.DEFINE_bool(
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
)
)
_HOST_CALLBACK_LEGACY = config.DEFINE_bool(
_HOST_CALLBACK_LEGACY = config.bool_flag(
'jax_host_callback_legacy',
config.bool_env('JAX_HOST_CALLBACK_LEGACY', True),
help=(