2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2021-04-19 08:52:48 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2023-07-27 12:15:16 -07:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
from collections.abc import Hashable, Iterator
|
2021-04-19 08:52:48 -07:00
|
|
|
import contextlib
|
|
|
|
import functools
|
2021-08-06 07:05:43 -07:00
|
|
|
import itertools
|
2022-10-13 17:06:22 +02:00
|
|
|
import logging
|
2021-04-19 08:52:48 -07:00
|
|
|
import os
|
|
|
|
import sys
|
|
|
|
import threading
|
2023-12-08 12:09:04 +00:00
|
|
|
from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar
|
2023-10-25 12:38:19 +01:00
|
|
|
import warnings
|
2022-06-02 10:33:53 -07:00
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
from jax._src import lib
|
|
|
|
from jax._src.lib import jax_jit
|
2022-04-11 14:59:04 +00:00
|
|
|
from jax._src.lib import transfer_guard_lib
|
2021-12-21 20:55:03 +00:00
|
|
|
from jax._src.lib import xla_client
|
Add `jax_debug_log_modules` config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-06-07 00:20:32 +00:00
|
|
|
from jax._src import logging_config
|
2021-04-19 08:52:48 -07:00
|
|
|
|
2022-10-13 17:06:22 +02:00
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Generic type parameter for the value type held by a config option; used by
# the Generic classes defined later in this module.
_T = TypeVar('_T')
|
|
|
|
|
2022-10-13 17:06:22 +02:00
|
|
|
|
2021-04-19 08:52:48 -07:00
|
|
|
def bool_env(varname: str, default: bool) -> bool:
  """Interpret an environment variable as a boolean.

  The comparison is case insensitive. Accepted true spellings are 'y', 'yes',
  't', 'true', 'on' and '1'; accepted false spellings are 'n', 'no', 'f',
  'false', 'off' and '0'.

  Args:
    varname: name of the environment variable to read.
    default: value used (via its string form) when the variable is not set.

  Returns:
    The boolean the variable's text maps to.

  Raises:
    ValueError: if the text is not one of the accepted spellings.
  """
  truthy = ('y', 'yes', 't', 'true', 'on', '1')
  falsy = ('n', 'no', 'f', 'false', 'off', '0')
  val = os.getenv(varname, str(default)).lower()
  if val in truthy:
    return True
  if val in falsy:
    return False
  raise ValueError(f"invalid truth value {val!r} for environment {varname!r}")
|
2021-04-19 08:52:48 -07:00
|
|
|
|
|
|
|
def int_env(varname: str, default: int) -> int:
  """Interpret an environment variable as an integer.

  Args:
    varname: name of the environment variable to read.
    default: value used (via its string form) when the variable is not set.

  Returns:
    The result of ``int()`` applied to the variable's text.
  """
  text = os.getenv(varname, str(default))
  return int(text)
|
2021-04-19 08:52:48 -07:00
|
|
|
|
|
|
|
|
2022-03-25 18:37:15 -07:00
|
|
|
# Text appended by `define_bool_state` to the help string of a boolean flag
# declared with `upgrade=True`.
UPGRADE_BOOL_HELP = (
    " This will be enabled by default in future versions of JAX, at which point "
    "all uses of the flag will be considered deprecated (following the "
    "`API compatibility policy "
    "<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_).")

