2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2021-04-19 08:52:48 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2023-07-27 12:15:16 -07:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2024-12-09 07:34:26 -08:00
|
|
|
from collections.abc import Callable, Iterator, Sequence
|
2021-04-19 08:52:48 -07:00
|
|
|
import contextlib
|
2025-02-14 14:45:25 -08:00
|
|
|
import enum
|
2021-04-19 08:52:48 -07:00
|
|
|
import functools
|
2021-08-06 07:05:43 -07:00
|
|
|
import itertools
|
2022-10-13 17:06:22 +02:00
|
|
|
import logging
|
2021-04-19 08:52:48 -07:00
|
|
|
import os
|
|
|
|
import sys
|
2024-12-09 07:34:26 -08:00
|
|
|
from typing import Any, Generic, NoReturn, Optional, Protocol, TypeVar, cast
|
2022-06-02 10:33:53 -07:00
|
|
|
|
2024-10-17 12:22:39 -07:00
|
|
|
from jax._src.lib import guard_lib
|
2021-09-23 06:33:25 -07:00
|
|
|
from jax._src.lib import jax_jit
|
2021-12-21 20:55:03 +00:00
|
|
|
from jax._src.lib import xla_client
|
Add `jax_debug_log_modules` config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-06-07 00:20:32 +00:00
|
|
|
from jax._src import logging_config
|
2021-04-19 08:52:48 -07:00
|
|
|
|
2024-12-09 07:34:26 -08:00
|
|
|
config_ext = xla_client._xla.config
|
2024-11-05 08:31:12 -08:00
|
|
|
|
2022-10-13 17:06:22 +02:00
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
2023-07-27 12:15:16 -07:00
|
|
|
_T = TypeVar('_T')
|
|
|
|
|
2022-10-13 17:06:22 +02:00
|
|
|
|
2025-02-14 14:45:25 -08:00
|
|
|
class EffortLevel(enum.Enum):
  """Effort level enum, mirroring the XLA effort options.

  A member can be obtained from its integer value (``EffortLevel(9)``) or,
  through ``_missing_``, from its name (``EffortLevel('O0')``).
  """

  UNKNOWN = 0
  O0 = 9
  O1 = 19
  O2 = 29
  O3 = 39

  @classmethod
  def _missing_(cls, value: object) -> EffortLevel | None:
    """Resolves a member from its name when lookup by value has failed.

    Args:
      value: the object that was passed to ``EffortLevel(...)``.

    Returns:
      The member named ``value``, or ``None`` if there is no such member, in
      which case ``enum`` raises its usual ``ValueError``.
    """
    # Only strings can be names. Guarding here keeps unhashable inputs (lists,
    # dicts, ...) away from the dict lookup, which would otherwise surface a
    # ``TypeError`` instead of the ``ValueError`` callers of an enum expect.
    if isinstance(value, str):
      return _effort_from_string.get(value)
    return None


# Name -> member table consulted by ``EffortLevel._missing_``.
_effort_from_string: dict[Any, EffortLevel] = {
    'UNKNOWN': EffortLevel.UNKNOWN,
    'O0': EffortLevel.O0,
    'O1': EffortLevel.O1,
    'O2': EffortLevel.O2,
    'O3': EffortLevel.O3,
}
|
|
|
|
|
|
|
|
|
2021-04-19 08:52:48 -07:00
|
|
|
def bool_env(varname: str, default: bool) -> bool:
  """Interpret an environment variable as a boolean.

  The comparison is case insensitive. 'y', 'yes', 't', 'true', 'on' and '1'
  mean True; 'n', 'no', 'f', 'false', 'off' and '0' mean False.

  Args:
    varname: the name of the variable
    default: the value to use when the variable is not set

  Returns:
    The boolean that the variable (or ``default``) spells.

  Raises:
    ValueError: if the text is not one of the spellings listed above.
  """
  truthy = ('y', 'yes', 't', 'true', 'on', '1')
  falsy = ('n', 'no', 'f', 'false', 'off', '0')
  # The default goes through the same text parsing as a real setting.
  val = os.getenv(varname, str(default)).lower()
  if val in truthy:
    return True
  if val in falsy:
    return False
  raise ValueError(f"invalid truth value {val!r} for environment {varname!r}")
|
2021-04-19 08:52:48 -07:00
|
|
|
|
|
|
|
def int_env(varname: str, default: int) -> int:
  """Interpret an environment variable as an integer.

  Args:
    varname: the name of the variable
    default: the value to use when the variable is not set

  Returns:
    ``int()`` applied to the variable's text, or to ``str(default)`` when the
    variable is not set.
  """
  text = os.getenv(varname, str(default))
  return int(text)
|
2021-04-19 08:52:48 -07:00
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
class ValueHolder(Protocol[_T]):
  """Structural interface for an object that stores one configuration value.

  Two kinds of holder exist: a ``Flag`` is assigned exactly once and never
  modified afterwards, while a ``State`` can additionally be changed locally
  within a thread by means of a context manager.
  """

  # The value currently held.
  value: _T

  def _set(self, value: _T) -> None:
    """Stores ``value`` in the holder."""
    ...
