From ce0d9e9b9faa7f2efa548720658c301a3377e465 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 15 Apr 2024 10:35:50 +0100 Subject: [PATCH] Changed the naming of internal config APIs The new naming highlights that we have two kinds of configuration options: flags, set at most once, and states, which can be changed locally per thread via a context manager. The renames are * FlagHolder -> Flag * DEFINE_ -> _flag * _StateContextManager -> State * define__state -> _state --- jax/_src/compiler.py | 4 +- jax/_src/config.py | 268 +++++++++++++++--------------- jax/_src/core.py | 2 +- jax/_src/interpreters/mlir.py | 4 +- jax/_src/maps.py | 6 +- jax/_src/pallas/pallas_call.py | 2 +- jax/_src/pretty_printer.py | 2 +- jax/_src/test_util.py | 14 +- jax/_src/tpu_custom_call.py | 4 +- jax/_src/xla_bridge.py | 22 +-- jax/experimental/host_callback.py | 8 +- 11 files changed, 172 insertions(+), 164 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 7bc1fd66f..1d08b8296 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -37,13 +37,13 @@ from jax._src.lib.mlir import ir import numpy as np -_DISABLE_MOST_OPTIMIZATIONS = config.DEFINE_bool( +_DISABLE_MOST_OPTIMIZATIONS = config.bool_flag( 'jax_disable_most_optimizations', config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False), 'Try not to do much optimization work. This can be useful if the cost of ' 'optimization is greater than that of running a less-optimized program.') -_COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer( +_COMPILER_DETAILED_LOGGING_MIN_OPS = config.int_flag( "jax_compiler_detailed_logging_min_ops", config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10), help=( diff --git a/jax/_src/config.py b/jax/_src/config.py index e63349c12..eac4475e2 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Hashable, Iterator +from collections.abc import Hashable, Iterator, Sequence import contextlib import functools import itertools @@ -22,7 +22,9 @@ import logging import os import sys import threading -from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast +from typing import ( + Any, Callable, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast, +) from jax._src import lib from jax._src.lib import jax_jit @@ -60,24 +62,24 @@ def int_env(varname: str, default: int) -> int: return int(os.getenv(varname, str(default))) -UPGRADE_BOOL_HELP = ( - " This will be enabled by default in future versions of JAX, at which " - "point all uses of the flag will be considered deprecated (following " - "the `API compatibility policy " - "`_).") +class ValueHolder(Protocol[_T]): + """A holder for a configuration value. -UPGRADE_BOOL_EXTRA_DESC = " (transient)" + There are two kinds of value holders: ``Flag``, which is assigned exactly + once and never modified after; and ``State``, which can be changed locally + within a thread via a context manager. + """ + + value: _T + + def _set(self, value: _T) -> None: ... class Config: _HAS_DYNAMIC_ATTRIBUTES = True def __init__(self): - # There are two kinds of value holders: FlagHolders, which hold global - # flags, and StateContextManagers, which hold state that can be changed - # locally within a thread. A value holder needs a `.value` property and a - # `._set()` method. - self._value_holders = {} + self._value_holders: dict[str, ValueHolder] = {} self.meta = {} self.use_absl = False self._contextmanager_flags = set() @@ -113,7 +115,7 @@ class Config: def config_with_absl(self): """Registers absl flags for the JAX configs. - E.g., for each JAX config defined using define_bool_state(), this method + E.g., for each JAX config defined using bool_state(), this method registers an absl boolean flag, with the same name. This is the recommended method to call if you use `app.run(main)` and you @@ -237,7 +239,8 @@ unset = _Unset() _thread_local_state = threading.local() -class _StateContextManager(Generic[_T]): +class State(Generic[_T]): + __slots__ = ( '_name', '_value', '_update_thread_local_hook', '_update_global_hook', '_validator', '_default_context_manager_value', '__doc__', '__name__', @@ -318,7 +321,16 @@ class _StateContextManager(Generic[_T]): update_global_hook(self._value) -def define_bool_state( +UPGRADE_BOOL_HELP = ( + " This will be enabled by default in future versions of JAX, at which " + "point all uses of the flag will be considered deprecated (following " + "the `API compatibility policy " + "`_).") + +UPGRADE_BOOL_EXTRA_DESC = " (transient)" + + +def bool_state( name: str, default: bool, help: str, @@ -327,7 +339,7 @@ def define_bool_state( update_thread_local_hook: Callable[[bool | None], None] | None = None, upgrade: bool = False, extra_description: str = '', -) -> _StateContextManager[bool]: +) -> State[bool]: """Set up thread-local state and return a contextmanager for managing it. This function is a convenience wrapper. It defines a flag, environment @@ -360,7 +372,7 @@ def define_bool_state( Example: - enable_foo = config.define_bool_state( + ENABLE_FOO = config.bool_state( name='jax_enable_foo', default=False, help='Enable foo.') @@ -388,7 +400,7 @@ def define_bool_state( extra_description += UPGRADE_BOOL_EXTRA_DESC config._contextmanager_flags.add(name) - s = _StateContextManager[bool]( + s = State[bool]( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, extra_description=extra_description, default_context_manager_value=True) @@ -397,18 +409,18 @@ def define_bool_state( return s -def define_enum_state( +def enum_state( name: str, - enum_values: list[str], + enum_values: Sequence[str], default: str, help: str, *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str]: +) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -437,7 +449,7 @@ def define_enum_state( raise ValueError(f"new enum value must be in {enum_values}, " f"got {new_val} of type {type(new_val)}.") - s = _StateContextManager[str]( + s = State[str]( name, default, help, @@ -454,18 +466,18 @@ def define_enum_state( return s -def define_optional_enum_state( +def optional_enum_state( name: str, - enum_values: list[str], + enum_values: Sequence[str], default: str | None, help: str, *, update_global_hook: Callable[[str | None], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str | None]: +) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -495,7 +507,7 @@ def define_optional_enum_state( raise ValueError(f"new enum value must be None or in {enum_values}, " f"got {new_val} of type {type(new_val)}.") - s = _StateContextManager['str | None']( + s = State['str | None']( name, default, help, update_global_hook, update_thread_local_hook, validate ) @@ -508,17 +520,17 @@ def define_optional_enum_state( return s -def define_int_state( +def int_state( name: str, default: int, help: str, *, update_global_hook: Callable[[int], None] | None = None, update_thread_local_hook: Callable[[int | None], None] | None = None, -) -> _StateContextManager[int]: +) -> State[int]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -548,24 +560,24 @@ def define_int_state( raise ValueError(f'new int config value must be None or of type int, ' f'got {new_val} of type {type(new_val)}') - s = _StateContextManager[int](name, default, help, update_global_hook, - update_thread_local_hook, validate) + s = State[int](name, default, help, update_global_hook, + update_thread_local_hook, validate) config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s -def define_float_state( +def float_state( name: str, default: float, help: str, *, update_global_hook: Callable[[float], None] | None = None, update_thread_local_hook: Callable[[float | None], None] | None = None, -) -> _StateContextManager[float]: +) -> State[float]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -596,24 +608,24 @@ def define_float_state( f'new float config value must be None or of type float, ' f'got {new_val} of type {type(new_val)}') - s = _StateContextManager[float](name, default, help, update_global_hook, - update_thread_local_hook, validate) + s = State[float](name, default, help, update_global_hook, + update_thread_local_hook, validate) config.add_option(name, s, float, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s -def define_string_state( +def string_state( name: str, default: str, help: str, *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str]: +) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -640,24 +652,24 @@ def define_string_state( raise TypeError('new string config value must be of type str,' f' got {new_val} of type {type(new_val)}.') - return define_string_or_object_state( + return string_or_object_state( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, validator=validator) -def define_optional_string_state( +def optional_string_state( name: str, default: str | None, help: str, *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str | None]: +) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -684,13 +696,13 @@ def define_optional_string_state( raise ValueError('new string config value must be None or of type str,' f' got {new_val} of type {type(new_val)}.') - return define_string_or_object_state( + return string_or_object_state( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, validator=validator) -def define_string_or_object_state( +def string_or_object_state( name: str, default: Any, help: str, @@ -698,10 +710,10 @@ def define_string_or_object_state( update_global_hook: Callable[[Any], None] | None = None, update_thread_local_hook: Callable[[Any], None] | None = None, validator: Callable[[Any], None] | None = None, -) -> _StateContextManager[Any]: +) -> State[Any]: """Set up thread-local state and return a contextmanager for managing it. - Similar to ``define_string_state``, except the context manager will accept + Similar to ``string_state``, except the context manager will accept any object, not just a string. Any value passed via command line flag or environment variable will be treated as a string. @@ -728,7 +740,7 @@ def define_string_or_object_state( default = os.getenv(name.upper(), default) config._contextmanager_flags.add(name) - s = _StateContextManager[Any]( + s = State[Any]( name, default, help, update_global_hook, update_thread_local_hook, validator) setattr(Config, name, property(lambda _: s.value)) @@ -736,7 +748,8 @@ def define_string_or_object_state( return s -class FlagHolder(Generic[_T]): +class Flag(Generic[_T]): + __slots__ = ("_name", "value", "_update_hook") _name: str @@ -761,42 +774,37 @@ class FlagHolder(Generic[_T]): self._update_hook(value) -def check_exists(name): - if name not in config._value_holders: - raise AttributeError(f"Unrecognized config option: {name}") - - -def DEFINE_bool(name, default, *args, **kwargs) -> FlagHolder[bool]: +def bool_flag(name, default, *args, **kwargs) -> Flag[bool]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, bool, args, kwargs) return holder -def DEFINE_integer(name, default, *args, **kwargs) -> FlagHolder[int]: +def int_flag(name, default, *args, **kwargs) -> Flag[int]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, int, args, kwargs) return holder -def DEFINE_float(name, default, *args, **kwargs) -> FlagHolder[float]: +def float_flag(name, default, *args, **kwargs) -> Flag[float]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, float, args, kwargs) return holder -def DEFINE_string(name, default, *args, **kwargs) -> FlagHolder[str]: +def string_flag(name, default, *args, **kwargs) -> Flag[str]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, str, args, kwargs) return holder -def DEFINE_enum(name, default, *args, **kwargs) -> FlagHolder[str]: +def enum_flag(name, default, *args, **kwargs) -> Flag[str]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, 'enum', args, kwargs) return holder @@ -885,7 +893,7 @@ def update_thread_local_jit_state(**kw): # TODO(b/214340779): remove flag when XLA:CPU is improved. -jax2tf_associative_scan_reductions = define_bool_state( +jax2tf_associative_scan_reductions = bool_state( name='jax2tf_associative_scan_reductions', default=False, help=( @@ -900,7 +908,7 @@ jax2tf_associative_scan_reductions = define_bool_state( ) ) -jax2tf_default_native_serialization = define_bool_state( +jax2tf_default_native_serialization = bool_state( name='jax2tf_default_native_serialization', default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True), help=( @@ -910,7 +918,7 @@ jax2tf_default_native_serialization = define_bool_state( ) ) -jax_serialization_version = define_int_state( +jax_serialization_version = int_state( name='jax_serialization_version', default=int_env('JAX_SERIALIZATION_VERSION', 0), # We use 0 to detect default. help=( @@ -918,7 +926,7 @@ jax_serialization_version = define_int_state( ) ) -jax_export_calling_convention_version = define_int_state( +jax_export_calling_convention_version = int_state( name='jax_export_calling_convention_version', # Note: bump the default calling convention version at least one month after # we update XlaCallModule to support the new version, so that serialized @@ -933,7 +941,7 @@ jax_export_calling_convention_version = define_int_state( ) ) -jax_platforms = define_optional_string_state( +jax_platforms = optional_string_state( name='jax_platforms', default=None, help=( @@ -949,18 +957,18 @@ jax_platforms = define_optional_string_state( 'otherwise.' )) -jax_pjrt_client_create_options = define_optional_string_state( +jax_pjrt_client_create_options = optional_string_state( name='jax_pjrt_client_create_options', default=None, help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings ' 'provided to a device platform pjrt client as extra arguments.')) -enable_checks = define_bool_state( +enable_checks = bool_state( name='jax_enable_checks', default=False, help='Turn on invariant checking for JAX internals. Makes things slower.') -debug_key_reuse = define_bool_state( +debug_key_reuse = bool_state( name='jax_debug_key_reuse', default=False, help=('Turn on experimental key reuse checking. With this configuration enabled,' @@ -969,7 +977,7 @@ debug_key_reuse = define_bool_state( ' an error. Currently enabling this leads to a small Python overhead on' ' every call to a JIT-compiled function with keys as inputs or outputs.')) -check_tracer_leaks = define_bool_state( +check_tracer_leaks = bool_state( name='jax_check_tracer_leaks', default=False, help=('Turn on checking for leaked tracers as soon as a trace completes. ' @@ -979,7 +987,7 @@ check_tracer_leaks = define_bool_state( 'to disable any debuggers while leak checking is enabled.')) checking_leaks = functools.partial(check_tracer_leaks, True) -debug_nans = define_bool_state( +debug_nans = bool_state( name='jax_debug_nans', default=False, help=('Add nan checks to every operation. When a nan is detected on the ' @@ -987,7 +995,7 @@ debug_nans = define_bool_state( 'version in an attempt to more precisely identify the operation ' 'which produced the nan.')) -debug_infs = define_bool_state( +debug_infs = bool_state( name='jax_debug_infs', default=False, help=('Add inf checks to every operation. When an inf is detected on the ' @@ -995,7 +1003,7 @@ debug_infs = define_bool_state( 'version in an attempt to more precisely identify the operation ' 'which produced the inf.')) -log_compiles = define_bool_state( +log_compiles = bool_state( name='jax_log_compiles', default=False, help=('Log a message each time `jit` or `pmap` compiles an XLA ' @@ -1003,7 +1011,7 @@ log_compiles = define_bool_state( 'option is set, the log level is WARNING; otherwise the level is ' 'DEBUG.')) -explain_cache_misses = define_bool_state( +explain_cache_misses = bool_state( name='jax_explain_cache_misses', default=False, help=('Each time there is a miss on one of the main caches (e.g. the ' @@ -1011,14 +1019,14 @@ explain_cache_misses = define_bool_state( '`logging`. When this option is set, the log level is WARNING; ' 'otherwise the level is DEBUG.')) -log_checkpoint_residuals = define_bool_state( +log_checkpoint_residuals = bool_state( name='jax_log_checkpoint_residuals', default=False, help=('Log a message every time jax.checkpoint (aka jax.remat) is ' 'partially evaluated (e.g. for autodiff), printing what residuals ' 'are saved.')) -pmap_shmap_merge = define_bool_state( +pmap_shmap_merge = bool_state( name='jax_pmap_shmap_merge', default=False, upgrade=True, @@ -1030,7 +1038,7 @@ def _update_jax_memories_global(val): def _update_jax_memories_thread_local(val): lib.jax_jit.thread_local_state().enable_memories = val -enable_memories = define_bool_state( +enable_memories = bool_state( 'jax_enable_memories', default=False, upgrade=True, @@ -1039,7 +1047,7 @@ enable_memories = define_bool_state( help=("If True, will allow fetching memory kinds available on executable " "and annotate Shardings with it.")) -spmd_mode = define_enum_state( +spmd_mode = enum_state( name='jax_spmd_mode', enum_values=['allow_all', 'allow_jit'], default='allow_jit', @@ -1052,14 +1060,14 @@ spmd_mode = define_enum_state( " execute on non-fully addressable `jax.Array`s.")) -distributed_debug = define_bool_state( +distributed_debug = bool_state( name='jax_distributed_debug', default=False, help=('Enable logging useful for debugging multi-process distributed ' 'computations. Logging is performed with `logging` at WARNING ' 'level.')) -random_seed_offset = define_int_state( +random_seed_offset = int_state( name='jax_random_seed_offset', default=0, help=('Offset to all random seeds (e.g. argument to jax.random.key()).'), @@ -1069,7 +1077,7 @@ random_seed_offset = define_int_state( random_seed_offset=val) ) -legacy_prng_key = define_enum_state( +legacy_prng_key = enum_state( name='jax_legacy_prng_key', enum_values=['allow', 'warn', 'error'], default='allow', @@ -1077,21 +1085,21 @@ legacy_prng_key = define_enum_state( 'jax.random APIs.') ) -enable_custom_prng = define_bool_state( +enable_custom_prng = bool_state( name='jax_enable_custom_prng', default=False, upgrade=True, help=('Enables an internal upgrade that allows one to define custom ' 'pseudo-random number generator implementations.')) -default_prng_impl = define_enum_state( +default_prng_impl = enum_state( name='jax_default_prng_impl', enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'], default='threefry2x32', help=('Select the default PRNG implementation, used when one is not ' 'explicitly provided at seeding time.')) -threefry_partitionable = define_bool_state( +threefry_partitionable = bool_state( name='jax_threefry_partitionable', default=False, upgrade=True, @@ -1106,7 +1114,7 @@ threefry_partitionable = define_bool_state( update_thread_local_hook=lambda val: update_thread_local_jit_state( threefry_partitionable=val)) -threefry_gpu_kernel_lowering = define_bool_state( +threefry_gpu_kernel_lowering = bool_state( name='jax_threefry_gpu_kernel_lowering', default=False, help=('On GPU, lower threefry PRNG operations to a kernel implementation. ' @@ -1118,7 +1126,7 @@ threefry_gpu_kernel_lowering = define_bool_state( threefry_gpu_kernel_lowering=val)) -softmax_custom_jvp = define_bool_state( +softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', default=False, upgrade=True, @@ -1131,14 +1139,14 @@ softmax_custom_jvp = define_bool_state( softmax_custom_jvp=val)) -enable_custom_vjp_by_custom_transpose = define_bool_state( +enable_custom_vjp_by_custom_transpose = bool_state( name='jax_enable_custom_vjp_by_custom_transpose', default=False, upgrade=True, help=('Enables an internal upgrade that implements `jax.custom_vjp` by ' 'reduction to `jax.custom_jvp` and `jax.custom_transpose`.')) -raise_persistent_cache_errors = define_bool_state( +raise_persistent_cache_errors = bool_state( name='jax_raise_persistent_cache_errors', default=False, help=('If true, exceptions raised when reading or writing to the ' @@ -1148,14 +1156,14 @@ raise_persistent_cache_errors = define_bool_state( 'continue. Defaults to false so cache bugs or intermittent issues ' 'are non-fatal.')) -persistent_cache_min_compile_time_secs = define_float_state( +persistent_cache_min_compile_time_secs = float_state( name='jax_persistent_cache_min_compile_time_secs', default=1., help=('The minimum compile time of a computation to be written to the ' 'persistent compilation cache. This threshold can be raised to ' 'decrease the number of entries written to the cache.')) -persistent_cache_min_entry_size_bytes = define_int_state( +persistent_cache_min_entry_size_bytes = int_state( name='jax_persistent_cache_min_entry_size_bytes', default=0, help=('The minimum size (in bytes) of an entry that will be cached in the ' @@ -1166,7 +1174,7 @@ persistent_cache_min_entry_size_bytes = define_int_state( ' filesystem being used for the cache. ' '* > 0: the actual minimum size desired; no overrides.')) -compilation_cache_include_metadata_in_key = define_bool_state( +compilation_cache_include_metadata_in_key = bool_state( name='jax_compilation_cache_include_metadata_in_key', default=False, help=( @@ -1178,7 +1186,7 @@ compilation_cache_include_metadata_in_key = define_bool_state( ), ) -hlo_source_file_canonicalization_regex = define_optional_string_state( +hlo_source_file_canonicalization_regex = optional_string_state( name='jax_hlo_source_file_canonicalization_regex', default=None, help=('Used to canonicalize the source_path metadata of HLO instructions ' @@ -1188,7 +1196,7 @@ hlo_source_file_canonicalization_regex = define_optional_string_state( 'persistent compilation cache, which includes HLO metadata in the ' 'cache key.')) -include_full_tracebacks_in_locations = define_bool_state( +include_full_tracebacks_in_locations = bool_state( name='jax_include_full_tracebacks_in_locations', default=True, help=( @@ -1196,7 +1204,7 @@ include_full_tracebacks_in_locations = define_bool_state( ), ) -traceback_in_locations_limit = define_int_state( +traceback_in_locations_limit = int_state( name='jax_traceback_in_locations_limit', default=10, help=( @@ -1206,7 +1214,7 @@ traceback_in_locations_limit = define_int_state( ), ) -share_autotune_config_between_hosts = define_bool_state( +share_autotune_config_between_hosts = bool_state( name='jax_share_autotune_config_between_hosts', default=False, help=( @@ -1220,7 +1228,7 @@ share_autotune_config_between_hosts = define_bool_state( ), ) -share_binary_between_hosts = define_bool_state( +share_binary_between_hosts = bool_state( name='jax_share_binary_between_hosts', default=False, help=( @@ -1229,13 +1237,13 @@ share_binary_between_hosts = define_bool_state( ), ) -share_binary_between_hosts_timeout_ms = define_int_state( +share_binary_between_hosts_timeout_ms = int_state( name='jax_share_binary_between_hosts_timeout_ms', default=20 * 60 * 1000, help='Timeout for the compiled module share.', ) -enable_pgle = define_bool_state( +enable_pgle = bool_state( name='jax_enable_pgle', default=False, help=( @@ -1249,7 +1257,7 @@ enable_pgle = define_bool_state( enable_pgle=val), ) -pgle_profiling_runs = define_int_state( +pgle_profiling_runs = int_state( name='jax_pgle_profiling_runs', default=3, help=( @@ -1264,14 +1272,14 @@ pgle_profiling_runs = define_int_state( ), ) -pgle_aggregation_percentile = define_int_state( +pgle_aggregation_percentile = int_state( name='jax_pgle_aggregation_percentile', default=90, help='Percentile used to aggregate performance data between devices when ' 'PGLE is used.', ) -enable_compilation_cache = define_bool_state( +enable_compilation_cache = bool_state( name='jax_enable_compilation_cache', default=True, help=('If set to False, the compilation cache will be disabled regardless ' @@ -1280,7 +1288,7 @@ enable_compilation_cache = define_bool_state( 'set_cache_dir().'), ) -compilation_cache_dir = define_optional_string_state( +compilation_cache_dir = optional_string_state( name='jax_compilation_cache_dir', default=None, help=('Path for the cache. ' @@ -1289,7 +1297,7 @@ compilation_cache_dir = define_optional_string_state( '2. The value of this flag set in the command line or by default.'), ) -compilation_cache_max_size = define_int_state( +compilation_cache_max_size = int_state( name='jax_compilation_cache_max_size', default=-1, help=('The maximum size (in bytes) allowed for the persistent compilation ' @@ -1301,7 +1309,7 @@ compilation_cache_max_size = define_int_state( 'size to grow indefinitely.'), ) -default_dtype_bits = define_enum_state( +default_dtype_bits = enum_state( name='jax_default_dtype_bits', enum_values=['32', '64'], default='64', @@ -1309,7 +1317,7 @@ default_dtype_bits = define_enum_state( 'This is a temporary flag that will be used during the process ' 'of deprecating the ``jax_enable_x64`` flag.')) -numpy_dtype_promotion = define_enum_state( +numpy_dtype_promotion = enum_state( name='jax_numpy_dtype_promotion', enum_values=['standard', 'strict'], default='standard', @@ -1328,7 +1336,7 @@ def _update_x64_global(val): def _update_x64_thread_local(val): lib.jax_jit.thread_local_state().enable_x64 = val -enable_x64 = define_bool_state( +enable_x64 = bool_state( name='jax_enable_x64', default=False, help='Enable 64-bit types to be used', @@ -1363,7 +1371,7 @@ def _validate_default_device(val): # TODO(skye): default_device only accepts devices for now. Make it work with # platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). -default_device = define_string_or_object_state( +default_device = string_or_object_state( name='jax_default_device', default=None, help=( @@ -1383,7 +1391,7 @@ def _update_disable_jit_global(val): def _update_disable_jit_thread_local(val): lib.jax_jit.thread_local_state().disable_jit = val -disable_jit = define_bool_state( +disable_jit = bool_state( name='jax_disable_jit', default=False, help=('Disable JIT compilation and just call original Python.'), @@ -1391,7 +1399,7 @@ disable_jit = define_bool_state( update_thread_local_hook=_update_disable_jit_thread_local) -numpy_rank_promotion = define_enum_state( +numpy_rank_promotion = enum_state( name='jax_numpy_rank_promotion', enum_values=['allow', 'warn', 'raise'], default='allow', @@ -1402,7 +1410,7 @@ numpy_rank_promotion = define_enum_state( update_thread_local_hook=lambda val: \ update_thread_local_jit_state(numpy_rank_promotion=val)) -default_matmul_precision = define_optional_enum_state( +default_matmul_precision = optional_enum_state( name='jax_default_matmul_precision', enum_values=['bfloat16', 'tensorfloat32', 'float32'], default=None, @@ -1427,7 +1435,7 @@ default_matmul_precision = define_optional_enum_state( update_thread_local_hook=lambda val: \ update_thread_local_jit_state(default_matmul_precision=val)) -traceback_filtering = define_enum_state( +traceback_filtering = enum_state( name = 'jax_traceback_filtering', enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames", "auto"], @@ -1448,14 +1456,14 @@ traceback_filtering = define_enum_state( # This flag is for internal use. # TODO(tianjianlu): Removes once we always enable cusparse lowering. # TODO(b/262050896): Set to true after bug is fixed -bcoo_cusparse_lowering = define_bool_state( +bcoo_cusparse_lowering = bool_state( name='jax_bcoo_cusparse_lowering', default=False, help=('Enables lowering BCOO ops to cuSparse.')) # TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging # if the intended backend can handle lowering the result -dynamic_shapes = define_bool_state( +dynamic_shapes = bool_state( name='jax_dynamic_shapes', default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')), help=('Enables experimental features for staging out computations with ' @@ -1467,26 +1475,26 @@ dynamic_shapes = define_bool_state( # This flag is temporary during rollout of the remat barrier. # TODO(parkers): Remove if there are no complaints. -remat_opt_barrier = define_bool_state( +remat_opt_barrier = bool_state( name='jax_remat_opt_barrier', default=True, help=('Enables using optimization-barrier op for lowering remat.')) # TODO(sharadmv,mattjj): set default to True, then remove -eager_pmap = define_bool_state( +eager_pmap = bool_state( name='jax_eager_pmap', default=True, upgrade=True, help='Enable eager-mode pmap when jax_disable_jit is activated.') # TODO(mattjj): remove once we land mutable array plumbing, or face great shame -custom_vjp_disable_shape_check = define_bool_state( +custom_vjp_disable_shape_check = bool_state( name='jax_custom_vjp_disable_shape_check', default=False, upgrade=True, help='Disable the check from #19009 to enable some custom_vjp hacks.') -xla_runtime_errors = define_bool_state( +xla_runtime_errors = bool_state( name='jax_experimental_unsafe_xla_runtime_errors', default=False, help=('Enable XLA runtime errors for jax.experimental.checkify.checks ' @@ -1496,7 +1504,7 @@ xla_runtime_errors = define_bool_state( 'work under pmap/pjit.') ) -jax_xla_profile_version = define_int_state( +jax_xla_profile_version = int_state( name='jax_xla_profile_version', default=0, help=( @@ -1548,7 +1556,7 @@ def _update_transfer_guard(state, key, val): else: assert False, f'Invalid transfer guard level {val}' -transfer_guard_host_to_device = define_optional_enum_state( +transfer_guard_host_to_device = optional_enum_state( name='jax_transfer_guard_host_to_device', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1563,7 +1571,7 @@ transfer_guard_host_to_device = define_optional_enum_state( update_thread_local_hook=lambda val: _update_transfer_guard( transfer_guard_lib.thread_local_state(), 'host_to_device', val)) -transfer_guard_device_to_device = define_optional_enum_state( +transfer_guard_device_to_device = optional_enum_state( name='jax_transfer_guard_device_to_device', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1578,7 +1586,7 @@ transfer_guard_device_to_device = define_optional_enum_state( update_thread_local_hook=lambda val: _update_transfer_guard( transfer_guard_lib.thread_local_state(), 'device_to_device', val)) -transfer_guard_device_to_host = define_optional_enum_state( +transfer_guard_device_to_host = optional_enum_state( name='jax_transfer_guard_device_to_host', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1599,7 +1607,7 @@ def _update_all_transfer_guard_global(val): 'jax_transfer_guard_device_to_host'): config.update(name, val) -_transfer_guard = define_optional_enum_state( +_transfer_guard = optional_enum_state( name='jax_transfer_guard', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1643,7 +1651,7 @@ def _update_debug_log_modules(module_names_str: str | None): logging_config.enable_debug_logging(module_name) # Don't define a context manager since this isn't threadsafe. -define_string_state( +string_state( name='jax_debug_log_modules', default='', help=('Comma-separated list of module names (e.g. "jax" or ' @@ -1651,7 +1659,7 @@ define_string_state( 'for.'), update_global_hook=_update_debug_log_modules) -pmap_no_rank_reduction = define_bool_state( +pmap_no_rank_reduction = bool_state( name='jax_pmap_no_rank_reduction', default=False, help=( diff --git a/jax/_src/core.py b/jax/_src/core.py index 78146cc2b..812da70fd 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -62,7 +62,7 @@ zip, unsafe_zip = safe_zip, zip map, unsafe_map = safe_map, map -_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.DEFINE_integer( +_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.int_flag( 'jax_tracer_error_num_traceback_frames', config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5), help='Set the number of stack frames in JAX tracer error messages.' diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a3e42bce3..3ba40907b 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -67,14 +67,14 @@ Value = Any # = ir.Value # mypy implicitly sets this variable to true when type checking. MYPY = False -_JAX_DUMP_IR_TO = config.DEFINE_string( +_JAX_DUMP_IR_TO = config.string_flag( 'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''), help="Path to which the IR that is emitted by JAX should be dumped as " "text files. If omitted, JAX will not dump IR. " "Supports the special value 'sponge' to pick the path from the " "environment variable TEST_UNDECLARED_OUTPUTS_DIR.") -_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.DEFINE_string( +_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.string_flag( 'jax_include_debug_info_in_dumps', os.getenv('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', "True"), help="Determine whether or not to keep debug symbols and location information " diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 10713b624..98dc6e855 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -1850,7 +1850,7 @@ def _ensure_spmd_and(f): return update -SPMD_LOWERING = config.define_bool_state( +SPMD_LOWERING = config.bool_state( name="experimental_xmap_spmd_lowering", default=False, help=("When set, multi-device xmap computations will be compiled through " @@ -1858,7 +1858,7 @@ SPMD_LOWERING = config.define_bool_state( "Not supported on CPU!"), update_global_hook=_clear_compilation_cache, update_thread_local_hook=_thread_local_flag_unsupported) -SPMD_LOWERING_MANUAL = config.define_bool_state( +SPMD_LOWERING_MANUAL = config.bool_state( name="experimental_xmap_spmd_lowering_manual", default=False, help=("When set, multi-device xmap computations will be compiled using " @@ -1867,7 +1867,7 @@ SPMD_LOWERING_MANUAL = config.define_bool_state( "Requires experimental_xmap_spmd_lowering!"), update_global_hook=_ensure_spmd_and(_clear_compilation_cache), update_thread_local_hook=_thread_local_flag_unsupported) -_ENSURE_FIXED_SHARDING = config.define_bool_state( +_ENSURE_FIXED_SHARDING = config.bool_state( name="experimental_xmap_ensure_fixed_sharding", default=False, help=("When set and `experimental_xmap_spmd_lowering` is enabled, the lowering will " diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 973ca4054..bc9d66f89 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -925,7 +925,7 @@ def _extract_function_name(f: Callable, name: str | None) -> str: return name -_PALLAS_USE_MOSAIC_GPU = config.DEFINE_bool( +_PALLAS_USE_MOSAIC_GPU = config.bool_flag( "jax_pallas_use_mosaic_gpu", default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False), help=( diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index 0614bb8a8..5c1e7e119 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -42,7 +42,7 @@ except ImportError: colorama = None -_PPRINT_USE_COLOR = config.DEFINE_bool( +_PPRINT_USE_COLOR = config.bool_flag( 'jax_pprint_use_color', config.bool_env('JAX_PPRINT_USE_COLOR', True), help='Enable jaxpr pretty-printing with colorful syntax highlighting.' diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index df8868929..f5a6208f8 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -66,18 +66,18 @@ import numpy.random as npr # jax.test_util. Functionality appearing here is for internal use only, and # may be changed or removed at any time and without any deprecation cycle. -_TEST_DUT = config.DEFINE_string( +_TEST_DUT = config.string_flag( 'jax_test_dut', '', help= 'Describes the device under test in case special consideration is required.' ) -NUM_GENERATED_CASES = config.DEFINE_integer( +NUM_GENERATED_CASES = config.int_flag( 'jax_num_generated_cases', int(os.getenv('JAX_NUM_GENERATED_CASES', '10')), help='Number of generated cases to test') -_MAX_CASES_SAMPLING_RETRIES = config.DEFINE_integer( +_MAX_CASES_SAMPLING_RETRIES = config.int_flag( 'max_cases_sampling_retries', int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')), 'Number of times a failed test sample should be retried. ' @@ -85,23 +85,23 @@ _MAX_CASES_SAMPLING_RETRIES = config.DEFINE_integer( 'sampling process is terminated.' ) -_SKIP_SLOW_TESTS = config.DEFINE_bool( +_SKIP_SLOW_TESTS = config.bool_flag( 'jax_skip_slow_tests', config.bool_env('JAX_SKIP_SLOW_TESTS', False), help='Skip tests marked as slow (> 5 sec).' ) -_TEST_TARGETS = config.DEFINE_string( +_TEST_TARGETS = config.string_flag( 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), 'Regular expression specifying which tests to run, called via re.search on ' 'the test name. If empty or unspecified, run all tests.' ) -_EXCLUDE_TEST_TARGETS = config.DEFINE_string( +_EXCLUDE_TEST_TARGETS = config.string_flag( 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), 'Regular expression specifying which tests NOT to run, called via re.search ' 'on the test name. If empty or unspecified, run all tests.' ) -TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.DEFINE_bool( +TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag( 'jax_test_with_persistent_compilation_cache', config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False), help='If enabled, the persistent compilation cache will be enabled for all ' diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index e3702d2b0..7b88d5314 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -47,7 +47,7 @@ try: except ImportError: FLAGS = {} -_MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state( +_MOSAIC_USE_PYTHON_PIPELINE = config.bool_state( name="mosaic_use_python_pipeline", default=False, help=( @@ -57,7 +57,7 @@ _MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state( ), ) -_MOSAIC_ALLOW_HLO = config.define_bool_state( +_MOSAIC_ALLOW_HLO = config.bool_state( name="jax_mosaic_allow_hlo", default=False, help="Allow hlo dialects in Mosaic", diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index d370a31b9..c6c03acd2 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -65,46 +65,46 @@ XlaBackend = xla_client.Client MIN_COMPUTE_CAPABILITY = 52 # TODO(phawkins): Remove jax_xla_backend. -_XLA_BACKEND = config.DEFINE_string( +_XLA_BACKEND = config.string_flag( 'jax_xla_backend', '', 'Deprecated, please use --jax_platforms instead.') -BACKEND_TARGET = config.DEFINE_string( +BACKEND_TARGET = config.string_flag( 'jax_backend_target', os.getenv('JAX_BACKEND_TARGET', '').lower(), 'Either "local" or "rpc:address" to connect to a remote service target.') # TODO(skye): warn when this is used once we test out --jax_platforms a bit -_PLATFORM_NAME = config.DEFINE_string( +_PLATFORM_NAME = config.string_flag( 'jax_platform_name', os.getenv('JAX_PLATFORM_NAME', '').lower(), 'Deprecated, please use --jax_platforms instead.') -CUDA_VISIBLE_DEVICES = config.DEFINE_string( +CUDA_VISIBLE_DEVICES = config.string_flag( 'jax_cuda_visible_devices', 'all', 'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') -_ROCM_VISIBLE_DEVICES = config.DEFINE_string( +_ROCM_VISIBLE_DEVICES = config.string_flag( 'jax_rocm_visible_devices', 'all', 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') -_USE_MOCK_GPU_CLIENT = config.DEFINE_bool( +_USE_MOCK_GPU_CLIENT = config.bool_flag( name="use_mock_gpu_client", default=False, help="If True, use a mock GPU client instead of a real one.", ) -_MOCK_NUM_GPUS = config.DEFINE_integer( +_MOCK_NUM_GPUS = config.int_flag( name="mock_num_gpus", default=1, help="Mock GPU client number of gpus.", ) -_CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool( +_CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag( name="jax_cpu_enable_gloo_collectives", default=False, help="Deprecated, please use jax_cpu_collectives_implementation instead.", ) -_CPU_COLLECTIVES_IMPLEMENTATION = config.DEFINE_string( +_CPU_COLLECTIVES_IMPLEMENTATION = config.string_flag( name='jax_cpu_collectives_implementation', default='none', help='Cross-process collective implementation used on CPU. Either "none", ' @@ -113,7 +113,7 @@ _CPU_COLLECTIVES_IMPLEMENTATION = config.DEFINE_string( # TODO(yueshengys): turn default back to True after resolving memory increase # issue. -_CPU_ENABLE_ASYNC_DISPATCH = config.DEFINE_bool( +_CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag( name="jax_cpu_enable_async_dispatch", default=False, help="Only applies to non-parallel computations. If False, run computations" @@ -417,7 +417,7 @@ def _check_cuda_versions(raise_on_first_error: bool = False, def make_gpu_client( - *, platform_name: str, visible_devices_flag: config.FlagHolder[str] + *, platform_name: str, visible_devices_flag: config.Flag[str] ) -> xla_client.Client: visible_devices = visible_devices_flag.value allowed_devices = None diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index b56bd7cec..a38261d57 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -541,12 +541,12 @@ from jax._src.lib.mlir.dialects import hlo import numpy as np -_HOST_CALLBACK_INLINE = config.DEFINE_bool( +_HOST_CALLBACK_INLINE = config.bool_flag( 'jax_host_callback_inline', config.bool_env('JAX_HOST_CALLBACK_INLINE', False), help='Inline the host_callback, if not in a staged context.' ) -_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer( +_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.int_flag( 'jax_host_callback_max_queue_byte_size', config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)), help=('The size in bytes of the buffer used to hold outfeeds from each ' @@ -555,7 +555,7 @@ _HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer( 'until the Python callback consume more outfeeds.'), lower_bound=int(16 * 1e6) ) -_HOST_CALLBACK_OUTFEED = config.DEFINE_bool( +_HOST_CALLBACK_OUTFEED = config.bool_flag( 'jax_host_callback_outfeed', config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False), help=( @@ -564,7 +564,7 @@ _HOST_CALLBACK_OUTFEED = config.DEFINE_bool( 'Has no effect on TPU, since only the outfeed mechanism is implemented.' ) ) -_HOST_CALLBACK_LEGACY = config.DEFINE_bool( +_HOST_CALLBACK_LEGACY = config.bool_flag( 'jax_host_callback_legacy', config.bool_env('JAX_HOST_CALLBACK_LEGACY', True), help=(