|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from collections.abc import Hashable, Iterator
|
|
|
|
|
from collections.abc import Hashable, Iterator, Sequence
|
|
|
|
|
import contextlib
|
|
|
|
|
import functools
|
|
|
|
|
import itertools
|
|
|
|
@ -22,7 +22,9 @@ import logging
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
import threading
|
|
|
|
|
from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast
|
|
|
|
|
from typing import (
|
|
|
|
|
Any, Callable, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from jax._src import lib
|
|
|
|
|
from jax._src.lib import jax_jit
|
|
|
|
@ -60,24 +62,24 @@ def int_env(varname: str, default: int) -> int:
|
|
|
|
|
return int(os.getenv(varname, str(default)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
UPGRADE_BOOL_HELP = (
|
|
|
|
|
" This will be enabled by default in future versions of JAX, at which "
|
|
|
|
|
"point all uses of the flag will be considered deprecated (following "
|
|
|
|
|
"the `API compatibility policy "
|
|
|
|
|
"<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_).")
|
|
|
|
|
class ValueHolder(Protocol[_T]):
|
|
|
|
|
"""A holder for a configuration value.
|
|
|
|
|
|
|
|
|
|
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
|
|
|
|
|
There are two kinds of value holders: ``Flag``, which is assigned exactly
|
|
|
|
|
once and never modified after; and ``State``, which can be changed locally
|
|
|
|
|
within a thread via a context manager.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
value: _T
|
|
|
|
|
|
|
|
|
|
def _set(self, value: _T) -> None: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
|
_HAS_DYNAMIC_ATTRIBUTES = True
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
# There are two kinds of value holders: FlagHolders, which hold global
|
|
|
|
|
# flags, and StateContextManagers, which hold state that can be changed
|
|
|
|
|
# locally within a thread. A value holder needs a `.value` property and a
|
|
|
|
|
# `._set()` method.
|
|
|
|
|
self._value_holders = {}
|
|
|
|
|
self._value_holders: dict[str, ValueHolder] = {}
|
|
|
|
|
self.meta = {}
|
|
|
|
|
self.use_absl = False
|
|
|
|
|
self._contextmanager_flags = set()
|
|
|
|
@ -113,7 +115,7 @@ class Config:
|
|
|
|
|
def config_with_absl(self):
|
|
|
|
|
"""Registers absl flags for the JAX configs.
|
|
|
|
|
|
|
|
|
|
E.g., for each JAX config defined using define_bool_state(), this method
|
|
|
|
|
E.g., for each JAX config defined using bool_state(), this method
|
|
|
|
|
registers an absl boolean flag, with the same name.
|
|
|
|
|
|
|
|
|
|
This is the recommended method to call if you use `app.run(main)` and you
|
|
|
|
@ -237,7 +239,8 @@ unset = _Unset()
|
|
|
|
|
|
|
|
|
|
_thread_local_state = threading.local()
|
|
|
|
|
|
|
|
|
|
class _StateContextManager(Generic[_T]):
|
|
|
|
|
class State(Generic[_T]):
|
|
|
|
|
|
|
|
|
|
__slots__ = (
|
|
|
|
|
'_name', '_value', '_update_thread_local_hook', '_update_global_hook',
|
|
|
|
|
'_validator', '_default_context_manager_value', '__doc__', '__name__',
|
|
|
|
@ -318,7 +321,16 @@ class _StateContextManager(Generic[_T]):
|
|
|
|
|
update_global_hook(self._value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def define_bool_state(
|
|
|
|
|
UPGRADE_BOOL_HELP = (
|
|
|
|
|
" This will be enabled by default in future versions of JAX, at which "
|
|
|
|
|
"point all uses of the flag will be considered deprecated (following "
|
|
|
|
|
"the `API compatibility policy "
|
|
|
|
|
"<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_).")
|
|
|
|
|
|
|
|
|
|
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bool_state(
|
|
|
|
|
name: str,
|
|
|
|
|
default: bool,
|
|
|
|
|
help: str,
|
|
|
|
@ -327,7 +339,7 @@ def define_bool_state(
|
|
|
|
|
update_thread_local_hook: Callable[[bool | None], None] | None = None,
|
|
|
|
|
upgrade: bool = False,
|
|
|
|
|
extra_description: str = '',
|
|
|
|
|
) -> _StateContextManager[bool]:
|
|
|
|
|
) -> State[bool]:
|
|
|
|
|
"""Set up thread-local state and return a contextmanager for managing it.
|
|
|
|
|
|
|
|
|
|
This function is a convenience wrapper. It defines a flag, environment
|
|
|
|
@ -360,7 +372,7 @@ def define_bool_state(
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
|
|
enable_foo = config.define_bool_state(
|
|
|
|
|
ENABLE_FOO = config.bool_state(
|
|
|
|
|
name='jax_enable_foo',
|
|
|
|
|
default=False,
|
|
|
|
|
help='Enable foo.')
|
|
|
|
@ -388,7 +400,7 @@ def define_bool_state(
|
|
|
|
|
extra_description += UPGRADE_BOOL_EXTRA_DESC
|
|
|
|
|
config._contextmanager_flags.add(name)
|
|
|
|
|
|
|
|
|
|
s = _StateContextManager[bool](
|
|
|
|
|
s = State[bool](
|
|
|
|
|
name, default, help, update_global_hook=update_global_hook,
|
|
|
|
|
update_thread_local_hook=update_thread_local_hook,
|
|
|
|
|
extra_description=extra_description, default_context_manager_value=True)
|
|
|
|
@ -397,18 +409,18 @@ def define_bool_state(
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def define_enum_state(
|
|
|
|
|
def enum_state(
|
|
|
|
|
name: str,
|
|
|
|
|
enum_values: list[str],
|
|
|
|
|
enum_values: Sequence[str],
|
|
|
|
|
default: str,
|
|
|
|
|
help: str,
|
|
|
|
|
*,
|
|
|
|
|
update_global_hook: Callable[[str], None] | None = None,
|
|
|
|
|
update_thread_local_hook: Callable[[str | None], None] | None = None,
|
|
|
|
|
) -> _StateContextManager[str]:
|
|
|
|
|
) -> State[str]:
|
|
|
|
|
"""Set up thread-local state and return a contextmanager for managing it.
|
|
|
|
|
|
|
|
|
|
See docstring for ``define_bool_state``.
|
|
|
|
|
See docstring for ``bool_state``.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name: string, converted to lowercase to define the name of the config
|
|
|
|
@ -437,7 +449,7 @@ def define_enum_state(
|
|
|
|
|
raise ValueError(f"new enum value must be in {enum_values}, "
|
|
|
|
|
f"got {new_val} of type {type(new_val)}.")
|
|
|
|
|
|
|
|
|
|
s = _StateContextManager[str](
|
|
|
|
|
s = State[str](
|
|
|
|
|
name,
|
|
|
|
|
default,
|
|
|
|
|
help,
|
|
|
|
@ -454,18 +466,18 @@ def define_enum_state(
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def define_optional_enum_state(
|
|
|
|
|
def optional_enum_state(
|
|
|
|
|
name: str,
|
|
|
|
|
enum_values: list[str],
|
|
|
|
|
enum_values: Sequence[str],
|
|
|
|
|
default: str | None,
|
|
|
|
|
help: str,
|
|
|
|
|
*,
|
|
|
|
|
update_global_hook: Callable[[str | None], None] | None = None,
|
|
|
|
|
update_thread_local_hook: Callable[[str | None], None] | None = None,
|
|
|
|
|
) -> _StateContextManager[str | None]:
|
|
|
|
|
) -> State[str | None]:
|
|
|
|
|
"""Set up thread-local state and return a contextmanager for managing it.
|
|
|
|
|
|
|
|
|
|
See docstring for ``define_bool_state``.
|
|
|
|
|
See docstring for ``bool_state``.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name: string, converted to lowercase to define the name of the config
|
|
|
|
@ -495,7 +507,7 @@ def define_optional_enum_state(
|
|
|
|
|
raise ValueError(f"new enum value must be None or in {enum_values}, "
|
|
|
|
|
f"got {new_val} of type {type(new_val)}.")
|
|
|
|
|
|
|
|
|
|
s = _StateContextManager['str | None'](
|
|
|
|
|
s = State['str | None'](
|
|
|
|
|
name, default, help, update_global_hook, update_thread_local_hook,
|
|
|
|
|
validate
|
|
|
|
|
)
|
|
|
|
@ -508,17 +520,17 @@ def define_optional_enum_state(
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def define_int_state(
|
|
|
|
|
def int_state(
|
|
|
|
|
name: str,
|
|
|
|
|
default: int,
|
|
|
|
|
help: str,
|
|
|
|
|
*,
|
|
|
|
|
update_global_hook: Callable[[int], None] | None = None,
|
|
|
|
|
update_thread_local_hook: Callable[[int | None], None] | None = None,
|
|
|
|
|
) -> _StateContextManager[int]:
|
|
|
|
|
) -> State[int]:
|
|
|
|
|
"""Set up thread-local state and return a contextmanager for managing it.
|
|
|
|
|
|
|
|
|
|
See docstring for ``define_bool_state``.
|
|
|
|
|
See docstring for ``bool_state``.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name: string, converted to lowercase to define the name of the config
|
|
|
|
@ -548,24 +560,24 @@ def define_int_state(
|
|
|
|
|
raise ValueError(f'new int config value must be None or of type int, '
|
|
|
|
|
f'got {new_val} of type {type(new_val)}')
|
|
|
|
|
|
|
|
|
|
s = _StateContextManager[int](name, default, help, update_global_hook,
|
|
|
|
|
update_thread_local_hook, validate)
|
|
|
|
|
s = State[int](name, default, help, update_global_hook,
|
|
|
|
|
update_thread_local_hook, validate)
|
|
|
|
|
config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help})
|
|
|
|
|
setattr(Config, name, property(lambda _: s.value))
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def define_float_state(
|
|
|
|
|
def float_state(
|
|
|
|
|
name: str,
|
|
|
|
|
default: float,
|
|
|
|
|
help: str,
|
|
|
|
|
*,
|
|
|
|
|
update_global_hook: Callable[[float], None] | None = None,
|
|
|
|
|
update_thread_local_hook: Callable[[float | None], None] | None = None,
|
|
|
|
|
) -> _StateContextManager[float]:
|
|
|
|
|
) -> State[float]:
|
|
|
|
|
"""Set up thread-local state and return a contextmanager for managing it.
|
|
|
|
|
|
|
|
|
|
See docstring for ``define_bool_state``.
|
|
|
|
|
See docstring for ``bool_state``.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name: string, converted to lowercase to define the name of the config
|
|
|
|
@ -596,24 +608,24 @@ def define_float_state(
|
|
|
|
|
f'new float config value must be None or of type float, '
|
|
|
|
|
f'got {new_val} of type {type(new_val)}')
|
|
|
|
|
|
|
|
|
|
s = _StateContextManager[float](name, default, help, update_global_hook,
|
|
|
|
|
update_thread_local_hook, validate)
|
|
|
|
|
s = State[float](name, default, help, update_global_hook,
|
|
|
|
|
update_thread_local_hook, validate)
|
|
|
|
|
config.add_option(name, s, float, meta_args=[], meta_kwargs={"help": help})
|
|
|
|
|
setattr(Config, name, property(lambda _: s.value))
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def define_string_state(
|
|
|
|
|
def string_state(
|
|
|
|
|
name: str,
|
|
|
|
|
default: str,
|
|
|
|
|
help: str,
|
|
|
|
|
*,
|
|
|
|
|
update_global_hook: Callable[[str], None] | None = None,
|
|
|
|
|
update_thread_local_hook: Callable[[str | None], None] | None = None,
|
|
|
|
|
) -> _StateContextManager[str]:
|
|
|
|
|
) -> State[str]:
|
|
|
|
|
"""Set up thread-local state and return a contextmanager for managing it.
|
|
|
|
|
|
|
|
|
|
See docstring for ``define_bool_state``.
|
|
|
|
|
See docstring for ``bool_state``.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name: string, converted to lowercase to define the name of the config
|
|
|
|
@ -640,24 +652,24 @@ def define_string_state(
|
|
|
|
|
raise TypeError('new string config value must be of type str,'
|
|
|
|
|
f' got {new_val} of type {type(new_val)}.')
|
|
|
|
|
|
|
|
|
|
return define_string_or_object_state(
|
|
|
|
|
return string_or_object_state(
|
|
|
|
|
name, default, help,
|
|
|
|
|
update_global_hook=update_global_hook,
|
|
|
|
|
update_thread_local_hook=update_thread_local_hook,
|
|
|
|
|
validator=validator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def define_optional_string_state(
|
|
|
|
|
def optional_string_state(
|
|
|
|
|
name: str,
|
|
|
|
|
default: str | None,
|
|
|
|
|
help: str,
|
|
|
|
|
*,
|
|
|
|
|
update_global_hook: Callable[[str], None] | None = None,
|
|
|
|
|
update_thread_local_hook: Callable[[str | None], None] | None = None,
|
|
|
|
|
) -> _StateContextManager[str | None]:
|
|
|
|
|
) -> State[str | None]:
|
|
|
|
|
"""Set up thread-local state and return a contextmanager for managing it.
|
|
|
|
|
|
|
|
|
|
See docstring for ``define_bool_state``.
|
|
|
|
|
See docstring for ``bool_state``.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name: string, converted to lowercase to define the name of the config
|
|
|
|
@ -684,13 +696,13 @@ def define_optional_string_state(
|
|
|
|
|
raise ValueError('new string config value must be None or of type str,'
|
|
|
|
|
f' got {new_val} of type {type(new_val)}.')
|
|
|
|
|
|
|
|
|
|
return define_string_or_object_state(
|
|
|
|
|
return string_or_object_state(
|
|
|
|
|
name, default, help,
|
|
|
|
|
update_global_hook=update_global_hook,
|
|
|
|
|
update_thread_local_hook=update_thread_local_hook,
|
|
|
|
|
validator=validator)
|
|
|
|
|
|
|
|
|
|
def define_string_or_object_state(
|
|
|
|
|
def string_or_object_state(
|
|
|
|
|
name: str,
|
|
|
|
|
default: Any,
|
|
|
|
|
help: str,
|
|
|
|
@ -698,10 +710,10 @@ def define_string_or_object_state(
|
|
|
|
|
update_global_hook: Callable[[Any], None] | None = None,
|
|
|
|
|
update_thread_local_hook: Callable[[Any], None] | None = None,
|
|
|
|
|
validator: Callable[[Any], None] | None = None,
|
|
|
|
|
) -> _StateContextManager[Any]:
|
|
|
|
|
) -> State[Any]:
|
|
|
|
|
"""Set up thread-local state and return a contextmanager for managing it.
|
|
|
|
|
|
|
|
|
|
Similar to ``define_string_state``, except the context manager will accept
|
|
|
|
|
Similar to ``string_state``, except the context manager will accept
|
|
|
|
|
any object, not just a string. Any value passed via command line flag or
|
|
|
|
|
environment variable will be treated as a string.
|
|
|
|
|
|
|
|
|
@ -728,7 +740,7 @@ def define_string_or_object_state(
|
|
|
|
|
default = os.getenv(name.upper(), default)
|
|
|
|
|
config._contextmanager_flags.add(name)
|
|
|
|
|
|
|
|
|
|
s = _StateContextManager[Any](
|
|
|
|
|
s = State[Any](
|
|
|
|
|
name, default, help, update_global_hook, update_thread_local_hook,
|
|
|
|
|
validator)
|
|
|
|
|
setattr(Config, name, property(lambda _: s.value))
|
|
|
|
@ -736,7 +748,8 @@ def define_string_or_object_state(
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FlagHolder(Generic[_T]):
|
|
|
|
|
class Flag(Generic[_T]):
|
|
|
|
|
|
|
|
|
|
__slots__ = ("_name", "value", "_update_hook")
|
|
|
|
|
|
|
|
|
|
_name: str
|
|
|
|
@ -761,42 +774,37 @@ class FlagHolder(Generic[_T]):
|
|
|
|
|
self._update_hook(value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_exists(name):
|
|
|
|
|
if name not in config._value_holders:
|
|
|
|
|
raise AttributeError(f"Unrecognized config option: {name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def DEFINE_bool(name, default, *args, **kwargs) -> FlagHolder[bool]:
|
|
|
|
|
def bool_flag(name, default, *args, **kwargs) -> Flag[bool]:
|
|
|
|
|
update_hook = kwargs.pop("update_hook", None)
|
|
|
|
|
holder = FlagHolder(name, default, update_hook)
|
|
|
|
|
holder = Flag(name, default, update_hook)
|
|
|
|
|
config.add_option(name, holder, bool, args, kwargs)
|
|
|
|
|
return holder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def DEFINE_integer(name, default, *args, **kwargs) -> FlagHolder[int]:
|
|
|
|
|
def int_flag(name, default, *args, **kwargs) -> Flag[int]:
|
|
|
|
|
update_hook = kwargs.pop("update_hook", None)
|
|
|
|
|
holder = FlagHolder(name, default, update_hook)
|
|
|
|
|
holder = Flag(name, default, update_hook)
|
|
|
|
|
config.add_option(name, holder, int, args, kwargs)
|
|
|
|
|
return holder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def DEFINE_float(name, default, *args, **kwargs) -> FlagHolder[float]:
|
|
|
|
|
def float_flag(name, default, *args, **kwargs) -> Flag[float]:
|
|
|
|
|
update_hook = kwargs.pop("update_hook", None)
|
|
|
|
|
holder = FlagHolder(name, default, update_hook)
|
|
|
|
|
holder = Flag(name, default, update_hook)
|
|
|
|
|
config.add_option(name, holder, float, args, kwargs)
|
|
|
|
|
return holder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def DEFINE_string(name, default, *args, **kwargs) -> FlagHolder[str]:
|
|
|
|
|
def string_flag(name, default, *args, **kwargs) -> Flag[str]:
|
|
|
|
|
update_hook = kwargs.pop("update_hook", None)
|
|
|
|
|
holder = FlagHolder(name, default, update_hook)
|
|
|
|
|
holder = Flag(name, default, update_hook)
|
|
|
|
|
config.add_option(name, holder, str, args, kwargs)
|
|
|
|
|
return holder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def DEFINE_enum(name, default, *args, **kwargs) -> FlagHolder[str]:
|
|
|
|
|
def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
|
|
|
|
|
update_hook = kwargs.pop("update_hook", None)
|
|
|
|
|
holder = FlagHolder(name, default, update_hook)
|
|
|
|
|
holder = Flag(name, default, update_hook)
|
|
|
|
|
config.add_option(name, holder, 'enum', args, kwargs)
|
|
|
|
|
return holder
|
|
|
|
|
|
|
|
|
@ -885,7 +893,7 @@ def update_thread_local_jit_state(**kw):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO(b/214340779): remove flag when XLA:CPU is improved.
|
|
|
|
|
jax2tf_associative_scan_reductions = define_bool_state(
|
|
|
|
|
jax2tf_associative_scan_reductions = bool_state(
|
|
|
|
|
name='jax2tf_associative_scan_reductions',
|
|
|
|
|
default=False,
|
|
|
|
|
help=(
|
|
|
|
@ -900,7 +908,7 @@ jax2tf_associative_scan_reductions = define_bool_state(
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
jax2tf_default_native_serialization = define_bool_state(
|
|
|
|
|
jax2tf_default_native_serialization = bool_state(
|
|
|
|
|
name='jax2tf_default_native_serialization',
|
|
|
|
|
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True),
|
|
|
|
|
help=(
|
|
|
|
@ -910,7 +918,7 @@ jax2tf_default_native_serialization = define_bool_state(
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
jax_serialization_version = define_int_state(
|
|
|
|
|
jax_serialization_version = int_state(
|
|
|
|
|
name='jax_serialization_version',
|
|
|
|
|
default=int_env('JAX_SERIALIZATION_VERSION', 0), # We use 0 to detect default.
|
|
|
|
|
help=(
|
|
|
|
@ -918,7 +926,7 @@ jax_serialization_version = define_int_state(
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
jax_export_calling_convention_version = define_int_state(
|
|
|
|
|
jax_export_calling_convention_version = int_state(
|
|
|
|
|
name='jax_export_calling_convention_version',
|
|
|
|
|
# Note: bump the default calling convention version at least one month after
|
|
|
|
|
# we update XlaCallModule to support the new version, so that serialized
|
|
|
|
@ -933,7 +941,7 @@ jax_export_calling_convention_version = define_int_state(
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
jax_platforms = define_optional_string_state(
|
|
|
|
|
jax_platforms = optional_string_state(
|
|
|
|
|
name='jax_platforms',
|
|
|
|
|
default=None,
|
|
|
|
|
help=(
|
|
|
|
@ -949,18 +957,18 @@ jax_platforms = define_optional_string_state(
|
|
|
|
|
'otherwise.'
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
jax_pjrt_client_create_options = define_optional_string_state(
|
|
|
|
|
jax_pjrt_client_create_options = optional_string_state(
|
|
|
|
|
name='jax_pjrt_client_create_options',
|
|
|
|
|
default=None,
|
|
|
|
|
help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings '
|
|
|
|
|
'provided to a device platform pjrt client as extra arguments.'))
|
|
|
|
|
|
|
|
|
|
enable_checks = define_bool_state(
|
|
|
|
|
enable_checks = bool_state(
|
|
|
|
|
name='jax_enable_checks',
|
|
|
|
|
default=False,
|
|
|
|
|
help='Turn on invariant checking for JAX internals. Makes things slower.')
|
|
|
|
|
|
|
|
|
|
debug_key_reuse = define_bool_state(
|
|
|
|
|
debug_key_reuse = bool_state(
|
|
|
|
|
name='jax_debug_key_reuse',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Turn on experimental key reuse checking. With this configuration enabled,'
|
|
|
|
@ -969,7 +977,7 @@ debug_key_reuse = define_bool_state(
|
|
|
|
|
' an error. Currently enabling this leads to a small Python overhead on'
|
|
|
|
|
' every call to a JIT-compiled function with keys as inputs or outputs.'))
|
|
|
|
|
|
|
|
|
|
check_tracer_leaks = define_bool_state(
|
|
|
|
|
check_tracer_leaks = bool_state(
|
|
|
|
|
name='jax_check_tracer_leaks',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Turn on checking for leaked tracers as soon as a trace completes. '
|
|
|
|
@ -979,7 +987,7 @@ check_tracer_leaks = define_bool_state(
|
|
|
|
|
'to disable any debuggers while leak checking is enabled.'))
|
|
|
|
|
checking_leaks = functools.partial(check_tracer_leaks, True)
|
|
|
|
|
|
|
|
|
|
debug_nans = define_bool_state(
|
|
|
|
|
debug_nans = bool_state(
|
|
|
|
|
name='jax_debug_nans',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Add nan checks to every operation. When a nan is detected on the '
|
|
|
|
@ -987,7 +995,7 @@ debug_nans = define_bool_state(
|
|
|
|
|
'version in an attempt to more precisely identify the operation '
|
|
|
|
|
'which produced the nan.'))
|
|
|
|
|
|
|
|
|
|
debug_infs = define_bool_state(
|
|
|
|
|
debug_infs = bool_state(
|
|
|
|
|
name='jax_debug_infs',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Add inf checks to every operation. When an inf is detected on the '
|
|
|
|
@ -995,7 +1003,7 @@ debug_infs = define_bool_state(
|
|
|
|
|
'version in an attempt to more precisely identify the operation '
|
|
|
|
|
'which produced the inf.'))
|
|
|
|
|
|
|
|
|
|
log_compiles = define_bool_state(
|
|
|
|
|
log_compiles = bool_state(
|
|
|
|
|
name='jax_log_compiles',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Log a message each time `jit` or `pmap` compiles an XLA '
|
|
|
|
@ -1003,7 +1011,7 @@ log_compiles = define_bool_state(
|
|
|
|
|
'option is set, the log level is WARNING; otherwise the level is '
|
|
|
|
|
'DEBUG.'))
|
|
|
|
|
|
|
|
|
|
explain_cache_misses = define_bool_state(
|
|
|
|
|
explain_cache_misses = bool_state(
|
|
|
|
|
name='jax_explain_cache_misses',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Each time there is a miss on one of the main caches (e.g. the '
|
|
|
|
@ -1011,14 +1019,14 @@ explain_cache_misses = define_bool_state(
|
|
|
|
|
'`logging`. When this option is set, the log level is WARNING; '
|
|
|
|
|
'otherwise the level is DEBUG.'))
|
|
|
|
|
|
|
|
|
|
log_checkpoint_residuals = define_bool_state(
|
|
|
|
|
log_checkpoint_residuals = bool_state(
|
|
|
|
|
name='jax_log_checkpoint_residuals',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Log a message every time jax.checkpoint (aka jax.remat) is '
|
|
|
|
|
'partially evaluated (e.g. for autodiff), printing what residuals '
|
|
|
|
|
'are saved.'))
|
|
|
|
|
|
|
|
|
|
pmap_shmap_merge = define_bool_state(
|
|
|
|
|
pmap_shmap_merge = bool_state(
|
|
|
|
|
name='jax_pmap_shmap_merge',
|
|
|
|
|
default=False,
|
|
|
|
|
upgrade=True,
|
|
|
|
@ -1030,7 +1038,7 @@ def _update_jax_memories_global(val):
|
|
|
|
|
def _update_jax_memories_thread_local(val):
|
|
|
|
|
lib.jax_jit.thread_local_state().enable_memories = val
|
|
|
|
|
|
|
|
|
|
enable_memories = define_bool_state(
|
|
|
|
|
enable_memories = bool_state(
|
|
|
|
|
'jax_enable_memories',
|
|
|
|
|
default=False,
|
|
|
|
|
upgrade=True,
|
|
|
|
@ -1039,7 +1047,7 @@ enable_memories = define_bool_state(
|
|
|
|
|
help=("If True, will allow fetching memory kinds available on executable "
|
|
|
|
|
"and annotate Shardings with it."))
|
|
|
|
|
|
|
|
|
|
spmd_mode = define_enum_state(
|
|
|
|
|
spmd_mode = enum_state(
|
|
|
|
|
name='jax_spmd_mode',
|
|
|
|
|
enum_values=['allow_all', 'allow_jit'],
|
|
|
|
|
default='allow_jit',
|
|
|
|
@ -1052,14 +1060,14 @@ spmd_mode = define_enum_state(
|
|
|
|
|
" execute on non-fully addressable `jax.Array`s."))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
distributed_debug = define_bool_state(
|
|
|
|
|
distributed_debug = bool_state(
|
|
|
|
|
name='jax_distributed_debug',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Enable logging useful for debugging multi-process distributed '
|
|
|
|
|
'computations. Logging is performed with `logging` at WARNING '
|
|
|
|
|
'level.'))
|
|
|
|
|
|
|
|
|
|
random_seed_offset = define_int_state(
|
|
|
|
|
random_seed_offset = int_state(
|
|
|
|
|
name='jax_random_seed_offset',
|
|
|
|
|
default=0,
|
|
|
|
|
help=('Offset to all random seeds (e.g. argument to jax.random.key()).'),
|
|
|
|
@ -1069,7 +1077,7 @@ random_seed_offset = define_int_state(
|
|
|
|
|
random_seed_offset=val)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
legacy_prng_key = define_enum_state(
|
|
|
|
|
legacy_prng_key = enum_state(
|
|
|
|
|
name='jax_legacy_prng_key',
|
|
|
|
|
enum_values=['allow', 'warn', 'error'],
|
|
|
|
|
default='allow',
|
|
|
|
@ -1077,21 +1085,21 @@ legacy_prng_key = define_enum_state(
|
|
|
|
|
'jax.random APIs.')
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
enable_custom_prng = define_bool_state(
|
|
|
|
|
enable_custom_prng = bool_state(
|
|
|
|
|
name='jax_enable_custom_prng',
|
|
|
|
|
default=False,
|
|
|
|
|
upgrade=True,
|
|
|
|
|
help=('Enables an internal upgrade that allows one to define custom '
|
|
|
|
|
'pseudo-random number generator implementations.'))
|
|
|
|
|
|
|
|
|
|
default_prng_impl = define_enum_state(
|
|
|
|
|
default_prng_impl = enum_state(
|
|
|
|
|
name='jax_default_prng_impl',
|
|
|
|
|
enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'],
|
|
|
|
|
default='threefry2x32',
|
|
|
|
|
help=('Select the default PRNG implementation, used when one is not '
|
|
|
|
|
'explicitly provided at seeding time.'))
|
|
|
|
|
|
|
|
|
|
threefry_partitionable = define_bool_state(
|
|
|
|
|
threefry_partitionable = bool_state(
|
|
|
|
|
name='jax_threefry_partitionable',
|
|
|
|
|
default=False,
|
|
|
|
|
upgrade=True,
|
|
|
|
@ -1106,7 +1114,7 @@ threefry_partitionable = define_bool_state(
|
|
|
|
|
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
|
|
|
|
threefry_partitionable=val))
|
|
|
|
|
|
|
|
|
|
threefry_gpu_kernel_lowering = define_bool_state(
|
|
|
|
|
threefry_gpu_kernel_lowering = bool_state(
|
|
|
|
|
name='jax_threefry_gpu_kernel_lowering',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
|
|
|
|
@ -1118,7 +1126,7 @@ threefry_gpu_kernel_lowering = define_bool_state(
|
|
|
|
|
threefry_gpu_kernel_lowering=val))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
softmax_custom_jvp = define_bool_state(
|
|
|
|
|
softmax_custom_jvp = bool_state(
|
|
|
|
|
name='jax_softmax_custom_jvp',
|
|
|
|
|
default=False,
|
|
|
|
|
upgrade=True,
|
|
|
|
@ -1131,14 +1139,14 @@ softmax_custom_jvp = define_bool_state(
|
|
|
|
|
softmax_custom_jvp=val))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
enable_custom_vjp_by_custom_transpose = define_bool_state(
|
|
|
|
|
enable_custom_vjp_by_custom_transpose = bool_state(
|
|
|
|
|
name='jax_enable_custom_vjp_by_custom_transpose',
|
|
|
|
|
default=False,
|
|
|
|
|
upgrade=True,
|
|
|
|
|
help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
|
|
|
|
|
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))
|
|
|
|
|
|
|
|
|
|
raise_persistent_cache_errors = define_bool_state(
|
|
|
|
|
raise_persistent_cache_errors = bool_state(
|
|
|
|
|
name='jax_raise_persistent_cache_errors',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('If true, exceptions raised when reading or writing to the '
|
|
|
|
@ -1148,14 +1156,14 @@ raise_persistent_cache_errors = define_bool_state(
|
|
|
|
|
'continue. Defaults to false so cache bugs or intermittent issues '
|
|
|
|
|
'are non-fatal.'))
|
|
|
|
|
|
|
|
|
|
persistent_cache_min_compile_time_secs = define_float_state(
|
|
|
|
|
persistent_cache_min_compile_time_secs = float_state(
|
|
|
|
|
name='jax_persistent_cache_min_compile_time_secs',
|
|
|
|
|
default=1.,
|
|
|
|
|
help=('The minimum compile time of a computation to be written to the '
|
|
|
|
|
'persistent compilation cache. This threshold can be raised to '
|
|
|
|
|
'decrease the number of entries written to the cache.'))
|
|
|
|
|
|
|
|
|
|
persistent_cache_min_entry_size_bytes = define_int_state(
|
|
|
|
|
persistent_cache_min_entry_size_bytes = int_state(
|
|
|
|
|
name='jax_persistent_cache_min_entry_size_bytes',
|
|
|
|
|
default=0,
|
|
|
|
|
help=('The minimum size (in bytes) of an entry that will be cached in the '
|
|
|
|
@ -1166,7 +1174,7 @@ persistent_cache_min_entry_size_bytes = define_int_state(
|
|
|
|
|
' filesystem being used for the cache. '
|
|
|
|
|
'* > 0: the actual minimum size desired; no overrides.'))
|
|
|
|
|
|
|
|
|
|
compilation_cache_include_metadata_in_key = define_bool_state(
|
|
|
|
|
compilation_cache_include_metadata_in_key = bool_state(
|
|
|
|
|
name='jax_compilation_cache_include_metadata_in_key',
|
|
|
|
|
default=False,
|
|
|
|
|
help=(
|
|
|
|
@ -1178,7 +1186,7 @@ compilation_cache_include_metadata_in_key = define_bool_state(
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
hlo_source_file_canonicalization_regex = define_optional_string_state(
|
|
|
|
|
hlo_source_file_canonicalization_regex = optional_string_state(
|
|
|
|
|
name='jax_hlo_source_file_canonicalization_regex',
|
|
|
|
|
default=None,
|
|
|
|
|
help=('Used to canonicalize the source_path metadata of HLO instructions '
|
|
|
|
@ -1188,7 +1196,7 @@ hlo_source_file_canonicalization_regex = define_optional_string_state(
|
|
|
|
|
'persistent compilation cache, which includes HLO metadata in the '
|
|
|
|
|
'cache key.'))
|
|
|
|
|
|
|
|
|
|
include_full_tracebacks_in_locations = define_bool_state(
|
|
|
|
|
include_full_tracebacks_in_locations = bool_state(
|
|
|
|
|
name='jax_include_full_tracebacks_in_locations',
|
|
|
|
|
default=True,
|
|
|
|
|
help=(
|
|
|
|
@ -1196,7 +1204,7 @@ include_full_tracebacks_in_locations = define_bool_state(
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
traceback_in_locations_limit = define_int_state(
|
|
|
|
|
traceback_in_locations_limit = int_state(
|
|
|
|
|
name='jax_traceback_in_locations_limit',
|
|
|
|
|
default=10,
|
|
|
|
|
help=(
|
|
|
|
@ -1206,7 +1214,7 @@ traceback_in_locations_limit = define_int_state(
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
share_autotune_config_between_hosts = define_bool_state(
|
|
|
|
|
share_autotune_config_between_hosts = bool_state(
|
|
|
|
|
name='jax_share_autotune_config_between_hosts',
|
|
|
|
|
default=False,
|
|
|
|
|
help=(
|
|
|
|
@ -1220,7 +1228,7 @@ share_autotune_config_between_hosts = define_bool_state(
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
share_binary_between_hosts = define_bool_state(
|
|
|
|
|
share_binary_between_hosts = bool_state(
|
|
|
|
|
name='jax_share_binary_between_hosts',
|
|
|
|
|
default=False,
|
|
|
|
|
help=(
|
|
|
|
@ -1229,13 +1237,13 @@ share_binary_between_hosts = define_bool_state(
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
share_binary_between_hosts_timeout_ms = define_int_state(
|
|
|
|
|
share_binary_between_hosts_timeout_ms = int_state(
|
|
|
|
|
name='jax_share_binary_between_hosts_timeout_ms',
|
|
|
|
|
default=20 * 60 * 1000,
|
|
|
|
|
help='Timeout for the compiled module share.',
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
enable_pgle = define_bool_state(
|
|
|
|
|
enable_pgle = bool_state(
|
|
|
|
|
name='jax_enable_pgle',
|
|
|
|
|
default=False,
|
|
|
|
|
help=(
|
|
|
|
@ -1249,7 +1257,7 @@ enable_pgle = define_bool_state(
|
|
|
|
|
enable_pgle=val),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
pgle_profiling_runs = define_int_state(
|
|
|
|
|
pgle_profiling_runs = int_state(
|
|
|
|
|
name='jax_pgle_profiling_runs',
|
|
|
|
|
default=3,
|
|
|
|
|
help=(
|
|
|
|
@ -1264,14 +1272,14 @@ pgle_profiling_runs = define_int_state(
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
pgle_aggregation_percentile = define_int_state(
|
|
|
|
|
pgle_aggregation_percentile = int_state(
|
|
|
|
|
name='jax_pgle_aggregation_percentile',
|
|
|
|
|
default=90,
|
|
|
|
|
help='Percentile used to aggregate performance data between devices when '
|
|
|
|
|
'PGLE is used.',
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
enable_compilation_cache = define_bool_state(
|
|
|
|
|
enable_compilation_cache = bool_state(
|
|
|
|
|
name='jax_enable_compilation_cache',
|
|
|
|
|
default=True,
|
|
|
|
|
help=('If set to False, the compilation cache will be disabled regardless '
|
|
|
|
@ -1280,7 +1288,7 @@ enable_compilation_cache = define_bool_state(
|
|
|
|
|
'set_cache_dir().'),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
compilation_cache_dir = define_optional_string_state(
|
|
|
|
|
compilation_cache_dir = optional_string_state(
|
|
|
|
|
name='jax_compilation_cache_dir',
|
|
|
|
|
default=None,
|
|
|
|
|
help=('Path for the cache. '
|
|
|
|
@ -1289,7 +1297,7 @@ compilation_cache_dir = define_optional_string_state(
|
|
|
|
|
'2. The value of this flag set in the command line or by default.'),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
compilation_cache_max_size = define_int_state(
|
|
|
|
|
compilation_cache_max_size = int_state(
|
|
|
|
|
name='jax_compilation_cache_max_size',
|
|
|
|
|
default=-1,
|
|
|
|
|
help=('The maximum size (in bytes) allowed for the persistent compilation '
|
|
|
|
@ -1301,7 +1309,7 @@ compilation_cache_max_size = define_int_state(
|
|
|
|
|
'size to grow indefinitely.'),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
default_dtype_bits = define_enum_state(
|
|
|
|
|
default_dtype_bits = enum_state(
|
|
|
|
|
name='jax_default_dtype_bits',
|
|
|
|
|
enum_values=['32', '64'],
|
|
|
|
|
default='64',
|
|
|
|
@ -1309,7 +1317,7 @@ default_dtype_bits = define_enum_state(
|
|
|
|
|
'This is a temporary flag that will be used during the process '
|
|
|
|
|
'of deprecating the ``jax_enable_x64`` flag.'))
|
|
|
|
|
|
|
|
|
|
numpy_dtype_promotion = define_enum_state(
|
|
|
|
|
numpy_dtype_promotion = enum_state(
|
|
|
|
|
name='jax_numpy_dtype_promotion',
|
|
|
|
|
enum_values=['standard', 'strict'],
|
|
|
|
|
default='standard',
|
|
|
|
@ -1328,7 +1336,7 @@ def _update_x64_global(val):
|
|
|
|
|
def _update_x64_thread_local(val):
|
|
|
|
|
lib.jax_jit.thread_local_state().enable_x64 = val
|
|
|
|
|
|
|
|
|
|
enable_x64 = define_bool_state(
|
|
|
|
|
enable_x64 = bool_state(
|
|
|
|
|
name='jax_enable_x64',
|
|
|
|
|
default=False,
|
|
|
|
|
help='Enable 64-bit types to be used',
|
|
|
|
@ -1363,7 +1371,7 @@ def _validate_default_device(val):
|
|
|
|
|
|
|
|
|
|
# TODO(skye): default_device only accepts devices for now. Make it work with
|
|
|
|
|
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
|
|
|
|
|
default_device = define_string_or_object_state(
|
|
|
|
|
default_device = string_or_object_state(
|
|
|
|
|
name='jax_default_device',
|
|
|
|
|
default=None,
|
|
|
|
|
help=(
|
|
|
|
@ -1383,7 +1391,7 @@ def _update_disable_jit_global(val):
|
|
|
|
|
def _update_disable_jit_thread_local(val):
|
|
|
|
|
lib.jax_jit.thread_local_state().disable_jit = val
|
|
|
|
|
|
|
|
|
|
disable_jit = define_bool_state(
|
|
|
|
|
disable_jit = bool_state(
|
|
|
|
|
name='jax_disable_jit',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Disable JIT compilation and just call original Python.'),
|
|
|
|
@ -1391,7 +1399,7 @@ disable_jit = define_bool_state(
|
|
|
|
|
update_thread_local_hook=_update_disable_jit_thread_local)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
numpy_rank_promotion = define_enum_state(
|
|
|
|
|
numpy_rank_promotion = enum_state(
|
|
|
|
|
name='jax_numpy_rank_promotion',
|
|
|
|
|
enum_values=['allow', 'warn', 'raise'],
|
|
|
|
|
default='allow',
|
|
|
|
@ -1402,7 +1410,7 @@ numpy_rank_promotion = define_enum_state(
|
|
|
|
|
update_thread_local_hook=lambda val: \
|
|
|
|
|
update_thread_local_jit_state(numpy_rank_promotion=val))
|
|
|
|
|
|
|
|
|
|
default_matmul_precision = define_optional_enum_state(
|
|
|
|
|
default_matmul_precision = optional_enum_state(
|
|
|
|
|
name='jax_default_matmul_precision',
|
|
|
|
|
enum_values=['bfloat16', 'tensorfloat32', 'float32'],
|
|
|
|
|
default=None,
|
|
|
|
@ -1427,7 +1435,7 @@ default_matmul_precision = define_optional_enum_state(
|
|
|
|
|
update_thread_local_hook=lambda val: \
|
|
|
|
|
update_thread_local_jit_state(default_matmul_precision=val))
|
|
|
|
|
|
|
|
|
|
traceback_filtering = define_enum_state(
|
|
|
|
|
traceback_filtering = enum_state(
|
|
|
|
|
name = 'jax_traceback_filtering',
|
|
|
|
|
enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames",
|
|
|
|
|
"auto"],
|
|
|
|
@ -1448,14 +1456,14 @@ traceback_filtering = define_enum_state(
|
|
|
|
|
# This flag is for internal use.
|
|
|
|
|
# TODO(tianjianlu): Removes once we always enable cusparse lowering.
|
|
|
|
|
# TODO(b/262050896): Set to true after bug is fixed
|
|
|
|
|
bcoo_cusparse_lowering = define_bool_state(
|
|
|
|
|
bcoo_cusparse_lowering = bool_state(
|
|
|
|
|
name='jax_bcoo_cusparse_lowering',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Enables lowering BCOO ops to cuSparse.'))
|
|
|
|
|
|
|
|
|
|
# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging
|
|
|
|
|
# if the intended backend can handle lowering the result
|
|
|
|
|
dynamic_shapes = define_bool_state(
|
|
|
|
|
dynamic_shapes = bool_state(
|
|
|
|
|
name='jax_dynamic_shapes',
|
|
|
|
|
default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
|
|
|
|
|
help=('Enables experimental features for staging out computations with '
|
|
|
|
@ -1467,26 +1475,26 @@ dynamic_shapes = define_bool_state(
|
|
|
|
|
|
|
|
|
|
# This flag is temporary during rollout of the remat barrier.
|
|
|
|
|
# TODO(parkers): Remove if there are no complaints.
|
|
|
|
|
remat_opt_barrier = define_bool_state(
|
|
|
|
|
remat_opt_barrier = bool_state(
|
|
|
|
|
name='jax_remat_opt_barrier',
|
|
|
|
|
default=True,
|
|
|
|
|
help=('Enables using optimization-barrier op for lowering remat.'))
|
|
|
|
|
|
|
|
|
|
# TODO(sharadmv,mattjj): set default to True, then remove
|
|
|
|
|
eager_pmap = define_bool_state(
|
|
|
|
|
eager_pmap = bool_state(
|
|
|
|
|
name='jax_eager_pmap',
|
|
|
|
|
default=True,
|
|
|
|
|
upgrade=True,
|
|
|
|
|
help='Enable eager-mode pmap when jax_disable_jit is activated.')
|
|
|
|
|
|
|
|
|
|
# TODO(mattjj): remove once we land mutable array plumbing, or face great shame
|
|
|
|
|
custom_vjp_disable_shape_check = define_bool_state(
|
|
|
|
|
custom_vjp_disable_shape_check = bool_state(
|
|
|
|
|
name='jax_custom_vjp_disable_shape_check',
|
|
|
|
|
default=False,
|
|
|
|
|
upgrade=True,
|
|
|
|
|
help='Disable the check from #19009 to enable some custom_vjp hacks.')
|
|
|
|
|
|
|
|
|
|
xla_runtime_errors = define_bool_state(
|
|
|
|
|
xla_runtime_errors = bool_state(
|
|
|
|
|
name='jax_experimental_unsafe_xla_runtime_errors',
|
|
|
|
|
default=False,
|
|
|
|
|
help=('Enable XLA runtime errors for jax.experimental.checkify.checks '
|
|
|
|
@ -1496,7 +1504,7 @@ xla_runtime_errors = define_bool_state(
|
|
|
|
|
'work under pmap/pjit.')
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
jax_xla_profile_version = define_int_state(
|
|
|
|
|
jax_xla_profile_version = int_state(
|
|
|
|
|
name='jax_xla_profile_version',
|
|
|
|
|
default=0,
|
|
|
|
|
help=(
|
|
|
|
@ -1548,7 +1556,7 @@ def _update_transfer_guard(state, key, val):
|
|
|
|
|
else:
|
|
|
|
|
assert False, f'Invalid transfer guard level {val}'
|
|
|
|
|
|
|
|
|
|
transfer_guard_host_to_device = define_optional_enum_state(
|
|
|
|
|
transfer_guard_host_to_device = optional_enum_state(
|
|
|
|
|
name='jax_transfer_guard_host_to_device',
|
|
|
|
|
enum_values=[
|
|
|
|
|
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
|
|
|
@ -1563,7 +1571,7 @@ transfer_guard_host_to_device = define_optional_enum_state(
|
|
|
|
|
update_thread_local_hook=lambda val: _update_transfer_guard(
|
|
|
|
|
transfer_guard_lib.thread_local_state(), 'host_to_device', val))
|
|
|
|
|
|
|
|
|
|
transfer_guard_device_to_device = define_optional_enum_state(
|
|
|
|
|
transfer_guard_device_to_device = optional_enum_state(
|
|
|
|
|
name='jax_transfer_guard_device_to_device',
|
|
|
|
|
enum_values=[
|
|
|
|
|
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
|
|
|
@ -1578,7 +1586,7 @@ transfer_guard_device_to_device = define_optional_enum_state(
|
|
|
|
|
update_thread_local_hook=lambda val: _update_transfer_guard(
|
|
|
|
|
transfer_guard_lib.thread_local_state(), 'device_to_device', val))
|
|
|
|
|
|
|
|
|
|
transfer_guard_device_to_host = define_optional_enum_state(
|
|
|
|
|
transfer_guard_device_to_host = optional_enum_state(
|
|
|
|
|
name='jax_transfer_guard_device_to_host',
|
|
|
|
|
enum_values=[
|
|
|
|
|
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
|
|
|
@ -1599,7 +1607,7 @@ def _update_all_transfer_guard_global(val):
|
|
|
|
|
'jax_transfer_guard_device_to_host'):
|
|
|
|
|
config.update(name, val)
|
|
|
|
|
|
|
|
|
|
_transfer_guard = define_optional_enum_state(
|
|
|
|
|
_transfer_guard = optional_enum_state(
|
|
|
|
|
name='jax_transfer_guard',
|
|
|
|
|
enum_values=[
|
|
|
|
|
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
|
|
|
|
@ -1643,7 +1651,7 @@ def _update_debug_log_modules(module_names_str: str | None):
|
|
|
|
|
logging_config.enable_debug_logging(module_name)
|
|
|
|
|
|
|
|
|
|
# Don't define a context manager since this isn't threadsafe.
|
|
|
|
|
define_string_state(
|
|
|
|
|
string_state(
|
|
|
|
|
name='jax_debug_log_modules',
|
|
|
|
|
default='',
|
|
|
|
|
help=('Comma-separated list of module names (e.g. "jax" or '
|
|
|
|
@ -1651,7 +1659,7 @@ define_string_state(
|
|
|
|
|
'for.'),
|
|
|
|
|
update_global_hook=_update_debug_log_modules)
|
|
|
|
|
|
|
|
|
|
pmap_no_rank_reduction = define_bool_state(
|
|
|
|
|
pmap_no_rank_reduction = bool_state(
|
|
|
|
|
name='jax_pmap_no_rank_reduction',
|
|
|
|
|
default=False,
|
|
|
|
|
help=(
|
|
|
|
|