|
2022-03-25 18:37:15 -07:00
|
|
|
|
|
|
|
|
2021-04-19 08:52:48 -07:00
|
|
|
class Config:
  """Registry of configuration options, keyed by option name.

  Each option is backed by a holder object exposing ``value`` and ``_set``.
  ``meta`` keeps, per option, the ``(type, args, kwargs)`` triple that is used
  to define the matching absl flag.
  """

  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self):
    self._value_holders: dict[str, ValueHolder] = {}
    self.meta = {}
    self.use_absl = False
    # Names of options that also have a context manager; ``read`` rejects them.
    self._contextmanager_flags = set()

  def update(self, name, val):
    """Sets option ``name`` to ``val`` through its holder.

    Raises:
      AttributeError: if no option called ``name`` has been registered.
    """
    holders = self._value_holders
    if name in holders:
      holders[name]._set(val)
      return
    raise AttributeError(f"Unrecognized config option: {name}")

  def read(self, name):
    """Returns the value of option ``name``.

    Raises:
      AttributeError: if the option has a corresponding context manager, or
        if it has not been registered.
    """
    if name not in self._contextmanager_flags:
      return self._read(name)
    raise AttributeError(
        "For flags with a corresponding contextmanager, read their value "
        f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.")

  def _read(self, name):
    """Returns the value of option ``name`` without the ``read`` restriction.

    Raises:
      AttributeError: if no option called ``name`` has been registered.
    """
    try:
      current = self._value_holders[name].value
    except KeyError:
      raise AttributeError(f"Unrecognized config option: {name}")
    return current

  @property
  def values(self):
    """A freshly built dict mapping every option name to its current value."""
    snapshot = {}
    for option_name, holder in self._value_holders.items():
      snapshot[option_name] = holder.value
    return snapshot

  def add_option(self, name, holder, opt_type, meta_args, meta_kwargs):
    """Registers a new option together with its flag metadata.

    Args:
      name: the option name.
      holder: the object storing the option's value.
      opt_type: key selecting the absl ``DEFINE_*`` function in
        ``config_with_absl`` (``bool``, ``int``, ``float``, ``str`` or
        ``'enum'``).
      meta_args: extra positional arguments for that ``DEFINE_*`` function.
      meta_kwargs: extra keyword arguments for that ``DEFINE_*`` function.

    Raises:
      Exception: if ``name`` is already registered.
    """
    if name in self._value_holders:
      raise Exception(f"Config option {name} already defined")
    self._value_holders[name] = holder
    self.meta[name] = (opt_type, meta_args, meta_kwargs)

  def config_with_absl(self):
    """Registers absl flags for the JAX configs.

    E.g., for each JAX config defined using bool_state(), this method
    registers an absl boolean flag, with the same name.

    This is the recommended method to call if you use `app.run(main)` and you
    need JAX flags.

    Examples:

    ```python
    from absl import app
    import jax
    ...

    if __name__ == '__main__':
      jax.config.config_with_absl()
      app.run(main)
    ```

    """
    import absl.flags as absl_FLAGS  # noqa: F401  # pytype: disable=import-error
    from absl import app, flags as absl_flags  # pytype: disable=import-error

    self.use_absl = True
    self.absl_flags = absl_flags
    definers = {
        bool: absl_flags.DEFINE_bool,
        int: absl_flags.DEFINE_integer,
        float: absl_flags.DEFINE_float,
        str: absl_flags.DEFINE_string,
        'enum': absl_flags.DEFINE_enum,
    }

    for option_name, (flag_type, meta_args, meta_kwargs) in self.meta.items():
      holder = self._value_holders[option_name]
      define_flag = definers[flag_type]
      # The holder's current value becomes the absl flag's default.
      define_flag(option_name, holder.value, *meta_args, **meta_kwargs)
    # Copy parsed flag values back into the holders once absl has initialized.
    app.call_after_init(lambda: self.complete_absl_config(absl_flags))

  def complete_absl_config(self, absl_flags):
    """Copies the values of flags that absl marks as present into holders."""
    # NOTE: avoid calling from outside this module. Instead, use
    # `config_with_absl()`, and (in rare cases) `parse_flags_with_absl()`.
    for option_name, holder in self._value_holders.items():
      try:
        flag = absl_flags.FLAGS[option_name]
      except KeyError:
        # This can happen if a new flag was added after config_with_absl() was
        # called, but before complete_absl_config was run. We could in principle
        # add code to DEFINE_... to register any newly added flags with ABSL
        # if config_with_absl() has already been called, but arguably the user
        # should have called config_with_absl() later.
        continue
      else:
        if flag.present:
          holder._set(flag.value)

  def parse_flags_with_absl(self):
    """Parses command-line args that start with --jax.

    This method should be used only by advanced users. Most users should use
    :meth:`config_with_absl` instead.

    This method has serious limitations: e.g., although it parses only the
    --jax* command-line args, it runs the validators of all registered absl
    flags, even non-JAX ones that have not been set yet; as such, for the
    non-JAX flags, the validators run on the default flag values, not on the
    values indicated by the command-line args.
    """
    global already_configured_with_absl
    if already_configured_with_absl:
      return

    # Extract just the --jax... flags (before the first --) from argv. In some
    # environments (e.g. ipython/colab) argv might be a mess of things
    # parseable by absl and other junk.
    before_separator = itertools.takewhile(lambda arg: arg != '--', sys.argv)
    jax_argv = ['']
    jax_argv.extend(arg for arg in before_separator if arg.startswith('--jax'))

    import absl.flags  # pytype: disable=import-error
    self.config_with_absl()
    absl.flags.FLAGS(jax_argv, known_only=True)
    self.complete_absl_config(absl.flags)
    already_configured_with_absl = True
|
|
|
|
|
2021-04-21 06:36:08 -07:00
|
|
|
|
2024-12-09 07:34:26 -08:00
|
|
|
def trace_context():
  """Returns a tuple of configuration values that affect tracing.

  These values are included in the cache key for linear_util.cache.

  Values included in this set should also most likely be included in
  the C++ JIT state, which is handled separately.

  Returns:
    A tuple holding the current ``.value`` of each of the module-level
    configuration objects listed below, always in the same order.
  """
  return (axis_env_state.value, mesh_context_manager.value,
          xla_metadata_context_manager.value,
          abstract_mesh_context_manager.value,
          compute_on_context_manager.value, enable_x64.value,
          numpy_rank_promotion.value, default_matmul_precision.value,
          dynamic_shapes.value,
          eager_constant_folding.value,
          numpy_dtype_promotion.value,
          default_device.value, random_seed_offset.value,
          threefry_partitionable.value,
          threefry_gpu_kernel_lowering.value,
          use_direct_linearize.value,
          softmax_custom_jvp.value,
          disable_jit.value,
          debug_key_reuse.value,
          jax_xla_profile_version.value,
          _check_rep.value,
          # Technically this affects jaxpr->stablehlo lowering, not tracing.
          hlo_source_file_canonicalization_regex.value,
          pgle_profiling_runs.value,
          enable_pgle.value,
          use_shardy_partitioner.value,
          use_high_dynamic_range_gumbel.value,
          error_checking_behavior_nan.value,
          error_checking_behavior_divide.value,
          error_checking_behavior_oob.value)
|
2023-10-25 12:38:19 +01:00
|
|
|
|
|
|
|
# The module-level ``Config`` instance; option factories such as ``bool_state``
# register their options on it via ``config.add_option``.
config = Config()

# Module-level aliases for bound methods of the ``config`` instance.
_read = config._read
update = config.update
parse_flags_with_absl = config.parse_flags_with_absl
|
|
|
|
|
|
|
|
|
2022-04-12 15:05:53 -07:00
|
|
|
class NoDefault:
  """Type of the ``no_default`` sentinel, which stands for "no value given"."""


no_default = NoDefault()
|
|
|
|
|
2025-01-08 14:08:33 -08:00
|
|
|
# Maps an option name to its ``State`` object; ``State.__init__`` adds the
# entry for every state it constructs.
config_states = {}
|
|
|
|
|
2024-12-09 07:34:26 -08:00
|
|
|
class State(config_ext.Config[_T]):
  # A configuration option whose value is kept by the ``config_ext.Config``
  # base class and which can be overridden with a context manager (see
  # ``__call__``).
  #
  # NOTE: this class deliberately has no class docstring: '__doc__' is one of
  # the slots below (it is filled in per instance by ``__init__``), and a
  # class-level docstring would conflict with that slot.

  __slots__ = (
      '_name', '_update_thread_local_hook', '_update_global_hook',
      '_validator', '_default_context_manager_value', '__doc__', '__name__',
  )

  def __init__(
      self,
      name: str,
      default: _T,
      help,
      update_global_hook: Callable[[_T], None] | None = None,
      update_thread_local_hook: Callable[[_T | None], None] | None = None,
      validator: Callable[[Any], None] | None = None,
      extra_description: str = '',
      default_context_manager_value: Any = no_default,
      include_in_jit_key: bool = False,
  ):
    """Creates the state and records it in ``config_states``.

    Args:
      name: the option name; a leading ``jax_`` is dropped to form
        ``__name__``.
      default: the initial value, forwarded to the base-class constructor.
      help: help text, placed after the summary line of ``__doc__``.
      update_global_hook: optional callback, called here with ``default`` and
        by ``_set`` with each new global value.
      update_thread_local_hook: optional callback used by the context manager
        when the thread-local value changes.
      validator: optional callable applied to ``default`` here and to every
        value given to ``_set`` or to the context manager.
      extra_description: text appended to the summary line of ``__doc__``.
      default_context_manager_value: the value the context manager uses when
        it is called without an argument; ``no_default`` means there is none.
      include_in_jit_key: forwarded to the base-class constructor.
    """
    super().__init__(default, include_in_jit_key)
    self._name = name
    if name.startswith('jax_'):
      self.__name__ = name[4:]
    else:
      self.__name__ = name
    self.__doc__ = (f"Context manager for `{name}` config option"
                    f"{extra_description}.\n\n{help}")
    self._update_global_hook = update_global_hook
    self._update_thread_local_hook = update_thread_local_hook
    self._validator = validator
    self._default_context_manager_value = default_context_manager_value
    # Validate the default first, then announce it to the global hook.
    if validator:
      validator(default)
    if update_global_hook:
      update_global_hook(default)
    config_states[name] = self

  def __bool__(self) -> NoReturn:
    """Always raises: the state object itself has no truth value."""
    type_name = type(self).__name__
    raise TypeError(
        "bool() not supported for instances of type '{0}' "
        "(did you mean to use '{0}.value' instead?)".format(type_name))

  def _set(self, value: _T) -> None:
    """Validates ``value``, stores it globally and calls the global hook."""
    check = self._validator
    if check:
      check(value)
    self.set_global(value)
    notify = self._update_global_hook
    if notify:
      notify(value)

  def __call__(self, new_val: Any = no_default):
    """Returns a ``StateContextManager`` for this state and ``new_val``."""
    return StateContextManager(self, new_val)

  def _add_hooks(self, update_global_hook, update_thread_local_hook):
    """Private method that adds hooks to an existing context-manager.

    Used to avoid cyclic import dependencies.

    The new global hook is called right away with the current global value.
    """
    self._update_global_hook = update_global_hook
    self._update_thread_local_hook = update_thread_local_hook
    update_global_hook(self.get_global())
|
2021-04-19 08:52:48 -07:00
|
|
|
|
|
|
|
|
2025-02-25 10:29:39 -08:00
|
|
|
class StateContextManager(contextlib.ContextDecorator):
  """Context manager (and decorator) that temporarily overrides a ``State``.

  On entry the new value is installed with ``state.swap_local``; on exit the
  value that was replaced is put back with ``state.set_local``. The state's
  thread-local hook, if any, is called after each of the two changes.
  """

  __slots__ = ['state', 'new_val', 'prev']

  def __init__(self, state, new_val):
    """Resolves and validates the value to install on entry.

    Args:
      state: the ``State`` being overridden.
      new_val: the value to install, or ``no_default`` to use the state's
        ``_default_context_manager_value``.

    Raises:
      TypeError: if ``new_val`` is ``no_default`` and the state has no default
        context-manager value.
    """
    self.state = state

    if new_val is no_default:
      if state._default_context_manager_value is not no_default:
        new_val = state._default_context_manager_value  # default_context_manager_value provided to constructor
      else:
        # no default_value provided to constructor and no value provided as an
        # argument, so we raise an error
        raise TypeError(f"Context manager for {state.__name__} config option "
                        "requires an argument representing the new value for "
                        "the config option.")
    # Store the *resolved* value. Storing the raw argument before the default
    # was substituted meant that calling the manager without an argument
    # validated the default but then installed the ``no_default`` sentinel.
    self.new_val = new_val
    if state._validator:
      state._validator(new_val)

  def __enter__(self):
    """Installs the new value and remembers the one it replaced."""
    self.prev = self.state.swap_local(self.new_val)
    if self.state._update_thread_local_hook:
      self.state._update_thread_local_hook(self.new_val)

  def __exit__(self, exc_type, exc_value, traceback):
    """Restores the value that was in place before ``__enter__``."""
    self.state.set_local(self.prev)
    if self.state._update_thread_local_hook:
      if self.prev is config_ext.unset:
        # The hook is told ``None`` rather than the ``unset`` marker.
        self.state._update_thread_local_hook(None)
      else:
        self.state._update_thread_local_hook(cast(Optional[Any], self.prev))
|
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Text that ``bool_state`` appends to an option's help when ``upgrade=True``.
UPGRADE_BOOL_HELP = (
    " This will be enabled by default in future versions of JAX, at which "
    "point all uses of the flag will be considered deprecated (following "
    "the `API compatibility policy "
    "<https://docs.jax.dev/en/latest/api_compatibility.html>`_).")

# Text that ``bool_state`` appends to ``extra_description`` when
# ``upgrade=True``.
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
|
|
|
|
|
|
|
|
|
|
|
|
def bool_state(
    name: str,
    default: bool,
    help: str,
    *,
    update_global_hook: Callable[[bool], None] | None = None,
    update_thread_local_hook: Callable[[bool | None], None] | None = None,
    upgrade: bool = False,
    extra_description: str = '',
    include_in_jit_key: bool = False,
) -> State[bool]:
  """Set up thread-local state and return a contextmanager for managing it.

  This function is a convenience wrapper. It defines a flag, environment
  variable, and corresponding thread-local state, which can be managed via the
  contextmanager it returns.

  The thread-local state value can be read via the ``config.<option_name>``
  attribute, where ``config`` is the singleton ``Config`` instance.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: boolean, a default value for the option. If the environment
      variable is set, its value (parsed by ``bool_env``) is used instead.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback that is called with the updated
      value of the global state when it is altered or set initially.
    update_thread_local_hook: an optional callback that is called with the
      updated value of the thread-local state when it is altered or set
      initially.
    upgrade: optional indicator that this flag controls a canonical feature
      upgrade, so that it is `True` for the incoming functionality, `False`
      for the outgoing functionality to be deprecated.
    extra_description: string, optional: extra information to add to the
      summary description.
    include_in_jit_key: boolean, passed through unchanged to the ``State``
      constructor.

  Returns:
    A contextmanager to control the thread-local state value.

  Raises:
    TypeError: if ``default`` is not a ``bool``.

  Examples:

    enable_foo = config.bool_state(
        name='jax_enable_foo',
        default=False,
        help='Enable foo.')

    # Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo
    # command-line flag can be used to control the process-level value of
    # the configuration option, in addition to using e.g.
    # ``config.update("jax_enable_foo", True)`` directly. We can also use a
    # context manager:

    with enable_foo(True):
      ...

  The value of the thread-local state or flag can be accessed via
  ``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
  an error.
  """
  if not isinstance(default, bool):
    type_name = getattr(type(default), '__name__', type(default))
    raise TypeError(f"Default value must be of type bool, got {default} "
                    f"of type {type_name}")

  # The environment variable uses the upper-cased name, the option itself the
  # lower-cased one.
  env_var = name.upper()
  default = bool_env(env_var, default)
  name = name.lower()
  if upgrade:
    help = help + ' ' + UPGRADE_BOOL_HELP
    extra_description = extra_description + UPGRADE_BOOL_EXTRA_DESC
  config._contextmanager_flags.add(name)

  state = State[bool](
      name,
      default,
      help,
      update_global_hook=update_global_hook,
      update_thread_local_hook=update_thread_local_hook,
      extra_description=extra_description,
      default_context_manager_value=True,
      include_in_jit_key=include_in_jit_key,
  )
  config.add_option(name, state, bool, meta_args=[], meta_kwargs={"help": help})

  def _read_option(_):
    return state.value

  # Expose the option as a read-only attribute, ``config.<name>``.
  setattr(Config, name, property(_read_option))
  return state
|
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def enum_state(
    name: str,
    enum_values: Sequence[str],
    default: str,
    help: str,
    *,
    update_global_hook: Callable[[str], None] | None = None,
    update_thread_local_hook: Callable[[str | None], None] | None = None,
    include_in_jit_key: bool = False,
) -> State[str]:
  """Set up thread-local state and return a contextmanager for managing it.

  See docstring for ``bool_state``.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    enum_values: list of strings representing the possible values for the
      option.
    default: string, default value. If the environment variable named
      ``name.upper()`` is set, its value replaces this default.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback that is called with the updated
      value of the global state when it is altered or set initially.
    update_thread_local_hook: an optional callback that is called with the
      updated value of the thread-local state when it is altered or set
      initially.
    include_in_jit_key: forwarded unchanged to ``State``.
      NOTE(review): presumably controls whether the value participates in the
      jit cache key -- confirm against ``State``.

  Returns:
    A contextmanager to control the thread-local state value.

  Raises:
    TypeError: if ``default`` is not a ``str``.
    ValueError: if the effective default (after consulting the environment)
      is not one of ``enum_values``.
  """
  if not isinstance(default, str):
    raise TypeError(f"Default value must be of type str, got {default} "
                    f"of type {getattr(type(default), '__name__', type(default))}")
  name = name.lower()
  # The environment variable, when present, overrides the caller's default and
  # is validated against ``enum_values`` exactly like the default itself.
  default = os.getenv(name.upper(), default)
  if default not in enum_values:
    raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}")
  config._contextmanager_flags.add(name)

  def validate(new_val):
    # Exact ``str`` only (subclasses are rejected), and it must be a listed
    # enum member.
    if type(new_val) is not str or new_val not in enum_values:
      raise ValueError(f"new enum value must be in {enum_values}, "
                       f"got {new_val} of type {type(new_val)}.")

  s = State[str](
      name,
      default,
      help,
      update_global_hook=update_global_hook,
      update_thread_local_hook=update_thread_local_hook,
      validator=validate,
      include_in_jit_key=include_in_jit_key,
  )
  config.add_option(
      name, s, 'enum',
      meta_args=[],
      meta_kwargs={"enum_values": enum_values, "help": help}
  )
  # Expose the current value as a read-only attribute on ``Config``.
  setattr(Config, name, property(lambda _: s.value))
  return s
|
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def optional_enum_state(
    name: str,
    enum_values: Sequence[str],
    default: str | None,
    help: str,
    *,
    update_global_hook: Callable[[str | None], None] | None = None,
    update_thread_local_hook: Callable[[str | None], None] | None = None,
    include_in_jit_key: bool = False,
) -> State[str | None]:
  """Set up thread-local state and return a contextmanager for managing it.

  See docstring for ``bool_state``.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    enum_values: list of strings representing the possible values for the
      option.
    default: optional string, default value. If the environment variable
      named ``name.upper()`` is set, its value replaces this default.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback that is called with the updated
      value of the global state when it is altered or set initially.
    update_thread_local_hook: an optional callback that is called with the
      updated value of the thread-local state when it is altered or set
      initially.
    include_in_jit_key: forwarded unchanged to ``State``.
      NOTE(review): presumably controls whether the value participates in the
      jit cache key -- confirm against ``State``.

  Returns:
    A contextmanager to control the thread-local state value.

  Raises:
    TypeError: if ``default`` is neither ``None`` nor a ``str``.
    ValueError: if the effective default (after consulting the environment)
      is not ``None`` and not one of ``enum_values``.
  """
  if default is not None and not isinstance(default, str):
    raise TypeError(f"Default value must be of type str or None, got {default} "
                    f"of type {getattr(type(default), '__name__', type(default))}")
  name = name.lower()
  # The environment variable, when present, overrides the caller's default.
  default = os.getenv(name.upper(), default)
  if default is not None and default not in enum_values:
    raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}")
  config._contextmanager_flags.add(name)

  def validate(new_val):
    # ``None`` is always allowed; otherwise require an exact ``str`` that is a
    # listed enum member.
    if (new_val is not None and
        (type(new_val) is not str or new_val not in enum_values)):
      raise ValueError(f"new enum value must be None or in {enum_values}, "
                       f"got {new_val} of type {type(new_val)}.")

  s = State['str | None'](
      name, default, help, update_global_hook, update_thread_local_hook,
      validate, include_in_jit_key=include_in_jit_key,
  )
  config.add_option(
      name, s, 'enum',
      meta_args=[],
      meta_kwargs={"enum_values": enum_values, "help": help}
  )
  # Expose the current value as a read-only attribute on ``Config``.
  setattr(Config, name, property(lambda _: s.value))
  return s
