rocm_jax/jax/_src/config.py
Sergei Lebedev ce0d9e9b9f Changed the naming of internal config APIs
The new naming highlights that we have two kinds of configuration options:
flags, set at most once, and states, which can be changed locally per thread
via a context manager.

The renames are

* FlagHolder -> Flag
* DEFINE_<type> -> <type>_flag
* _StateContextManager -> State
* define_<type>_state -> <type>_state
2024-06-18 11:48:57 +01:00

1669 lines
63 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Hashable, Iterator, Sequence
import contextlib
import functools
import itertools
import logging
import os
import sys
import threading
from typing import (
Any, Callable, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast,
)
from jax._src import lib
from jax._src.lib import jax_jit
from jax._src.lib import transfer_guard_lib
from jax._src.lib import xla_client
from jax._src import logging_config
logger = logging.getLogger(__name__)
_T = TypeVar('_T')
def bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
false values are 'n', 'no', 'f', 'false', 'off', and '0'.
Args:
varname: the name of the variable
default: the default boolean value
Raises: ValueError if the environment variable is anything else.
"""
val = os.getenv(varname, str(default))
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return True
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return False
else:
raise ValueError(f"invalid truth value {val!r} for environment {varname!r}")
def int_env(varname: str, default: int) -> int:
"""Read an environment variable and interpret it as an integer."""
return int(os.getenv(varname, str(default)))
class ValueHolder(Protocol[_T]):
"""A holder for a configuration value.
There are two kinds of value holders: ``Flag``, which is assigned exactly
once and never modified after; and ``State``, which can be changed locally
within a thread via a context manager.
"""
value: _T
def _set(self, value: _T) -> None: ...
class Config:
_HAS_DYNAMIC_ATTRIBUTES = True
def __init__(self):
self._value_holders: dict[str, ValueHolder] = {}
self.meta = {}
self.use_absl = False
self._contextmanager_flags = set()
def update(self, name, val):
if name not in self._value_holders:
raise AttributeError(f"Unrecognized config option: {name}")
self._value_holders[name]._set(val)
def read(self, name):
if name in self._contextmanager_flags:
raise AttributeError(
"For flags with a corresponding contextmanager, read their value "
f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.")
return self._read(name)
def _read(self, name):
try:
return self._value_holders[name].value
except KeyError:
raise AttributeError(f"Unrecognized config option: {name}")
@property
def values(self):
return {name: holder.value for name, holder in self._value_holders.items()}
def add_option(self, name, holder, opt_type, meta_args, meta_kwargs):
if name in self._value_holders:
raise Exception(f"Config option {name} already defined")
self._value_holders[name] = holder
self.meta[name] = (opt_type, meta_args, meta_kwargs)
def config_with_absl(self):
"""Registers absl flags for the JAX configs.
E.g., for each JAX config defined using bool_state(), this method
registers an absl boolean flag, with the same name.
This is the recommended method to call if you use `app.run(main)` and you
need JAX flags. Example:
```python
from absl import app
import jax
...
if __name__ == '__main__':
jax.config.config_with_absl()
app.run(main)
```
"""
import absl.flags as absl_FLAGS # noqa: F401 # pytype: disable=import-error
from absl import app, flags as absl_flags # pytype: disable=import-error
self.use_absl = True
self.absl_flags = absl_flags
absl_defs = { bool: absl_flags.DEFINE_bool,
int: absl_flags.DEFINE_integer,
float: absl_flags.DEFINE_float,
str: absl_flags.DEFINE_string,
'enum': absl_flags.DEFINE_enum }
for name, (flag_type, meta_args, meta_kwargs) in self.meta.items():
holder = self._value_holders[name]
absl_defs[flag_type](name, holder.value, *meta_args, **meta_kwargs)
app.call_after_init(lambda: self.complete_absl_config(absl_flags))
def complete_absl_config(self, absl_flags):
# NOTE: avoid calling from outside this module. Instead, use
# `config_with_absl()`, and (in rare cases) `parse_flags_with_absl()`.
for name, holder in self._value_holders.items():
try:
flag = absl_flags.FLAGS[name]
except KeyError:
# This can happen if a new flag was added after config_with_absl() was
# called, but before complete_absl_config was run. We could in principle
# add code to DEFINE_... to register any newly added flags with ABSL
# if config_with_absl() has already been called, but arguably the user
# should have called config_with_absl() later.
continue
if flag.present:
holder._set(flag.value)
def parse_flags_with_absl(self):
"""Parses command-line args that start with --jax.
This method should be used only by advanced users. Most users should use
:meth:`config_with_absl` instead.
This method has serious limitations: e.g., although it parses only the
--jax* command-line args, it runs the validators of all registered absl
flags, even non-JAX ones that have not been set yet; as such, for the
non-JAX flags, the validators run on the default flag values, not on the
values indicated by the command-line args.
"""
global already_configured_with_absl
if not already_configured_with_absl:
# Extract just the --jax... flags (before the first --) from argv. In some
# environments (e.g. ipython/colab) argv might be a mess of things
# parseable by absl and other junk.
jax_argv = itertools.takewhile(lambda a: a != '--', sys.argv)
jax_argv = ['', *(a for a in jax_argv if a.startswith('--jax'))]
import absl.flags # pytype: disable=import-error
self.config_with_absl()
absl.flags.FLAGS(jax_argv, known_only=True)
self.complete_absl_config(absl.flags)
already_configured_with_absl = True
def trace_context():
"""Returns a tuple of configuration values that affect tracing.
These values are included in the cache key for linear_util.cache.
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately.
"""
tls = jax_jit.thread_local_state()
axis_env_state = ()
mesh_context_manager = ()
context: Any = tls.extra_jit_context
if context and context.axis_env_state is not None:
axis_env_state = context.axis_env_state
if context and context.mesh_context_manager:
mesh_context_manager = context.mesh_context_manager
return (axis_env_state, mesh_context_manager, enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value, numpy_dtype_promotion.value,
default_device.value, random_seed_offset.value,
threefry_partitionable.value,
threefry_gpu_kernel_lowering.value,
softmax_custom_jvp.value,
enable_memories.value,
disable_jit.value,
debug_key_reuse.value,
jax_xla_profile_version.value,
# Technically this affects jaxpr->stablehlo lowering, not tracing.
hlo_source_file_canonicalization_regex.value,
pgle_profiling_runs.value,
enable_pgle.value)
config = Config()
_read = config._read
update = config.update
parse_flags_with_absl = config.parse_flags_with_absl
class NoDefault: pass
no_default = NoDefault()
class _Unset: pass
unset = _Unset()
_thread_local_state = threading.local()
class State(Generic[_T]):
__slots__ = (
'_name', '_value', '_update_thread_local_hook', '_update_global_hook',
'_validator', '_default_context_manager_value', '__doc__', '__name__',
)
def __init__(
self,
name: str,
default: _T,
help,
update_global_hook: Callable[[_T], None] | None = None,
update_thread_local_hook: Callable[[_T | None], None] | None = None,
validator: Callable[[Any], None] | None = None,
extra_description: str = '',
default_context_manager_value: Any = no_default,
):
self._name = name
self.__name__ = name[4:] if name.startswith('jax_') else name
self.__doc__ = (f"Context manager for `{name}` config option"
f"{extra_description}.\n\n{help}")
self._update_global_hook = update_global_hook
self._update_thread_local_hook = update_thread_local_hook
self._validator = validator
self._default_context_manager_value = default_context_manager_value
self._set(default)
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(
type(self).__name__))
def _set(self, value: _T) -> None:
self._value = value
if self._update_global_hook:
self._update_global_hook(value)
@property
def value(self) -> _T:
val = _thread_local_state.__dict__.get(self._name, unset)
return cast(_T, val) if val is not unset else self._value
@contextlib.contextmanager
def __call__(self, new_val: Any = no_default):
if new_val is no_default:
if self._default_context_manager_value is not no_default:
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
else:
# no default_value provided to constructor and no value provided as an
# argument, so we raise an error
raise TypeError(f"Context manager for {self.__name__} config option "
"requires an argument representing the new value for "
"the config option.")
if self._validator:
self._validator(new_val)
prev_val = getattr(_thread_local_state, self._name, unset)
setattr(_thread_local_state, self._name, new_val)
if self._update_thread_local_hook:
self._update_thread_local_hook(new_val)
try:
yield
finally:
if prev_val is unset:
delattr(_thread_local_state, self._name)
if self._update_thread_local_hook:
self._update_thread_local_hook(None)
else:
setattr(_thread_local_state, self._name, prev_val)
if self._update_thread_local_hook:
self._update_thread_local_hook(cast(_T, prev_val))
def _add_hooks(self, update_global_hook, update_thread_local_hook):
"""Private method that adds hooks to an existing context-manager.
Used to avoid cyclic import dependencies."""
self._update_thread_local_hook = update_thread_local_hook
self._update_global_hook = update_global_hook
update_global_hook(self._value)
UPGRADE_BOOL_HELP = (
" This will be enabled by default in future versions of JAX, at which "
"point all uses of the flag will be considered deprecated (following "
"the `API compatibility policy "
"<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_).")
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
def bool_state(
name: str,
default: bool,
help: str,
*,
update_global_hook: Callable[[bool], None] | None = None,
update_thread_local_hook: Callable[[bool | None], None] | None = None,
upgrade: bool = False,
extra_description: str = '',
) -> State[bool]:
"""Set up thread-local state and return a contextmanager for managing it.
This function is a convenience wrapper. It defines a flag, environment
variable, and corresponding thread-local state, which can be managed via the
contextmanager it returns.
The thread-local state value can be read via the ``config.<option_name>``
attribute, where ``config`` is the singleton ``Config`` instance.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
default: boolean, a default value for the option.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
update_global_hook: a optional callback that is called with the updated
value of the global state when it is altered or set initially.
update_thread_local_hook: a optional callback that is called with the
updated value of the thread-local state when it is altered or set
initially.
upgrade: optional indicator that this flag controls a canonical feature
upgrade, so that it is `True` for the incoming functionality, `False`
for the outgoing functionality to be deprecated.
extra_description: string, optional: extra information to add to the
summary description.
Returns:
A contextmanager to control the thread-local state value.
Example:
ENABLE_FOO = config.bool_state(
name='jax_enable_foo',
default=False,
help='Enable foo.')
# Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo
# command-line flag can be used to control the process-level value of
# the configuration option, in addition to using e.g.
# ``config.update("jax_enable_foo", True)`` directly. We can also use a
# context manager:
with enable_foo(True):
...
The value of the thread-local state or flag can be accessed via
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
an error.
"""
if not isinstance(default, bool):
raise TypeError(f"Default value must be of type bool, got {default} "
f"of type {getattr(type(default), '__name__', type(default))}")
default = bool_env(name.upper(), default)
name = name.lower()
if upgrade:
help += ' ' + UPGRADE_BOOL_HELP
extra_description += UPGRADE_BOOL_EXTRA_DESC
config._contextmanager_flags.add(name)
s = State[bool](
name, default, help, update_global_hook=update_global_hook,
update_thread_local_hook=update_thread_local_hook,
extra_description=extra_description, default_context_manager_value=True)
config.add_option(name, s, bool, meta_args=[], meta_kwargs={"help": help})
setattr(Config, name, property(lambda _: s.value))
return s
def enum_state(
name: str,
enum_values: Sequence[str],
default: str,
help: str,
*,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> State[str]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
enum_values: list of strings representing the possible values for the
option.
default: string, default value.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
Returns:
A contextmanager to control the thread-local state value.
"""
if not isinstance(default, str):
raise TypeError(f"Default value must be of type str, got {default} "
f"of type {getattr(type(default), '__name__', type(default))}")
name = name.lower()
default = os.getenv(name.upper(), default)
if default not in enum_values:
raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}")
config._contextmanager_flags.add(name)
def validator(new_val):
if type(new_val) is not str or new_val not in enum_values:
raise ValueError(f"new enum value must be in {enum_values}, "
f"got {new_val} of type {type(new_val)}.")
s = State[str](
name,
default,
help,
update_global_hook=update_global_hook,
update_thread_local_hook=update_thread_local_hook,
validator=validator,
)
config.add_option(
name, s, 'enum',
meta_args=[],
meta_kwargs={"enum_values": enum_values, "help": help}
)
setattr(Config, name, property(lambda _: s.value))
return s
def optional_enum_state(
name: str,
enum_values: Sequence[str],
default: str | None,
help: str,
*,
update_global_hook: Callable[[str | None], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> State[str | None]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
enum_values: list of strings representing the possible values for the
option.
default: optional string, default value.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
Returns:
A contextmanager to control the thread-local state value.
"""
if default is not None and not isinstance(default, str):
raise TypeError(f"Default value must be of type str or None, got {default} "
f"of type {getattr(type(default), '__name__', type(default))}")
name = name.lower()
default = os.getenv(name.upper(), default)
if default is not None and default not in enum_values:
raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}")
config._contextmanager_flags.add(name)
def validate(new_val):
if (new_val is not None and
(type(new_val) is not str or new_val not in enum_values)):
raise ValueError(f"new enum value must be None or in {enum_values}, "
f"got {new_val} of type {type(new_val)}.")
s = State['str | None'](
name, default, help, update_global_hook, update_thread_local_hook,
validate
)
config.add_option(
name, s, 'enum',
meta_args=[],
meta_kwargs={"enum_values": enum_values, "help": help}
)
setattr(Config, name, property(lambda _: s.value))
return s
def int_state(
name: str,
default: int,
help: str,
*,
update_global_hook: Callable[[int], None] | None = None,
update_thread_local_hook: Callable[[int | None], None] | None = None,
) -> State[int]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
default: optional int, default value.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
Returns:
A contextmanager to control the thread-local state value.
"""
if not isinstance(default, int):
raise TypeError(f"Default value must be of type int, got {default} "
f"of type {getattr(type(default), '__name__', type(default))}")
name = name.lower()
default_env = os.getenv(name.upper())
if default_env is not None:
try:
default = int(default_env)
except ValueError:
raise ValueError(f"Invalid value \"{default_env}\" for JAX flag {name}")
config._contextmanager_flags.add(name)
def validate(new_val):
if new_val is not None and not isinstance(new_val, int):
raise ValueError(f'new int config value must be None or of type int, '
f'got {new_val} of type {type(new_val)}')
s = State[int](name, default, help, update_global_hook,
update_thread_local_hook, validate)
config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help})
setattr(Config, name, property(lambda _: s.value))
return s
def float_state(
name: str,
default: float,
help: str,
*,
update_global_hook: Callable[[float], None] | None = None,
update_thread_local_hook: Callable[[float | None], None] | None = None,
) -> State[float]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
default: default value.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
Returns:
A contextmanager to control the thread-local state value.
"""
if not isinstance(default, float):
raise TypeError(f"Default value must be of type float, got {default} "
f"of type {getattr(type(default), '__name__', type(default))}")
name = name.lower()
default_env = os.getenv(name.upper())
if default_env is not None:
try:
default = float(default_env)
except ValueError:
raise ValueError(f"Invalid value \"{default_env}\" for JAX flag {name}")
config._contextmanager_flags.add(name)
def validate(new_val):
if new_val is not None and not isinstance(new_val, (float, int)):
raise ValueError(
f'new float config value must be None or of type float, '
f'got {new_val} of type {type(new_val)}')
s = State[float](name, default, help, update_global_hook,
update_thread_local_hook, validate)
config.add_option(name, s, float, meta_args=[], meta_kwargs={"help": help})
setattr(Config, name, property(lambda _: s.value))
return s
def string_state(
name: str,
default: str,
help: str,
*,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> State[str]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
default: string, a default value for the option.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
update_global_hook: an optional callback that is called with the updated
value of the global state when it is altered or set initially.
update_thread_local_hook: an optional callback that is called with the
updated value of the thread-local state when it is altered or set
initially.
Returns:
A contextmanager to control the thread-local state value.
"""
if not isinstance(default, str):
raise TypeError(f"Default value must be of type str, got {default} "
f"of type {getattr(type(default), '__name__', type(default))}")
def validator(new_val):
if not isinstance(new_val, str):
raise TypeError('new string config value must be of type str,'
f' got {new_val} of type {type(new_val)}.')
return string_or_object_state(
name, default, help,
update_global_hook=update_global_hook,
update_thread_local_hook=update_thread_local_hook,
validator=validator)
def optional_string_state(
name: str,
default: str | None,
help: str,
*,
update_global_hook: Callable[[str], None] | None = None,
update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> State[str | None]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``bool_state``.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
default: optional string, a default value for the option.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
update_global_hook: an optional callback that is called with the updated
value of the global state when it is altered or set initially.
update_thread_local_hook: an optional callback that is called with the
updated value of the thread-local state when it is altered or set
initially.
Returns:
A contextmanager to control the thread-local state value.
"""
if default is not None and not isinstance(default, str):
raise TypeError(f"Default value must be of type str or None, got {default} "
f"of type {getattr(type(default), '__name__', type(default))}")
def validator(new_val):
if new_val is not None and not isinstance(new_val, str):
raise ValueError('new string config value must be None or of type str,'
f' got {new_val} of type {type(new_val)}.')
return string_or_object_state(
name, default, help,
update_global_hook=update_global_hook,
update_thread_local_hook=update_thread_local_hook,
validator=validator)
def string_or_object_state(
name: str,
default: Any,
help: str,
*,
update_global_hook: Callable[[Any], None] | None = None,
update_thread_local_hook: Callable[[Any], None] | None = None,
validator: Callable[[Any], None] | None = None,
) -> State[Any]:
"""Set up thread-local state and return a contextmanager for managing it.
Similar to ``string_state``, except the context manager will accept
any object, not just a string. Any value passed via command line flag or
environment variable will be treated as a string.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
default: string, a default value for the option.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
update_global_hook: an optional callback that is called with the updated
value of the global state when it is altered or set initially.
update_thread_local_hook: an optional callback that is called with the
updated value of the thread-local state when it is altered or set
initially.
validator: an optional callback that is called with the new
value on any update, and should raise an error if the new value is
invalid.
Returns:
A contextmanager to control the thread-local state value.
"""
name = name.lower()
default = os.getenv(name.upper(), default)
config._contextmanager_flags.add(name)
s = State[Any](
name, default, help, update_global_hook, update_thread_local_hook,
validator)
setattr(Config, name, property(lambda _: s.value))
config.add_option(name, s, str, meta_args=[], meta_kwargs={"help": help})
return s
class Flag(Generic[_T]):
__slots__ = ("_name", "value", "_update_hook")
_name: str
value: _T
_update_hook: Callable[[Any], None] | None
def __init__(self, name: str, default: _T,
update_hook: Callable[[Any], None] | None = None):
self._name = name
self._update_hook = update_hook
self._set(default)
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(
type(self).__name__))
def _set(self, value: _T) -> None:
self.value = value
if self._update_hook is not None:
self._update_hook(value)
def bool_flag(name, default, *args, **kwargs) -> Flag[bool]:
update_hook = kwargs.pop("update_hook", None)
holder = Flag(name, default, update_hook)
config.add_option(name, holder, bool, args, kwargs)
return holder
def int_flag(name, default, *args, **kwargs) -> Flag[int]:
update_hook = kwargs.pop("update_hook", None)
holder = Flag(name, default, update_hook)
config.add_option(name, holder, int, args, kwargs)
return holder
def float_flag(name, default, *args, **kwargs) -> Flag[float]:
update_hook = kwargs.pop("update_hook", None)
holder = Flag(name, default, update_hook)
config.add_option(name, holder, float, args, kwargs)
return holder
def string_flag(name, default, *args, **kwargs) -> Flag[str]:
update_hook = kwargs.pop("update_hook", None)
holder = Flag(name, default, update_hook)
config.add_option(name, holder, str, args, kwargs)
return holder
def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
update_hook = kwargs.pop("update_hook", None)
holder = Flag(name, default, update_hook)
config.add_option(name, holder, 'enum', args, kwargs)
return holder
already_configured_with_absl = False
# The C++ JIT maintains its own copy of several configuration items as
# a global/thread-local state. These methods allow updates to part of the
# state when a configuration value changes.
class _GlobalExtraJitContext(NamedTuple):
numpy_rank_promotion: str | None = None
numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None
dynamic_shapes: bool = False
random_seed_offset: int = 0
threefry_partitionable: bool = False
threefry_gpu_kernel_lowering: bool = False
softmax_custom_jvp: bool = False
xla_profile_version: int = 0
pgle_profiling_runs: int = 0
enable_pgle: bool = False
def _update_global_jit_state(**kw):
gs = jax_jit.global_state()
context = gs.extra_jit_context or _GlobalExtraJitContext()
gs.extra_jit_context = context._replace(**kw)
class _ThreadLocalExtraJitContext(NamedTuple):
"""A namedtuple containing states to add to the cache key.
Just in time compilation (for jit, pmap, etc) behavior is configurable through
global and thread-local options, used in the cache key.
The initialization, which uses both config.py and core.py is done using
`_update_thread_local_jit_state` in core.py to prevent circular imports.
"""
dynamic_trace_state: Any | None = None
axis_env_state: Hashable = ()
mesh_context_manager: Hashable = ()
# Values set by _StateContextManager context managers.
# CAUTION: these must be initialized to `None`! The state context manager
# restores these to None on exit. If the object default is not `None`, the
# context manager is not a no-op, which leads to problems with stale state
# (e.g. spurious cache misses in tests).
numpy_rank_promotion: str | None = None
numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None
dynamic_shapes: bool | None = None
random_seed_offset: int | None = None
threefry_partitionable: bool | None = None
threefry_gpu_kernel_lowering: bool | None = None
softmax_custom_jvp: bool | None = None
xla_profile_version: int | None = None
pgle_profiling_runs: int | None = None
enable_pgle: bool | None = None
class _ThreadLocalStateCache(threading.local):
""""A thread local cache for _ThreadLocalExtraJitContext
The extra_jit_context in jax_jit.thread_local_state() may get updated and thus
incurring dispatch overhead for comparing this python object during jit calls.
We want to duduplicate the objects that have the same hash/equality to also
have the same object ID, since the equality check is much faster if the object
IDs match.
"""
def __init__(self):
self.canonicalize = functools.lru_cache(128)(lambda x: x)
_thread_local_state_cache = _ThreadLocalStateCache()
def update_thread_local_jit_state(**kw):
tls = jax_jit.thread_local_state()
# After xla_client._version >= 70, the thread_local object will necessarily
# be initialized when accessed. The following line can be removed when the
# minimum jaxlib version is past version 70
context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
tmp = context._replace(**kw)
tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
# TODO(b/214340779): remove flag when XLA:CPU is improved.
jax2tf_associative_scan_reductions = bool_state(
name='jax2tf_associative_scan_reductions',
default=False,
help=(
'JAX has two separate lowering rules for the cumulative reduction '
'primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses '
'a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. '
'The latter has a slow implementation on CPUs and GPUs. '
'By default, jax2tf uses the TPU lowering. Set this flag to True to '
'use the associative scan lowering usage, and only if it makes a difference '
'for your application. '
'See the jax2tf README.md for more details.'
)
)
jax2tf_default_native_serialization = bool_state(
name='jax2tf_default_native_serialization',
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True),
help=(
'Sets the default value of the native_serialization parameter to '
'jax2tf.convert. Prefer using the parameter instead of the flag, '
'the flag may be removed in the future.'
)
)
jax_serialization_version = int_state(
name='jax_serialization_version',
default=int_env('JAX_SERIALIZATION_VERSION', 0), # We use 0 to detect default.
help=(
'DEPRECATED: use jax_export_calling_convention_version.'
)
)
jax_export_calling_convention_version = int_state(
name='jax_export_calling_convention_version',
# Note: bump the default calling convention version at least one month after
# we update XlaCallModule to support the new version, so that serialized
# modules are forward compatible with deployed versions of XlaCallModule.
# Version 9 of XlaCallModule is supported since October 27th, 2023.
default=int_env('JAX_EXPORT_CALLING_CONVENTION_VERSION', 9),
help=(
'The calling convention version number to use for exporting. This must be '
'within the range of versions supported by the tf.XlaCallModule '
'used in your deployment environment. '
'See https://jax.readthedocs.io/en/latest/export/shape_poly.html#calling-convention-versions.'
)
)
jax_platforms = optional_string_state(
name='jax_platforms',
default=None,
help=(
'Comma-separated list of platform names specifying which platforms jax '
'should initialize. If any of the platforms in this list are not successfully '
'initialized, an exception will be raised and the program will be aborted. '
'The first platform in the list will be the default platform. '
'For example, config.jax_platforms=cpu,tpu means that CPU and TPU backends '
'will be initialized, and the CPU backend will be used unless otherwise '
'specified. If TPU initialization fails, it will raise an exception. '
'By default, jax will try to initialize all available '
'platforms and will default to GPU or TPU if available, and fallback to CPU '
'otherwise.'
))
jax_pjrt_client_create_options = optional_string_state(
name='jax_pjrt_client_create_options',
default=None,
help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings '
'provided to a device platform pjrt client as extra arguments.'))
enable_checks = bool_state(
name='jax_enable_checks',
default=False,
help='Turn on invariant checking for JAX internals. Makes things slower.')
debug_key_reuse = bool_state(
name='jax_debug_key_reuse',
default=False,
help=('Turn on experimental key reuse checking. With this configuration enabled,'
' typed PRNG keys (i.e. keys created with jax.random.key()) will have their'
' usage tracked, and incorrect reuse of a previously-used key will lead to'
' an error. Currently enabling this leads to a small Python overhead on'
' every call to a JIT-compiled function with keys as inputs or outputs.'))
check_tracer_leaks = bool_state(
name='jax_check_tracer_leaks',
default=False,
help=('Turn on checking for leaked tracers as soon as a trace completes. '
'Enabling leak checking may have performance impacts: some caching '
'is disabled, and other overheads may be added. Additionally, be aware '
'that some Python debuggers can cause false positives, so it is recommended '
'to disable any debuggers while leak checking is enabled.'))
checking_leaks = functools.partial(check_tracer_leaks, True)
debug_nans = bool_state(
name='jax_debug_nans',
default=False,
help=('Add nan checks to every operation. When a nan is detected on the '
'output of a jit-compiled computation, call into the un-compiled '
'version in an attempt to more precisely identify the operation '
'which produced the nan.'))
debug_infs = bool_state(
name='jax_debug_infs',
default=False,
help=('Add inf checks to every operation. When an inf is detected on the '
'output of a jit-compiled computation, call into the un-compiled '
'version in an attempt to more precisely identify the operation '
'which produced the inf.'))
log_compiles = bool_state(
name='jax_log_compiles',
default=False,
help=('Log a message each time `jit` or `pmap` compiles an XLA '
'computation. Logging is performed with `logging`. When this '
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))
explain_cache_misses = bool_state(
name='jax_explain_cache_misses',
default=False,
help=('Each time there is a miss on one of the main caches (e.g. the '
'tracing cache), log an explanation.. Logging is performed with '
'`logging`. When this option is set, the log level is WARNING; '
'otherwise the level is DEBUG.'))
log_checkpoint_residuals = bool_state(
name='jax_log_checkpoint_residuals',
default=False,
help=('Log a message every time jax.checkpoint (aka jax.remat) is '
'partially evaluated (e.g. for autodiff), printing what residuals '
'are saved.'))
pmap_shmap_merge = bool_state(
name='jax_pmap_shmap_merge',
default=False,
upgrade=True,
help='If True, pmap and shard_map API will be merged.')
def _update_jax_memories_global(val):
lib.jax_jit.global_state().enable_memories = val
def _update_jax_memories_thread_local(val):
lib.jax_jit.thread_local_state().enable_memories = val
enable_memories = bool_state(
'jax_enable_memories',
default=False,
upgrade=True,
update_global_hook=_update_jax_memories_global,
update_thread_local_hook=_update_jax_memories_thread_local,
help=("If True, will allow fetching memory kinds available on executable "
"and annotate Shardings with it."))
spmd_mode = enum_state(
name='jax_spmd_mode',
enum_values=['allow_all', 'allow_jit'],
default='allow_jit',
help=("Decides whether Math on `jax.Array`'s that are not fully addressable "
"(i.e. spans across multiple processes) is allowed. The options are: "
"* allow_jit: Default, `pjit` and `jax.jit` computations are allowed "
" to execute on non-fully addressable `jax.Array`s\n"
"* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
" `jax.jit` and all other operations are allowed to "
" execute on non-fully addressable `jax.Array`s."))
distributed_debug = bool_state(
name='jax_distributed_debug',
default=False,
help=('Enable logging useful for debugging multi-process distributed '
'computations. Logging is performed with `logging` at WARNING '
'level.'))
random_seed_offset = int_state(
name='jax_random_seed_offset',
default=0,
help=('Offset to all random seeds (e.g. argument to jax.random.key()).'),
update_global_hook=lambda val: _update_global_jit_state(
random_seed_offset=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
random_seed_offset=val)
)
legacy_prng_key = enum_state(
name='jax_legacy_prng_key',
enum_values=['allow', 'warn', 'error'],
default='allow',
help=('Specify the behavior when raw PRNG keys are passed to '
'jax.random APIs.')
)
enable_custom_prng = bool_state(
name='jax_enable_custom_prng',
default=False,
upgrade=True,
help=('Enables an internal upgrade that allows one to define custom '
'pseudo-random number generator implementations.'))
default_prng_impl = enum_state(
name='jax_default_prng_impl',
enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'],
default='threefry2x32',
help=('Select the default PRNG implementation, used when one is not '
'explicitly provided at seeding time.'))
threefry_partitionable = bool_state(
name='jax_threefry_partitionable',
default=False,
upgrade=True,
help=('Enables internal threefry PRNG implementation changes that '
'render it automatically partitionable in some cases. Without this '
'flag, using the standard jax.random pseudo-random number generation '
'may result in extraneous communication and/or redundant distributed '
'computation. With this flag, the communication overheads disappear '
'in some cases.'),
update_global_hook=lambda val: _update_global_jit_state(
threefry_partitionable=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
threefry_partitionable=val))
threefry_gpu_kernel_lowering = bool_state(
name='jax_threefry_gpu_kernel_lowering',
default=False,
help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
'This makes compile times faster at a potential runtime memory '
'cost.'),
update_global_hook=lambda val: _update_global_jit_state(
threefry_gpu_kernel_lowering=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
threefry_gpu_kernel_lowering=val))
softmax_custom_jvp = bool_state(
name='jax_softmax_custom_jvp',
default=False,
upgrade=True,
help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should '
'improve memory usage and stability. Set True to use new '
'behavior. See https://github.com/google/jax/pull/15677'),
update_global_hook=lambda val: _update_global_jit_state(
softmax_custom_jvp=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
softmax_custom_jvp=val))
enable_custom_vjp_by_custom_transpose = bool_state(
name='jax_enable_custom_vjp_by_custom_transpose',
default=False,
upgrade=True,
help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))
raise_persistent_cache_errors = bool_state(
name='jax_raise_persistent_cache_errors',
default=False,
help=('If true, exceptions raised when reading or writing to the '
'persistent compilation cache will be allowed through, halting '
'program execution if not manually caught. If false, exceptions are '
'caught and raised as warnings, allowing program execution to '
'continue. Defaults to false so cache bugs or intermittent issues '
'are non-fatal.'))
persistent_cache_min_compile_time_secs = float_state(
name='jax_persistent_cache_min_compile_time_secs',
default=1.,
help=('The minimum compile time of a computation to be written to the '
'persistent compilation cache. This threshold can be raised to '
'decrease the number of entries written to the cache.'))
persistent_cache_min_entry_size_bytes = int_state(
name='jax_persistent_cache_min_entry_size_bytes',
default=0,
help=('The minimum size (in bytes) of an entry that will be cached in the '
'persistent compilation cache: '
'* -1: disable the size restriction and prevent overrides. '
'* Leave at default (0) to allow for overrides. The override will '
' typically ensure that the minimum size is optimal for the '
' filesystem being used for the cache. '
'* > 0: the actual minimum size desired; no overrides.'))
compilation_cache_include_metadata_in_key = bool_state(
name='jax_compilation_cache_include_metadata_in_key',
default=False,
help=(
'Include metadata, such as file names and line numbers, in the'
' compilation cache key. If false, the cache will still get hits even'
' if functions or files are moved, etc. However, it means that'
' executables loaded from the cache may have stale metadata, which'
' may show up in, e.g., profiles.'
),
)
hlo_source_file_canonicalization_regex = optional_string_state(
name='jax_hlo_source_file_canonicalization_regex',
default=None,
help=('Used to canonicalize the source_path metadata of HLO instructions '
'by removing the given regex. If set, re.sub() is called on each '
'source_file with the given regex, and all matches are removed. '
'This can be used to avoid spurious cache misses when using the '
'persistent compilation cache, which includes HLO metadata in the '
'cache key.'))
include_full_tracebacks_in_locations = bool_state(
name='jax_include_full_tracebacks_in_locations',
default=True,
help=(
'Include Python tracebacks in MLIR locations in IR emitted by JAX.'
),
)
traceback_in_locations_limit = int_state(
name='jax_traceback_in_locations_limit',
default=10,
help=(
'Limit the number of frames at the Python traceback frames included in '
'MLIR locations. If set to the negative value, traceback will not be '
'limited.'
),
)
share_autotune_config_between_hosts = bool_state(
name='jax_share_autotune_config_between_hosts',
default=False,
help=(
'If set to True, the coordinator process will share autotune configs '
'other participants. This will increase overall compilation time, but '
'will lead to equal compiled modules in each process. '
'If both jax_share_binary_between_hosts and '
'jax_share_autotune_config_between_hosts are set, compiled HLO will be '
"shared when it's possible and autotune config sharing will be used "
'as a fallback.'
),
)
share_binary_between_hosts = bool_state(
name='jax_share_binary_between_hosts',
default=False,
help=(
'If set to True, the compiled module will be shared between hosts '
'directly.'
),
)
share_binary_between_hosts_timeout_ms = int_state(
name='jax_share_binary_between_hosts_timeout_ms',
default=20 * 60 * 1000,
help='Timeout for the compiled module share.',
)
enable_pgle = bool_state(
name='jax_enable_pgle',
default=False,
help=(
'If set to True and the property jax_pgle_profiling_runs is set to '
'greater than 0, the modules will be recompiled after running specified '
'number times with collected data provided to the profile guided latency '
'estimator.'
),
update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
enable_pgle=val),
)
pgle_profiling_runs = int_state(
name='jax_pgle_profiling_runs',
default=3,
help=(
'Amount of times module should be profiled before recompilation when '
'PGLE is used.'
),
update_global_hook=lambda val: _update_global_jit_state(
pgle_profiling_runs=val
),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
pgle_profiling_runs=val
),
)
pgle_aggregation_percentile = int_state(
name='jax_pgle_aggregation_percentile',
default=90,
help='Percentile used to aggregate performance data between devices when '
'PGLE is used.',
)
enable_compilation_cache = bool_state(
name='jax_enable_compilation_cache',
default=True,
help=('If set to False, the compilation cache will be disabled regardless '
'of whether set_cache_dir() was called. If set to True, the '
'path could be set to a default value or via a call to '
'set_cache_dir().'),
)
compilation_cache_dir = optional_string_state(
name='jax_compilation_cache_dir',
default=None,
help=('Path for the cache. '
'Precedence: '
'1. A call to compilation_cache.set_cache_dir(). '
'2. The value of this flag set in the command line or by default.'),
)
compilation_cache_max_size = int_state(
name='jax_compilation_cache_max_size',
default=-1,
help=('The maximum size (in bytes) allowed for the persistent compilation '
'cache. When set, the least recently accessed cache entry(s) '
'will be deleted once the total cache directory size '
'exceeds the specified limit. '
'Caching will be disabled if this value is set to 0. A '
'special value of -1 indicates no limit, allowing the cache '
'size to grow indefinitely.'),
)
default_dtype_bits = enum_state(
name='jax_default_dtype_bits',
enum_values=['32', '64'],
default='64',
help=('Specify bit width of default dtypes, either 32-bit or 64-bit. '
'This is a temporary flag that will be used during the process '
'of deprecating the ``jax_enable_x64`` flag.'))
numpy_dtype_promotion = enum_state(
name='jax_numpy_dtype_promotion',
enum_values=['standard', 'strict'],
default='standard',
help=('Specify the rules used for implicit type promotion in operations '
'between arrays. Options are "standard" or "strict"; in strict-mode, '
'binary operations between arrays of differing strongly-specified '
'dtypes will result in an error.'),
update_global_hook=lambda val: \
_update_global_jit_state(numpy_dtype_promotion=val),
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(numpy_dtype_promotion=val))
def _update_x64_global(val):
lib.jax_jit.global_state().enable_x64 = val
def _update_x64_thread_local(val):
lib.jax_jit.thread_local_state().enable_x64 = val
enable_x64 = bool_state(
name='jax_enable_x64',
default=False,
help='Enable 64-bit types to be used',
update_global_hook=_update_x64_global,
update_thread_local_hook=_update_x64_thread_local)
# TODO(phawkins): remove after fixing users of FLAGS.x64_enabled.
config._contextmanager_flags.remove('jax_enable_x64')
setattr(Config, "x64_enabled", property(lambda _: enable_x64.value))
def _update_default_device_global(val):
lib.jax_jit.global_state().default_device = val
def _update_default_device_thread_local(val):
lib.jax_jit.thread_local_state().default_device = val
def _validate_default_device(val):
if val is not None and not isinstance(val, xla_client.Device):
# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
# all JAX backends use a single C++ device interface.
if 'Device' in str(type(val)):
logger.info(
'Allowing non-`xla_client.Device` default device: %s, type: %s',
repr(val), type(val))
return
raise ValueError('jax.default_device must be passed a Device object (e.g. '
f"`jax.devices('cpu')[0]`), got: {val!r}")
# TODO(skye): default_device only accepts devices for now. Make it work with
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
default_device = string_or_object_state(
name='jax_default_device',
default=None,
help=(
'Configure the default device for JAX operations. Set to a Device '
'object (e.g. ``jax.devices("cpu")[0]``) to use that Device as the '
'default device for JAX operations and jit\'d function calls (there is '
'no effect on multi-device computations, e.g. pmapped function calls). '
'Set to None to use the system default device. See '
':ref:`faq-data-placement` for more information on device placement.'),
update_global_hook=_update_default_device_global,
update_thread_local_hook=_update_default_device_thread_local,
validator=_validate_default_device)
def _update_disable_jit_global(val):
lib.jax_jit.global_state().disable_jit = val
def _update_disable_jit_thread_local(val):
lib.jax_jit.thread_local_state().disable_jit = val
disable_jit = bool_state(
name='jax_disable_jit',
default=False,
help=('Disable JIT compilation and just call original Python.'),
update_global_hook=_update_disable_jit_global,
update_thread_local_hook=_update_disable_jit_thread_local)
numpy_rank_promotion = enum_state(
name='jax_numpy_rank_promotion',
enum_values=['allow', 'warn', 'raise'],
default='allow',
help=('Control NumPy-style automatic rank promotion broadcasting '
'("allow", "warn", or "raise").'),
update_global_hook=lambda val: \
_update_global_jit_state(numpy_rank_promotion=val),
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(numpy_rank_promotion=val))
default_matmul_precision = optional_enum_state(
name='jax_default_matmul_precision',
enum_values=['bfloat16', 'tensorfloat32', 'float32'],
default=None,
help=('Control the default matmul and conv precision for 32bit inputs.\n\n'
'Some platforms, like TPU, offer configurable precision levels for '
'matrix multiplication and convolution computations, trading off '
'accuracy for speed. The precision can be controlled for each '
'operation; for example, see the :func:`jax.lax.conv_general_dilated` '
'and :func:`jax.lax.dot` docstrings. But it can be useful to control '
'the default behavior obtained when an operation is not given a '
'specific precision.\n\n'
'This option can be used to control the default precision '
'level for computations involved in matrix multiplication and '
'convolution on 32bit inputs. The levels roughly describe the '
"precision at which scalar products are computed. The 'bfloat16' "
"option is the fastest and least precise; 'float32' is similar to "
"full float32 precision; 'tensorfloat32' is intermediate.\n\n"),
update_global_hook=lambda val: \
_update_global_jit_state(default_matmul_precision=val),
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(default_matmul_precision=val))
traceback_filtering = enum_state(
name = 'jax_traceback_filtering',
enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames",
"auto"],
default="auto",
help="Controls how JAX filters internal frames out of tracebacks.\n\n"
"Valid values are:\n"
" * \"off\": disables traceback filtering.\n"
" * \"auto\": use \"tracebackhide\" if running under a sufficiently"
" new IPython, or \"remove_frames\" otherwise.\n"
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to"
" hidden stack frames, which some traceback printers support.\n"
" * \"remove_frames\": removes hidden frames from tracebacks, and adds"
" the unfiltered traceback as a __cause__ of the exception.\n"
" * \"quiet_remove_frames\": removes hidden frames from tracebacks, and adds"
" a brief message (to the __cause__ of the exception) describing that this has"
" happened.\n")
# This flag is for internal use.
# TODO(tianjianlu): Removes once we always enable cusparse lowering.
# TODO(b/262050896): Set to true after bug is fixed
bcoo_cusparse_lowering = bool_state(
name='jax_bcoo_cusparse_lowering',
default=False,
help=('Enables lowering BCOO ops to cuSparse.'))
# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging
# if the intended backend can handle lowering the result
dynamic_shapes = bool_state(
name='jax_dynamic_shapes',
default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
help=('Enables experimental features for staging out computations with '
'dynamic shapes.'),
update_global_hook=lambda val: \
_update_global_jit_state(dynamic_shapes=val),
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(dynamic_shapes=val))
# This flag is temporary during rollout of the remat barrier.
# TODO(parkers): Remove if there are no complaints.
remat_opt_barrier = bool_state(
name='jax_remat_opt_barrier',
default=True,
help=('Enables using optimization-barrier op for lowering remat.'))
# TODO(sharadmv,mattjj): set default to True, then remove
eager_pmap = bool_state(
name='jax_eager_pmap',
default=True,
upgrade=True,
help='Enable eager-mode pmap when jax_disable_jit is activated.')
# TODO(mattjj): remove once we land mutable array plumbing, or face great shame
custom_vjp_disable_shape_check = bool_state(
name='jax_custom_vjp_disable_shape_check',
default=False,
upgrade=True,
help='Disable the check from #19009 to enable some custom_vjp hacks.')
xla_runtime_errors = bool_state(
name='jax_experimental_unsafe_xla_runtime_errors',
default=False,
help=('Enable XLA runtime errors for jax.experimental.checkify.checks '
'on CPU and GPU. These errors are async, might get lost and are not '
'very readable. But, they crash the computation and enable you '
'to write jittable checks without needing to checkify. Does not '
'work under pmap/pjit.')
)
jax_xla_profile_version = int_state(
name='jax_xla_profile_version',
default=0,
help=(
'Optional profile version for XLA compilation. This is meaningful '
'only when XLA is configured to support the remote compilation '
'profile feature.'),
update_global_hook=lambda val: _update_global_jit_state(
xla_profile_version=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
xla_profile_version=val),
)
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""
state = transfer_guard_lib.thread_local_state()
prev = state.explicit_device_put
state.explicit_device_put = True
try:
yield
finally:
state.explicit_device_put = prev
@contextlib.contextmanager
def explicit_device_get_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_get() call."""
state = transfer_guard_lib.thread_local_state()
prev = state.explicit_device_get
state.explicit_device_get = True
try:
yield
finally:
state.explicit_device_get = prev
def _update_transfer_guard(state, key, val):
"""Applies the transfer guard level within transfer_guard_lib."""
if val is None:
setattr(state, key, None)
elif val == 'allow':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW)
elif val == 'log':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG)
elif val == 'disallow':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW)
elif val == 'log_explicit':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT)
elif val == 'disallow_explicit':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
else:
assert False, f'Invalid transfer guard level {val}'
transfer_guard_host_to_device = optional_enum_state(
name='jax_transfer_guard_host_to_device',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for host-to-device transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'host_to_device', val),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'host_to_device', val))
transfer_guard_device_to_device = optional_enum_state(
name='jax_transfer_guard_device_to_device',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for device-to-device transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'device_to_device', val),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'device_to_device', val))
transfer_guard_device_to_host = optional_enum_state(
name='jax_transfer_guard_device_to_host',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for device-to-host transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'device_to_host', val),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'device_to_host', val))
def _update_all_transfer_guard_global(val):
for name in ('jax_transfer_guard_host_to_device',
'jax_transfer_guard_device_to_device',
'jax_transfer_guard_device_to_host'):
config.update(name, val)
_transfer_guard = optional_enum_state(
name='jax_transfer_guard',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard_*.
default=None,
help=('Select the transfer guard level for all transfers. This option is '
'set-only; the transfer guard level for a specific direction should '
'be read using the per-transfer direction option. '
'Default is "allow".'),
update_global_hook=_update_all_transfer_guard_global)
@contextlib.contextmanager
def transfer_guard(new_val: str) -> Iterator[None]:
"""A contextmanager to control the transfer guard level for all transfers.
For more information, see
https://jax.readthedocs.io/en/latest/transfer_guard.html
Args:
new_val: The new thread-local transfer guard level for all transfers.
Yields:
None.
"""
with contextlib.ExitStack() as stack:
stack.enter_context(transfer_guard_host_to_device(new_val))
stack.enter_context(transfer_guard_device_to_device(new_val))
stack.enter_context(transfer_guard_device_to_host(new_val))
stack.enter_context(_transfer_guard(new_val))
yield
def _update_debug_log_modules(module_names_str: str | None):
logging_config.disable_all_debug_logging()
if not module_names_str:
return
module_names = module_names_str.split(',')
for module_name in module_names:
logging_config.enable_debug_logging(module_name)
# Don't define a context manager since this isn't threadsafe.
string_state(
name='jax_debug_log_modules',
default='',
help=('Comma-separated list of module names (e.g. "jax" or '
'"jax._src.xla_bridge,jax._src.dispatch") to enable debug logging '
'for.'),
update_global_hook=_update_debug_log_modules)
pmap_no_rank_reduction = bool_state(
name='jax_pmap_no_rank_reduction',
default=False,
help=(
"If True, pmap shards have a the same rank as their enclosing array."
)
)