[JAX] [XLA:Python] Move JAX configuration objects into C++.

A noticeable amount of time during JAX tracing is spent getting and setting the values of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed that code up.

There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary lookup and isn't terribly fast.
* We can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We currently spend a considerable amount of time eagerly computing cache keys via update_thread_local_jit_state, and most of that work is pointless. Instead, `jit` can simply pull the config items it needs on demand (see the sketch after this list).
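A rough illustration of the first point, using only the standard library. This is not JAX code: `CppBackedConfig` is an invented stand-in for the C++-side config object this change introduces, and the point is only the relative cost of the two reads.

    import threading
    import timeit

    tls = threading.local()
    tls.jax_enable_x64 = False

    class CppBackedConfig:
      # Stand-in for a C++-backed config value: a single slot read,
      # with no per-access dictionary lookup in Python.
      __slots__ = ('value',)
      def __init__(self, value):
        self.value = value

    enable_x64 = CppBackedConfig(False)

    # Every read of threading.local goes through its instance dict;
    # the slot read avoids that per-read cost.
    print(timeit.timeit(lambda: tls.jax_enable_x64, number=10**6))
    print(timeit.timeit(lambda: enable_x64.value, number=10**6))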

PiperOrigin-RevId: 693114411
Authored by Peter Hawkins on 2024-11-04 15:38:25 -08:00; committed by jax authors
parent 38b4d00100
commit ab47d4687f
7 changed files with 396 additions and 287 deletions

View File

@@ -123,8 +123,8 @@ def _update_debug_special_global(_):
     jax_jit.global_state().post_hook = None
 
 def _update_debug_special_thread_local(_):
-  if (getattr(config._thread_local_state, "jax_debug_nans", False) or
-      getattr(config._thread_local_state, "jax_debug_infs", False)):
+  if (config.debug_nans.get_local() == True or
+      config.debug_infs.get_local() == True):
     jax_jit.thread_local_state().post_hook = _nan_check_posthook
   else:
     jax_jit.thread_local_state().post_hook = None
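A note on the `== True` comparisons above: `get_local()` is assumed to return the thread-local override when one is set and an `unset` sentinel otherwise, and the sentinel is an ordinary (hence truthy) object. A minimal sketch of that contract, with a dict standing in for the real thread-local storage:

    class _Unset: pass
    unset = _Unset()

    thread_local_overrides: dict = {}

    def get_local(name):
      # Returns the thread-local override if present, else the sentinel.
      return thread_local_overrides.get(name, unset)

    assert bool(get_local('jax_debug_nans'))          # sentinel is truthy...
    assert not (get_local('jax_debug_nans') == True)  # ...but compares unequal to True
    thread_local_overrides['jax_debug_nans'] = True
    assert get_local('jax_debug_nans') == True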

View File

@@ -29,8 +29,8 @@ compute_on_context = ComputeOnContext()
 
 @contextmanager
 def extend_compute_type(c_type: str):
   compute_on_context.stack.append(c_type)
-  config.update_thread_local_jit_state(
-      compute_on_context_manager=tuple(compute_on_context.stack))
+  config.compute_on_context_manager.set_local(
+      tuple(compute_on_context.stack))
   try:
     if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1:
       raise NotImplementedError(
@@ -39,8 +39,7 @@ def extend_compute_type(c_type: str):
     yield compute_on_context.stack[-1]
   finally:
     compute_on_context.stack.pop()
-    config.update_thread_local_jit_state(
-        compute_on_context_manager=tuple(compute_on_context.stack))
+    config.compute_on_context_manager.set_local(tuple(compute_on_context.stack))
 
 def current_compute_type() -> str | None:
   return compute_on_context.stack[-1] if compute_on_context.stack else None
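The pattern in this hunk — a Python list as the source of truth, with a tuple snapshot mirrored into a thread-local config item on every push and pop so it can participate in the jit cache key — sketched below with a hypothetical stand-in for the config object:

    from contextlib import contextmanager

    class FakeConfigItem:
      # Hypothetical stand-in for the thread-local side of a config object.
      def __init__(self, default):
        self._value = default
      def set_local(self, value):
        self._value = value
      def get_local(self):
        return self._value

    stack: list = []
    compute_on_snapshot = FakeConfigItem(())

    @contextmanager
    def extend(c_type):
      stack.append(c_type)
      compute_on_snapshot.set_local(tuple(stack))
      try:
        yield stack[-1]
      finally:
        stack.pop()
        compute_on_snapshot.set_local(tuple(stack))

    with extend('device_host'):
      assert compute_on_snapshot.get_local() == ('device_host',)
    assert compute_on_snapshot.get_local() == ()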

View File

@@ -22,14 +22,23 @@ import logging
 import os
 import sys
 import threading
-from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast
+from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast, TYPE_CHECKING
 
 from jax._src import lib
 from jax._src.lib import guard_lib
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension_version
 from jax._src import logging_config
 
+# TODO(phawkins): reenable pytype after xla_extension_version >= 295
+# pytype: skip-file
+
+if xla_extension_version >= 295:
+  config_ext = xla_client._xla.config
+else:
+  config_ext = None
+
 logger = logging.getLogger(__name__)
 
 _T = TypeVar('_T')
@@ -191,49 +200,79 @@ class Config:
     already_configured_with_absl = True
 
-def trace_context():
-  """Returns a tuple of configuration values that affect tracing.
-
-  These values are included in the cache key for linear_util.cache.
-
-  Values included in this set should also most likely be included in
-  the C++ JIT state, which is handled separately.
-  """
-  tls = jax_jit.thread_local_state()
-  axis_env_state = ()
-  mesh_context_manager = ()
-  xla_metadata_context_manager = ()
-  compute_on_context_manager = ()
-  context: Any = tls.extra_jit_context
-  if context and context.axis_env_state is not None:
-    axis_env_state = context.axis_env_state
-  if context and context.mesh_context_manager:
-    mesh_context_manager = context.mesh_context_manager
-  if context and context.xla_metadata_context_manager:
-    xla_metadata_context_manager = context.xla_metadata_context_manager
-  if context and context.compute_on_context_manager:
-    compute_on_context_manager = context.compute_on_context_manager
-  return (axis_env_state, mesh_context_manager, xla_metadata_context_manager,
-          compute_on_context_manager, enable_x64.value,
-          numpy_rank_promotion.value, default_matmul_precision.value,
-          dynamic_shapes.value,
-          eager_constant_folding.value,
-          numpy_dtype_promotion.value,
-          default_device.value, random_seed_offset.value,
-          threefry_partitionable.value,
-          threefry_gpu_kernel_lowering.value,
-          sharding_in_types.value,
-          softmax_custom_jvp.value,
-          enable_memories.value,
-          disable_jit.value,
-          debug_key_reuse.value,
-          jax_xla_profile_version.value,
-          # Technically this affects jaxpr->stablehlo lowering, not tracing.
-          hlo_source_file_canonicalization_regex.value,
-          pgle_profiling_runs.value,
-          enable_pgle.value,
-          use_shardy_partitioner.value)
+if xla_extension_version >= 295:
+  def trace_context():
+    """Returns a tuple of configuration values that affect tracing.
+
+    These values are included in the cache key for linear_util.cache.
+
+    Values included in this set should also most likely be included in
+    the C++ JIT state, which is handled separately.
+    """
+    return (axis_env_state.value, mesh_context_manager.value,
+            xla_metadata_context_manager.value,
+            compute_on_context_manager.value, enable_x64.value,
+            numpy_rank_promotion.value, default_matmul_precision.value,
+            dynamic_shapes.value,
+            eager_constant_folding.value,
+            numpy_dtype_promotion.value,
+            default_device.value, random_seed_offset.value,
+            threefry_partitionable.value,
+            threefry_gpu_kernel_lowering.value,
+            sharding_in_types.value,
+            softmax_custom_jvp.value,
+            enable_memories.value,
+            disable_jit.value,
+            debug_key_reuse.value,
+            jax_xla_profile_version.value,
+            # Technically this affects jaxpr->stablehlo lowering, not tracing.
+            hlo_source_file_canonicalization_regex.value,
+            pgle_profiling_runs.value,
+            enable_pgle.value,
+            use_shardy_partitioner.value)
+else:
+  def trace_context():
+    """Returns a tuple of configuration values that affect tracing.
+
+    These values are included in the cache key for linear_util.cache.
+
+    Values included in this set should also most likely be included in
+    the C++ JIT state, which is handled separately.
+    """
+    tls = jax_jit.thread_local_state()
+    axis_env_state = ()
+    mesh_context_manager = ()
+    xla_metadata_context_manager = ()
+    compute_on_context_manager = ()
+    context: Any = tls.extra_jit_context
+    if context and context.axis_env_state is not None:
+      axis_env_state = context.axis_env_state
+    if context and context.mesh_context_manager:
+      mesh_context_manager = context.mesh_context_manager
+    if context and context.xla_metadata_context_manager:
+      xla_metadata_context_manager = context.xla_metadata_context_manager
+    if context and context.compute_on_context_manager:
+      compute_on_context_manager = context.compute_on_context_manager
+    return (axis_env_state, mesh_context_manager, xla_metadata_context_manager,
+            compute_on_context_manager, enable_x64.value,
+            numpy_rank_promotion.value, default_matmul_precision.value,
+            dynamic_shapes.value,
+            eager_constant_folding.value,
+            numpy_dtype_promotion.value,
+            default_device.value, random_seed_offset.value,
+            threefry_partitionable.value,
+            threefry_gpu_kernel_lowering.value,
+            sharding_in_types.value,
+            softmax_custom_jvp.value,
+            enable_memories.value,
+            disable_jit.value,
+            debug_key_reuse.value,
+            jax_xla_profile_version.value,
+            # Technically this affects jaxpr->stablehlo lowering, not tracing.
+            hlo_source_file_canonicalization_regex.value,
+            pgle_profiling_runs.value,
+            enable_pgle.value,
+            use_shardy_partitioner.value)
 
 config = Config()
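Either version of trace_context() returns a plain tuple, which callers fold into memoization keys. A toy cache (not the linear_util implementation) showing why every tracing-relevant config item must appear in the tuple:

    _trace_cache: dict = {}

    def cached_trace(fn, args, trace_context_tuple):
      # Any config flip changes the tuple, so a trace recorded under one
      # configuration is never reused under another.
      key = (fn, args, trace_context_tuple)
      if key not in _trace_cache:
        _trace_cache[key] = f"traced {fn.__name__} with {trace_context_tuple}"
      return _trace_cache[key]

    def f(x): return x

    a = cached_trace(f, (1,), (False, 'allow'))  # e.g. (enable_x64, rank promotion)
    b = cached_trace(f, (1,), (True, 'allow'))   # differs only in config
    assert a is not b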
@@ -245,94 +284,185 @@ parse_flags_with_absl = config.parse_flags_with_absl
 
 class NoDefault: pass
 no_default = NoDefault()
 
-class _Unset: pass
-unset = _Unset()
-
-_thread_local_state = threading.local()
-
-class State(Generic[_T]):
-
-  __slots__ = (
-      '_name', '_value', '_update_thread_local_hook', '_update_global_hook',
-      '_validator', '_default_context_manager_value', '__doc__', '__name__',
-  )
-
-  def __init__(
-      self,
-      name: str,
-      default: _T,
-      help,
-      update_global_hook: Callable[[_T], None] | None = None,
-      update_thread_local_hook: Callable[[_T | None], None] | None = None,
-      validator: Callable[[Any], None] | None = None,
-      extra_description: str = '',
-      default_context_manager_value: Any = no_default,
-  ):
-    self._name = name
-    self.__name__ = name[4:] if name.startswith('jax_') else name
-    self.__doc__ = (f"Context manager for `{name}` config option"
-                    f"{extra_description}.\n\n{help}")
-    self._update_global_hook = update_global_hook
-    self._update_thread_local_hook = update_thread_local_hook
-    self._validator = validator
-    self._default_context_manager_value = default_context_manager_value
-    self._set(default)
-
-  def __bool__(self) -> NoReturn:
-    raise TypeError(
-        "bool() not supported for instances of type '{0}' "
-        "(did you mean to use '{0}.value' instead?)".format(
-            type(self).__name__))
-
-  def _set(self, value: _T) -> None:
-    if self._validator:
-      self._validator(value)
-    self._value = value
-    if self._update_global_hook:
-      self._update_global_hook(value)
-
-  @property
-  def value(self) -> _T:
-    val = _thread_local_state.__dict__.get(self._name, unset)
-    return cast(_T, val) if val is not unset else self._value
-
-  @contextlib.contextmanager
-  def __call__(self, new_val: Any = no_default):
-    if new_val is no_default:
-      if self._default_context_manager_value is not no_default:
-        new_val = self._default_context_manager_value  # default_context_manager_value provided to constructor
-      else:
-        # no default_value provided to constructor and no value provided as an
-        # argument, so we raise an error
-        raise TypeError(f"Context manager for {self.__name__} config option "
-                        "requires an argument representing the new value for "
-                        "the config option.")
-    if self._validator:
-      self._validator(new_val)
-    prev_val = getattr(_thread_local_state, self._name, unset)
-    setattr(_thread_local_state, self._name, new_val)
-    if self._update_thread_local_hook:
-      self._update_thread_local_hook(new_val)
-    try:
-      yield
-    finally:
-      if prev_val is unset:
-        delattr(_thread_local_state, self._name)
-        if self._update_thread_local_hook:
-          self._update_thread_local_hook(None)
-      else:
-        setattr(_thread_local_state, self._name, prev_val)
-        if self._update_thread_local_hook:
-          self._update_thread_local_hook(cast(_T, prev_val))
-
-  def _add_hooks(self, update_global_hook, update_thread_local_hook):
-    """Private method that adds hooks to an existing context-manager.
-
-    Used to avoid cyclic import dependencies."""
-    self._update_thread_local_hook = update_thread_local_hook
-    self._update_global_hook = update_global_hook
-    update_global_hook(self._value)
+if xla_extension_version >= 295:
+  class State(config_ext.Config[_T]):
+
+    __slots__ = (
+        '_name', '_update_thread_local_hook', '_update_global_hook',
+        '_validator', '_default_context_manager_value', '__doc__', '__name__',
+    )
+
+    def __init__(
+        self,
+        name: str,
+        default: _T,
+        help,
+        update_global_hook: Callable[[_T], None] | None = None,
+        update_thread_local_hook: Callable[[_T | None], None] | None = None,
+        validator: Callable[[Any], None] | None = None,
+        extra_description: str = '',
+        default_context_manager_value: Any = no_default,
+        include_in_jit_key: bool = False,
+    ):
+      super().__init__(default, include_in_jit_key)
+      self._name = name
+      self.__name__ = name[4:] if name.startswith('jax_') else name
+      self.__doc__ = (f"Context manager for `{name}` config option"
+                      f"{extra_description}.\n\n{help}")
+      self._update_global_hook = update_global_hook
+      self._update_thread_local_hook = update_thread_local_hook
+      self._validator = validator
+      self._default_context_manager_value = default_context_manager_value
+      if self._validator:
+        self._validator(default)
+      if self._update_global_hook:
+        self._update_global_hook(default)
+
+    def __bool__(self) -> NoReturn:
+      raise TypeError(
+          "bool() not supported for instances of type '{0}' "
+          "(did you mean to use '{0}.value' instead?)".format(
+              type(self).__name__))
+
+    def _set(self, value: _T) -> None:
+      if self._validator:
+        self._validator(value)
+      self.set_global(value)
+      if self._update_global_hook:
+        self._update_global_hook(value)
+
+    @contextlib.contextmanager
+    def __call__(self, new_val: Any = no_default):
+      if new_val is no_default:
+        if self._default_context_manager_value is not no_default:
+          new_val = self._default_context_manager_value  # default_context_manager_value provided to constructor
+        else:
+          # no default_value provided to constructor and no value provided as an
+          # argument, so we raise an error
+          raise TypeError(f"Context manager for {self.__name__} config option "
+                          "requires an argument representing the new value for "
+                          "the config option.")
+      if self._validator:
+        self._validator(new_val)
+      prev_val = self.swap_local(new_val)
+      if self._update_thread_local_hook:
+        self._update_thread_local_hook(new_val)
+      try:
+        yield
+      finally:
+        self.set_local(prev_val)
+        if self._update_thread_local_hook:
+          if prev_val is config_ext.unset:
+            self._update_thread_local_hook(None)
+          else:
+            self._update_thread_local_hook(cast(Optional[Any], prev_val))
+
+    def _add_hooks(self, update_global_hook, update_thread_local_hook):
+      """Private method that adds hooks to an existing context-manager.
+
+      Used to avoid cyclic import dependencies."""
+      self._update_thread_local_hook = update_thread_local_hook
+      self._update_global_hook = update_global_hook
+      update_global_hook(self.get_global())
+
+else:
+  class _Unset: pass
+  unset = _Unset()
+
+  _thread_local_state = threading.local()
+
+  class State(Generic[_T]):
+
+    __slots__ = (
+        '_name', '_value', '_update_thread_local_hook', '_update_global_hook',
+        '_validator', '_default_context_manager_value', '__doc__', '__name__',
+    )
+
+    def __init__(
+        self,
+        name: str,
+        default: _T,
+        help,
+        update_global_hook: Callable[[_T], None] | None = None,
+        update_thread_local_hook: Callable[[_T | None], None] | None = None,
+        validator: Callable[[Any], None] | None = None,
+        extra_description: str = '',
+        default_context_manager_value: Any = no_default,
+        include_in_jit_key: bool = False,
+    ):
+      self._name = name
+      self.__name__ = name[4:] if name.startswith('jax_') else name
+      self.__doc__ = (f"Context manager for `{name}` config option"
+                      f"{extra_description}.\n\n{help}")
+      if include_in_jit_key:
+        assert update_global_hook is None
+        assert update_thread_local_hook is None
+        update_global_hook = lambda val: _update_global_jit_state(
+            **{self.__name__: val})
+        update_thread_local_hook = lambda val: update_thread_local_jit_state(
+            **{self.__name__: val})
+      self._update_global_hook = update_global_hook
+      self._update_thread_local_hook = update_thread_local_hook
+      self._validator = validator
+      self._default_context_manager_value = default_context_manager_value
+      self._set(default)
+
+    def __bool__(self) -> NoReturn:
+      raise TypeError(
+          "bool() not supported for instances of type '{0}' "
+          "(did you mean to use '{0}.value' instead?)".format(
+              type(self).__name__))
+
+    def _set(self, value: _T) -> None:
+      if self._validator:
+        self._validator(value)
+      self._value = value
+      if self._update_global_hook:
+        self._update_global_hook(value)
+
+    @property
+    def value(self) -> _T:
+      val = _thread_local_state.__dict__.get(self._name, unset)
+      return cast(_T, val) if val is not unset else self._value
+
+    def get_local(self) -> Any:
+      return _thread_local_state.__dict__.get(self._name, unset)
+
+    @contextlib.contextmanager
+    def __call__(self, new_val: Any = no_default):
+      if new_val is no_default:
+        if self._default_context_manager_value is not no_default:
+          new_val = self._default_context_manager_value  # default_context_manager_value provided to constructor
+        else:
+          # no default_value provided to constructor and no value provided as an
+          # argument, so we raise an error
+          raise TypeError(f"Context manager for {self.__name__} config option "
+                          "requires an argument representing the new value for "
+                          "the config option.")
+      if self._validator:
+        self._validator(new_val)
+      prev_val = getattr(_thread_local_state, self._name, unset)
+      setattr(_thread_local_state, self._name, new_val)
+      if self._update_thread_local_hook:
+        self._update_thread_local_hook(new_val)
+      try:
+        yield
+      finally:
+        if prev_val is unset:
+          delattr(_thread_local_state, self._name)
+          if self._update_thread_local_hook:
+            self._update_thread_local_hook(None)
+        else:
+          setattr(_thread_local_state, self._name, prev_val)
+          if self._update_thread_local_hook:
+            self._update_thread_local_hook(cast(_T, prev_val))
+
+    def _add_hooks(self, update_global_hook, update_thread_local_hook):
+      """Private method that adds hooks to an existing context-manager.
+
+      Used to avoid cyclic import dependencies."""
+      self._update_thread_local_hook = update_thread_local_hook
+      self._update_global_hook = update_global_hook
+      update_global_hook(self._value)
 
 UPGRADE_BOOL_HELP = (
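Whichever branch defines State, its public surface is unchanged: options are read through .value and overridden thread-locally via the context manager. A quick usage check against a real option (assumes x64 is not enabled globally in your environment):

    from jax._src import config

    print(config.enable_x64.value)       # global default (False unless overridden)
    with config.enable_x64(True):        # scoped thread-local override
      assert config.enable_x64.value
    assert not config.enable_x64.value   # restored on exit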
@@ -353,6 +483,7 @@ def bool_state(
     update_thread_local_hook: Callable[[bool | None], None] | None = None,
     upgrade: bool = False,
     extra_description: str = '',
+    include_in_jit_key: bool = False,
 ) -> State[bool]:
   """Set up thread-local state and return a contextmanager for managing it.
@@ -417,7 +548,8 @@ def bool_state(
   s = State[bool](
       name, default, help, update_global_hook=update_global_hook,
       update_thread_local_hook=update_thread_local_hook,
-      extra_description=extra_description, default_context_manager_value=True)
+      extra_description=extra_description, default_context_manager_value=True,
+      include_in_jit_key=include_in_jit_key)
   config.add_option(name, s, bool, meta_args=[], meta_kwargs={"help": help})
   setattr(Config, name, property(lambda _: s.value))
   return s
@@ -431,6 +563,7 @@ def enum_state(
     *,
     update_global_hook: Callable[[str], None] | None = None,
     update_thread_local_hook: Callable[[str | None], None] | None = None,
+    include_in_jit_key: bool = False,
 ) -> State[str]:
   """Set up thread-local state and return a contextmanager for managing it.
@@ -470,6 +603,7 @@ def enum_state(
       update_global_hook=update_global_hook,
       update_thread_local_hook=update_thread_local_hook,
       validator=validator,
+      include_in_jit_key=include_in_jit_key,
   )
   config.add_option(
       name, s, 'enum',
@@ -488,6 +622,7 @@ def optional_enum_state(
     *,
     update_global_hook: Callable[[str | None], None] | None = None,
     update_thread_local_hook: Callable[[str | None], None] | None = None,
+    include_in_jit_key: bool = False,
 ) -> State[str | None]:
   """Set up thread-local state and return a contextmanager for managing it.
@@ -523,7 +658,7 @@ def optional_enum_state(
   s = State['str | None'](
       name, default, help, update_global_hook, update_thread_local_hook,
-      validate
+      validate, include_in_jit_key=include_in_jit_key,
   )
   config.add_option(
       name, s, 'enum',
@@ -541,6 +676,7 @@ def int_state(
     *,
     update_global_hook: Callable[[int], None] | None = None,
     update_thread_local_hook: Callable[[int | None], None] | None = None,
+    include_in_jit_key: bool = False,
 ) -> State[int]:
   """Set up thread-local state and return a contextmanager for managing it.
@@ -575,7 +711,8 @@ def int_state(
                       f'got {new_val} of type {type(new_val)}')
 
   s = State[int](name, default, help, update_global_hook,
-                 update_thread_local_hook, validate)
+                 update_thread_local_hook, validate,
+                 include_in_jit_key=include_in_jit_key)
   config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help})
   setattr(Config, name, property(lambda _: s.value))
   return s
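With include_in_jit_key threaded through the factories, an option that must affect the jit cache key is declared in one line instead of a pair of hand-written hook lambdas. A hypothetical declaration (the option name below is invented; bool_state is the factory shown above):

    from jax._src.config import bool_state

    my_experimental_flag = bool_state(
        name='jax_my_experimental_flag',   # hypothetical option
        default=False,
        help='Example flag that participates in the jit cache key.',
        include_in_jit_key=True)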
@@ -826,92 +963,119 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
 
 already_configured_with_absl = False
 
-
-# The C++ JIT maintains its own copy of several configuration items as
-# a global/thread-local state. These methods allow updates to part of the
-# state when a configuration value changes.
-
-class _GlobalExtraJitContext(NamedTuple):
-  numpy_rank_promotion: str | None = None
-  numpy_dtype_promotion: str | None = None
-  default_matmul_precision: Any | None = None
-  dynamic_shapes: bool = False
-  eager_constant_folding: bool = False
-  random_seed_offset: int = 0
-  threefry_partitionable: bool = False
-  threefry_gpu_kernel_lowering: bool = False
-  sharding_in_types: bool = False
-  softmax_custom_jvp: bool = False
-  xla_profile_version: int = 0
-  pgle_profiling_runs: int = 0
-  enable_pgle: bool = False
-  use_shardy_partitioner: bool = False
-
-
-def _update_global_jit_state(**kw):
-  gs = jax_jit.global_state()
-  context = gs.extra_jit_context or _GlobalExtraJitContext()
-  gs.extra_jit_context = context._replace(**kw)
-
-
-class _ThreadLocalExtraJitContext(NamedTuple):
-  """A namedtuple containing states to add to the cache key.
-
-  Just in time compilation (for jit, pmap, etc) behavior is configurable through
-  global and thread-local options, used in the cache key.
-
-  The initialization, which uses both config.py and core.py is done using
-  `_update_thread_local_jit_state` in core.py to prevent circular imports.
-  """
-  trace_state: Any | None = None
-  axis_env_state: Hashable = ()
-  mesh_context_manager: Hashable = ()
-  compute_on_context_manager: Hashable = ()
-  xla_metadata_context_manager: Hashable = ()
-
-  # Values set by _StateContextManager context managers.
-  # CAUTION: these must be initialized to `None`! The state context manager
-  # restores these to None on exit. If the object default is not `None`, the
-  # context manager is not a no-op, which leads to problems with stale state
-  # (e.g. spurious cache misses in tests).
-  numpy_rank_promotion: str | None = None
-  numpy_dtype_promotion: str | None = None
-  default_matmul_precision: Any | None = None
-  dynamic_shapes: bool | None = None
-  eager_constant_folding : bool | None = None
-  random_seed_offset: int | None = None
-  threefry_partitionable: bool | None = None
-  threefry_gpu_kernel_lowering: bool | None = None
-  sharding_in_types: bool | None = None
-  softmax_custom_jvp: bool | None = None
-  xla_profile_version: int | None = None
-  pgle_profiling_runs: int | None = None
-  enable_pgle: bool | None = None
-  use_shardy_partitioner: bool | None = None
-
-
-class _ThreadLocalStateCache(threading.local):
-  """"A thread local cache for _ThreadLocalExtraJitContext
-
-  The extra_jit_context in jax_jit.thread_local_state() may get updated and thus
-  incurring dispatch overhead for comparing this python object during jit calls.
-  We want to deduplicate the objects that have the same hash/equality to also
-  have the same object ID, since the equality check is much faster if the object
-  IDs match.
-  """
-  def __init__(self):
-    self.canonicalize = functools.lru_cache(128)(lambda x: x)
-
-
-_thread_local_state_cache = _ThreadLocalStateCache()
-
-
-def update_thread_local_jit_state(**kw):
-  tls = jax_jit.thread_local_state()
-  # After xla_client._version >= 70, the thread_local object will necessarily
-  # be initialized when accessed. The following line can be removed when the
-  # minimum jaxlib version is past version 70
-  context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
-  tmp = context._replace(**kw)
-  tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
+if xla_extension_version >= 295:
+  trace_state = config_ext.Config(None, include_in_jit_key=True)
+  axis_env_state = config_ext.Config((), include_in_jit_key=True)
+  mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
+  compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
+  xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
+else:
+  # The C++ JIT maintains its own copy of several configuration items as
+  # a global/thread-local state. These methods allow updates to part of the
+  # state when a configuration value changes.
+  class _GlobalExtraJitContext(NamedTuple):
+    numpy_rank_promotion: str | None = None
+    numpy_dtype_promotion: str | None = None
+    default_matmul_precision: Any | None = None
+    dynamic_shapes: bool = False
+    eager_constant_folding: bool = False
+    random_seed_offset: int = 0
+    threefry_partitionable: bool = False
+    threefry_gpu_kernel_lowering: bool = False
+    sharding_in_types: bool = False
+    softmax_custom_jvp: bool = False
+    xla_profile_version: int = 0
+    pgle_profiling_runs: int = 0
+    enable_pgle: bool = False
+    use_shardy_partitioner: bool = False
+
+  def _update_global_jit_state(**kw):
+    gs = jax_jit.global_state()
+    context = gs.extra_jit_context or _GlobalExtraJitContext()
+    gs.extra_jit_context = context._replace(**kw)
+
+  class _ThreadLocalExtraJitContext(NamedTuple):
+    """A namedtuple containing states to add to the cache key.
+
+    Just in time compilation (for jit, pmap, etc) behavior is configurable through
+    global and thread-local options, used in the cache key.
+
+    The initialization, which uses both config.py and core.py is done using
+    `_update_thread_local_jit_state` in core.py to prevent circular imports.
+    """
+    trace_state: Any | None = None
+    axis_env_state: Hashable = ()
+    mesh_context_manager: Hashable = ()
+    compute_on_context_manager: Hashable = ()
+    xla_metadata_context_manager: Hashable = ()
+
+    # Values set by _StateContextManager context managers.
+    # CAUTION: these must be initialized to `None`! The state context manager
+    # restores these to None on exit. If the object default is not `None`, the
+    # context manager is not a no-op, which leads to problems with stale state
+    # (e.g. spurious cache misses in tests).
+    numpy_rank_promotion: str | None = None
+    numpy_dtype_promotion: str | None = None
+    default_matmul_precision: Any | None = None
+    dynamic_shapes: bool | None = None
+    eager_constant_folding : bool | None = None
+    random_seed_offset: int | None = None
+    threefry_partitionable: bool | None = None
+    threefry_gpu_kernel_lowering: bool | None = None
+    sharding_in_types: bool | None = None
+    softmax_custom_jvp: bool | None = None
+    xla_profile_version: int | None = None
+    pgle_profiling_runs: int | None = None
+    enable_pgle: bool | None = None
+    use_shardy_partitioner: bool | None = None
+
+  class _ThreadLocalStateCache(threading.local):
+    """"A thread local cache for _ThreadLocalExtraJitContext
+
+    The extra_jit_context in jax_jit.thread_local_state() may get updated and thus
+    incurring dispatch overhead for comparing this python object during jit calls.
+    We want to deduplicate the objects that have the same hash/equality to also
+    have the same object ID, since the equality check is much faster if the object
+    IDs match.
+    """
+    def __init__(self):
+      self.canonicalize = functools.lru_cache(128)(lambda x: x)
+
+  _thread_local_state_cache = _ThreadLocalStateCache()
+
+  def update_thread_local_jit_state(**kw):
+    tls = jax_jit.thread_local_state()
+    # After xla_client._version >= 70, the thread_local object will necessarily
+    # be initialized when accessed. The following line can be removed when the
+    # minimum jaxlib version is past version 70
+    context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
+    tmp = context._replace(**kw)
+    tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
+
+  class JitConfig:
+    def __init__(self, name):
+      self._name = name
+
+    def value(self):
+      return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
+
+    def get_local(self):
+      return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
+
+    def set_local(self, value):
+      update_thread_local_jit_state(**{self._name: value})
+
+  trace_state = JitConfig('trace_state')
+  axis_env_state = JitConfig('axis_env_state')
+  mesh_context_manager = JitConfig('mesh_context_manager')
+  compute_on_context_manager = JitConfig('compute_on_context_manager')
+  xla_metadata_context_manager = JitConfig('xla_metadata_context_manager')
 
 # TODO(b/214340779): remove flag when XLA:CPU is improved.
 jax2tf_associative_scan_reductions = bool_state(
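On older jaxlibs, the JitConfig shim above gives each mirrored item the same get_local()/set_local() surface as the C++ config objects, forwarding to one field of the thread-local context. A condensed sketch of that idea (a dict stands in for the NamedTuple, so this is illustrative, not the shipped code):

    class JitConfigSketch:
      def __init__(self, state, name):
        self._state, self._name = state, name
      def get_local(self):
        return self._state[self._name]
      def set_local(self, value):
        self._state[self._name] = value

    thread_local_jit_state = {'axis_env_state': ()}
    axis_env_state = JitConfigSketch(thread_local_jit_state, 'axis_env_state')
    axis_env_state.set_local((('i', 4),))
    assert axis_env_state.get_local() == (('i', 4),)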
@@ -1102,10 +1266,7 @@ random_seed_offset = int_state(
     name='jax_random_seed_offset',
     default=0,
     help=('Offset to all random seeds (e.g. argument to jax.random.key()).'),
-    update_global_hook=lambda val: _update_global_jit_state(
-        random_seed_offset=val),
-    update_thread_local_hook=lambda val: update_thread_local_jit_state(
-        random_seed_offset=val)
+    include_in_jit_key=True,
 )
 
 legacy_prng_key = enum_state(
@@ -1140,10 +1301,7 @@ threefry_partitionable = bool_state(
           'may result in extraneous communication and/or redundant distributed '
           'computation. With this flag, the communication overheads disappear '
           'in some cases.'),
-    update_global_hook=lambda val: _update_global_jit_state(
-        threefry_partitionable=val),
-    update_thread_local_hook=lambda val: update_thread_local_jit_state(
-        threefry_partitionable=val))
+    include_in_jit_key=True)
 
 threefry_gpu_kernel_lowering = bool_state(
     name='jax_threefry_gpu_kernel_lowering',
@@ -1151,20 +1309,14 @@ threefry_gpu_kernel_lowering = bool_state(
     help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
           'This makes compile times faster at a potential runtime memory '
           'cost.'),
-    update_global_hook=lambda val: _update_global_jit_state(
-        threefry_gpu_kernel_lowering=val),
-    update_thread_local_hook=lambda val: update_thread_local_jit_state(
-        threefry_gpu_kernel_lowering=val))
+    include_in_jit_key=True)
 
 sharding_in_types = bool_state(
     name='jax_sharding_in_types',
     default=False,
     help=('When True, enables forward only sharding propagation in JAX and '
           'avals have sharding on them.'),
-    update_global_hook=lambda val: _update_global_jit_state(
-        sharding_in_types=val),
-    update_thread_local_hook=lambda val: update_thread_local_jit_state(
-        sharding_in_types=val))
+    include_in_jit_key=True)
 
 data_dependent_tracing_fallback = bool_state(
     name='jax_data_dependent_tracing_fallback',
@@ -1179,10 +1331,7 @@ softmax_custom_jvp = bool_state(
     help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should '
           'improve memory usage and stability. Set True to use new '
           'behavior. See https://github.com/jax-ml/jax/pull/15677'),
-    update_global_hook=lambda val: _update_global_jit_state(
-        softmax_custom_jvp=val),
-    update_thread_local_hook=lambda val: update_thread_local_jit_state(
-        softmax_custom_jvp=val))
+    include_in_jit_key=True)
 
 enable_custom_vjp_by_custom_transpose = bool_state(
@@ -1298,9 +1447,7 @@ enable_pgle = bool_state(
         'number times with collected data provided to the profile guided latency '
         'estimator.'
     ),
-    update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val),
-    update_thread_local_hook=lambda val: update_thread_local_jit_state(
-        enable_pgle=val),
+    include_in_jit_key=True,
 )
 
 pgle_profiling_runs = int_state(
@@ -1310,12 +1457,7 @@ pgle_profiling_runs = int_state(
         'Amount of times module should be profiled before recompilation when '
         'PGLE is used.'
     ),
-    update_global_hook=lambda val: _update_global_jit_state(
-        pgle_profiling_runs=val
-    ),
-    update_thread_local_hook=lambda val: update_thread_local_jit_state(
-        pgle_profiling_runs=val
-    ),
+    include_in_jit_key=True,
 )
 
 pgle_aggregation_percentile = int_state(
@@ -1381,10 +1523,7 @@ numpy_dtype_promotion = enum_state(
           'between arrays. Options are "standard" or "strict"; in strict-mode, '
           'binary operations between arrays of differing strongly-specified '
           'dtypes will result in an error.'),
-    update_global_hook=lambda val: \
-      _update_global_jit_state(numpy_dtype_promotion=val),
-    update_thread_local_hook=lambda val: \
-      update_thread_local_jit_state(numpy_dtype_promotion=val))
+    include_in_jit_key=True)
 
 disallow_mesh_context_manager = bool_state(
     name='jax_disallow_mesh_context_manager',
@@ -1470,10 +1609,7 @@ numpy_rank_promotion = enum_state(
     default='allow',
     help=('Control NumPy-style automatic rank promotion broadcasting '
           '("allow", "warn", or "raise").'),
-    update_global_hook=lambda val: \
-      _update_global_jit_state(numpy_rank_promotion=val),
-    update_thread_local_hook=lambda val: \
-      update_thread_local_jit_state(numpy_rank_promotion=val))
+    include_in_jit_key=True)
 
 default_matmul_precision = optional_enum_state(
     name='jax_default_matmul_precision',
@@ -1509,10 +1645,7 @@ default_matmul_precision = optional_enum_state(
           '"algorithm" for functions that perform matrix multiplications, like '
           ':func:`jax.lax.dot`. To specify an algorithm, set this option to '
           'the name of a :class:`~jax.lax.DotAlgorithmPreset`.\n\n'),
-    update_global_hook=lambda val: \
-      _update_global_jit_state(default_matmul_precision=val),
-    update_thread_local_hook=lambda val: \
-      update_thread_local_jit_state(default_matmul_precision=val))
+    include_in_jit_key=True)
 
 traceback_filtering = enum_state(
     name = 'jax_traceback_filtering',
@@ -1547,20 +1680,14 @@ dynamic_shapes = bool_state(
     default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
     help=('Enables experimental features for staging out computations with '
          'dynamic shapes.'),
-    update_global_hook=lambda val: \
-      _update_global_jit_state(dynamic_shapes=val),
-    update_thread_local_hook=lambda val: \
-      update_thread_local_jit_state(dynamic_shapes=val))
+    include_in_jit_key=True)
 
 # This is for stackless backward compat with e.g. equinox
 eager_constant_folding = bool_state(
     name='eager_constant_folding',
     default=False,
     help=('Attempt constant folding during staging.'),
-    update_global_hook=lambda val: \
-      _update_global_jit_state(eager_constant_folding=val),
-    update_thread_local_hook=lambda val: \
-      update_thread_local_jit_state(eager_constant_folding=val))
+    include_in_jit_key=True)
 
 # This flag is temporary during rollout of the remat barrier.
 # TODO(parkers): Remove if there are no complaints.
@@ -1619,10 +1746,7 @@ jax_xla_profile_version = int_state(
         'Optional profile version for XLA compilation. This is meaningful '
         'only when XLA is configured to support the remote compilation '
         'profile feature.'),
-    update_global_hook=lambda val: _update_global_jit_state(
-        xla_profile_version=val),
-    update_thread_local_hook=lambda val: update_thread_local_jit_state(
-        xla_profile_version=val),
+    include_in_jit_key=True,
 )
 
 @contextlib.contextmanager
@@ -1821,10 +1945,5 @@ use_shardy_partitioner = bool_state(
         'framework for MLIR. Currently Shardy is experimental in JAX. See '
         'www.github.com/openxla/shardy'
     ),
-    update_global_hook=lambda val: _update_global_jit_state(
-        use_shardy_partitioner=val
-    ),
-    update_thread_local_hook=lambda val: update_thread_local_jit_state(
-        use_shardy_partitioner=val
-    ),
+    include_in_jit_key=True,
 )

View File

@@ -1016,18 +1016,16 @@ class TracingContext(threading.local):
   def set_trace(self, trace):
     self.trace = trace
     ts = ref(trace) if trace is not None else None
-    config.update_thread_local_jit_state(trace_state=ts)
+    config.trace_state.set_local(ts)
 
   def set_axis_env(self, axis_env):
     self.axis_env = axis_env
-    config.update_thread_local_jit_state(
-        axis_env_state=self.axis_env.as_hashable_key())
+    config.axis_env_state.set_local(axis_env.as_hashable_key())
 
   def update_thread_local_jit_state(self):
     ts = ref(self.trace) if self.trace is not None else None
-    config.update_thread_local_jit_state(
-        trace_state=ts,
-        axis_env_state=self.axis_env.as_hashable_key())
+    config.trace_state.set_local(ts)
+    config.axis_env_state.set_local(self.axis_env.as_hashable_key())
 
 trace_ctx = TracingContext()
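Note that set_trace stores a weakref in the config item (ts = ref(trace) above), so parking the current trace in thread-local jit state does not keep a finished trace alive. A minimal illustration of that property, independent of JAX:

    import weakref

    class Trace:  # stand-in for a JAX trace object
      pass

    t = Trace()
    ts = weakref.ref(t)
    assert ts() is t
    del t
    assert ts() is None  # the stored reference no longer pins the trace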
@@ -1071,10 +1069,7 @@ def _initialize_jax_jit_thread_local_state():
 
   This function does not live in `config.py`, to prevent circular imports.
   """
-  tls = jax_jit.thread_local_state()
-  if tls.extra_jit_context is None:
-    trace_ctx.update_thread_local_jit_state()
+  trace_ctx.update_thread_local_jit_state()
 
 jax_jit.set_thread_local_state_initialization_callback(
     _initialize_jax_jit_thread_local_state)

View File

@@ -224,17 +224,17 @@ class Mesh(contextlib.ContextDecorator):
     new_env = thread_resources.stack[-1].with_mesh(self)
     thread_resources.stack.append(new_env)
     thread_resources.env = new_env
-    jax_config.update_thread_local_jit_state(
-        mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack
-                                   if not t.physical_mesh.empty))
+    jax_config.mesh_context_manager.set_local(
+        tuple(t.physical_mesh for t in thread_resources.stack
+              if not t.physical_mesh.empty))
     return self
 
   def __exit__(self, exc_type, exc_value, traceback):
     thread_resources.stack.pop()
     thread_resources.env = thread_resources.stack[-1]
-    jax_config.update_thread_local_jit_state(
-        mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack
-                                   if not t.physical_mesh.empty))
+    jax_config.mesh_context_manager.set_local(
+        tuple(t.physical_mesh for t in thread_resources.stack
+              if not t.physical_mesh.empty))
     return False
 
   @property
@@ -410,7 +410,7 @@ class AbstractMesh:
 
   @staticmethod
   def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
-    jax_config.update_thread_local_jit_state(mesh_context_manager=mesh)
+    jax_config.mesh_context_manager.set_local(mesh)
     return

View File

@@ -41,15 +41,13 @@ def set_xla_metadata(*args, **kwargs):
       thread_local_metadata.val,
       new_metadata,
   )
-  config.update_thread_local_jit_state(
-      xla_metadata_context_manager=tuple(
-          (v, k) for k, v in sorted(new_metadata.items())))
+  config.xla_metadata_context_manager.set_local(
+      tuple((v, k) for k, v in sorted(new_metadata.items()))
+  )
   try:
     yield
   finally:
     thread_local_metadata.val = prev_metadata
-    config.update_thread_local_jit_state(
-        xla_metadata_context_manager=tuple(
-            (v, k) for k, v in sorted(prev_metadata.items())
-        )
+    config.xla_metadata_context_manager.set_local(
+        tuple((v, k) for k, v in sorted(prev_metadata.items()))
     )

View File

@@ -2215,8 +2215,6 @@ class CppPmapTest(PythonPmapTest):
     pmaped_f(inputs)
     self.assertEqual(pmaped_f._cache_size, 1)
 
-    config.update_thread_local_jit_state()
-
     pmaped_f(inputs)
     self.assertEqual(pmaped_f._cache_size, 1)