|
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def int_state(
    name: str,
    default: int,
    help: str,
    *,
    update_global_hook: Callable[[int], None] | None = None,
    update_thread_local_hook: Callable[[int | None], None] | None = None,
    include_in_jit_key: bool = False,
    validator: Callable[[Any], None] | None = None,
) -> State[int]:
  """Set up thread-local state and return a contextmanager for managing it.

  See docstring for ``bool_state``.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: int, default value (``None`` is rejected). If the environment
      variable named ``name.upper()`` is set, it is parsed with ``int()`` and
      replaces this default.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback that is called with the updated
      value of the global state when it is altered or set initially.
    update_thread_local_hook: an optional callback that is called with the
      updated value of the thread-local state when it is altered or set
      initially.
    include_in_jit_key: forwarded unchanged to ``State``.
      NOTE(review): presumably controls whether the value participates in the
      jit cache key -- confirm against ``State``.
    validator: an optional extra check, called with each new non-``None``
      value after the built-in int type check has passed. Its return value is
      ignored, so it is expected to raise to reject a value.

  Returns:
    A contextmanager to control the thread-local state value.

  Raises:
    TypeError: if ``default`` is not an ``int``.
    ValueError: if the environment variable is set but cannot be parsed as an
      int.
  """
  if not isinstance(default, int):
    raise TypeError(f"Default value must be of type int, got {default} "
                    f"of type {getattr(type(default), '__name__', type(default))}")
  name = name.lower()
  default_env = os.getenv(name.upper())
  if default_env is not None:
    try:
      default = int(default_env)
    except ValueError as err:
      # Chain explicitly so the traceback shows the parse failure as the cause
      # rather than as an unrelated error raised during handling.
      raise ValueError(
          f"Invalid value \"{default_env}\" for JAX flag {name}") from err
  config._contextmanager_flags.add(name)

  def validate(new_val):
    if new_val is not None and not isinstance(new_val, int):
      raise ValueError(f'new int config value must be None or of type int, '
                       f'got {new_val} of type {type(new_val)}')
    # The caller-supplied validator only ever sees type-checked, non-None
    # values.
    if new_val is not None and validator is not None:
      validator(new_val)

  s = State[int](name, default, help, update_global_hook,
                 update_thread_local_hook, validate,
                 include_in_jit_key=include_in_jit_key)
  config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help})
  # Expose the current value as a read-only attribute on ``Config``.
  setattr(Config, name, property(lambda _: s.value))
  return s
|
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def float_state(
    name: str,
    default: float,
    help: str,
    *,
    update_global_hook: Callable[[float], None] | None = None,
    update_thread_local_hook: Callable[[float | None], None] | None = None,
) -> State[float]:
  """Set up thread-local state and return a contextmanager for managing it.

  See docstring for ``bool_state``.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: float, default value. Must be an actual ``float``; an ``int``
      default is rejected. If the environment variable named ``name.upper()``
      is set, it is parsed with ``float()`` and replaces this default.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback that is called with the updated
      value of the global state when it is altered or set initially.
    update_thread_local_hook: an optional callback that is called with the
      updated value of the thread-local state when it is altered or set
      initially.

  Returns:
    A contextmanager to control the thread-local state value.

  Raises:
    TypeError: if ``default`` is not a ``float``.
    ValueError: if the environment variable is set but cannot be parsed as a
      float.
  """
  if not isinstance(default, float):
    raise TypeError(f"Default value must be of type float, got {default} "
                    f"of type {getattr(type(default), '__name__', type(default))}")
  name = name.lower()
  default_env = os.getenv(name.upper())
  if default_env is not None:
    try:
      default = float(default_env)
    except ValueError as err:
      # Chain explicitly so the traceback shows the parse failure as the cause
      # rather than as an unrelated error raised during handling.
      raise ValueError(
          f"Invalid value \"{default_env}\" for JAX flag {name}") from err
  config._contextmanager_flags.add(name)

  def validate(new_val):
    # Updates (unlike the default) may be ints as well as floats.
    if new_val is not None and not isinstance(new_val, (float, int)):
      raise ValueError(
          f'new float config value must be None or of type float, '
          f'got {new_val} of type {type(new_val)}')

  s = State[float](name, default, help, update_global_hook,
                   update_thread_local_hook, validate)
  config.add_option(name, s, float, meta_args=[], meta_kwargs={"help": help})
  # Expose the current value as a read-only attribute on ``Config``.
  setattr(Config, name, property(lambda _: s.value))
  return s
|
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def string_state(
|
2023-10-25 12:38:19 +01:00
|
|
|
name: str,
|
2024-02-12 06:43:23 -08:00
|
|
|
default: str,
|
2023-10-25 12:38:19 +01:00
|
|
|
help: str,
|
|
|
|
*,
|
2023-12-08 12:09:04 +00:00
|
|
|
update_global_hook: Callable[[str], None] | None = None,
|
|
|
|
update_thread_local_hook: Callable[[str | None], None] | None = None,
|
2024-04-15 10:35:50 +01:00
|
|
|
) -> State[str]:
|
2023-10-25 12:38:19 +01:00
|
|
|
"""Set up thread-local state and return a contextmanager for managing it.
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
See docstring for ``bool_state``.
|
2023-10-25 12:38:19 +01:00
|
|
|
|
|
|
|
Args:
|
|
|
|
name: string, converted to lowercase to define the name of the config
|
|
|
|
option (and absl flag). It is converted to uppercase to define the
|
|
|
|
corresponding shell environment variable.
|
|
|
|
default: string, a default value for the option.
|
|
|
|
help: string, used to populate the flag help information as well as the
|
|
|
|
docstring of the returned context manager.
|
|
|
|
update_global_hook: an optional callback that is called with the updated
|
|
|
|
value of the global state when it is altered or set initially.
|
|
|
|
update_thread_local_hook: an optional callback that is called with the
|
|
|
|
updated value of the thread-local state when it is altered or set
|
|
|
|
initially.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A contextmanager to control the thread-local state value.
|
|
|
|
"""
|
2024-02-12 06:43:23 -08:00
|
|
|
if not isinstance(default, str):
|
2024-04-09 10:19:57 +05:30
|
|
|
raise TypeError(f"Default value must be of type str, got {default} "
|
|
|
|
f"of type {getattr(type(default), '__name__', type(default))}")
|
2023-10-25 12:38:19 +01:00
|
|
|
|
Refactor and optimize the implementation of config options.
Previously accessing the value of a configuration option required accessing a `values` dictionary in jax._src.config.config. By moving the values into the FlagHolder and StateContextManager objects, we allow for accessing the values directly without as many dictionary accesses.
Timings before, note `enable_x64` is a `StateContextManager` value and `jax_pprint_use_color` is a `FlagHolder` value:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
1000000 loops, best of 5: 328 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 377 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 293 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
1000000 loops, best of 5: 358 ns per loop
```
Timings after:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
10000000 loops, best of 5: 175 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 220 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 316 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
10000000 loops, best of 5: 54.9 ns per loop
```
i.e., if accessed via the holder object directly, this change is a significant speedup, and if accessed by the `jax.config` dictionary this change is a good speedup for `StateContextManager` values and a small slowdown for `FlagHolder` values.
In a subsequent change, we should do more work to avoid using `jax.config.xyz` to access flag values inside jax.
I also note that one of the reasons `StateContextManager` values are a bit suboptimal is that they still require a dictionary lookup in a `threading.local` object. We can probably avoid that with a small C++ extension.
PiperOrigin-RevId: 606340656
2024-02-12 13:03:58 -08:00
|
|
|
def validator(new_val):
|
2024-02-12 06:43:23 -08:00
|
|
|
if not isinstance(new_val, str):
|
2024-04-08 13:08:24 +05:30
|
|
|
raise TypeError('new string config value must be of type str,'
|
2023-10-25 12:38:19 +01:00
|
|
|
f' got {new_val} of type {type(new_val)}.')
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
return string_or_object_state(
|
2023-10-25 12:38:19 +01:00
|
|
|
name, default, help,
|
|
|
|
update_global_hook=update_global_hook,
|
|
|
|
update_thread_local_hook=update_thread_local_hook,
|
Refactor and optimize the implementation of config options.
Previously accessing the value of a configuration option required accessing a `values` dictionary in jax._src.config.config. By moving the values into the FlagHolder and StateContextManager objects, we allow for accessing the values directly without as many dictionary accesses.
Timings before, note `enable_x64` is a `StateContextManager` value and `jax_pprint_use_color` is a `FlagHolder` value:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
1000000 loops, best of 5: 328 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 377 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 293 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
1000000 loops, best of 5: 358 ns per loop
```
Timings after:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
10000000 loops, best of 5: 175 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 220 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 316 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
10000000 loops, best of 5: 54.9 ns per loop
```
i.e., if accessed via the holder object directly, this change is a significant speedup, and if accessed by the `jax.config` dictionary this change is a good speedup for `StateContextManager` values and a small slowdown for `FlagHolder` values.
In a subsequent change, we should do more work to avoid using `jax.config.xyz` to access flag values inside jax.
I also note that one of the reasons `StateContextManager` values are a bit suboptimal is that they still require a dictionary lookup in a `threading.local` object. We can probably avoid that with a small C++ extension.
PiperOrigin-RevId: 606340656
2024-02-12 13:03:58 -08:00
|
|
|
validator=validator)
|
2023-10-25 12:38:19 +01:00
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def optional_string_state(
    name: str,
    default: str | None,
    help: str,
    *,
    update_global_hook: Callable[[str], None] | None = None,
    update_thread_local_hook: Callable[[str | None], None] | None = None,
) -> State[str | None]:
  """Create a config state whose value is either a string or ``None``.

  This is a thin wrapper over ``string_or_object_state`` that type-checks the
  default up front and installs a validator for later updates. See the
  docstring for ``bool_state`` for the general contract.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: the default value for the option; a string or ``None``.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback that is called with the updated
      value of the global state when it is altered or set initially.
    update_thread_local_hook: an optional callback that is called with the
      updated value of the thread-local state when it is altered or set
      initially.

  Returns:
    A contextmanager to control the thread-local state value.

  Raises:
    TypeError: if ``default`` is neither ``None`` nor a ``str``.
  """
  default_is_valid = default is None or isinstance(default, str)
  if not default_is_valid:
    raise TypeError(f"Default value must be of type str or None, got {default} "
                    f"of type {getattr(type(default), '__name__', type(default))}")

  def _check_value(new_val):
    # Only ``None`` and ``str`` instances are accepted; anything else is
    # rejected with a ValueError.
    if new_val is None or isinstance(new_val, str):
      return
    raise ValueError('new string config value must be None or of type str,'
                     f' got {new_val} of type {type(new_val)}.')

  return string_or_object_state(
      name,
      default,
      help,
      update_global_hook=update_global_hook,
      update_thread_local_hook=update_thread_local_hook,
      validator=_check_value,
  )
|
2024-02-12 06:43:23 -08:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def string_or_object_state(
    name: str,
    default: Any,
    help: str,
    *,
    update_global_hook: Callable[[Any], None] | None = None,
    update_thread_local_hook: Callable[[Any], None] | None = None,
    validator: Callable[[Any], None] | None = None,
) -> State[Any]:
  """Create a config state that accepts arbitrary objects as its value.

  Similar to ``string_state``, except the context manager will accept any
  object, not just a string. Any value passed via command line flag or
  environment variable will be treated as a string.

  Args:
    name: string, converted to lowercase to define the name of the config
      option (and absl flag). It is converted to uppercase to define the
      corresponding shell environment variable.
    default: the default value for the option. If the corresponding
      environment variable is set, its (string) value is used instead.
    help: string, used to populate the flag help information as well as the
      docstring of the returned context manager.
    update_global_hook: an optional callback that is called with the updated
      value of the global state when it is altered or set initially.
    update_thread_local_hook: an optional callback that is called with the
      updated value of the thread-local state when it is altered or set
      initially.
    validator: an optional callback that is called with the new value on any
      update, and should raise an error if the new value is invalid.

  Returns:
    A contextmanager to control the thread-local state value.
  """
  name = name.lower()
  env_var = name.upper()
  # A value found in the environment takes precedence over ``default``.
  default = os.getenv(env_var, default)
  config._contextmanager_flags.add(name)

  state = State[Any](
      name,
      default,
      help,
      update_global_hook,
      update_thread_local_hook,
      validator,
  )

  # Expose the current value as a read-only attribute on ``Config``.
  def _read_value(_):
    return state.value

  setattr(Config, name, property(_read_value))
  config.add_option(name, state, str, meta_args=[], meta_kwargs={"help": help})
  return state