# Text appended by `define_bool_state` to the extra description of a boolean
# flag declared with `upgrade=True`.
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
|
|
|
|
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Names that `Config.__getattr__` still resolves (to the module-level object of
# the same name) while emitting a DeprecationWarning.
_CONFIG_DEPRECATIONS = {
    # Added October 26, 2023:
    "check_exists",
    "DEFINE_bool", "DEFINE_integer", "DEFINE_float",
    "DEFINE_string", "DEFINE_enum",
    "define_bool_state", "define_enum_state", "define_int_state",
    "define_float_state", "define_string_state",
    "define_string_or_object_state",
}
|
2023-07-27 12:15:16 -07:00
|
|
|
|
|
|
|
|
2021-04-19 08:52:48 -07:00
|
|
|
class Config:
  """Registry holding the process-level value of every config option.

  Options are registered with `add_option`, changed with `update`, and read
  with `read` / `_read`. The registry can also declare every option as an absl
  flag (`config_with_absl`) and copy parsed flag values back into itself
  (`complete_absl_config`).
  """
  # NOTE(review): presumably a marker for static type checkers that attributes
  # are attached to this class at runtime (the define_*_state helpers in this
  # module do so with setattr) -- confirm against the type checker's docs.
  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self):
    # Option name -> current process-level value.
    self.values = {}
    # Option name -> (flag type, extra positional args, extra keyword args),
    # exactly as needed to re-declare the option as an absl flag.
    self.meta = {}
    # Set to True once `config_with_absl` has run.
    self.use_absl = False
    # Names of options that `read` refuses to return (they must be read via
    # the `config.<name>` attribute instead).
    self._contextmanager_flags = set()
    # Option name -> callback invoked with the new value by `update`.
    self._update_hooks = {}

  def __getattr__(self, name):
    """Resolve deprecated ``config.<name>`` helpers, warning on each access.

    Only called when normal attribute lookup fails. Names listed in
    `_CONFIG_DEPRECATIONS` are looked up in this module's globals and returned
    after a DeprecationWarning is issued.

    Raises:
      AttributeError: if `name` is not a deprecated helper that exists in this
        module.
    """
    fn = None
    if name in _CONFIG_DEPRECATIONS:
      fn = globals().get(name, None)
    if fn is None:
      # `!r` already supplies the quotes around the class name; the original
      # message carried an extra, unbalanced leading quote.
      raise AttributeError(
          f"{type(self).__name__!r} object has no attribute {name!r}")
    message = (
        f"jax.config.{name} is deprecated. Please use other libraries "
        "for configuration instead."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return fn

  def update(self, name, val):
    """Set the process-level value of option `name` and run its update hook.

    Raises:
      AttributeError: if `name` has not been registered.
    """
    if name not in self.values:
      raise AttributeError(f"Unrecognized config option: {name}")
    self.values[name] = val

    hook = self._update_hooks.get(name, None)
    if hook:
      hook(val)

  def read(self, name):
    """Return the process-level value of option `name`.

    Raises:
      AttributeError: if `name` is registered as a context-manager flag (those
        must be read via ``config.<name>``) or is not registered at all.
    """
    if name in self._contextmanager_flags:
      raise AttributeError(
          "For flags with a corresponding contextmanager, read their value "
          f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.")
    return self._read(name)

  def _read(self, name):
    """Return the process-level value of `name` without the `read` guard.

    Raises:
      AttributeError: if `name` has not been registered.
    """
    try:
      return self.values[name]
    except KeyError:
      # The KeyError adds nothing to the message below; suppress the context.
      raise AttributeError(f"Unrecognized config option: {name}") from None

  def add_option(self, name, default, opt_type, meta_args, meta_kwargs,
                 update_hook: Callable[[Any], None] | None = None):
    """Register a new option with its default value and flag metadata.

    Args:
      name: option name; must not already be registered.
      default: initial process-level value.
      opt_type: key selecting the absl ``DEFINE_*`` function used by
        `config_with_absl` (``bool``, ``int``, ``float``, ``str`` or
        ``'enum'``).
      meta_args: extra positional arguments forwarded to the absl definer.
      meta_kwargs: extra keyword arguments forwarded to the absl definer.
      update_hook: optional callback; called once immediately with `default`
        and again by `update` with each new value.

    Raises:
      Exception: if `name` is already registered.
    """
    if name in self.values:
      raise Exception(f"Config option {name} already defined")
    self.values[name] = default
    self.meta[name] = (opt_type, meta_args, meta_kwargs)
    if update_hook:
      self._update_hooks[name] = update_hook
      update_hook(default)

  def config_with_absl(self):
    """Declare every registered option as an absl flag.

    Also schedules `complete_absl_config` through ``app.call_after_init``.
    """
    # Run this before calling `app.run(main)` etc
    import absl.flags as absl_FLAGS  # noqa: F401  # pytype: disable=import-error
    from absl import app, flags as absl_flags  # pytype: disable=import-error

    self.use_absl = True
    self.absl_flags = absl_flags
    absl_defs = { bool: absl_flags.DEFINE_bool,
                  int:  absl_flags.DEFINE_integer,
                  float: absl_flags.DEFINE_float,
                  str:  absl_flags.DEFINE_string,
                  'enum': absl_flags.DEFINE_enum }

    for name, val in self.values.items():
      flag_type, meta_args, meta_kwargs = self.meta[name]
      absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)
    app.call_after_init(lambda: self.complete_absl_config(absl_flags))

  def complete_absl_config(self, absl_flags):
    """Copy the value of every flag marked ``present`` into this registry."""
    # `update` only reassigns existing keys, so iterating the dict directly is
    # as safe as the `.items()` loop it replaces.
    for name in self.values:
      try:
        flag = absl_flags.FLAGS[name]
      except KeyError:
        # This can happen if a new flag was added after config_with_absl() was
        # called, but before complete_absl_config was run. We could in principle
        # add code to DEFINE_... to register any newly added flags with ABSL
        # if config_with_absl() has already been called, but arguably the user
        # should have called config_with_absl() later.
        continue
      if flag.present:
        self.update(name, flag.value)

  def parse_flags_with_absl(self):
    """Parse the ``--jax...`` flags found in ``sys.argv``, at most once.

    Repeat calls are no-ops, guarded by the module-level
    ``already_configured_with_absl`` flag.
    """
    global already_configured_with_absl
    if not already_configured_with_absl:
      # Extract just the --jax... flags (before the first --) from argv. In some
      # environments (e.g. ipython/colab) argv might be a mess of things
      # parseable by absl and other junk.
      jax_argv = itertools.takewhile(lambda a: a != '--', sys.argv)
      jax_argv = ['', *(a for a in jax_argv if a.startswith('--jax'))]

      import absl.flags  # pytype: disable=import-error
      self.config_with_absl()
      absl.flags.FLAGS(jax_argv, known_only=True)
      self.complete_absl_config(absl.flags)
      already_configured_with_absl = True

  def _trace_context(self):
    """Returns a tuple of configuration values that affect tracing.

    These values are included in the cache key for linear_util.cache.

    Values included in this set should also most likely be included in
    the C++ JIT state, which is handled separately.
    """
    tls = jax_jit.thread_local_state()
    axis_env_state = ()
    mesh_context_manager = ()
    context = tls.extra_jit_context
    if context and context.axis_env_state is not None:
      axis_env_state = context.axis_env_state
    if context and context.mesh_context_manager:
      mesh_context_manager = context.mesh_context_manager
    return (axis_env_state, mesh_context_manager, self.x64_enabled,
            self.jax_numpy_rank_promotion, self.jax_default_matmul_precision,
            self.jax_dynamic_shapes, self.jax_numpy_dtype_promotion,
            self.jax_default_device,
            self.jax_random_seed_offset,
            self.jax_threefry_partitionable,
            self.jax_softmax_custom_jvp,
            self.jax_enable_memories,
            self.jax_disable_jit,
            self.jax_xla_profile_version,
            # Technically this affects jaxpr->stablehlo lowering, not tracing.
            self.jax_hlo_source_file_canonicalization_regex)
|
2021-04-21 06:36:08 -07:00
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
|
|
|
|
# The single Config instance used throughout this module.
config = Config()

# Module-level shortcuts bound to the singleton's methods.
_read = config._read
update = config.update
parse_flags_with_absl = config.parse_flags_with_absl
|
|
|
|
|
|
|
|
|
2022-04-12 15:05:53 -07:00
|
|
|
class NoDefault:
  """Type of the `no_default` sentinel."""


# Sentinel used as a parameter default to mean "no value was supplied".
no_default = NoDefault()


class _Unset:
  """Type of the `unset` sentinel."""


# Sentinel meaning "no thread-local value is currently stored".
unset = _Unset()

# Holds per-thread option overrides, stored as attributes keyed by option name.
_thread_local_state = threading.local()
|
|
|
|
|
|
|
|
|
2023-10-04 09:44:12 -07:00
|
|
|
class _StateContextManager(Generic[_T]):
  """Callable that temporarily overrides one config option for this thread.

  Calling the instance yields a context manager that stores the new value in
  `_thread_local_state` under the option's name and restores the previous
  state on exit. The `value` property returns the thread-local override when
  one is set and the process-level value from `config` otherwise.
  """

  def __init__(self, name, help, update_thread_local_hook,
               validate_new_val_hook: Callable[[Any], None] | None = None,
               extra_description: str = "", default_value: Any = no_default):
    self._name = name
    # Public name drops a leading 'jax_' prefix.
    self.__name__ = name[4:] if name.startswith('jax_') else name
    self.__doc__ = (f"Context manager for `{name}` config option"
                    f"{extra_description}.\n\n{help}")
    self._update_thread_local_hook = update_thread_local_hook
    self._validate_new_val_hook = validate_new_val_hook
    # Value used when the instance is called without an argument.
    self._default_value = default_value

  def __bool__(self) -> NoReturn:
    # Truth-testing the manager itself is always an error; callers must go
    # through `.value`.
    template = ("bool() not supported for instances of type '{0}' "
                "(did you mean to use '{0}.value' instead?)")
    raise TypeError(template.format(type(self).__name__))

  @property
  def value(self) -> _T:
    """Current value: the thread-local override if set, else the global one."""
    override = _thread_local_state.__dict__.get(self._name, unset)
    if override is unset:
      return config._read(self._name)
    return override

  @contextlib.contextmanager
  def __call__(self, new_val: Any = no_default):
    """Override the option with `new_val` for the duration of a `with` block.

    Raises:
      TypeError: if no value is passed and no default was given at
        construction time.
    """
    if new_val is no_default:
      if self._default_value is no_default:
        # no default_value provided to constructor and no value provided as an
        # argument, so we raise an error
        raise TypeError(f"Context manager for {self.__name__} config option "
                        "requires an argument representing the new value for "
                        "the config option.")
      new_val = self._default_value  # default_value provided to constructor
    if self._validate_new_val_hook:
      self._validate_new_val_hook(new_val)

    prev_val = getattr(_thread_local_state, self._name, unset)
    setattr(_thread_local_state, self._name, new_val)
    if self._update_thread_local_hook:
      self._update_thread_local_hook(new_val)
    try:
      yield
    finally:
      had_override = prev_val is not unset
      if had_override:
        setattr(_thread_local_state, self._name, prev_val)
      else:
        delattr(_thread_local_state, self._name)
      # The hook is told None when the thread-local override was removed.
      if self._update_thread_local_hook:
        self._update_thread_local_hook(prev_val if had_override else None)

  def _add_hooks(self, update_global_hook, update_thread_local_hook):
    """Private method that installs hooks on an existing context manager.

    The global hook is registered on `config` and called immediately with the
    option's current process-level value.

    Used to avoid cyclic import dependencies."""
    self._update_thread_local_hook = update_thread_local_hook
    config._update_hooks[self._name] = update_global_hook
    update_global_hook(config._read(self._name))
|
|
|
|
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
def define_bool_state(
    name: str,
    default: bool,
    help: str,
    *,
    update_global_hook: Callable[[bool], None] | None = None,
    update_thread_local_hook: Callable[[bool | None], None] | None = None,
    upgrade: bool = False,
    extra_description: str = '',
) -> _StateContextManager[bool]:
  """Set up thread-local state and return a contextmanager for managing it.

  This function is a convenience wrapper. It defines a flag, environment
  variable, and corresponding thread-local state, which can be managed via the
  contextmanager it returns.

  The thread-local state value can be read via the ``config.<option_name>``
  attribute, where ``config`` is the singleton ``Config`` instance.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: boolean, a default value for the option.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback that is called with the updated
      value of the global state when it is altered or set initially.
    update_thread_local_hook: an optional callback that is called with the
      updated value of the thread-local state when it is altered or set
      initially.
    upgrade: optional indicator that this flag controls a canonical feature
      upgrade, so that it is `True` for the incoming functionality, `False`
      for the outgoing functionality to be deprecated.
    extra_description: string, optional: extra information to add to the
      summary description.

  Returns:
    A contextmanager to control the thread-local state value.

  Example:

    enable_foo = config.define_bool_state(
        name='jax_enable_foo',
        default=False,
        help='Enable foo.')

    # Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo
    # command-line flag can be used to control the process-level value of
    # the configuration option, in addition to using e.g.
    # ``config.update("jax_enable_foo", True)`` directly. We can also use a
    # context manager:

    with enable_foo(True):
      ...

  The value of the thread-local state or flag can be accessed via
  ``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
  an error.
  """
  flag_name = name.lower()
  help_text = help
  description = extra_description
  if upgrade:
    help_text = help_text + ' ' + UPGRADE_BOOL_HELP
    description = description + UPGRADE_BOOL_EXTRA_DESC

  # The environment variable (upper-cased name) overrides `default`.
  initial = bool_env(flag_name.upper(), default)
  DEFINE_bool(flag_name, initial, help_text, update_hook=update_global_hook)
  config._contextmanager_flags.add(flag_name)

  manager = _StateContextManager[bool](
      flag_name, help_text, update_thread_local_hook,
      extra_description=description, default_value=True)

  def _current_value(_):
    return manager.value

  # Expose the option as a read-only attribute on every Config instance.
  setattr(Config, flag_name, property(_current_value))
  return manager
|
|
|
|
|
|
|
|
|
|
|
|
def define_enum_state(
    name: str,
    enum_values: list[str],
    default: str | None,
    help: str,
    *,
    update_global_hook: Callable[[str], None] | None = None,
    update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str]:
  """Set up thread-local state and return a contextmanager for managing it.

  See docstring for ``define_bool_state``.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    enum_values: list of strings representing the possible values for the
      option.
    default: optional string, default value.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback, registered as the option's
      update hook.
    update_thread_local_hook: an optional callback handed to the returned
      context manager.

  Returns:
    A contextmanager to control the thread-local state value.

  Raises:
    ValueError: if the initial value (environment variable or `default`) is
      not None and not one of `enum_values`.
  """
  flag_name = name.lower()
  # The environment variable (upper-cased name) overrides `default`.
  initial = os.getenv(flag_name.upper(), default)
  if initial is not None and initial not in enum_values:
    raise ValueError(f"Invalid value \"{initial}\" for JAX flag {flag_name}")
  DEFINE_enum(flag_name, initial,
              enum_values=enum_values, help=help,
              update_hook=update_global_hook)
  config._contextmanager_flags.add(flag_name)

  def validate(new_val):
    # None is always accepted; otherwise the value must be exactly a `str`
    # (not a subclass) and one of the allowed values.
    if new_val is None:
      return
    if type(new_val) is str and new_val in enum_values:
      return
    raise ValueError(f"new enum value must be None or in {enum_values}, "
                     f"got {new_val} of type {type(new_val)}.")

  manager = _StateContextManager[str](
      flag_name, help, update_thread_local_hook, validate)

  def _current_value(_):
    return manager.value

  # Expose the option as a read-only attribute on every Config instance.
  setattr(Config, flag_name, property(_current_value))
  return manager
|
|
|
|
|
|
|
|
|
|
|
|
def define_int_state(
    name: str,
    default: int | None,
    help: str,
    *,
    update_global_hook: Callable[[int | None], None] | None = None,
    update_thread_local_hook: Callable[[int | None], None] | None = None,
) -> _StateContextManager[int]:
  """Set up thread-local state and return a contextmanager for managing it.

  See docstring for ``define_bool_state``.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: optional int, default value.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback, registered as the option's
      update hook; it receives the option's (integer or None) value.
    update_thread_local_hook: an optional callback handed to the returned
      context manager; it receives the thread-local value, or None when the
      thread-local override is removed.

  Returns:
    A contextmanager to control the thread-local state value.

  Raises:
    ValueError: if the environment variable (or `default`) cannot be converted
      with ``int()``.
  """
  name = name.lower()
  # The environment variable (upper-cased name) overrides `default`.
  default_env = os.getenv(name.upper(), default)
  if default_env is not None:
    try:
      default = int(default_env)
    except ValueError as err:
      raise ValueError(
          f"Invalid value \"{default_env}\" for JAX flag {name}") from err
  DEFINE_integer(name, default, help=help, update_hook=update_global_hook)
  config._contextmanager_flags.add(name)

  def validate(new_val):
    if new_val is not None and not isinstance(new_val, int):
      raise ValueError(f'new int config value must be None or of type int, '
                       f'got {new_val} of type {type(new_val)}')

  s = _StateContextManager[int](name, help, update_thread_local_hook, validate)
  # Expose the option as a read-only attribute on every Config instance.
  setattr(Config, name, property(lambda _: s.value))
  return s
|
|
|
|
|
|
|
|
|
|
|
|
def define_float_state(
    name: str,
    default: float | None,
    help: str,
    *,
    update_global_hook: Callable[[float | None], None] | None = None,
    update_thread_local_hook: Callable[[float | None], None] | None = None,
) -> _StateContextManager[float]:
  """Set up thread-local state and return a contextmanager for managing it.

  See docstring for ``define_bool_state``.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: optional float, default value.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback, registered as the option's
      update hook; it receives the option's (float or None) value.
    update_thread_local_hook: an optional callback handed to the returned
      context manager; it receives the thread-local value, or None when the
      thread-local override is removed.

  Returns:
    A contextmanager to control the thread-local state value.

  Raises:
    ValueError: if the environment variable (or `default`) cannot be converted
      with ``float()``.
  """
  name = name.lower()
  # The environment variable (upper-cased name) overrides `default`.
  default_env = os.getenv(name.upper(), default)
  if default_env is not None:
    try:
      default = float(default_env)
    except ValueError as err:
      raise ValueError(
          f"Invalid value \"{default_env}\" for JAX flag {name}") from err
  DEFINE_float(name, default, help=help, update_hook=update_global_hook)
  config._contextmanager_flags.add(name)

  def validate(new_val):
    # Plain ints are accepted alongside floats.
    if new_val is not None and not isinstance(new_val, (float, int)):
      raise ValueError(
          f'new float config value must be None or of type float, '
          f'got {new_val} of type {type(new_val)}')

  s = _StateContextManager[float](name, help, update_thread_local_hook,
                                  validate)
  # Expose the option as a read-only attribute on every Config instance.
  setattr(Config, name, property(lambda _: s.value))
  return s
|
|
|
|
|
|
|
|
|
|
|
|
def define_string_state(
    name: str,
    default: str | None,
    help: str,
    *,
    update_global_hook: Callable[[str], None] | None = None,
    update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> _StateContextManager[str]:
  """Set up thread-local state and return a contextmanager for managing it.

  See docstring for ``define_bool_state``. Delegates to
  ``define_string_or_object_state`` with a validator that only accepts
  ``None`` or ``str`` values.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: string, a default value for the option.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback that is called with the updated
      value of the global state when it is altered or set initially.
    update_thread_local_hook: an optional callback that is called with the
      updated value of the thread-local state when it is altered or set
      initially.

  Returns:
    A contextmanager to control the thread-local state value.
  """

  def validate(new_val):
    if new_val is None or isinstance(new_val, str):
      return
    raise ValueError(f'new string config value must be None or of type str,'
                     f' got {new_val} of type {type(new_val)}.')

  hooks = dict(
      update_global_hook=update_global_hook,
      update_thread_local_hook=update_thread_local_hook,
      validate_new_val_hook=validate,
  )
  return define_string_or_object_state(name, default, help, **hooks)
|
|
|
|
|
|
|
|
|
|
|
|
def define_string_or_object_state(
    name: str,
    default: Any,
    help: str,
    *,
    update_global_hook: Callable[[Any], None] | None = None,
    update_thread_local_hook: Callable[[Any], None] | None = None,
    validate_new_val_hook: Callable[[Any], None] | None = None,
) -> _StateContextManager[Any]:
  """Set up thread-local state and return a contextmanager for managing it.

  Similar to ``define_string_state``, except the context manager will accept
  any object, not just a string. Any value passed via commandline flag or
  environment variable will be treated as a string.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: string, a default value for the option.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback that is called with the updated
      value of the global state when it is altered or set initially.
    update_thread_local_hook: an optional callback that is called with the
      updated value of the thread-local state when it is altered or set
      initially.
    validate_new_val_hook: an optional callback that is called with the new
      value on any update, and should raise an error if the new value is
      invalid.

  Returns:
    A contextmanager to control the thread-local state value.
  """
  flag_name = name.lower()
  # The environment variable (upper-cased name) overrides `default`.
  initial = os.getenv(flag_name.upper(), default)
  DEFINE_string(flag_name, initial, help=help, update_hook=update_global_hook)
  config._contextmanager_flags.add(flag_name)

  manager = _StateContextManager[Any](
      flag_name, help, update_thread_local_hook, validate_new_val_hook)

  def _current_value(_):
    return manager.value

  # Expose the option as a read-only attribute on every Config instance.
  setattr(Config, flag_name, property(_current_value))
  return manager
|
2023-10-12 13:15:22 +01:00
|
|
|
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
class FlagHolder(Generic[_T]):
  """Handle for one config option; `.value` reads it through `config.read`."""

  def __init__(self, name: str):
    self._name = name

  def __bool__(self) -> NoReturn:
    # Truth-testing the holder itself is always an error; callers must go
    # through `.value`.
    template = ("bool() not supported for instances of type '{0}' "
                "(did you mean to use '{0}.value' instead?)")
    raise TypeError(template.format(type(self).__name__))

  @property
  def value(self) -> _T:
    """The option's current value, as returned by ``config.read``."""
    current = config.read(self._name)
    return current
|
2023-07-27 12:15:16 -07:00
|
|
|
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
def check_exists(name):
  """Raise AttributeError unless `name` is a registered config option."""
  if name in config.values:
    return
  raise AttributeError(f"Unrecognized config option: {name}")
|
2023-07-27 12:15:16 -07:00
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
|
|
|
|
def DEFINE_bool(name, default, *args, **kwargs) -> FlagHolder[bool]:
  """Register a boolean option on `config` and return a holder for it.

  An ``update_hook`` keyword, if given, is removed from `kwargs` and passed to
  ``config.add_option``; the remaining arguments are stored as flag metadata.
  """
  hook = kwargs.pop("update_hook", None)
  config.add_option(name, default, bool, args, kwargs, update_hook=hook)
  holder = FlagHolder(name)
  return holder
|
|
|
|
|
|
|
|
|
|
|
|
def DEFINE_integer(name, default, *args, **kwargs) -> FlagHolder[int]:
  """Register an integer option on `config` and return a holder for it.

  An ``update_hook`` keyword, if given, is removed from `kwargs` and passed to
  ``config.add_option``; the remaining arguments are stored as flag metadata.
  """
  hook = kwargs.pop("update_hook", None)
  config.add_option(name, default, int, args, kwargs, update_hook=hook)
  holder = FlagHolder(name)
  return holder
|
|
|
|
|
|
|
|
|
|
|
|
def DEFINE_float(name, default, *args, **kwargs) -> FlagHolder[float]:
  """Register a float option on `config` and return a holder for it.

  An ``update_hook`` keyword, if given, is removed from `kwargs` and passed to
  ``config.add_option``; the remaining arguments are stored as flag metadata.
  """
  hook = kwargs.pop("update_hook", None)
  config.add_option(name, default, float, args, kwargs, update_hook=hook)
  holder = FlagHolder(name)
  return holder
|
|
|
|
|
|
|
|
|
|
|
|
def DEFINE_string(name, default, *args, **kwargs) -> FlagHolder[str]:
  """Register a string option on `config` and return a holder for it.

  An ``update_hook`` keyword, if given, is removed from `kwargs` and passed to
  ``config.add_option``; the remaining arguments are stored as flag metadata.
  """
  hook = kwargs.pop("update_hook", None)
  config.add_option(name, default, str, args, kwargs, update_hook=hook)
  holder = FlagHolder(name)
  return holder
|
|
|
|
|
|
|
|
|
|
|
|
def DEFINE_enum(name, default, *args, **kwargs) -> FlagHolder[str]:
  """Register an enum option on `config` and return a holder for it.

  The option is stored under the ``'enum'`` type key. An ``update_hook``
  keyword, if given, is removed from `kwargs` and passed to
  ``config.add_option``; the remaining arguments are stored as flag metadata.
  """
  hook = kwargs.pop("update_hook", None)
  config.add_option(name, default, 'enum', args, kwargs, update_hook=hook)
  holder = FlagHolder(name)
  return holder
|
2023-07-27 12:15:16 -07:00
|
|
|
|
|
|
|
|
2021-04-19 08:52:48 -07:00
|
|
|
# Set to True by `Config.parse_flags_with_absl` after its first run so that
# later calls do nothing.
already_configured_with_absl = False
|
|
|
|
|
|
|
|
|
2021-04-21 06:36:08 -07:00
|
|
|
# The C++ JIT maintains its own copy of several configuration items as
|
|
|
|
# a global/thread-local state. These methods allow updates to part of the
|
|
|
|
# state when a configuration value changes.
|
2022-05-31 12:57:02 -07:00
|
|
|
class _GlobalExtraJitContext(NamedTuple):
|
2023-12-08 12:09:04 +00:00
|
|
|
numpy_rank_promotion: str | None = None
|
|
|
|
numpy_dtype_promotion: str | None = None
|
|
|
|
default_matmul_precision: Any | None = None
|
2022-03-30 17:52:55 -07:00
|
|
|
dynamic_shapes: bool = False
|
2023-12-12 18:31:07 -08:00
|
|
|
random_seed_offset: int = 0
|
2022-11-22 13:58:59 -08:00
|
|
|
threefry_partitionable: bool = False
|
2023-05-23 11:56:50 -07:00
|
|
|
softmax_custom_jvp: bool = False
|
2023-10-27 12:52:24 -07:00
|
|
|
xla_profile_version: int = 0
|
2021-04-21 06:36:08 -07:00
|
|
|
|
|
|
|
|
2022-05-31 12:57:02 -07:00
|
|
|
def _update_global_jit_state(**kw):
  """Overwrite fields of the global ``extra_jit_context`` with ``kw``.

  When the global state holds no context yet (the stored value is falsy), a
  default ``_GlobalExtraJitContext`` is used as the starting point.

  Args:
    **kw: field names of ``_GlobalExtraJitContext`` mapped to new values.
  """
  global_state = jax_jit.global_state()
  current = global_state.extra_jit_context
  if not current:
    current = _GlobalExtraJitContext()
  global_state.extra_jit_context = current._replace(**kw)
|
|
|
|
|
|
|
|
|
2022-05-31 12:57:02 -07:00
|
|
|
class _ThreadLocalExtraJitContext(NamedTuple):
|
2023-10-25 12:38:19 +01:00
|
|
|
"""A namedtuple containing states to add to the cache key.
|
2022-05-31 12:57:02 -07:00
|
|
|
|
|
|
|
Just in time compilation (for jit, pmap, etc) behavior is configurable through
|
|
|
|
global and thread-local options, used in the cache key.
|
|
|
|
|
|
|
|
The initialization, which uses both config.py and core.py is done using
|
|
|
|
`_update_thread_local_jit_state` in core.py to prevent circular imports.
|
|
|
|
"""
|
2023-12-08 12:09:04 +00:00
|
|
|
dynamic_trace_state: Any | None = None
|
2022-08-15 14:26:39 -07:00
|
|
|
axis_env_state: Hashable = ()
|
2023-03-06 10:45:02 -08:00
|
|
|
mesh_context_manager: Hashable = ()
|
2023-12-08 12:09:04 +00:00
|
|
|
numpy_rank_promotion: str | None = None
|
|
|
|
numpy_dtype_promotion: str | None = None
|
|
|
|
default_matmul_precision: Any | None = None
|
2022-03-30 17:52:55 -07:00
|
|
|
dynamic_shapes: bool = False
|
2023-12-12 18:31:07 -08:00
|
|
|
random_seed_offset: int = 0
|
2023-10-31 13:41:41 -07:00
|
|
|
threefry_partitionable: bool = False
|
2023-05-23 11:56:50 -07:00
|
|
|
softmax_custom_jvp: bool = False
|
2023-10-27 12:52:24 -07:00
|
|
|
xla_profile_version: int = 0
|
2021-04-21 06:36:08 -07:00
|
|
|
|
|
|
|
|
2022-06-22 20:04:05 -07:00
|
|
|
class _ThreadLocalStateCache(threading.local):
|
|
|
|
""""A thread local cache for _ThreadLocalExtraJitContext
|
|
|
|
|
|
|
|
The extra_jit_context in jax_jit.thread_local_state() may get updated and thus
|
|
|
|
incurring dispatch overhead for comparing this python object during jit calls.
|
|
|
|
We want to duduplicate the objects that have the same hash/equality to also
|
|
|
|
have the same object ID, since the equality check is much faster if the object
|
|
|
|
IDs match.
|
|
|
|
"""
|
|
|
|
def __init__(self):
|
|
|
|
self.canonicalize = functools.lru_cache(128)(lambda x: x)
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level cache instance; `update_thread_local_jit_state` passes each new
# context through its `canonicalize` function before storing it.
_thread_local_state_cache = _ThreadLocalStateCache()
|
|
|
|
|
|
|
|
|
2021-04-21 06:36:08 -07:00
|
|
|
def update_thread_local_jit_state(**kw):
  """Overwrite fields of the thread-local ``extra_jit_context`` with ``kw``.

  The updated context is passed through
  ``_thread_local_state_cache.canonicalize`` before it is stored, so equal
  contexts are stored as the same object.

  Args:
    **kw: field names of ``_ThreadLocalExtraJitContext`` mapped to new values.
  """
  thread_state = jax_jit.thread_local_state()
  # After xla_client._version >= 70, the thread_local object will necessarily
  # be initialized when accessed. The fallback below can be removed when the
  # minimum jaxlib version is past version 70
  current = thread_state.extra_jit_context
  if not current:
    current = _ThreadLocalExtraJitContext()
  updated = current._replace(**kw)
  thread_state.extra_jit_context = _thread_local_state_cache.canonicalize(
      updated)
|
2021-04-21 06:36:08 -07:00
|
|
|
|
|
|
|
|
2022-01-13 12:28:25 +02:00
|
|
|
# Boolean option selecting the jax2tf lowering for cumulative reductions.
# TODO(b/214340779): remove flag when XLA:CPU is improved.
jax2tf_associative_scan_reductions = define_bool_state(
    name='jax2tf_associative_scan_reductions',
    default=False,
    help=(
        'JAX has two separate lowering rules for the cumulative reduction '
        'primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses '
        'a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. '
        'The latter has a slow implementation on CPUs and GPUs. '
        'By default, jax2tf uses the TPU lowering. Set this flag to True to '
        'use the associative scan lowering usage, and only if it makes a difference '
        'for your application. '
        'See the jax2tf README.md for more details.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option; its default is read from the
# JAX2TF_DEFAULT_NATIVE_SERIALIZATION environment variable (True if unset).
jax2tf_default_native_serialization = define_bool_state(
    name='jax2tf_default_native_serialization',
    default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True),
    help=(
        'Sets the default value of the native_serialization parameter to '
        'jax2tf.convert. Prefer using the parameter instead of the flag, '
        'the flag may be removed in the future.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Integer option; its default is read from the JAX_SERIALIZATION_VERSION
# environment variable (8 if unset).
jax_serialization_version = define_int_state(
    name='jax_serialization_version',
    # Note: bump the default serialization version at least one month after
    # we update XlaCallModule to support the new version, so that serialized
    # modules are forward compatible with deployed versions of XlaCallModule.
    # Version 8 of XlaCallModule is supported since July 21st, 2023.
    default=int_env('JAX_SERIALIZATION_VERSION', 8),
    help=(
        'The version number to use for native serialization. This must be '
        'within the range of versions supported by the tf.XlaCallModule '
        'used in your deployment environment. '
        'See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# String option holding a comma-separated platform list; unset by default.
jax_platforms = define_string_state(
    name='jax_platforms',
    default=None,
    help=(
        'Comma-separated list of platform names specifying which platforms jax '
        'should initialize. If any of the platforms in this list are not successfully '
        'initialized, an exception will be raised and the program will be aborted. '
        'The first platform in the list will be the default platform. '
        'For example, config.jax_platforms=cpu,tpu means that CPU and TPU backends '
        'will be initialized, and the CPU backend will be used unless otherwise '
        'specified. If TPU initialization fails, it will raise an exception. '
        'By default, jax will try to initialize all available '
        'platforms and will default to GPU or TPU if available, and fallback to CPU '
        'otherwise.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default.
enable_checks = define_bool_state(
    name='jax_enable_checks',
    default=False,
    help='Turn on invariant checking for JAX internals. Makes things slower.',
)
|
|
|
|
|
2023-12-11 12:03:48 -08:00
|
|
|
# Boolean option, off by default.
enable_key_reuse_checks = define_bool_state(
    name='jax_enable_key_reuse_checks',
    default=False,
    help="Turn on experimental key reuse checking.",
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default.
check_tracer_leaks = define_bool_state(
    name='jax_check_tracer_leaks',
    default=False,
    help=(
        'Turn on checking for leaked tracers as soon as a trace completes. '
        'Enabling leak checking may have performance impacts: some caching '
        'is disabled, and other overheads may be added. Additionally, be aware '
        'that some Python debuggers can cause false positives, so it is recommended '
        'to disable any debuggers while leak checking is enabled.'
    ),
)
# Shorthand for `check_tracer_leaks` with its first argument bound to True.
checking_leaks = functools.partial(check_tracer_leaks, True)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default.
debug_nans = define_bool_state(
    name='jax_debug_nans',
    default=False,
    help=(
        'Add nan checks to every operation. When a nan is detected on the '
        'output of a jit-compiled computation, call into the un-compiled '
        'version in an attempt to more precisely identify the operation '
        'which produced the nan.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default.
debug_infs = define_bool_state(
    name='jax_debug_infs',
    default=False,
    help=(
        'Add inf checks to every operation. When an inf is detected on the '
        'output of a jit-compiled computation, call into the un-compiled '
        'version in an attempt to more precisely identify the operation '
        'which produced the inf.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default.
# The help text previously read "each time every time"; the duplicated phrase
# has been removed.
log_compiles = define_bool_state(
    name='jax_log_compiles',
    default=False,
    help=(
        'Log a message each time `jit` or `pmap` compiles an XLA '
        'computation. Logging is performed with `logging`. When this '
        'option is set, the log level is WARNING; otherwise the level is '
        'DEBUG.'
    ),
)
|
|
|
|
|
2023-06-09 14:43:42 -07:00
|
|
|
# Boolean option, off by default.
# The help text previously contained a doubled period ("explanation..").
explain_cache_misses = define_bool_state(
    name='jax_explain_cache_misses',
    default=False,
    help=(
        'Each time there is a miss on one of the main caches (e.g. the '
        'tracing cache), log an explanation. Logging is performed with '
        '`logging`. When this option is set, the log level is WARNING; '
        'otherwise the level is DEBUG.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default.
log_checkpoint_residuals = define_bool_state(
    name='jax_log_checkpoint_residuals',
    default=False,
    help=(
        'Log a message every time jax.checkpoint (aka jax.remat) is '
        'partially evaluated (e.g. for autodiff), printing what residuals '
        'are saved.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default.
parallel_functions_output_gda = define_bool_state(
    name='jax_parallel_functions_output_gda',
    default=False,
    help='If True, pjit will output GDAs.',
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default; registered with `upgrade=True`.
pmap_shmap_merge = define_bool_state(
    name='jax_pmap_shmap_merge',
    default=False,
    upgrade=True,
    help='If True, pmap and shard_map API will be merged.',
)
|
|
|
|
|
2023-08-24 22:23:13 -07:00
|
|
|
def _update_jax_memories_global(val):
  """Store ``val`` as ``enable_memories`` on the jax_jit global state."""
  global_state = lib.jax_jit.global_state()
  global_state.enable_memories = val
|
2023-08-24 22:23:13 -07:00
|
|
|
|
|
|
|
def _update_jax_memories_thread_local(val):
  """Store ``val`` as ``enable_memories`` on the jax_jit thread-local state."""
  thread_state = lib.jax_jit.thread_local_state()
  thread_state.enable_memories = val
|
2023-08-24 22:23:13 -07:00
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default; registered with `upgrade=True`. The two
# hooks copy each new value onto the jax_jit global / thread-local state.
enable_memories = define_bool_state(
    'jax_enable_memories',
    default=False,
    upgrade=True,
    update_global_hook=_update_jax_memories_global,
    update_thread_local_hook=_update_jax_memories_thread_local,
    help=(
        "If True, will allow fetching memory kinds available on executable "
        "and annotate Shardings with it."
    ),
)
|
2023-04-20 21:22:16 -07:00
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option with values 'allow_all' / 'allow_jit'; defaults to 'allow_jit'.
spmd_mode = define_enum_state(
    name='jax_spmd_mode',
    enum_values=['allow_all', 'allow_jit'],
    default='allow_jit',
    help=(
        "Decides whether Math on `jax.Array`'s that are not fully addressable "
        "(i.e. spans across multiple processes) is allowed. The options are: "
        "* allow_jit: Default, `pjit` and `jax.jit` computations are allowed "
        " to execute on non-fully addressable `jax.Array`s\n"
        "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
        " `jax.jit` and all other operations are allowed to "
        " execute on non-fully addressable `jax.Array`s."
    ),
)
|
2022-10-31 09:46:46 -07:00
|
|
|
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default.
distributed_debug = define_bool_state(
    name='jax_distributed_debug',
    default=False,
    help=(
        'Enable logging useful for debugging multi-process distributed '
        'computations. Logging is performed with `logging` at WARNING '
        'level.'
    ),
)
|
|
|
|
|
2023-12-12 18:31:07 -08:00
|
|
|
# Integer option, 0 by default. The hooks copy each new value into the
# global / thread-local extra jit context.
random_seed_offset = define_int_state(
    name='jax_random_seed_offset',
    default=0,
    help=('Offset to all random seeds (e.g. argument to jax.random.key()).'),
    update_global_hook=(
        lambda val: _update_global_jit_state(random_seed_offset=val)),
    update_thread_local_hook=(
        lambda val: update_thread_local_jit_state(random_seed_offset=val)),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option with values 'allow' / 'warn' / 'error'; defaults to 'allow'.
legacy_prng_key = define_enum_state(
    name='jax_legacy_prng_key',
    enum_values=['allow', 'warn', 'error'],
    default='allow',
    help=(
        'Specify the behavior when raw PRNG keys are passed to '
        'jax.random APIs.'
    ),
)
|
2022-03-25 18:37:15 -07:00
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default; registered with `upgrade=True`.
enable_custom_prng = define_bool_state(
    name='jax_enable_custom_prng',
    default=False,
    upgrade=True,
    help=(
        'Enables an internal upgrade that allows one to define custom '
        'pseudo-random number generator implementations.'
    ),
)
|
2021-08-15 08:09:30 -07:00
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option with values 'threefry2x32' / 'rbg' / 'unsafe_rbg'; defaults to
# 'threefry2x32'.
default_prng_impl = define_enum_state(
    name='jax_default_prng_impl',
    enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'],
    default='threefry2x32',
    help=(
        'Select the default PRNG implementation, used when one is not '
        'explicitly provided at seeding time.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default; registered with `upgrade=True`. The hooks
# copy each new value into the global / thread-local extra jit context.
threefry_partitionable = define_bool_state(
    name='jax_threefry_partitionable',
    default=False,
    upgrade=True,
    help=(
        'Enables internal threefry PRNG implementation changes that '
        'render it automatically partitionable in some cases. Without this '
        'flag, using the standard jax.random pseudo-random number generation '
        'may result in extraneous communication and/or redundant distributed '
        'computation. With this flag, the communication overheads disappear '
        'in some cases.'
    ),
    update_global_hook=(
        lambda val: _update_global_jit_state(threefry_partitionable=val)),
    update_thread_local_hook=(
        lambda val: update_thread_local_jit_state(threefry_partitionable=val)),
)
|
2022-10-25 08:13:55 -07:00
|
|
|
|
2023-04-19 18:11:35 -07:00
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default; registered with `upgrade=True`. The hooks
# copy each new value into the global / thread-local extra jit context.
softmax_custom_jvp = define_bool_state(
    name='jax_softmax_custom_jvp',
    default=False,
    upgrade=True,
    help=(
        'Use a new custom_jvp rule for jax.nn.softmax. The new rule should '
        'improve memory usage and stability. Set True to use new '
        'behavior. See https://github.com/google/jax/pull/15677'
    ),
    update_global_hook=(
        lambda val: _update_global_jit_state(softmax_custom_jvp=val)),
    update_thread_local_hook=(
        lambda val: update_thread_local_jit_state(softmax_custom_jvp=val)),
)
|
|
|
|
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default; registered with `upgrade=True`.
enable_custom_vjp_by_custom_transpose = define_bool_state(
    name='jax_enable_custom_vjp_by_custom_transpose',
    default=False,
    upgrade=True,
    help=(
        'Enables an internal upgrade that implements `jax.custom_vjp` by '
        'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default.
raise_persistent_cache_errors = define_bool_state(
    name='jax_raise_persistent_cache_errors',
    default=False,
    help=(
        'If true, exceptions raised when reading or writing to the '
        'persistent compilation cache will be allowed through, halting '
        'program execution if not manually caught. If false, exceptions are '
        'caught and raised as warnings, allowing program execution to '
        'continue. Defaults to false so cache bugs or intermittent issues '
        'are non-fatal.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Float option; the default is given as the integer literal 1.
persistent_cache_min_compile_time_secs = define_float_state(
    name='jax_persistent_cache_min_compile_time_secs',
    default=1,
    help=(
        'The minimum compile time of a computation to be written to the '
        'persistent compilation cache. This threshold can be raised to '
        'decrease the number of entries written to the cache.'
    ),
)
|
2022-09-27 20:59:08 +00:00
|
|
|
|
2024-01-04 15:16:25 -08:00
|
|
|
# Integer option, 0 by default.
persistent_cache_min_entry_size_bytes = define_int_state(
    name='jax_persistent_cache_min_entry_size_bytes',
    default=0,
    help=(
        'The minimum size (in bytes) of an entry that will be cached in the '
        'persistent compilation cache: '
        '* -1: disable the size restriction and prevent overrides. '
        '* Leave at default (0) to allow for overrides. The override will '
        ' typically ensure that the minimum size is optimal for the '
        ' filesystem being used for the cache. '
        '* > 0: the actual minimum size desired; no overrides.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default.
compilation_cache_include_metadata_in_key = define_bool_state(
    name='jax_compilation_cache_include_metadata_in_key',
    default=False,
    help=(
        'Include metadata, such as file names and line numbers, in the'
        ' compilation cache key. If false, the cache will still get hits even'
        ' if functions or files are moved, etc. However, it means that'
        ' executables loaded from the cache may have stale metadata, which'
        ' may show up in, e.g., profiles.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# String option, unset by default.
hlo_source_file_canonicalization_regex = define_string_state(
    name='jax_hlo_source_file_canonicalization_regex',
    default=None,
    help=(
        'Used to canonicalize the source_path metadata of HLO instructions '
        'by removing the given regex. If set, re.sub() is called on each '
        'source_file with the given regex, and all matches are removed. '
        'This can be used to avoid spurious cache misses when using the '
        'persistent compilation cache, which includes HLO metadata in the '
        'cache key.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, on by default.
include_full_tracebacks_in_locations = define_bool_state(
    name='jax_include_full_tracebacks_in_locations',
    default=True,
    help=(
        'Include Python tracebacks in MLIR locations in IR emitted by JAX.'
    ),
)
|
|
|
|
|
|
|
|
# Integer option, 10 by default.
traceback_in_locations_limit = define_int_state(
    name='jax_traceback_in_locations_limit',
    default=10,
    help=(
        'Limit the number of frames at the Python traceback frames included in '
        'MLIR locations. If set to the negative value, traceback will not be '
        'limited.'
    ),
)
|
|
|
|
|
2024-01-11 23:37:22 -08:00
|
|
|
# Boolean option, off by default.
share_binary_between_hosts = define_bool_state(
    name='jax_share_binary_between_hosts',
    default=False,
    help=(
        'If set to True, the compiled module will be shared between hosts '
        'directly.'
    ),
)
|
|
|
|
|
|
|
|
# Integer option; the default is 10 minutes expressed in milliseconds.
share_binary_between_hosts_timeout_ms = define_int_state(
    name='jax_share_binary_between_hosts_timeout_ms',
    default=10 * 60 * 1000,
    help='Timeout for the compiled module share.',
)
|
|
|
|
|
2023-11-27 14:52:22 -08:00
|
|
|
# Boolean option, on by default.
enable_compilation_cache = define_bool_state(
    name='jax_enable_compilation_cache',
    default=True,
    help=(
        'If set to False, the compilation cache will be disabled regardless '
        'of whether set_cache_dir() was called. If set to True, the '
        'path could be set to a default value or via a call to '
        'set_cache_dir().'
    ),
)
|
|
|
|
|
|
|
|
# String option, unset by default.
compilation_cache_dir = define_string_state(
    name='jax_compilation_cache_dir',
    default=None,
    help=(
        'Path for the cache. '
        'Precedence: '
        '1. A call to compilation_cache.set_cache_dir(). '
        '2. The value of this flag set in the command line or by default.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option with string values '32' / '64'; defaults to '64'.
default_dtype_bits = define_enum_state(
    name='jax_default_dtype_bits',
    enum_values=['32', '64'],
    default='64',
    help=(
        'Specify bit width of default dtypes, either 32-bit or 64-bit. '
        'This is a temporary flag that will be used during the process '
        'of deprecating the ``jax_enable_x64`` flag.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option with values 'standard' / 'strict'; defaults to 'standard'. The
# hooks copy each new value into the global / thread-local extra jit context.
numpy_dtype_promotion = define_enum_state(
    name='jax_numpy_dtype_promotion',
    enum_values=['standard', 'strict'],
    default='standard',
    help=(
        'Specify the rules used for implicit type promotion in operations '
        'between arrays. Options are "standard" or "strict"; in strict-mode, '
        'binary operations between arrays of differing strongly-specified '
        'dtypes will result in an error.'
    ),
    update_global_hook=(
        lambda val: _update_global_jit_state(numpy_dtype_promotion=val)),
    update_thread_local_hook=(
        lambda val: update_thread_local_jit_state(numpy_dtype_promotion=val)),
)
|
|
|
|
|
2021-04-19 08:52:48 -07:00
|
|
|
def _update_x64_global(val):
  """Store ``val`` as ``enable_x64`` on the jax_jit global state."""
  global_state = lib.jax_jit.global_state()
  global_state.enable_x64 = val
|
|
|
|
|
|
|
|
def _update_x64_thread_local(val):
  """Store ``val`` as ``enable_x64`` on the jax_jit thread-local state."""
  thread_state = lib.jax_jit.thread_local_state()
  thread_state.enable_x64 = val
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option, off by default. The hooks copy each new value onto the
# jax_jit global / thread-local state.
enable_x64 = define_bool_state(
    name='jax_enable_x64',
    default=False,
    help='Enable 64-bit types to be used',
    update_global_hook=_update_x64_global,
    update_thread_local_hook=_update_x64_thread_local,
)
|
|
|
|
|
|
|
|
# TODO(phawkins): remove after fixing users of FLAGS.x64_enabled.
# Drop 'jax_enable_x64' from the config's context-manager flag collection.
config._contextmanager_flags.remove('jax_enable_x64')

# Make `x64_enabled` on `Config` an alias of `jax_enable_x64`.
Config.x64_enabled = Config.jax_enable_x64  # type: ignore
|
|
|
|
|
2021-12-21 20:55:03 +00:00
|
|
|
|
|
|
|
def _update_default_device_global(val):
  """Store ``val`` as ``default_device`` on the jax_jit global state."""
  global_state = lib.jax_jit.global_state()
  global_state.default_device = val
|
|
|
|
|
2022-06-02 10:33:53 -07:00
|
|
|
|
2021-12-21 20:55:03 +00:00
|
|
|
def _update_default_device_thread_local(val):
  """Store ``val`` as ``default_device`` on the jax_jit thread-local state."""
  thread_state = lib.jax_jit.thread_local_state()
  thread_state.default_device = val
|
|
|
|
|
2022-06-02 10:33:53 -07:00
|
|
|
|
2021-12-21 20:55:03 +00:00
|
|
|
def _validate_default_device(val):
  """Check that ``val`` is acceptable as the default device.

  ``None`` and ``xla_client.Device`` instances are accepted silently. Any other
  object is accepted, with an info-level log message, only if the string form
  of its type contains ``'Device'``.

  Args:
    val: candidate value for the ``jax_default_device`` option.

  Raises:
    ValueError: if ``val`` is neither ``None``, an ``xla_client.Device``, nor
      an object whose type name mentions ``'Device'``.
  """
  if val is None or isinstance(val, xla_client.Device):
    return
  # TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
  # all JAX backends use a single C++ device interface.
  if 'Device' not in str(type(val)):
    raise ValueError('jax.default_device must be passed a Device object (e.g. '
                     f"`jax.devices('cpu')[0]`), got: {val!r}")
  logger.info(
      'Allowing non-`xla_client.Device` default device: %s, type: %s',
      repr(val), type(val))
|
2021-12-21 20:55:03 +00:00
|
|
|
|
2022-06-02 10:33:53 -07:00
|
|
|
|
2021-12-21 20:55:03 +00:00
|
|
|
# Option accepting a string or an object, unset by default. New values are
# checked by `_validate_default_device` and copied onto the jax_jit global /
# thread-local state by the two update hooks.
# TODO(skye): default_device only accepts devices for now. Make it work with
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
default_device = define_string_or_object_state(
    name='jax_default_device',
    default=None,
    help=(
        'Configure the default device for JAX operations. Set to a Device '
        'object (e.g. ``jax.devices("cpu")[0]``) to use that Device as the '
        'default device for JAX operations and jit\'d function calls (there is '
        'no effect on multi-device computations, e.g. pmapped function calls). '
        'Set to None to use the system default device. See '
        ':ref:`faq-data-placement` for more information on device placement.'
    ),
    update_global_hook=_update_default_device_global,
    update_thread_local_hook=_update_default_device_thread_local,
    validate_new_val_hook=_validate_default_device,
)
|
|
|
|
|
2021-04-19 08:52:48 -07:00
|
|
|
def _update_disable_jit_global(val):
  """Store ``val`` as ``disable_jit`` on the jax_jit global state."""
  global_state = lib.jax_jit.global_state()
  global_state.disable_jit = val
|
|
|
|
|
|
|
|
def _update_disable_jit_thread_local(val):
  """Stores `val` in the `disable_jit` field of the thread-local jax_jit state."""
  thread_state = lib.jax_jit.thread_local_state()
  thread_state.disable_jit = val
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option 'jax_disable_jit' (off by default); updates are forwarded to
# the two _update_disable_jit_* hooks.
disable_jit = define_bool_state(
    name='jax_disable_jit',
    default=False,
    help=('Disable JIT compilation and just call original Python.'),
    update_global_hook=_update_disable_jit_global,
    update_thread_local_hook=_update_disable_jit_thread_local,
)
|
|
|
|
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option 'jax_numpy_rank_promotion'; each update is passed on to the
# global or thread-local jit state under the `numpy_rank_promotion` key.
numpy_rank_promotion = define_enum_state(
    name='jax_numpy_rank_promotion',
    enum_values=['allow', 'warn', 'raise'],
    default='allow',
    help=('Control NumPy-style automatic rank promotion broadcasting '
          '("allow", "warn", or "raise").'),
    update_global_hook=(
        lambda val: _update_global_jit_state(numpy_rank_promotion=val)
    ),
    update_thread_local_hook=(
        lambda val: update_thread_local_jit_state(numpy_rank_promotion=val)
    ),
)
|
2021-04-19 08:52:48 -07:00
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option 'jax_default_matmul_precision' with no default (None); each
# update is passed on to the global or thread-local jit state under the
# `default_matmul_precision` key.
default_matmul_precision = define_enum_state(
    name='jax_default_matmul_precision',
    enum_values=['bfloat16', 'tensorfloat32', 'float32'],
    default=None,
    help=(
        'Control the default matmul and conv precision for 32bit inputs.\n\n'

        'Some platforms, like TPU, offer configurable precision levels for '
        'matrix multiplication and convolution computations, trading off '
        'accuracy for speed. The precision can be controlled for each '
        'operation; for example, see the :func:`jax.lax.conv_general_dilated` '
        'and :func:`jax.lax.dot` docstrings. But it can be useful to control '
        'the default behavior obtained when an operation is not given a '
        'specific precision.\n\n'

        'This option can be used to control the default precision '
        'level for computations involved in matrix multiplication and '
        'convolution on 32bit inputs. The levels roughly describe the '
        "precision at which scalar products are computed. The 'bfloat16' "
        "option is the fastest and least precise; 'float32' is similar to "
        "full float32 precision; 'tensorfloat32' is intermediate.\n\n"
    ),
    update_global_hook=(
        lambda val: _update_global_jit_state(default_matmul_precision=val)
    ),
    update_thread_local_hook=(
        lambda val: update_thread_local_jit_state(default_matmul_precision=val)
    ),
)
|
2021-06-02 15:22:50 -04:00
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option 'jax_traceback_filtering' (default "auto"). No update hooks are
# attached here; the help text enumerates what each value means.
traceback_filtering = define_enum_state(
    name='jax_traceback_filtering',
    enum_values=[
        "off",
        "tracebackhide",
        "remove_frames",
        "quiet_remove_frames",
        "auto",
    ],
    default="auto",
    help=(
        "Controls how JAX filters internal frames out of tracebacks.\n\n"
        "Valid values are:\n"
        " * \"off\": disables traceback filtering.\n"
        " * \"auto\": use \"tracebackhide\" if running under a sufficiently"
        " new IPython, or \"remove_frames\" otherwise.\n"
        " * \"tracebackhide\": adds \"__tracebackhide__\" annotations to"
        " hidden stack frames, which some traceback printers support.\n"
        " * \"remove_frames\": removes hidden frames from tracebacks, and adds"
        " the unfiltered traceback as a __cause__ of the exception.\n"
        " * \"quiet_remove_frames\": removes hidden frames from tracebacks, and adds"
        " a brief message (to the __cause__ of the exception) describing that this has"
        " happened.\n"
    ),
)
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2022-05-26 22:22:20 -07:00
|
|
|
# This flag is for internal use.
|
|
|
|
# TODO(tianjianlu): Remove once we always enable cusparse lowering.
|
2022-12-09 15:41:06 -08:00
|
|
|
# TODO(b/262050896): Set to true after bug is fixed
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option 'jax_bcoo_cusparse_lowering'; currently defaults to False.
bcoo_cusparse_lowering = define_bool_state(
    name='jax_bcoo_cusparse_lowering',
    default=False,
    help=('Enables lowering BCOO ops to cuSparse.'),
)
|
2022-01-20 22:58:09 -08:00
|
|
|
|
|
|
|
# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging
|
|
|
|
# if the intended backend can handle lowering the result
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option 'jax_dynamic_shapes'; each update is passed on to the global
# or thread-local jit state under the `dynamic_shapes` key.
# NOTE(review): the default is True for *any* non-empty JAX_DYNAMIC_SHAPES
# value (bool('0') is True) -- confirm whether define_bool_state re-parses the
# environment variable itself.
dynamic_shapes = define_bool_state(
    name='jax_dynamic_shapes',
    default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
    help=('Enables experimental features for staging out computations with '
          'dynamic shapes.'),
    update_global_hook=(
        lambda val: _update_global_jit_state(dynamic_shapes=val)
    ),
    update_thread_local_hook=(
        lambda val: update_thread_local_jit_state(dynamic_shapes=val)
    ),
)
|
2022-02-13 22:40:26 -08:00
|
|
|
|
2023-02-07 14:32:32 -08:00
|
|
|
# This flag is temporary during rollout of the remat barrier.
|
|
|
|
# TODO(parkers): Remove if there are no complaints.
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option 'jax_remat_opt_barrier'; the default is True exactly when
# `lib.version` is at least (0, 3, 6).
remat_opt_barrier = define_bool_state(
    name='jax_remat_opt_barrier',
    default=(lib.version >= (0, 3, 6)),
    help=('Enables using optimization-barrier op for lowering remat.'),
)
|
|
|
|
|
2022-08-11 10:49:56 -07:00
|
|
|
# TODO(sharadmv,mattjj): set default to True, then remove
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option 'jax_eager_pmap'; defaults to True and is registered with
# `upgrade=True`.
eager_pmap = define_bool_state(
    name='jax_eager_pmap',
    default=True,
    upgrade=True,
    help='Enable eager-mode pmap when jax_disable_jit is activated.',
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Boolean option 'jax_experimental_unsafe_xla_runtime_errors'; off by default.
xla_runtime_errors = define_bool_state(
    name='jax_experimental_unsafe_xla_runtime_errors',
    default=False,
    help=(
        'Enable XLA runtime errors for jax.experimental.checkify.checks '
        'on CPU and GPU. These errors are async, might get lost and are not '
        'very readable. But, they crash the computation and enable you '
        'to write jittable checks without needing to checkify. Does not '
        'work under pmap/pjit.'
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Integer option 'jax_xla_profile_version' (default 0); each update is passed
# on to the global or thread-local jit state under the `xla_profile_version`
# key.
jax_xla_profile_version = define_int_state(
    name='jax_xla_profile_version',
    default=0,
    help=('Optional profile version for XLA compilation. This is meaningful '
          'only when XLA is configured to support the remote compilation '
          'profile feature.'),
    update_global_hook=(
        lambda val: _update_global_jit_state(xla_profile_version=val)
    ),
    update_thread_local_hook=(
        lambda val: update_thread_local_jit_state(xla_profile_version=val)
    ),
)
|
|
|
|
|
2022-04-11 14:59:04 +00:00
|
|
|
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
  """Marks the enclosed code as running inside an explicit device_put*() call.

  Sets `explicit_device_put` on the transfer guard thread-local state to True
  for the duration of the `with` body, then restores whatever value it held
  before, even if the body raises.
  """
  tls = transfer_guard_lib.thread_local_state()
  saved = tls.explicit_device_put
  tls.explicit_device_put = True
  try:
    yield
  finally:
    # Restore rather than reset to False so nested scopes unwind correctly.
    tls.explicit_device_put = saved
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def explicit_device_get_scope() -> Iterator[None]:
  """Marks the enclosed code as running inside an explicit device_get() call.

  Sets `explicit_device_get` on the transfer guard thread-local state to True
  for the duration of the `with` body, then restores whatever value it held
  before, even if the body raises.
  """
  tls = transfer_guard_lib.thread_local_state()
  saved = tls.explicit_device_get
  tls.explicit_device_get = True
  try:
    yield
  finally:
    # Restore rather than reset to False so nested scopes unwind correctly.
    tls.explicit_device_get = saved
|
|
|
|
|
|
|
|
def _update_transfer_guard(state, key, val):
|
|
|
|
"""Applies the transfer guard level within transfer_guard_lib."""
|
|
|
|
if val is None:
|
|
|
|
setattr(state, key, None)
|
|
|
|
elif val == 'allow':
|
|
|
|
setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW)
|
|
|
|
elif val == 'log':
|
|
|
|
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG)
|
|
|
|
elif val == 'disallow':
|
|
|
|
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW)
|
|
|
|
elif val == 'log_explicit':
|
|
|
|
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT)
|
|
|
|
elif val == 'disallow_explicit':
|
|
|
|
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
|
|
|
|
else:
|
|
|
|
assert False, f'Invalid transfer guard level {val}'
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option 'jax_transfer_guard_host_to_device'; updates are written to the
# 'host_to_device' attribute of the transfer guard global/thread-local state.
transfer_guard_host_to_device = define_enum_state(
    name='jax_transfer_guard_host_to_device',
    enum_values=[
        'allow',
        'log',
        'disallow',
        'log_explicit',
        'disallow_explicit',
    ],
    # Keep None here: transfer_guard_lib supplies the effective default, and a
    # concrete value could accidentally override --jax_transfer_guard.
    default=None,
    help=('Select the transfer guard level for host-to-device transfers. '
          'Default is "allow".'),
    update_global_hook=(
        lambda val: _update_transfer_guard(
            transfer_guard_lib.global_state(), 'host_to_device', val)
    ),
    update_thread_local_hook=(
        lambda val: _update_transfer_guard(
            transfer_guard_lib.thread_local_state(), 'host_to_device', val)
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option 'jax_transfer_guard_device_to_device'; updates are written to the
# 'device_to_device' attribute of the transfer guard global/thread-local state.
transfer_guard_device_to_device = define_enum_state(
    name='jax_transfer_guard_device_to_device',
    enum_values=[
        'allow',
        'log',
        'disallow',
        'log_explicit',
        'disallow_explicit',
    ],
    # Keep None here: transfer_guard_lib supplies the effective default, and a
    # concrete value could accidentally override --jax_transfer_guard.
    default=None,
    help=('Select the transfer guard level for device-to-device transfers. '
          'Default is "allow".'),
    update_global_hook=(
        lambda val: _update_transfer_guard(
            transfer_guard_lib.global_state(), 'device_to_device', val)
    ),
    update_thread_local_hook=(
        lambda val: _update_transfer_guard(
            transfer_guard_lib.thread_local_state(), 'device_to_device', val)
    ),
)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option 'jax_transfer_guard_device_to_host'; updates are written to the
# 'device_to_host' attribute of the transfer guard global/thread-local state.
transfer_guard_device_to_host = define_enum_state(
    name='jax_transfer_guard_device_to_host',
    enum_values=[
        'allow',
        'log',
        'disallow',
        'log_explicit',
        'disallow_explicit',
    ],
    # Keep None here: transfer_guard_lib supplies the effective default, and a
    # concrete value could accidentally override --jax_transfer_guard.
    default=None,
    help=('Select the transfer guard level for device-to-host transfers. '
          'Default is "allow".'),
    update_global_hook=(
        lambda val: _update_transfer_guard(
            transfer_guard_lib.global_state(), 'device_to_host', val)
    ),
    update_thread_local_hook=(
        lambda val: _update_transfer_guard(
            transfer_guard_lib.thread_local_state(), 'device_to_host', val)
    ),
)
|
|
|
|
|
|
|
|
def _update_all_transfer_guard_global(val):
  """Applies `val` to each of the three per-direction transfer guard options."""
  per_direction_options = (
      'jax_transfer_guard_host_to_device',
      'jax_transfer_guard_device_to_device',
      'jax_transfer_guard_device_to_host',
  )
  for option_name in per_direction_options:
    config.update(option_name, val)
|
|
|
|
|
2023-10-25 12:38:19 +01:00
|
|
|
# Enum option 'jax_transfer_guard'. Only a global update hook is attached; it
# forwards the new value via _update_all_transfer_guard_global.
_transfer_guard = define_enum_state(
    name='jax_transfer_guard',
    enum_values=[
        'allow',
        'log',
        'disallow',
        'log_explicit',
        'disallow_explicit',
    ],
    # Keep None here: transfer_guard_lib supplies the effective default, and a
    # concrete value could accidentally override --jax_transfer_guard_*.
    default=None,
    help=('Select the transfer guard level for all transfers. This option is '
          'set-only; the transfer guard level for a specific direction should '
          'be read using the per-transfer direction option. '
          'Default is "allow".'),
    update_global_hook=_update_all_transfer_guard_global,
)
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def transfer_guard(new_val: str) -> Iterator[None]:
  """Context manager that sets the transfer guard level for every direction.

  Enters the host-to-device, device-to-device and device-to-host option
  contexts, followed by the combined `jax_transfer_guard` one, all with the
  same value; they are unwound together when the `with` body finishes.

  For more information, see
  https://jax.readthedocs.io/en/latest/transfer_guard.html

  Args:
    new_val: The new thread-local transfer guard level for all transfers.

  Yields:
    None.
  """
  option_contexts = (
      transfer_guard_host_to_device,
      transfer_guard_device_to_device,
      transfer_guard_device_to_host,
      _transfer_guard,
  )
  with contextlib.ExitStack() as stack:
    for make_context in option_contexts:
      stack.enter_context(make_context(new_val))
    yield
|
Add `jax_debug_log_modules` config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-06-07 00:20:32 +00:00
|
|
|
|
|
|
|
|
2023-12-08 12:09:04 +00:00
|
|
|
def _update_debug_log_modules(module_names_str: str | None):
  """Enables debug logging for the modules named in a comma-separated list.

  All debug logging previously enabled through `logging_config` is switched
  off first, so the given list replaces the earlier selection rather than
  extending it.

  Args:
    module_names_str: Comma-separated module names, e.g.
      "jax._src.xla_bridge,jax._src.dispatch". Whitespace around each name is
      ignored. None or an empty string leaves debug logging disabled.
  """
  logging_config.disable_all_debug_logging()
  if not module_names_str:
    return
  for raw_name in module_names_str.split(','):
    module_name = raw_name.strip()
    # Skip blank entries (e.g. from a trailing or doubled comma). In the stdlib
    # logging module an empty name denotes the root logger, so passing ''
    # through would presumably enable debug output far beyond the requested
    # modules -- confirm against logging_config.enable_debug_logging.
    if module_name:
      logging_config.enable_debug_logging(module_name)
|
|
|
|
|
|
|
|
# Don't define a context manager since this isn't threadsafe.
|
2023-10-25 12:38:19 +01:00
|
|
|
# String option 'jax_debug_log_modules' (default ''); the return value is
# discarded, and global updates are handed to _update_debug_log_modules.
define_string_state(
    name='jax_debug_log_modules',
    default='',
    help=(
        'Comma-separated list of module names (e.g. "jax" or '
        '"jax._src.xla_bridge,jax._src.dispatch") to enable debug logging '
        'for.'
    ),
    update_global_hook=_update_debug_log_modules,
)
|