|
2023-10-12 13:15:22 +01:00
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
class Flag(Generic[_T]):
  """Holder for the current value of a single configuration flag.

  The value lives directly on the instance as ``value``. If an ``update_hook``
  was supplied, it is called with the new value every time ``_set`` runs,
  including the initial assignment of ``default`` in the constructor.
  """

  __slots__ = ("_name", "value", "_update_hook")

  _name: str
  value: _T
  _update_hook: Callable[[Any], None] | None

  def __init__(self, name: str, default: _T,
               update_hook: Callable[[Any], None] | None = None):
    # The hook must be stored before ``_set`` runs, since ``_set`` reads it.
    self._update_hook = update_hook
    self._name = name
    self._set(default)

  def __bool__(self) -> NoReturn:
    # Truth-testing the holder itself is almost always a mistake (the caller
    # meant ``.value``), so refuse it outright.
    template = (
        "bool() not supported for instances of type '{0}' "
        "(did you mean to use '{0}.value' instead?)")
    raise TypeError(template.format(type(self).__name__))

  def _set(self, value: _T) -> None:
    """Store ``value`` and notify the update hook, if one is installed."""
    self.value = value
    hook = self._update_hook
    if hook is None:
      return
    hook(value)
|
2023-07-27 12:15:16 -07:00
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def bool_flag(name, default, *args, **kwargs) -> Flag[bool]:
  """Create a ``Flag`` holding ``default`` and register it as a bool option.

  Args:
    name: name of the flag; passed to both ``Flag`` and ``config.add_option``.
    default: initial value stored in the returned ``Flag``.
    *args: passed through to ``config.add_option`` as a tuple.
    **kwargs: passed through to ``config.add_option`` as a dict. An
      ``update_hook`` entry, if present, is removed first and given to
      ``Flag`` instead.

  Returns:
    The newly created ``Flag``.
  """
  hook = kwargs.pop("update_hook", None)
  flag = Flag(name, default, hook)
  config.add_option(name, flag, bool, args, kwargs)
  return flag
|
2023-10-25 12:38:19 +01:00
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def int_flag(name, default, *args, **kwargs) -> Flag[int]:
  """Create a ``Flag`` holding ``default`` and register it as an int option.

  Args:
    name: name of the flag; passed to both ``Flag`` and ``config.add_option``.
    default: initial value stored in the returned ``Flag``.
    *args: passed through to ``config.add_option`` as a tuple.
    **kwargs: passed through to ``config.add_option`` as a dict. An
      ``update_hook`` entry, if present, is removed first and given to
      ``Flag`` instead.

  Returns:
    The newly created ``Flag``.
  """
  hook = kwargs.pop("update_hook", None)
  flag = Flag(name, default, hook)
  config.add_option(name, flag, int, args, kwargs)
  return flag
|
2023-10-25 12:38:19 +01:00
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def float_flag(name, default, *args, **kwargs) -> Flag[float]:
  """Create a ``Flag`` holding ``default`` and register it as a float option.

  Args:
    name: name of the flag; passed to both ``Flag`` and ``config.add_option``.
    default: initial value stored in the returned ``Flag``.
    *args: passed through to ``config.add_option`` as a tuple.
    **kwargs: passed through to ``config.add_option`` as a dict. An
      ``update_hook`` entry, if present, is removed first and given to
      ``Flag`` instead.

  Returns:
    The newly created ``Flag``.
  """
  hook = kwargs.pop("update_hook", None)
  flag = Flag(name, default, hook)
  config.add_option(name, flag, float, args, kwargs)
  return flag
|
2023-10-25 12:38:19 +01:00
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def string_flag(name, default, *args, **kwargs) -> Flag[str]:
  """Create a string-valued `Flag` and register it with `config`.

  An optional ``update_hook`` keyword is removed from ``kwargs`` and passed to
  the `Flag` constructor. The remaining positional and keyword arguments are
  forwarded unchanged to ``config.add_option``.

  Returns:
    The newly created `Flag` holder.
  """
  # Pop the hook first so it is not forwarded to `config.add_option`.
  flag = Flag(name, default, kwargs.pop("update_hook", None))
  config.add_option(name, flag, str, args, kwargs)
  return flag
|
2023-10-25 12:38:19 +01:00
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
  """Create an enum-valued `Flag` and register it with `config`.

  An optional ``update_hook`` keyword is removed from ``kwargs`` and passed to
  the `Flag` constructor. The remaining positional and keyword arguments are
  forwarded unchanged to ``config.add_option``, with the option type given as
  the string ``'enum'``.

  Returns:
    The newly created `Flag` holder.
  """
  # Pop the hook first so it is not forwarded to `config.add_option`.
  flag = Flag(name, default, kwargs.pop("update_hook", None))
  config.add_option(name, flag, 'enum', args, kwargs)
  return flag
|
2023-07-27 12:15:16 -07:00
|
|
|
|
|
|
|
|
2021-04-19 08:52:48 -07:00
|
|
|
already_configured_with_absl = False
|
|
|
|
|
|
|
|
|
2024-12-09 07:34:26 -08:00
|
|
|
trace_state = config_ext.Config(None, include_in_jit_key=True)
|
|
|
|
axis_env_state = config_ext.Config((), include_in_jit_key=True)
|
|
|
|
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
|
2025-01-28 11:04:05 -08:00
|
|
|
abstract_mesh_context_manager = config_ext.Config(None, include_in_jit_key=True)
|
|
|
|
device_context = config_ext.Config(None, include_in_jit_key=True)
|
2025-01-26 09:24:01 -08:00
|
|
|
compute_on_context_manager = config_ext.Config(None, include_in_jit_key=True)
|
2025-01-26 12:08:12 -08:00
|
|
|
xla_metadata_context_manager = config_ext.Config(None, include_in_jit_key=True)
|
2024-11-05 08:31:12 -08:00
|
|
|
|
2021-04-21 06:36:08 -07:00
|
|
|
|
2022-01-13 12:28:25 +02:00
|
|
|
# TODO(b/214340779): remove flag when XLA:CPU is improved.
|
2024-04-15 10:35:50 +01:00
|
|
|
jax2tf_associative_scan_reductions = bool_state(
|
2022-01-13 12:28:25 +02:00
|
|
|
name='jax2tf_associative_scan_reductions',
|
|
|
|
default=False,
|
|
|
|
help=(
|
|
|
|
'JAX has two separate lowering rules for the cumulative reduction '
|
|
|
|
'primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses '
|
|
|
|
'a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. '
|
|
|
|
'The latter has a slow implementation on CPUs and GPUs. '
|
|
|
|
'By default, jax2tf uses the TPU lowering. Set this flag to True to '
|
|
|
|
'use the associative scan lowering usage, and only if it makes a difference '
|
|
|
|
'for your application. '
|
|
|
|
'See the jax2tf README.md for more details.'
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
jax2tf_default_native_serialization = bool_state(
|
2023-03-15 10:30:52 -07:00
|
|
|
name='jax2tf_default_native_serialization',
|
2023-08-07 01:03:15 -07:00
|
|
|
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True),
|
2023-03-15 10:30:52 -07:00
|
|
|
help=(
|
|
|
|
'Sets the default value of the native_serialization parameter to '
|
2023-03-28 10:17:19 -07:00
|
|
|
'jax2tf.convert. Prefer using the parameter instead of the flag, '
|
2024-07-16 02:04:59 -07:00
|
|
|
'the flag may be removed in the future. '
|
|
|
|
'Starting with JAX 0.4.31 non-native serialization is deprecated.'
|
2023-03-15 10:30:52 -07:00
|
|
|
)
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
jax_serialization_version = int_state(
|
2023-07-16 09:26:27 -07:00
|
|
|
name='jax_serialization_version',
|
2024-06-12 19:24:30 +02:00
|
|
|
default=int_env('JAX_SERIALIZATION_VERSION', 0), # We use 0 to detect default.
|
|
|
|
help=(
|
|
|
|
'DEPRECATED: use jax_export_calling_convention_version.'
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
jax_export_calling_convention_version = int_state(
|
2024-06-12 19:24:30 +02:00
|
|
|
name='jax_export_calling_convention_version',
|
|
|
|
# Note: bump the default calling convention version at least one month after
|
2023-07-16 09:26:27 -07:00
|
|
|
# we update XlaCallModule to support the new version, so that serialized
|
|
|
|
# modules are forward compatible with deployed versions of XlaCallModule.
|
2024-02-01 21:55:31 -08:00
|
|
|
# Version 9 of XlaCallModule is supported since October 27th, 2023.
|
2024-06-12 19:24:30 +02:00
|
|
|
default=int_env('JAX_EXPORT_CALLING_CONVENTION_VERSION', 9),
|
2023-07-16 09:26:27 -07:00
|
|
|
help=(
|
2024-06-12 19:24:30 +02:00
|
|
|
'The calling convention version number to use for exporting. This must be '
|
2023-07-16 09:26:27 -07:00
|
|
|
'within the range of versions supported by the tf.XlaCallModule '
|
|
|
|
'used in your deployment environment. '
|
2025-04-08 08:32:59 -07:00
|
|
|
'See https://docs.jax.dev/en/latest/export/shape_poly.html#calling-convention-versions.'
|
2023-07-16 09:26:27 -07:00
|
|
|
)
|
|
|
|
)
|
|
|
|
|
2024-06-19 13:00:51 +03:00
|
|
|
export_ignore_forward_compatibility = bool_state(
|
|
|
|
name='jax_export_ignore_forward_compatibility',
|
|
|
|
default=bool_env('JAX_EXPORT_IGNORE_FORWARD_COMPATIBILIY', False),
|
|
|
|
help=(
|
|
|
|
'Whether to ignore the forward compatibility lowering rules. '
|
2025-04-08 08:32:59 -07:00
|
|
|
'See https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls.'
|
2024-06-19 13:00:51 +03:00
|
|
|
)
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
jax_platforms = optional_string_state(
|
2022-06-17 16:38:56 +00:00
|
|
|
name='jax_platforms',
|
2022-12-21 09:05:36 -08:00
|
|
|
default=None,
|
2022-06-17 16:38:56 +00:00
|
|
|
help=(
|
|
|
|
'Comma-separated list of platform names specifying which platforms jax '
|
|
|
|
'should initialize. If any of the platforms in this list are not successfully '
|
|
|
|
'initialized, an exception will be raised and the program will be aborted. '
|
|
|
|
'The first platform in the list will be the default platform. '
|
|
|
|
'For example, config.jax_platforms=cpu,tpu means that CPU and TPU backends '
|
|
|
|
'will be initialized, and the CPU backend will be used unless otherwise '
|
|
|
|
'specified. If TPU initialization fails, it will raise an exception. '
|
|
|
|
'By default, jax will try to initialize all available '
|
|
|
|
'platforms and will default to GPU or TPU if available, and fallback to CPU '
|
|
|
|
'otherwise.'
|
|
|
|
))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
jax_pjrt_client_create_options = optional_string_state(
|
2024-05-28 13:42:18 -07:00
|
|
|
name='jax_pjrt_client_create_options',
|
|
|
|
default=None,
|
|
|
|
help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings '
|
|
|
|
'provided to a device platform pjrt client as extra arguments.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
enable_checks = bool_state(
|
2021-04-19 08:52:48 -07:00
|
|
|
name='jax_enable_checks',
|
|
|
|
default=False,
|
|
|
|
help='Turn on invariant checking for JAX internals. Makes things slower.')
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
debug_key_reuse = bool_state(
|
2024-03-21 10:47:16 -07:00
|
|
|
name='jax_debug_key_reuse',
|
2023-12-11 12:03:48 -08:00
|
|
|
default=False,
|
2024-02-29 15:30:19 -08:00
|
|
|
help=('Turn on experimental key reuse checking. With this configuration enabled,'
|
|
|
|
' typed PRNG keys (i.e. keys created with jax.random.key()) will have their'
|
|
|
|
' usage tracked, and incorrect reuse of a previously-used key will lead to'
|
|
|
|
' an error. Currently enabling this leads to a small Python overhead on'
|
|
|
|
' every call to a JIT-compiled function with keys as inputs or outputs.'))
|
2023-12-11 12:03:48 -08:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
check_tracer_leaks = bool_state(
|
2021-04-19 08:52:48 -07:00
|
|
|
name='jax_check_tracer_leaks',
|
|
|
|
default=False,
|
|
|
|
help=('Turn on checking for leaked tracers as soon as a trace completes. '
|
|
|
|
'Enabling leak checking may have performance impacts: some caching '
|
2021-11-08 09:21:18 -08:00
|
|
|
'is disabled, and other overheads may be added. Additionally, be aware '
|
|
|
|
'that some Python debuggers can cause false positives, so it is recommended '
|
|
|
|
'to disable any debuggers while leak checking is enabled.'))
|
2021-04-19 08:52:48 -07:00
|
|
|
checking_leaks = functools.partial(check_tracer_leaks, True)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
debug_nans = bool_state(
|
2021-04-19 08:52:48 -07:00
|
|
|
name='jax_debug_nans',
|
|
|
|
default=False,
|
|
|
|
help=('Add nan checks to every operation. When a nan is detected on the '
|
|
|
|
'output of a jit-compiled computation, call into the un-compiled '
|
|
|
|
'version in an attempt to more precisely identify the operation '
|
|
|
|
'which produced the nan.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
debug_infs = bool_state(
|
2021-04-19 08:52:48 -07:00
|
|
|
name='jax_debug_infs',
|
|
|
|
default=False,
|
|
|
|
help=('Add inf checks to every operation. When an inf is detected on the '
|
|
|
|
'output of a jit-compiled computation, call into the un-compiled '
|
|
|
|
'version in an attempt to more precisely identify the operation '
|
|
|
|
'which produced the inf.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
log_compiles = bool_state(
|
2021-04-19 08:52:48 -07:00
|
|
|
name='jax_log_compiles',
|
|
|
|
default=False,
|
2024-04-19 13:58:06 -07:00
|
|
|
help=('Log a message each time `jit` or `pmap` compiles an XLA '
|
2022-10-13 17:06:22 +02:00
|
|
|
'computation. Logging is performed with `logging`. When this '
|
2021-04-19 08:52:48 -07:00
|
|
|
'option is set, the log level is WARNING; otherwise the level is '
|
|
|
|
'DEBUG.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
explain_cache_misses = bool_state(
|
2023-06-09 14:43:42 -07:00
|
|
|
name='jax_explain_cache_misses',
|
|
|
|
default=False,
|
|
|
|
help=('Each time there is a miss on one of the main caches (e.g. the '
|
2025-04-03 10:25:02 +01:00
|
|
|
'tracing cache), log an explanation. Logging is performed with '
|
2023-06-09 14:43:42 -07:00
|
|
|
'`logging`. When this option is set, the log level is WARNING; '
|
|
|
|
'otherwise the level is DEBUG.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
log_checkpoint_residuals = bool_state(
|
2023-03-22 13:37:40 -07:00
|
|
|
name='jax_log_checkpoint_residuals',
|
|
|
|
default=False,
|
|
|
|
help=('Log a message every time jax.checkpoint (aka jax.remat) is '
|
|
|
|
'partially evaluated (e.g. for autodiff), printing what residuals '
|
|
|
|
'are saved.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
pmap_shmap_merge = bool_state(
|
2023-04-20 21:22:16 -07:00
|
|
|
name='jax_pmap_shmap_merge',
|
|
|
|
default=False,
|
|
|
|
upgrade=True,
|
|
|
|
help='If True, pmap and shard_map API will be merged.')
|
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
distributed_debug = bool_state(
|
2021-07-30 12:37:21 -07:00
|
|
|
name='jax_distributed_debug',
|
Add optional distributed debugging logging.
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.
This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to facilitate
grepping for this info.
Example output:
```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
process_index: 0
device_count: 8
local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
devices: None
abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
dtype=object), ('x',))
abstract args: []
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
[TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
dtype=object), ('x', 'y'))
abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```
2021-04-19 12:45:17 -07:00
|
|
|
default=False,
|
|
|
|
help=('Enable logging useful for debugging multi-process distributed '
|
2022-10-13 17:06:22 +02:00
|
|
|
'computations. Logging is performed with `logging` at WARNING '
|
Add optional distributed debugging logging.
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.
This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to facilitate
grepping for this info.
Example output:
```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
process_index: 0
device_count: 8
local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
devices: None
abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
dtype=object), ('x',))
abstract args: []
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
[TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
dtype=object), ('x', 'y'))
abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```
2021-04-19 12:45:17 -07:00
|
|
|
'level.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
random_seed_offset = int_state(
|
2023-12-12 18:31:07 -08:00
|
|
|
name='jax_random_seed_offset',
|
|
|
|
default=0,
|
|
|
|
help=('Offset to all random seeds (e.g. argument to jax.random.key()).'),
|
2024-11-05 08:31:12 -08:00
|
|
|
include_in_jit_key=True,
|
2023-12-12 18:31:07 -08:00
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
legacy_prng_key = enum_state(
|
2023-08-22 15:08:51 -07:00
|
|
|
name='jax_legacy_prng_key',
|
|
|
|
enum_values=['allow', 'warn', 'error'],
|
|
|
|
default='allow',
|
|
|
|
help=('Specify the behavior when raw PRNG keys are passed to '
|
|
|
|
'jax.random APIs.')
|
|
|
|
)
|
2022-03-25 18:37:15 -07:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
enable_custom_prng = bool_state(
|
2021-08-15 08:09:30 -07:00
|
|
|
name='jax_enable_custom_prng',
|
|
|
|
default=False,
|
2022-03-28 17:17:33 -07:00
|
|
|
upgrade=True,
|
2021-08-15 08:09:30 -07:00
|
|
|
help=('Enables an internal upgrade that allows one to define custom '
|
2022-03-25 18:37:15 -07:00
|
|
|
'pseudo-random number generator implementations.'))
|
2021-08-15 08:09:30 -07:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
default_prng_impl = enum_state(
|
2021-10-07 19:15:43 -07:00
|
|
|
name='jax_default_prng_impl',
|
|
|
|
enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'],
|
|
|
|
default='threefry2x32',
|
|
|
|
help=('Select the default PRNG implementation, used when one is not '
|
|
|
|
'explicitly provided at seeding time.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
threefry_partitionable = bool_state(
|
2022-10-25 08:13:55 -07:00
|
|
|
name='jax_threefry_partitionable',
|
2025-01-13 22:45:41 -08:00
|
|
|
default=True,
|
2022-10-25 08:13:55 -07:00
|
|
|
upgrade=True,
|
|
|
|
help=('Enables internal threefry PRNG implementation changes that '
|
2023-07-18 06:15:24 -07:00
|
|
|
'render it automatically partitionable in some cases. Without this '
|
|
|
|
'flag, using the standard jax.random pseudo-random number generation '
|
|
|
|
'may result in extraneous communication and/or redundant distributed '
|
2022-10-25 08:13:55 -07:00
|
|
|
'computation. With this flag, the communication overheads disappear '
|
2022-11-22 13:58:59 -08:00
|
|
|
'in some cases.'),
|
2024-11-05 08:31:12 -08:00
|
|
|
include_in_jit_key=True)
|
2022-10-25 08:13:55 -07:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
threefry_gpu_kernel_lowering = bool_state(
|
2024-05-01 10:32:36 -07:00
|
|
|
name='jax_threefry_gpu_kernel_lowering',
|
|
|
|
default=False,
|
|
|
|
help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
|
|
|
|
'This makes compile times faster at a potential runtime memory '
|
|
|
|
'cost.'),
|
2024-11-05 08:31:12 -08:00
|
|
|
include_in_jit_key=True)
|
2024-05-01 10:32:36 -07:00
|
|
|
|
2024-11-04 13:33:19 -05:00
|
|
|
# Boolean config option `jax_use_direct_linearize`; off by default.
use_direct_linearize = bool_state(
    name='jax_use_direct_linearize',
    default=False,
    # Help text fixed: the original read "instead JVP" (missing "of").
    help=('Use direct linearization instead of JVP followed by partial eval'),
    include_in_jit_key=True)
|
|
|
|
|
2025-04-08 13:46:14 -07:00
|
|
|
# TODO make it so people don't use this, this is internal...
|
|
|
|
_check_rep = bool_state(
|
|
|
|
name='check_rep',
|
|
|
|
default=False,
|
|
|
|
help='internal implementation detail of shard_map, DO NOT USE',
|
|
|
|
include_in_jit_key=True)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
softmax_custom_jvp = bool_state(
|
2023-04-19 18:11:35 -07:00
|
|
|
name='jax_softmax_custom_jvp',
|
2023-05-23 11:56:50 -07:00
|
|
|
default=False,
|
2023-04-19 18:11:35 -07:00
|
|
|
upgrade=True,
|
|
|
|
help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should '
|
2023-05-23 11:56:50 -07:00
|
|
|
'improve memory usage and stability. Set True to use new '
|
2024-09-20 07:51:48 -07:00
|
|
|
'behavior. See https://github.com/jax-ml/jax/pull/15677'),
|
2024-11-05 08:31:12 -08:00
|
|
|
include_in_jit_key=True)
|
2023-04-19 18:11:35 -07:00
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
enable_custom_vjp_by_custom_transpose = bool_state(
|
2022-03-25 18:37:15 -07:00
|
|
|
name='jax_enable_custom_vjp_by_custom_transpose',
|
|
|
|
default=False,
|
2022-03-28 17:17:33 -07:00
|
|
|
upgrade=True,
|
2022-03-25 18:37:15 -07:00
|
|
|
help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
|
|
|
|
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
raise_persistent_cache_errors = bool_state(
|
2022-09-27 20:59:08 +00:00
|
|
|
name='jax_raise_persistent_cache_errors',
|
|
|
|
default=False,
|
|
|
|
help=('If true, exceptions raised when reading or writing to the '
|
|
|
|
'persistent compilation cache will be allowed through, halting '
|
|
|
|
'program execution if not manually caught. If false, exceptions are '
|
|
|
|
'caught and raised as warnings, allowing program execution to '
|
|
|
|
'continue. Defaults to false so cache bugs or intermittent issues '
|
|
|
|
'are non-fatal.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
persistent_cache_min_compile_time_secs = float_state(
|
2022-10-28 23:53:30 +00:00
|
|
|
name='jax_persistent_cache_min_compile_time_secs',
|
2024-02-12 06:43:23 -08:00
|
|
|
default=1.,
|
2022-10-28 23:53:30 +00:00
|
|
|
help=('The minimum compile time of a computation to be written to the '
|
|
|
|
'persistent compilation cache. This threshold can be raised to '
|
|
|
|
'decrease the number of entries written to the cache.'))
|
2022-09-27 20:59:08 +00:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
persistent_cache_min_entry_size_bytes = int_state(
|
2024-01-04 15:16:25 -08:00
|
|
|
name='jax_persistent_cache_min_entry_size_bytes',
|
|
|
|
default=0,
|
|
|
|
help=('The minimum size (in bytes) of an entry that will be cached in the '
|
|
|
|
'persistent compilation cache: '
|
|
|
|
'* -1: disable the size restriction and prevent overrides. '
|
|
|
|
'* Leave at default (0) to allow for overrides. The override will '
|
|
|
|
' typically ensure that the minimum size is optimal for the '
|
|
|
|
' filesystem being used for the cache. '
|
|
|
|
'* > 0: the actual minimum size desired; no overrides.'))
|
|
|
|
|
2024-07-29 16:13:01 -07:00
|
|
|
# TODO: Change default to all
|
|
|
|
# Optional-string config option `jax_persistent_cache_enable_xla_caches`.
# The default value is 'xla_gpu_per_fusion_autotune_cache_dir'.
persistent_cache_enable_xla_caches = optional_string_state(
    name='jax_persistent_cache_enable_xla_caches',
    default='xla_gpu_per_fusion_autotune_cache_dir',
    # The help text is built from adjacent string literals, so every fragment
    # must end with a space; the original rendered "configurewhich".
    help=('When the persistent cache is enabled, additional XLA caching will '
          'also be enabled automatically. This option can be used to configure '
          'which XLA caching methods will be enabled.'),
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
compilation_cache_include_metadata_in_key = bool_state(
|
2023-04-19 13:26:24 -07:00
|
|
|
name='jax_compilation_cache_include_metadata_in_key',
|
|
|
|
default=False,
|
|
|
|
help=(
|
|
|
|
'Include metadata, such as file names and line numbers, in the'
|
|
|
|
' compilation cache key. If false, the cache will still get hits even'
|
|
|
|
' if functions or files are moved, etc. However, it means that'
|
|
|
|
' executables loaded from the cache may have stale metadata, which'
|
|
|
|
' may show up in, e.g., profiles.'
|
|
|
|
),
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
hlo_source_file_canonicalization_regex = optional_string_state(
|
2021-07-30 12:37:21 -07:00
|
|
|
name='jax_hlo_source_file_canonicalization_regex',
|
|
|
|
default=None,
|
|
|
|
help=('Used to canonicalize the source_path metadata of HLO instructions '
|
|
|
|
'by removing the given regex. If set, re.sub() is called on each '
|
|
|
|
'source_file with the given regex, and all matches are removed. '
|
|
|
|
'This can be used to avoid spurious cache misses when using the '
|
|
|
|
'persistent compilation cache, which includes HLO metadata in the '
|
|
|
|
'cache key.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
include_full_tracebacks_in_locations = bool_state(
|
2023-05-22 07:40:52 -07:00
|
|
|
name='jax_include_full_tracebacks_in_locations',
|
2023-12-17 21:55:34 -08:00
|
|
|
default=True,
|
2023-05-22 07:40:52 -07:00
|
|
|
help=(
|
2024-01-03 23:26:22 -08:00
|
|
|
'Include Python tracebacks in MLIR locations in IR emitted by JAX.'
|
|
|
|
),
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Integer config option `jax_traceback_in_locations_limit`; defaults to 10.
traceback_in_locations_limit = int_state(
    name='jax_traceback_in_locations_limit',
    default=10,
    # Help text reworded: the original read "number of frames at the Python
    # traceback frames" and "the negative value".
    help=(
        'Limit the number of Python traceback frames included in '
        'MLIR locations. If set to a negative value, the traceback will not '
        'be limited.'
    ),
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
share_binary_between_hosts = bool_state(
|
2024-01-11 23:37:22 -08:00
|
|
|
name='jax_share_binary_between_hosts',
|
|
|
|
default=False,
|
|
|
|
help=(
|
|
|
|
'If set to True, the compiled module will be shared between hosts '
|
|
|
|
'directly.'
|
|
|
|
),
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
share_binary_between_hosts_timeout_ms = int_state(
|
2024-01-11 23:37:22 -08:00
|
|
|
name='jax_share_binary_between_hosts_timeout_ms',
|
2024-02-06 01:27:21 -08:00
|
|
|
default=20 * 60 * 1000,
|
2024-01-11 23:37:22 -08:00
|
|
|
help='Timeout for the compiled module share.',
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Boolean config option `jax_enable_pgle`; off by default.
enable_pgle = bool_state(
    name='jax_enable_pgle',
    default=False,
    # Help text fixed: the original read "running specified number times".
    help=(
        'If set to True and the property jax_pgle_profiling_runs is set to '
        'greater than 0, the modules will be recompiled after running the '
        'specified number of times with collected data provided to the '
        'profile guided latency estimator.'
    ),
    include_in_jit_key=True,
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
pgle_profiling_runs = int_state(
|
2024-05-29 01:49:06 -07:00
|
|
|
name='jax_pgle_profiling_runs',
|
|
|
|
default=3,
|
|
|
|
help=(
|
|
|
|
'Amount of times module should be profiled before recompilation when '
|
|
|
|
'PGLE is used.'
|
|
|
|
),
|
2024-11-05 08:31:12 -08:00
|
|
|
include_in_jit_key=True,
|
2024-05-29 01:49:06 -07:00
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
pgle_aggregation_percentile = int_state(
|
2024-05-29 01:49:06 -07:00
|
|
|
name='jax_pgle_aggregation_percentile',
|
|
|
|
default=90,
|
|
|
|
help='Percentile used to aggregate performance data between devices when '
|
|
|
|
'PGLE is used.',
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
enable_compilation_cache = bool_state(
|
2023-11-27 14:52:22 -08:00
|
|
|
name='jax_enable_compilation_cache',
|
|
|
|
default=True,
|
|
|
|
help=('If set to False, the compilation cache will be disabled regardless '
|
2024-01-12 22:44:03 -08:00
|
|
|
'of whether set_cache_dir() was called. If set to True, the '
|
2023-11-27 14:52:22 -08:00
|
|
|
'path could be set to a default value or via a call to '
|
2024-01-12 22:44:03 -08:00
|
|
|
'set_cache_dir().'),
|
2023-11-27 14:52:22 -08:00
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
compilation_cache_dir = optional_string_state(
|
2023-11-27 14:52:22 -08:00
|
|
|
name='jax_compilation_cache_dir',
|
|
|
|
default=None,
|
|
|
|
help=('Path for the cache. '
|
|
|
|
'Precedence: '
|
2024-01-12 22:44:03 -08:00
|
|
|
'1. A call to compilation_cache.set_cache_dir(). '
|
2023-11-27 14:52:22 -08:00
|
|
|
'2. The value of this flag set in the command line or by default.'),
|
|
|
|
)
|
|
|
|
|
2024-11-14 10:38:53 +00:00
|
|
|
compilation_cache_expect_pgle = bool_state(
|
|
|
|
name='jax_compilation_cache_expect_pgle',
|
|
|
|
default=False,
|
|
|
|
help=('If set to True, compilation cache entries that were compiled with '
|
|
|
|
'profile data (i.e. PGLE was enabled and the requisite number of '
|
|
|
|
'executions were profiled) will be preferentially loaded, even if '
|
|
|
|
'PGLE is not currently enabled. A warning will be printed when no '
|
|
|
|
'preferred cache entry is found.')
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
compilation_cache_max_size = int_state(
|
2024-05-30 17:59:05 +04:00
|
|
|
name='jax_compilation_cache_max_size',
|
|
|
|
default=-1,
|
|
|
|
help=('The maximum size (in bytes) allowed for the persistent compilation '
|
|
|
|
'cache. When set, the least recently accessed cache entry(s) '
|
|
|
|
'will be deleted once the total cache directory size '
|
|
|
|
'exceeds the specified limit. '
|
|
|
|
'Caching will be disabled if this value is set to 0. A '
|
|
|
|
'special value of -1 indicates no limit, allowing the cache '
|
|
|
|
'size to grow indefinitely.'),
|
|
|
|
)
|
|
|
|
|
2024-09-10 10:02:05 -07:00
|
|
|
remove_custom_partitioning_ptr_from_cache_key = bool_state(
|
|
|
|
name='jax_remove_custom_partitioning_ptr_from_cache_key',
|
|
|
|
default=False,
|
|
|
|
help=('If set to True, remove the custom partitioning pointer '
|
|
|
|
'present in the precompiled stableHLO before hashing '
|
|
|
|
'during cache key computation. This is a potentially '
|
|
|
|
'unsafe flag to set and only users who are sure of '
|
|
|
|
'what they are trying to achieve should set it.'),
|
|
|
|
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
default_dtype_bits = enum_state(
|
2021-12-07 08:23:37 -08:00
|
|
|
name='jax_default_dtype_bits',
|
|
|
|
enum_values=['32', '64'],
|
|
|
|
default='64',
|
|
|
|
help=('Specify bit width of default dtypes, either 32-bit or 64-bit. '
|
|
|
|
'This is a temporary flag that will be used during the process '
|
|
|
|
'of deprecating the ``jax_enable_x64`` flag.'))
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
numpy_dtype_promotion = enum_state(
|
2022-05-26 10:56:09 -07:00
|
|
|
name='jax_numpy_dtype_promotion',
|
|
|
|
enum_values=['standard', 'strict'],
|
|
|
|
default='standard',
|
|
|
|
help=('Specify the rules used for implicit type promotion in operations '
|
|
|
|
'between arrays. Options are "standard" or "strict"; in strict-mode, '
|
|
|
|
'binary operations between arrays of differing strongly-specified '
|
|
|
|
'dtypes will result in an error.'),
|
2024-11-05 08:31:12 -08:00
|
|
|
include_in_jit_key=True)
|
2022-05-26 10:56:09 -07:00
|
|
|
|
2024-09-19 10:41:58 -07:00
|
|
|
disallow_mesh_context_manager = bool_state(
|
|
|
|
name='jax_disallow_mesh_context_manager',
|
|
|
|
default=False,
|
|
|
|
help=(
|
|
|
|
'If set to True, trying to use a mesh as a context manager will'
|
|
|
|
' result in a RuntimeError.'
|
|
|
|
),
|
|
|
|
)
|
|
|
|
|
2025-03-21 10:52:34 -07:00
|
|
|
# TODO(ayx): Move these 3 flags out of config once we have a user-level
|
|
|
|
# extension mechanism for adding contexts to which the jit cache is sensitive.
|
|
|
|
error_checking_behavior_nan = enum_state(
|
|
|
|
name='jax_error_checking_behavior_nan',
|
|
|
|
enum_values=['ignore', 'raise'],
|
|
|
|
default='ignore',
|
|
|
|
help=(
|
|
|
|
'Specify the behavior when a NaN is encountered. Options are "ignore"'
|
|
|
|
' or "raise".'
|
|
|
|
),
|
|
|
|
include_in_jit_key=True,
|
|
|
|
)
|
|
|
|
|
|
|
|
error_checking_behavior_divide = enum_state(
|
|
|
|
name='jax_error_checking_behavior_divide',
|
|
|
|
enum_values=['ignore', 'raise'],
|
|
|
|
default='ignore',
|
|
|
|
help=(
|
|
|
|
'Specify the behavior when a divide by zero is encountered. Options are'
|
|
|
|
' "ignore" or "raise".'
|
|
|
|
),
|
|
|
|
include_in_jit_key=True,
|
|
|
|
)
|
|
|
|
|
|
|
|
error_checking_behavior_oob = enum_state(
|
|
|
|
name='jax_error_checking_behavior_oob',
|
|
|
|
enum_values=['ignore', 'raise'],
|
|
|
|
default='ignore',
|
|
|
|
help=(
|
|
|
|
'Specify the behavior when an out of bounds access is encountered.'
|
|
|
|
' Options are "ignore" or "raise".'
|
|
|
|
),
|
|
|
|
include_in_jit_key=True,
|
|
|
|
)
|
|
|
|
|
2021-04-19 08:52:48 -07:00
|
|
|
def _update_x64_global(val):
  """Stores `val` as `enable_x64` on jax_jit's global state object."""
  global_state = jax_jit.global_state()
  global_state.enable_x64 = val
|
2021-04-19 08:52:48 -07:00
|
|
|
|
|
|
|
def _update_x64_thread_local(val):
  """Stores `val` as `enable_x64` on jax_jit's thread-local state object."""
  thread_state = jax_jit.thread_local_state()
  thread_state.enable_x64 = val
|
2021-04-19 08:52:48 -07:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Boolean config state `jax_enable_x64`, off by default. The two hooks are
# passed so that global and thread-local updates are each handed to the
# corresponding `_update_x64_*` function.
enable_x64 = bool_state(
    name='jax_enable_x64',
    default=False,
    help='Enable 64-bit types to be used',
    update_global_hook=_update_x64_global,
    update_thread_local_hook=_update_x64_thread_local)
|
|
|
|
|
|
|
|
# TODO(phawkins): remove after fixing users of FLAGS.x64_enabled.
# Drops 'jax_enable_x64' from the config object's `_contextmanager_flags`
# collection. NOTE(review): presumably this changes how attribute access for
# that flag is routed on `config` -- confirm against the Config class.
config._contextmanager_flags.remove('jax_enable_x64')
|
2021-04-19 08:52:48 -07:00
|
|
|
|
Refactor and optimize the implementation of config options.
Previously accessing the value of a configuration option required accessing a `values` dictionary in jax._src.config.config. By moving the values into the FlagHolder and StateContextManager objects, we allow for accessing the values directly without as many dictionary accesses.
Timings before, note `enable_x64` is a `StateContextManager` value and `jax_pprint_use_color` is a `FlagHolder` value:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
1000000 loops, best of 5: 328 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 377 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 293 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
1000000 loops, best of 5: 358 ns per loop
```
Timings after:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
10000000 loops, best of 5: 175 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 220 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 316 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
10000000 loops, best of 5: 54.9 ns per loop
```
i.e., if accessed via the holder object directly, this change is a significant speedup, and if accessed by the `jax.config` dictionary this change is a good speedup for `StateContextManager` values and a small slowdown for `FlagHolder` values.
In a subsequent change, we should do more work to avoid using `jax.config.xyz` to access flag values inside jax.
I also note that one of the reasons `StateContextManager` values are a bit suboptimal is that they still require a dictionary lookup in a `threading.local` object. We can probably avoid that with a small C++ extension.
PiperOrigin-RevId: 606340656
2024-02-12 13:03:58 -08:00
|
|
|
# Attach a read-only `x64_enabled` property to Config that returns the current
# value of the `enable_x64` state.
setattr(Config, "x64_enabled", property(lambda _: enable_x64.value))
|
2021-12-21 20:55:03 +00:00
|
|
|
|
|
|
|
def _update_default_device_global(val):
  """Stores `val` as `default_device` on jax_jit's global state object."""
  global_state = jax_jit.global_state()
  global_state.default_device = val
|
2021-12-21 20:55:03 +00:00
|
|
|
|
2022-06-02 10:33:53 -07:00
|
|
|
|
2021-12-21 20:55:03 +00:00
|
|
|
def _update_default_device_thread_local(val):
  """Stores `val` as `default_device` on jax_jit's thread-local state object."""
  thread_state = jax_jit.thread_local_state()
  thread_state.default_device = val
|
2021-12-21 20:55:03 +00:00
|
|
|
|
2022-06-02 10:33:53 -07:00
|
|
|
|
2021-12-21 20:55:03 +00:00
|
|
|
def _validate_default_device(val):
|
2024-11-07 00:24:32 +00:00
|
|
|
if (val is not None and
|
|
|
|
not isinstance(val, xla_client.Device) and
|
|
|
|
val not in ['cpu', 'gpu', 'tpu']):
|
2022-06-02 10:33:53 -07:00
|
|
|
# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
|
|
|
|
# all JAX backends use a single C++ device interface.
|
|
|
|
if 'Device' in str(type(val)):
|
2022-10-13 17:06:22 +02:00
|
|
|
logger.info(
|
2022-06-02 10:33:53 -07:00
|
|
|
'Allowing non-`xla_client.Device` default device: %s, type: %s',
|
|
|
|
repr(val), type(val))
|
|
|
|
return
|
2024-11-07 00:24:32 +00:00
|
|
|
raise ValueError('jax.default_device must be passed either a Device object (e.g. '
|
|
|
|
f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'"
|
|
|
|
f", got: {val!r}")
|
2021-12-21 20:55:03 +00:00
|
|
|
|
2022-06-02 10:33:53 -07:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Config state `jax_default_device`, holding either a string or an arbitrary
# object (default None). New values are checked by `_validate_default_device`,
# and the two hooks receive global and thread-local updates respectively.
default_device = string_or_object_state(
    name='jax_default_device',
    default=None,
    help=(
        'Configure the default device for JAX operations. Set to a Device '
        'object (e.g. ``jax.devices("cpu")[0]``) to use that Device as the '
        'default device for JAX operations and jit\'d function calls (there is '
        'no effect on multi-device computations, e.g. pmapped function calls). '
        'Set to None to use the system default device. See '
        ':ref:`faq-data-placement` for more information on device placement.'),
    update_global_hook=_update_default_device_global,
    update_thread_local_hook=_update_default_device_thread_local,
    validator=_validate_default_device)
|
2021-12-21 20:55:03 +00:00
|
|
|
|
2021-04-19 08:52:48 -07:00
|
|
|
def _update_disable_jit_global(val):
  """Stores `val` as `disable_jit` on jax_jit's global state object."""
  global_state = jax_jit.global_state()
  global_state.disable_jit = val
|
2021-04-19 08:52:48 -07:00
|
|
|
|
|
|
|
def _update_disable_jit_thread_local(val):
  """Stores `val` as `disable_jit` on jax_jit's thread-local state object."""
  thread_state = jax_jit.thread_local_state()
  thread_state.disable_jit = val
|
2021-04-19 08:52:48 -07:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Boolean config state `jax_disable_jit`, off by default. The two hooks are
# passed so that global and thread-local updates are each handed to the
# corresponding `_update_disable_jit_*` function.
disable_jit = bool_state(
    name='jax_disable_jit',
    default=False,
    help=('Disable JIT compilation and just call original Python.'),
    update_global_hook=_update_disable_jit_global,
    update_thread_local_hook=_update_disable_jit_thread_local)
|
|
|
|
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Enum config state `jax_numpy_rank_promotion` with values 'allow' (default),
# 'warn' and 'raise'. It is registered with include_in_jit_key=True.
numpy_rank_promotion = enum_state(
    name='jax_numpy_rank_promotion',
    enum_values=['allow', 'warn', 'raise'],
    default='allow',
    help=('Control NumPy-style automatic rank promotion broadcasting '
          '("allow", "warn", or "raise").'),
    include_in_jit_key=True)
|
2021-04-19 08:52:48 -07:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Optional enum config state `jax_default_matmul_precision` (default None,
# i.e. unset). The accepted strings fall into two groups: the legacy precision
# names and the dot-algorithm preset names. It is registered with
# include_in_jit_key=True.
default_matmul_precision = optional_enum_state(
    name='jax_default_matmul_precision',
    enum_values=[
        # Legacy precision API values
        'default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32',
        # Dot algorithm presets
        'ANY_F8_ANY_F8_F32', 'ANY_F8_ANY_F8_F32_FAST_ACCUM', 'ANY_F8_ANY_F8_ANY',
        'ANY_F8_ANY_F8_ANY_FAST_ACCUM', 'F16_F16_F16', 'F16_F16_F32',
        'BF16_BF16_BF16', 'BF16_BF16_F32', 'BF16_BF16_F32_X3',
        'BF16_BF16_F32_X6', 'BF16_BF16_F32_X9', 'TF32_TF32_F32',
        'TF32_TF32_F32_X3', 'F32_F32_F32', 'F64_F64_F64',
    ],
    default=None,
    help=('Control the default matmul and conv precision for 32bit inputs.\n\n'

          'Some platforms, like TPU, offer configurable precision levels for '
          'matrix multiplication and convolution computations, trading off '
          'accuracy for speed. The precision can be controlled for each '
          'operation; for example, see the :func:`jax.lax.conv_general_dilated` '
          'and :func:`jax.lax.dot` docstrings. But it can be useful to control '
          'the default behavior obtained when an operation is not given a '
          'specific precision.\n\n'

          'This option can be used to control the default precision '
          'level for computations involved in matrix multiplication and '
          'convolution on 32bit inputs. The levels roughly describe the '
          "precision at which scalar products are computed. The 'bfloat16' "
          "option is the fastest and least precise; 'float32' is similar to "
          "full float32 precision; 'tensorfloat32' is intermediate.\n\n"

          'This parameter can also be used to specify an accumulation '
          '"algorithm" for functions that perform matrix multiplications, like '
          ':func:`jax.lax.dot`. To specify an algorithm, set this option to '
          'the name of a :class:`~jax.lax.DotAlgorithmPreset`.\n\n'),
    include_in_jit_key=True)
|
2021-06-02 15:22:50 -04:00
|
|
|
|
2024-09-05 18:22:15 -07:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Enum config state `jax_traceback_filtering` (default "auto"). The help text
# below documents each accepted mode.
traceback_filtering = enum_state(
    name = 'jax_traceback_filtering',
    enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames",
                 "auto"],
    default="auto",
    help="Controls how JAX filters internal frames out of tracebacks. Valid values are:\n"
         "- ``off``: disables traceback filtering.\n"
         "- ``auto``: use ``tracebackhide`` if running under a sufficiently "
         "new IPython, or ``remove_frames`` otherwise.\n"
         "- ``tracebackhide``: adds ``__tracebackhide__`` annotations to "
         "hidden stack frames, which some traceback printers support.\n"
         "- ``remove_frames``: removes hidden frames from tracebacks, and adds "
         "the unfiltered traceback as a ``__cause__`` of the exception.\n"
         "- ``quiet_remove_frames``: removes hidden frames from tracebacks, and adds "
         "a brief message (to the ``__cause__`` of the exception) describing that this has "
         "happened.\n\n")
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2022-05-26 22:22:20 -07:00
|
|
|
# This flag is for internal use.
# TODO(tianjianlu): Remove once we always enable cusparse lowering.
# TODO(b/262050896): Set to true after bug is fixed
# Boolean config state `jax_bcoo_cusparse_lowering`, currently off by default.
bcoo_cusparse_lowering = bool_state(
    name='jax_bcoo_cusparse_lowering',
    default=False,
    help=('Enables lowering BCOO ops to cuSparse.'))
|
2022-01-20 22:58:09 -08:00
|
|
|
|
|
|
|
# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging
# if the intended backend can handle lowering the result
# Boolean config state `jax_dynamic_shapes`, registered with
# include_in_jit_key=True. Its default is computed at import time from the
# JAX_DYNAMIC_SHAPES environment variable.
# NOTE(review): bool() of any non-empty string is True, so a value such as
# "0" or "false" also yields a True default here -- confirm this is intended.
dynamic_shapes = bool_state(
    name='jax_dynamic_shapes',
    default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
    help=('Enables experimental features for staging out computations with '
          'dynamic shapes.'),
    include_in_jit_key=True)
|
2022-02-13 22:40:26 -08:00
|
|
|
|
2024-10-29 11:03:49 -07:00
|
|
|
# This is for stackless backward compat with e.g. equinox
# Boolean config state, off by default and registered with
# include_in_jit_key=True. Note that, unlike the neighbouring flags, its
# registered name carries no `jax_` prefix.
eager_constant_folding = bool_state(
    name='eager_constant_folding',
    default=False,
    help=('Attempt constant folding during staging.'),
    include_in_jit_key=True)
|
2024-10-29 11:03:49 -07:00
|
|
|
|
2024-09-18 14:53:45 -07:00
|
|
|
# Boolean config state registered as `jax_compiler_enable_remat_pass` (note the
# Python name differs from the flag name), on by default. Per its help text it
# toggles the rematerialization HLO pass.
enable_remat_opt_pass = bool_state(
    name='jax_compiler_enable_remat_pass',
    default=True,
    help=('Config to enable / disable the rematerialization HLO pass. '
          'Useful to allow XLA to automatically trade off memory and '
          'compute when encountering OOM errors. However, you are '
          'likely to get better results manually with jax.checkpoint'))
|
|
|
|
|
2024-08-23 21:21:55 +00:00
|
|
|
# Boolean config state `jax_no_tracing`, off by default. Per its help text,
# enabling it disallows tracing for JIT compilation.
no_tracing = bool_state(
    name='jax_no_tracing',
    default=False,
    help='Disallow tracing for JIT compilation.')
|
|
|
|
|
2024-07-31 16:25:30 +00:00
|
|
|
# Boolean config state `jax_disable_vmap_shmap_error`, off by default and
# registered with upgrade=False. Described by its help text as a temporary
# workaround.
disable_vmap_shmap_error = bool_state(
    name='jax_disable_vmap_shmap_error',
    default=False,
    upgrade=False,
    help='Temporary workaround to disable an error check in vmap-of-shmap.')
|
|
|
|
|
2024-04-04 18:21:10 -07:00
|
|
|
# TODO(mattjj): remove once we land mutable array plumbing, or face great shame
# Boolean config state `jax_custom_vjp_disable_shape_check`, off by default and
# registered with upgrade=True.
custom_vjp_disable_shape_check = bool_state(
    name='jax_custom_vjp_disable_shape_check',
    default=False,
    upgrade=True,
    help='Disable the check from #19009 to enable some custom_vjp hacks.')
|
|
|
|
|
2024-10-02 17:18:52 +00:00
|
|
|
# Boolean config state `jax_mutable_array_checks`, off by default and
# registered with upgrade=True. Per its help text it enables aliasing-related
# error checks for mutable arrays.
mutable_array_checks = bool_state(
    name='jax_mutable_array_checks',
    default=False,
    upgrade=True,
    help='Enable error checks for mutable arrays that rule out aliasing.')
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Boolean config state registered as
# `jax_experimental_unsafe_xla_runtime_errors`, off by default. The help text
# spells out the caveats of enabling it.
xla_runtime_errors = bool_state(
    name='jax_experimental_unsafe_xla_runtime_errors',
    default=False,
    help=('Enable XLA runtime errors for jax.experimental.checkify.checks '
          'on CPU and GPU. These errors are async, might get lost and are not '
          'very readable. But, they crash the computation and enable you '
          'to write jittable checks without needing to checkify. Does not '
          'work under pmap/pjit.')
)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Integer config state `jax_xla_profile_version` (default 0), registered with
# include_in_jit_key=True.
jax_xla_profile_version = int_state(
    name='jax_xla_profile_version',
    default=0,
    help=(
        'Optional profile version for XLA compilation. This is meaningful '
        'only when XLA is configured to support the remote compilation '
        'profile feature.'),
    include_in_jit_key=True,
)
|
|
|
|
|
2022-04-11 14:59:04 +00:00
|
|
|
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
  """Marks the enclosed code as being inside an explicit device_put*() call.

  Sets `explicit_device_put` to True on guard_lib's thread-local state while
  the `with` body runs, then puts back whatever value was there before, even
  if the body raises.
  """
  tls = guard_lib.thread_local_state()
  saved = tls.explicit_device_put
  tls.explicit_device_put = True
  try:
    yield
  finally:
    tls.explicit_device_put = saved
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def explicit_device_get_scope() -> Iterator[None]:
  """Marks the enclosed code as being inside an explicit device_get() call.

  Sets `explicit_device_get` to True on guard_lib's thread-local state while
  the `with` body runs, then puts back whatever value was there before, even
  if the body raises.
  """
  tls = guard_lib.thread_local_state()
  saved = tls.explicit_device_get
  tls.explicit_device_get = True
  try:
    yield
  finally:
    tls.explicit_device_get = saved
|
|
|
|
|
|
|
|
def _update_transfer_guard(state, key, val):
|
2024-10-17 12:22:39 -07:00
|
|
|
"""Applies the transfer guard level within guard_lib."""
|
2022-04-11 14:59:04 +00:00
|
|
|
if val is None:
|
|
|
|
setattr(state, key, None)
|
|
|
|
elif val == 'allow':
|
2024-10-17 12:22:39 -07:00
|
|
|
setattr(state, key, guard_lib.TransferGuardLevel.ALLOW)
|
2022-04-11 14:59:04 +00:00
|
|
|
elif val == 'log':
|
2024-10-17 12:22:39 -07:00
|
|
|
setattr(state, key, guard_lib.TransferGuardLevel.LOG)
|
2022-04-11 14:59:04 +00:00
|
|
|
elif val == 'disallow':
|
2024-10-17 12:22:39 -07:00
|
|
|
setattr(state, key, guard_lib.TransferGuardLevel.DISALLOW)
|
2022-04-11 14:59:04 +00:00
|
|
|
elif val == 'log_explicit':
|
2024-10-17 12:22:39 -07:00
|
|
|
setattr(state, key, guard_lib.TransferGuardLevel.LOG_EXPLICIT)
|
2022-04-11 14:59:04 +00:00
|
|
|
elif val == 'disallow_explicit':
|
2024-10-17 12:22:39 -07:00
|
|
|
setattr(state, key, guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
|
2022-04-11 14:59:04 +00:00
|
|
|
else:
|
|
|
|
assert False, f'Invalid transfer guard level {val}'
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Optional enum config state `jax_transfer_guard_host_to_device`. Both hooks
# call `_update_transfer_guard` with the 'host_to_device' key, on guard_lib's
# global and thread-local state respectively.
transfer_guard_host_to_device = optional_enum_state(
    name='jax_transfer_guard_host_to_device',
    enum_values=[
        'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
    ],
    # The default is applied by guard_lib. Use None here to avoid accidentally
    # overriding --jax_transfer_guard.
    default=None,
    help=('Select the transfer guard level for host-to-device transfers. '
          'Default is "allow".'),
    update_global_hook=lambda val: _update_transfer_guard(
        guard_lib.global_state(), 'host_to_device', val),
    update_thread_local_hook=lambda val: _update_transfer_guard(
        guard_lib.thread_local_state(), 'host_to_device', val))
|
2022-04-11 14:59:04 +00:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Optional enum config state `jax_transfer_guard_device_to_device`. Both hooks
# call `_update_transfer_guard` with the 'device_to_device' key, on guard_lib's
# global and thread-local state respectively.
transfer_guard_device_to_device = optional_enum_state(
    name='jax_transfer_guard_device_to_device',
    enum_values=[
        'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
    ],
    # The default is applied by guard_lib. Use None here to avoid accidentally
    # overriding --jax_transfer_guard.
    default=None,
    help=('Select the transfer guard level for device-to-device transfers. '
          'Default is "allow".'),
    update_global_hook=lambda val: _update_transfer_guard(
        guard_lib.global_state(), 'device_to_device', val),
    update_thread_local_hook=lambda val: _update_transfer_guard(
        guard_lib.thread_local_state(), 'device_to_device', val))
|
2022-04-11 14:59:04 +00:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Optional enum config state `jax_transfer_guard_device_to_host`. Both hooks
# call `_update_transfer_guard` with the 'device_to_host' key, on guard_lib's
# global and thread-local state respectively.
transfer_guard_device_to_host = optional_enum_state(
    name='jax_transfer_guard_device_to_host',
    enum_values=[
        'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
    ],
    # The default is applied by guard_lib. Use None here to avoid
    # accidentally overriding --jax_transfer_guard.
    default=None,
    help=('Select the transfer guard level for device-to-host transfers. '
          'Default is "allow".'),
    update_global_hook=lambda val: _update_transfer_guard(
        guard_lib.global_state(), 'device_to_host', val
    ),
    update_thread_local_hook=lambda val: _update_transfer_guard(
        guard_lib.thread_local_state(), 'device_to_host', val))
|
2022-04-11 14:59:04 +00:00
|
|
|
|
|
|
|
def _update_all_transfer_guard_global(val):
  """Writes `val` to each of the three per-direction transfer guard flags."""
  per_direction_flags = (
      'jax_transfer_guard_host_to_device',
      'jax_transfer_guard_device_to_device',
      'jax_transfer_guard_device_to_host',
  )
  for flag_name in per_direction_flags:
    config.update(flag_name, val)
|
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Optional enum config state `jax_transfer_guard`, covering all transfer
# directions at once. Only a global hook is registered; it is
# `_update_all_transfer_guard_global`.
_transfer_guard = optional_enum_state(
    name='jax_transfer_guard',
    enum_values=[
        'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
    ],
    # The default is applied by guard_lib. Use None here to avoid accidentally
    # overriding --jax_transfer_guard_*.
    default=None,
    help=('Select the transfer guard level for all transfers. This option is '
          'set-only; the transfer guard level for a specific direction should '
          'be read using the per-transfer direction option. '
          'Default is "allow".'),
    update_global_hook=_update_all_transfer_guard_global)
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def transfer_guard(new_val: str) -> Iterator[None]:
  """Context manager that sets the transfer guard level for every direction.

  For more information, see
  https://docs.jax.dev/en/latest/transfer_guard.html

  Args:
    new_val: The new thread-local transfer guard level for all transfers.

  Yields:
    None.
  """
  # Entered in this order and exited in reverse by the ExitStack.
  guards = (
      transfer_guard_host_to_device,
      transfer_guard_device_to_device,
      transfer_guard_device_to_host,
      _transfer_guard,
  )
  with contextlib.ExitStack() as stack:
    for guard in guards:
      stack.enter_context(guard(new_val))
    yield
|
Add `jax_debug_log_modules` config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-06-07 00:20:32 +00:00
|
|
|
|
|
|
|
|
2024-10-24 08:45:09 -07:00
|
|
|
def _update_garbage_collection_guard(state, key, val):
|
|
|
|
"""Applies the transfer guard level within guard_lib."""
|
|
|
|
if val is None:
|
|
|
|
setattr(state, key, None)
|
|
|
|
elif val == 'allow':
|
|
|
|
setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW)
|
|
|
|
elif val == 'log':
|
|
|
|
setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG)
|
|
|
|
elif val == 'fatal':
|
|
|
|
setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL)
|
|
|
|
else:
|
|
|
|
assert False, f'Invalid garbage collection guard level {val}'
|
2024-10-17 12:22:39 -07:00
|
|
|
|
2024-10-24 08:45:09 -07:00
|
|
|
# Optional enum config state `jax_array_garbage_collection_guard`. Both hooks
# call `_update_garbage_collection_guard` with the 'garbage_collect_array' key,
# on guard_lib's global and thread-local state respectively.
array_garbage_collection_guard = optional_enum_state(
    name='jax_array_garbage_collection_guard',
    enum_values=['allow', 'log', 'fatal'],
    # The default is applied by guard_lib.
    default=None,
    help=(
        'Select garbage collection guard level for ``jax.Array`` objects.\n\n'
        'This option can be used to control what happens when a ``jax.Array``'
        ' object is garbage collected. It is desirable for ``jax.Array``'
        ' objects to be freed by Python reference counting rather than garbage'
        ' collection in order to avoid device memory being held by the arrays'
        ' until garbage collection occurs.\n\n'
        'Valid values are:\n\n'
        '* ``allow``: do not log garbage collection of ``jax.Array`` objects.\n'
        '* ``log``: log an error when a ``jax.Array`` is garbage collected.\n'
        '* ``fatal``: fatal error if a ``jax.Array`` is garbage collected.\n\n'
        'Default is ``allow``. Note that not all cycles may be detected.'
    ),
    update_global_hook=lambda val: _update_garbage_collection_guard(
        guard_lib.global_state(), 'garbage_collect_array', val
    ),
    update_thread_local_hook=lambda val: _update_garbage_collection_guard(
        guard_lib.thread_local_state(), 'garbage_collect_array', val
    ),
)
|
2024-10-17 12:22:39 -07:00
|
|
|
|
Add `jax_debug_log_modules` config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-06-07 00:20:32 +00:00
|
|
|
# Don't define a context manager since this isn't threadsafe.
# String config state `jax_debug_log_modules` (default: empty string). The
# result of string_state() is deliberately not bound to a name; global updates
# are passed to logging_config.update_debug_log_modules.
string_state(
    name='jax_debug_log_modules',
    default='',
    help=('Comma-separated list of module names (e.g. "jax" or '
          '"jax._src.xla_bridge,jax._src.dispatch") to enable debug logging '
          'for.'),
    update_global_hook=logging_config.update_debug_log_modules)
|
|
|
|
|
|
|
|
# Don't define a context manager since this isn't threadsafe.
# Optional enum config state `jax_logging_level`. Its default is computed at
# import time as the level name of the "jax" logger; global updates are passed
# to logging_config.update_logging_level_global. The result of
# optional_enum_state() is deliberately not bound to a name.
optional_enum_state(
    name='jax_logging_level',
    enum_values=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
    default=logging.getLevelName(logging.getLogger("jax").level),
    help=('Set the corresponding logging level on all jax loggers. Only string'
          ' values from ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR",'
          ' "CRITICAL"] are accepted. If None, the logging level will not be'
          ' set. Includes C++ logging.'),
    update_global_hook=lambda logging_level: \
      logging_config.update_logging_level_global(logging_level=logging_level)
)
|
Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.
i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.
Why do this?
The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.
The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.
This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.
Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.
The change is disabled by default, so we do not expect any user visible impacts from this change.
PiperOrigin-RevId: 599787818
2024-01-19 03:53:01 -08:00
|
|
|
|
2024-04-15 10:35:50 +01:00
|
|
|
# Boolean option controlling the rank of pmap shards. When True, each shard
# keeps the rank of its enclosing array (e.g. an [8, 100] array sharded over
# 8 devices on its first axis has shards of shape [1, 100] rather than [100]).
pmap_no_rank_reduction = bool_state(
    name='jax_pmap_no_rank_reduction',
    default=True,
    help='If True, pmap shards have the same rank as their enclosing array.',
)
|
#sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SPMD partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
def test_sdy_lowering(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(jax.jit, out_shardings=s)
def f(x):
return x * 2
print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
sdy.mesh @mesh = <"x"=4, "y"=2>
func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
%c = stablehlo.constant dense<2> : tensor<i64>
%0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
%1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
return %1 : tensor<8x2xi64>
}
}
```
Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.
PiperOrigin-RevId: 655127611
2024-07-23 05:31:15 -07:00
|
|
|
|
|
|
|
# Boolean option selecting whether lowering targets Shardy. It is off by
# default and is flagged for inclusion in the jit key.
use_shardy_partitioner = bool_state(
    name='jax_use_shardy_partitioner',
    default=False,
    upgrade=True,
    include_in_jit_key=True,
    help=(
        'Whether to lower to Shardy. '
        'Shardy is a new open sourced propagation framework for MLIR. '
        'Currently Shardy is experimental in JAX. '
        'See www.github.com/openxla/shardy'
    ),
)
|
2024-11-18 08:11:04 -08:00
|
|
|
|
|
|
|
# Three-valued option ('off' / 'on' / 'auto') for the experimental
# MAGMA-backed lax.linalg.eig on GPU; the default is 'auto'.
gpu_use_magma = enum_state(
    name='jax_use_magma',
    enum_values=['off', 'on', 'auto'],
    default='auto',
    help=(
        'Enable experimental support for MAGMA-backed lax.linalg.eig on '
        'GPU. See the documentation for lax.linalg.eig for more details '
        'about how to use this feature.'
    ),
)
|
2024-11-26 13:57:47 +00:00
|
|
|
|
|
|
|
# Float option expressing how much effort goes into minimizing execution
# time; the help text gives the valid range as [-1.0, 1.0]. Defaults to 0.0.
exec_time_optimization_effort = float_state(
    name='jax_exec_time_optimization_effort',
    default=0.0,
    help=(
        'Effort for minimizing execution time '
        '(higher means more effort), valid range [-1.0, 1.0].'
    ),
)
|
|
|
|
|
|
|
|
# Float option expressing how much effort goes into minimizing memory usage;
# the help text gives the valid range as [-1.0, 1.0]. Defaults to 0.0.
memory_fitting_effort = float_state(
    name='jax_memory_fitting_effort',
    default=0.0,
    help=(
        'Effort for minimizing memory usage '
        '(higher means more effort), valid range [-1.0, 1.0].'
    ),
)
|
2025-01-28 17:00:06 -08:00
|
|
|
|
2025-02-14 14:45:25 -08:00
|
|
|
# Enum option for how strongly the compiler should optimize for execution
# time. Defaults to 'UNKNOWN' and is flagged for inclusion in the jit key.
optimization_level = enum_state(
    name='jax_optimization_level',
    enum_values=['UNKNOWN', 'O0', 'O1', 'O2', 'O3'],
    default='UNKNOWN',
    include_in_jit_key=True,
    help=(
        'The degree to which the compiler should optimize for '
        'execution time'
    ),
)
|
|
|
|
|
|
|
|
# Enum option for how hard the compiler should try to make the program fit
# in memory. Defaults to 'O2' and is flagged for inclusion in the jit key.
memory_fitting_level = enum_state(
    name='jax_memory_fitting_level',
    enum_values=['UNKNOWN', 'O0', 'O1', 'O2', 'O3'],
    default='O2',
    include_in_jit_key=True,
    help=(
        'The degree to which the compiler should attempt to make '
        'the program fit in memory'
    ),
)
|
|
|
|
|
2025-04-18 09:28:13 -07:00
|
|
|
# Default value of the jax_cpu_collectives_implementation option.
DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo'
|
|
|
|
|
2025-01-28 17:00:06 -08:00
|
|
|
# Optional enum option selecting the cross-process collective implementation
# used on CPU. The help text lists every value accepted by `enum_values` so
# the two cannot disagree.
cpu_collectives_implementation = optional_enum_state(
    name='jax_cpu_collectives_implementation',
    enum_values=["gloo", "mpi", "megascale"],
    default=DEFAULT_CPU_COLLECTIVES_IMPL,
    help=(
        "Cross-process collective implementation used on CPU. Must be one of "
        '("gloo", "mpi", "megascale")'
    ),
)
|
|
|
|
|
2025-02-26 18:16:45 -08:00
|
|
|
# Boolean option (off by default) permitting an Array to be created from an
# empty list of single-device arrays.
enable_empty_arrays = bool_state(
    name='jax_enable_empty_arrays',
    default=False,
    help=(
        'Enable the creation of an Array from an empty list of '
        'single-device arrays. This is to support MPMD/pipeline '
        'parallelism in McJAX (WIP).'
    ),
)
|
2025-03-02 19:41:52 -08:00
|
|
|
|
|
|
|
# Boolean option (off by default). Per its help text, enabling it makes
# Gumbel noise draw two samples so low-probability events are covered with
# more precision.
use_high_dynamic_range_gumbel = bool_state(
    name='jax_high_dynamic_range_gumbel',
    default=False,
    help=(
        'If True, Gumbel noise draws two samples to cover low probability '
        'events with more precision.'
    ),
)
|