mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Reverts a913fbf2fddc5b8c1b6c85b159d0eeb1bf65d461
PiperOrigin-RevId: 693360032
This commit is contained in:
parent
478b750c29
commit
0e8acff5c6
@ -123,8 +123,8 @@ def _update_debug_special_global(_):
|
||||
jax_jit.global_state().post_hook = None
|
||||
|
||||
def _update_debug_special_thread_local(_):
|
||||
if (getattr(config._thread_local_state, "jax_debug_nans", False) or
|
||||
getattr(config._thread_local_state, "jax_debug_infs", False)):
|
||||
if (config.debug_nans.get_local() == True or
|
||||
config.debug_infs.get_local() == True):
|
||||
jax_jit.thread_local_state().post_hook = _nan_check_posthook
|
||||
else:
|
||||
jax_jit.thread_local_state().post_hook = None
|
||||
|
@ -29,8 +29,8 @@ compute_on_context = ComputeOnContext()
|
||||
@contextmanager
|
||||
def extend_compute_type(c_type: str):
|
||||
compute_on_context.stack.append(c_type)
|
||||
config.update_thread_local_jit_state(
|
||||
compute_on_context_manager=tuple(compute_on_context.stack))
|
||||
config.compute_on_context_manager.set_local(
|
||||
tuple(compute_on_context.stack))
|
||||
try:
|
||||
if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1:
|
||||
raise NotImplementedError(
|
||||
@ -39,8 +39,7 @@ def extend_compute_type(c_type: str):
|
||||
yield compute_on_context.stack[-1]
|
||||
finally:
|
||||
compute_on_context.stack.pop()
|
||||
config.update_thread_local_jit_state(
|
||||
compute_on_context_manager=tuple(compute_on_context.stack))
|
||||
config.compute_on_context_manager.set_local(tuple(compute_on_context.stack))
|
||||
|
||||
def current_compute_type() -> str | None:
|
||||
return compute_on_context.stack[-1] if compute_on_context.stack else None
|
||||
|
@ -22,14 +22,23 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast
|
||||
from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast, TYPE_CHECKING
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src.lib import guard_lib
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import logging_config
|
||||
|
||||
# TODO(phawkins): reenable pytype after xla_extension_version >= 295
|
||||
# pytype: skip-file
|
||||
|
||||
if xla_extension_version >= 295:
|
||||
config_ext = xla_client._xla.config
|
||||
else:
|
||||
config_ext = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_T = TypeVar('_T')
|
||||
@ -191,49 +200,79 @@ class Config:
|
||||
already_configured_with_absl = True
|
||||
|
||||
|
||||
def trace_context():
|
||||
"""Returns a tuple of configuration values that affect tracing.
|
||||
if xla_extension_version >= 295:
|
||||
def trace_context():
|
||||
"""Returns a tuple of configuration values that affect tracing.
|
||||
|
||||
These values are included in the cache key for linear_util.cache.
|
||||
These values are included in the cache key for linear_util.cache.
|
||||
|
||||
Values included in this set should also most likely be included in
|
||||
the C++ JIT state, which is handled separately.
|
||||
"""
|
||||
tls = jax_jit.thread_local_state()
|
||||
axis_env_state = ()
|
||||
mesh_context_manager = ()
|
||||
xla_metadata_context_manager = ()
|
||||
compute_on_context_manager = ()
|
||||
Values included in this set should also most likely be included in
|
||||
the C++ JIT state, which is handled separately.
|
||||
"""
|
||||
return (axis_env_state.value, mesh_context_manager.value, xla_metadata_context_manager.value,
|
||||
compute_on_context_manager.value, enable_x64.value,
|
||||
numpy_rank_promotion.value, default_matmul_precision.value,
|
||||
dynamic_shapes.value,
|
||||
eager_constant_folding.value,
|
||||
numpy_dtype_promotion.value,
|
||||
default_device.value, random_seed_offset.value,
|
||||
threefry_partitionable.value,
|
||||
threefry_gpu_kernel_lowering.value,
|
||||
sharding_in_types.value,
|
||||
softmax_custom_jvp.value,
|
||||
enable_memories.value,
|
||||
disable_jit.value,
|
||||
debug_key_reuse.value,
|
||||
jax_xla_profile_version.value,
|
||||
# Technically this affects jaxpr->stablehlo lowering, not tracing.
|
||||
hlo_source_file_canonicalization_regex.value,
|
||||
pgle_profiling_runs.value,
|
||||
enable_pgle.value,
|
||||
use_shardy_partitioner.value)
|
||||
else:
|
||||
def trace_context():
|
||||
"""Returns a tuple of configuration values that affect tracing.
|
||||
|
||||
context: Any = tls.extra_jit_context
|
||||
if context and context.axis_env_state is not None:
|
||||
axis_env_state = context.axis_env_state
|
||||
if context and context.mesh_context_manager:
|
||||
mesh_context_manager = context.mesh_context_manager
|
||||
if context and context.xla_metadata_context_manager:
|
||||
xla_metadata_context_manager = context.xla_metadata_context_manager
|
||||
if context and context.compute_on_context_manager:
|
||||
compute_on_context_manager = context.compute_on_context_manager
|
||||
return (axis_env_state, mesh_context_manager, xla_metadata_context_manager,
|
||||
compute_on_context_manager, enable_x64.value,
|
||||
numpy_rank_promotion.value, default_matmul_precision.value,
|
||||
dynamic_shapes.value,
|
||||
eager_constant_folding.value,
|
||||
numpy_dtype_promotion.value,
|
||||
default_device.value, random_seed_offset.value,
|
||||
threefry_partitionable.value,
|
||||
threefry_gpu_kernel_lowering.value,
|
||||
sharding_in_types.value,
|
||||
softmax_custom_jvp.value,
|
||||
enable_memories.value,
|
||||
disable_jit.value,
|
||||
debug_key_reuse.value,
|
||||
jax_xla_profile_version.value,
|
||||
# Technically this affects jaxpr->stablehlo lowering, not tracing.
|
||||
hlo_source_file_canonicalization_regex.value,
|
||||
pgle_profiling_runs.value,
|
||||
enable_pgle.value,
|
||||
use_shardy_partitioner.value)
|
||||
These values are included in the cache key for linear_util.cache.
|
||||
|
||||
Values included in this set should also most likely be included in
|
||||
the C++ JIT state, which is handled separately.
|
||||
"""
|
||||
tls = jax_jit.thread_local_state()
|
||||
axis_env_state = ()
|
||||
mesh_context_manager = ()
|
||||
xla_metadata_context_manager = ()
|
||||
compute_on_context_manager = ()
|
||||
|
||||
context: Any = tls.extra_jit_context
|
||||
if context and context.axis_env_state is not None:
|
||||
axis_env_state = context.axis_env_state
|
||||
if context and context.mesh_context_manager:
|
||||
mesh_context_manager = context.mesh_context_manager
|
||||
if context and context.xla_metadata_context_manager:
|
||||
xla_metadata_context_manager = context.xla_metadata_context_manager
|
||||
if context and context.compute_on_context_manager:
|
||||
compute_on_context_manager = context.compute_on_context_manager
|
||||
return (axis_env_state, mesh_context_manager, xla_metadata_context_manager,
|
||||
compute_on_context_manager, enable_x64.value,
|
||||
numpy_rank_promotion.value, default_matmul_precision.value,
|
||||
dynamic_shapes.value,
|
||||
eager_constant_folding.value,
|
||||
numpy_dtype_promotion.value,
|
||||
default_device.value, random_seed_offset.value,
|
||||
threefry_partitionable.value,
|
||||
threefry_gpu_kernel_lowering.value,
|
||||
sharding_in_types.value,
|
||||
softmax_custom_jvp.value,
|
||||
enable_memories.value,
|
||||
disable_jit.value,
|
||||
debug_key_reuse.value,
|
||||
jax_xla_profile_version.value,
|
||||
# Technically this affects jaxpr->stablehlo lowering, not tracing.
|
||||
hlo_source_file_canonicalization_regex.value,
|
||||
pgle_profiling_runs.value,
|
||||
enable_pgle.value,
|
||||
use_shardy_partitioner.value)
|
||||
|
||||
config = Config()
|
||||
|
||||
@ -245,94 +284,185 @@ parse_flags_with_absl = config.parse_flags_with_absl
|
||||
class NoDefault: pass
|
||||
no_default = NoDefault()
|
||||
|
||||
if xla_extension_version >= 295:
|
||||
class State(config_ext.Config[_T]):
|
||||
|
||||
class _Unset: pass
|
||||
unset = _Unset()
|
||||
__slots__ = (
|
||||
'_name', '_update_thread_local_hook', '_update_global_hook',
|
||||
'_validator', '_default_context_manager_value', '__doc__', '__name__',
|
||||
)
|
||||
|
||||
_thread_local_state = threading.local()
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
default: _T,
|
||||
help,
|
||||
update_global_hook: Callable[[_T], None] | None = None,
|
||||
update_thread_local_hook: Callable[[_T | None], None] | None = None,
|
||||
validator: Callable[[Any], None] | None = None,
|
||||
extra_description: str = '',
|
||||
default_context_manager_value: Any = no_default,
|
||||
include_in_jit_key: bool = False,
|
||||
):
|
||||
super().__init__(default, include_in_jit_key)
|
||||
self._name = name
|
||||
self.__name__ = name[4:] if name.startswith('jax_') else name
|
||||
self.__doc__ = (f"Context manager for `{name}` config option"
|
||||
f"{extra_description}.\n\n{help}")
|
||||
self._update_global_hook = update_global_hook
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._validator = validator
|
||||
self._default_context_manager_value = default_context_manager_value
|
||||
if self._validator:
|
||||
self._validator(default)
|
||||
if self._update_global_hook:
|
||||
self._update_global_hook(default)
|
||||
|
||||
class State(Generic[_T]):
|
||||
def __bool__(self) -> NoReturn:
|
||||
raise TypeError(
|
||||
"bool() not supported for instances of type '{0}' "
|
||||
"(did you mean to use '{0}.value' instead?)".format(
|
||||
type(self).__name__))
|
||||
|
||||
__slots__ = (
|
||||
'_name', '_value', '_update_thread_local_hook', '_update_global_hook',
|
||||
'_validator', '_default_context_manager_value', '__doc__', '__name__',
|
||||
)
|
||||
def _set(self, value: _T) -> None:
|
||||
if self._validator:
|
||||
self._validator(value)
|
||||
self.set_global(value)
|
||||
if self._update_global_hook:
|
||||
self._update_global_hook(value)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
default: _T,
|
||||
help,
|
||||
update_global_hook: Callable[[_T], None] | None = None,
|
||||
update_thread_local_hook: Callable[[_T | None], None] | None = None,
|
||||
validator: Callable[[Any], None] | None = None,
|
||||
extra_description: str = '',
|
||||
default_context_manager_value: Any = no_default,
|
||||
):
|
||||
self._name = name
|
||||
self.__name__ = name[4:] if name.startswith('jax_') else name
|
||||
self.__doc__ = (f"Context manager for `{name}` config option"
|
||||
f"{extra_description}.\n\n{help}")
|
||||
self._update_global_hook = update_global_hook
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._validator = validator
|
||||
self._default_context_manager_value = default_context_manager_value
|
||||
self._set(default)
|
||||
|
||||
def __bool__(self) -> NoReturn:
|
||||
raise TypeError(
|
||||
"bool() not supported for instances of type '{0}' "
|
||||
"(did you mean to use '{0}.value' instead?)".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _set(self, value: _T) -> None:
|
||||
if self._validator:
|
||||
self._validator(value)
|
||||
self._value = value
|
||||
if self._update_global_hook:
|
||||
self._update_global_hook(value)
|
||||
|
||||
@property
|
||||
def value(self) -> _T:
|
||||
val = _thread_local_state.__dict__.get(self._name, unset)
|
||||
return cast(_T, val) if val is not unset else self._value
|
||||
|
||||
@contextlib.contextmanager
|
||||
def __call__(self, new_val: Any = no_default):
|
||||
if new_val is no_default:
|
||||
if self._default_context_manager_value is not no_default:
|
||||
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
|
||||
else:
|
||||
# no default_value provided to constructor and no value provided as an
|
||||
# argument, so we raise an error
|
||||
raise TypeError(f"Context manager for {self.__name__} config option "
|
||||
"requires an argument representing the new value for "
|
||||
"the config option.")
|
||||
if self._validator:
|
||||
self._validator(new_val)
|
||||
prev_val = getattr(_thread_local_state, self._name, unset)
|
||||
setattr(_thread_local_state, self._name, new_val)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(new_val)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev_val is unset:
|
||||
delattr(_thread_local_state, self._name)
|
||||
@contextlib.contextmanager
|
||||
def __call__(self, new_val: Any = no_default):
|
||||
if new_val is no_default:
|
||||
if self._default_context_manager_value is not no_default:
|
||||
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
|
||||
else:
|
||||
# no default_value provided to constructor and no value provided as an
|
||||
# argument, so we raise an error
|
||||
raise TypeError(f"Context manager for {self.__name__} config option "
|
||||
"requires an argument representing the new value for "
|
||||
"the config option.")
|
||||
if self._validator:
|
||||
self._validator(new_val)
|
||||
prev_val = self.swap_local(new_val)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(new_val)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.set_local(prev_val)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(None)
|
||||
else:
|
||||
setattr(_thread_local_state, self._name, prev_val)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(cast(_T, prev_val))
|
||||
if prev_val is config_ext.unset:
|
||||
self._update_thread_local_hook(None)
|
||||
else:
|
||||
self._update_thread_local_hook(cast(Optional[Any], prev_val))
|
||||
|
||||
def _add_hooks(self, update_global_hook, update_thread_local_hook):
|
||||
"""Private method that adds hooks to an existing context-manager.
|
||||
def _add_hooks(self, update_global_hook, update_thread_local_hook):
|
||||
"""Private method that adds hooks to an existing context-manager.
|
||||
|
||||
Used to avoid cyclic import dependencies."""
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._update_global_hook = update_global_hook
|
||||
update_global_hook(self._value)
|
||||
Used to avoid cyclic import dependencies."""
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._update_global_hook = update_global_hook
|
||||
update_global_hook(self.get_global())
|
||||
|
||||
else:
|
||||
class _Unset: pass
|
||||
unset = _Unset()
|
||||
|
||||
_thread_local_state = threading.local()
|
||||
|
||||
class State(Generic[_T]):
|
||||
|
||||
__slots__ = (
|
||||
'_name', '_value', '_update_thread_local_hook', '_update_global_hook',
|
||||
'_validator', '_default_context_manager_value', '__doc__', '__name__',
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
default: _T,
|
||||
help,
|
||||
update_global_hook: Callable[[_T], None] | None = None,
|
||||
update_thread_local_hook: Callable[[_T | None], None] | None = None,
|
||||
validator: Callable[[Any], None] | None = None,
|
||||
extra_description: str = '',
|
||||
default_context_manager_value: Any = no_default,
|
||||
include_in_jit_key: bool = False,
|
||||
):
|
||||
self._name = name
|
||||
self.__name__ = name[4:] if name.startswith('jax_') else name
|
||||
self.__doc__ = (f"Context manager for `{name}` config option"
|
||||
f"{extra_description}.\n\n{help}")
|
||||
if include_in_jit_key:
|
||||
assert update_global_hook is None
|
||||
assert update_thread_local_hook is None
|
||||
update_global_hook = lambda val: _update_global_jit_state(
|
||||
**{self.__name__: val})
|
||||
update_thread_local_hook = lambda val: update_thread_local_jit_state(
|
||||
**{self.__name__: val})
|
||||
self._update_global_hook = update_global_hook
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._validator = validator
|
||||
self._default_context_manager_value = default_context_manager_value
|
||||
self._set(default)
|
||||
def __bool__(self) -> NoReturn:
|
||||
raise TypeError(
|
||||
"bool() not supported for instances of type '{0}' "
|
||||
"(did you mean to use '{0}.value' instead?)".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _set(self, value: _T) -> None:
|
||||
if self._validator:
|
||||
self._validator(value)
|
||||
self._value = value
|
||||
if self._update_global_hook:
|
||||
self._update_global_hook(value)
|
||||
|
||||
@property
|
||||
def value(self) -> _T:
|
||||
val = _thread_local_state.__dict__.get(self._name, unset)
|
||||
return cast(_T, val) if val is not unset else self._value
|
||||
|
||||
def get_local(self) -> Any:
|
||||
return _thread_local_state.__dict__.get(self._name, unset)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def __call__(self, new_val: Any = no_default):
|
||||
if new_val is no_default:
|
||||
if self._default_context_manager_value is not no_default:
|
||||
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
|
||||
else:
|
||||
# no default_value provided to constructor and no value provided as an
|
||||
# argument, so we raise an error
|
||||
raise TypeError(f"Context manager for {self.__name__} config option "
|
||||
"requires an argument representing the new value for "
|
||||
"the config option.")
|
||||
if self._validator:
|
||||
self._validator(new_val)
|
||||
prev_val = getattr(_thread_local_state, self._name, unset)
|
||||
setattr(_thread_local_state, self._name, new_val)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(new_val)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev_val is unset:
|
||||
delattr(_thread_local_state, self._name)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(None)
|
||||
else:
|
||||
setattr(_thread_local_state, self._name, prev_val)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(cast(_T, prev_val))
|
||||
|
||||
def _add_hooks(self, update_global_hook, update_thread_local_hook):
|
||||
"""Private method that adds hooks to an existing context-manager.
|
||||
|
||||
Used to avoid cyclic import dependencies."""
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._update_global_hook = update_global_hook
|
||||
update_global_hook(self._value)
|
||||
|
||||
|
||||
UPGRADE_BOOL_HELP = (
|
||||
@ -353,6 +483,7 @@ def bool_state(
|
||||
update_thread_local_hook: Callable[[bool | None], None] | None = None,
|
||||
upgrade: bool = False,
|
||||
extra_description: str = '',
|
||||
include_in_jit_key: bool = False,
|
||||
) -> State[bool]:
|
||||
"""Set up thread-local state and return a contextmanager for managing it.
|
||||
|
||||
@ -417,7 +548,8 @@ def bool_state(
|
||||
s = State[bool](
|
||||
name, default, help, update_global_hook=update_global_hook,
|
||||
update_thread_local_hook=update_thread_local_hook,
|
||||
extra_description=extra_description, default_context_manager_value=True)
|
||||
extra_description=extra_description, default_context_manager_value=True,
|
||||
include_in_jit_key=include_in_jit_key)
|
||||
config.add_option(name, s, bool, meta_args=[], meta_kwargs={"help": help})
|
||||
setattr(Config, name, property(lambda _: s.value))
|
||||
return s
|
||||
@ -431,6 +563,7 @@ def enum_state(
|
||||
*,
|
||||
update_global_hook: Callable[[str], None] | None = None,
|
||||
update_thread_local_hook: Callable[[str | None], None] | None = None,
|
||||
include_in_jit_key: bool = False,
|
||||
) -> State[str]:
|
||||
"""Set up thread-local state and return a contextmanager for managing it.
|
||||
|
||||
@ -470,6 +603,7 @@ def enum_state(
|
||||
update_global_hook=update_global_hook,
|
||||
update_thread_local_hook=update_thread_local_hook,
|
||||
validator=validator,
|
||||
include_in_jit_key=include_in_jit_key,
|
||||
)
|
||||
config.add_option(
|
||||
name, s, 'enum',
|
||||
@ -488,6 +622,7 @@ def optional_enum_state(
|
||||
*,
|
||||
update_global_hook: Callable[[str | None], None] | None = None,
|
||||
update_thread_local_hook: Callable[[str | None], None] | None = None,
|
||||
include_in_jit_key: bool = False,
|
||||
) -> State[str | None]:
|
||||
"""Set up thread-local state and return a contextmanager for managing it.
|
||||
|
||||
@ -523,7 +658,7 @@ def optional_enum_state(
|
||||
|
||||
s = State['str | None'](
|
||||
name, default, help, update_global_hook, update_thread_local_hook,
|
||||
validate
|
||||
validate, include_in_jit_key=include_in_jit_key,
|
||||
)
|
||||
config.add_option(
|
||||
name, s, 'enum',
|
||||
@ -541,6 +676,7 @@ def int_state(
|
||||
*,
|
||||
update_global_hook: Callable[[int], None] | None = None,
|
||||
update_thread_local_hook: Callable[[int | None], None] | None = None,
|
||||
include_in_jit_key: bool = False,
|
||||
) -> State[int]:
|
||||
"""Set up thread-local state and return a contextmanager for managing it.
|
||||
|
||||
@ -575,7 +711,8 @@ def int_state(
|
||||
f'got {new_val} of type {type(new_val)}')
|
||||
|
||||
s = State[int](name, default, help, update_global_hook,
|
||||
update_thread_local_hook, validate)
|
||||
update_thread_local_hook, validate,
|
||||
include_in_jit_key=include_in_jit_key)
|
||||
config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help})
|
||||
setattr(Config, name, property(lambda _: s.value))
|
||||
return s
|
||||
@ -826,92 +963,119 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
|
||||
already_configured_with_absl = False
|
||||
|
||||
|
||||
# The C++ JIT maintains its own copy of several configuration items as
|
||||
# a global/thread-local state. These methods allow updates to part of the
|
||||
# state when a configuration value changes.
|
||||
class _GlobalExtraJitContext(NamedTuple):
|
||||
numpy_rank_promotion: str | None = None
|
||||
numpy_dtype_promotion: str | None = None
|
||||
default_matmul_precision: Any | None = None
|
||||
dynamic_shapes: bool = False
|
||||
eager_constant_folding: bool = False
|
||||
random_seed_offset: int = 0
|
||||
threefry_partitionable: bool = False
|
||||
threefry_gpu_kernel_lowering: bool = False
|
||||
sharding_in_types: bool = False
|
||||
softmax_custom_jvp: bool = False
|
||||
xla_profile_version: int = 0
|
||||
pgle_profiling_runs: int = 0
|
||||
enable_pgle: bool = False
|
||||
use_shardy_partitioner: bool = False
|
||||
if xla_extension_version >= 295:
|
||||
trace_state = config_ext.Config(None, include_in_jit_key=True)
|
||||
axis_env_state = config_ext.Config((), include_in_jit_key=True)
|
||||
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
else:
|
||||
# The C++ JIT maintains its own copy of several configuration items as
|
||||
# a global/thread-local state. These methods allow updates to part of the
|
||||
# state when a configuration value changes.
|
||||
class _GlobalExtraJitContext(NamedTuple):
|
||||
numpy_rank_promotion: str | None = None
|
||||
numpy_dtype_promotion: str | None = None
|
||||
default_matmul_precision: Any | None = None
|
||||
dynamic_shapes: bool = False
|
||||
eager_constant_folding: bool = False
|
||||
random_seed_offset: int = 0
|
||||
threefry_partitionable: bool = False
|
||||
threefry_gpu_kernel_lowering: bool = False
|
||||
sharding_in_types: bool = False
|
||||
softmax_custom_jvp: bool = False
|
||||
xla_profile_version: int = 0
|
||||
pgle_profiling_runs: int = 0
|
||||
enable_pgle: bool = False
|
||||
use_shardy_partitioner: bool = False
|
||||
|
||||
|
||||
def _update_global_jit_state(**kw):
|
||||
gs = jax_jit.global_state()
|
||||
context = gs.extra_jit_context or _GlobalExtraJitContext()
|
||||
gs.extra_jit_context = context._replace(**kw)
|
||||
def _update_global_jit_state(**kw):
|
||||
gs = jax_jit.global_state()
|
||||
context = gs.extra_jit_context or _GlobalExtraJitContext()
|
||||
gs.extra_jit_context = context._replace(**kw)
|
||||
|
||||
|
||||
class _ThreadLocalExtraJitContext(NamedTuple):
|
||||
"""A namedtuple containing states to add to the cache key.
|
||||
class _ThreadLocalExtraJitContext(NamedTuple):
|
||||
"""A namedtuple containing states to add to the cache key.
|
||||
|
||||
Just in time compilation (for jit, pmap, etc) behavior is configurable through
|
||||
global and thread-local options, used in the cache key.
|
||||
Just in time compilation (for jit, pmap, etc) behavior is configurable through
|
||||
global and thread-local options, used in the cache key.
|
||||
|
||||
The initialization, which uses both config.py and core.py, is done using
|
||||
`_update_thread_local_jit_state` in core.py to prevent circular imports.
|
||||
"""
|
||||
trace_state: Any | None = None
|
||||
axis_env_state: Hashable = ()
|
||||
mesh_context_manager: Hashable = ()
|
||||
compute_on_context_manager: Hashable = ()
|
||||
xla_metadata_context_manager: Hashable = ()
|
||||
The initialization, which uses both config.py and core.py, is done using
|
||||
`_update_thread_local_jit_state` in core.py to prevent circular imports.
|
||||
"""
|
||||
trace_state: Any | None = None
|
||||
axis_env_state: Hashable = ()
|
||||
mesh_context_manager: Hashable = ()
|
||||
compute_on_context_manager: Hashable = ()
|
||||
xla_metadata_context_manager: Hashable = ()
|
||||
|
||||
# Values set by _StateContextManager context managers.
|
||||
# CAUTION: these must be initialized to `None`! The state context manager
|
||||
# restores these to None on exit. If the object default is not `None`, the
|
||||
# context manager is not a no-op, which leads to problems with stale state
|
||||
# (e.g. spurious cache misses in tests).
|
||||
numpy_rank_promotion: str | None = None
|
||||
numpy_dtype_promotion: str | None = None
|
||||
default_matmul_precision: Any | None = None
|
||||
dynamic_shapes: bool | None = None
|
||||
eager_constant_folding : bool | None = None
|
||||
random_seed_offset: int | None = None
|
||||
threefry_partitionable: bool | None = None
|
||||
threefry_gpu_kernel_lowering: bool | None = None
|
||||
sharding_in_types: bool | None = None
|
||||
softmax_custom_jvp: bool | None = None
|
||||
xla_profile_version: int | None = None
|
||||
pgle_profiling_runs: int | None = None
|
||||
enable_pgle: bool | None = None
|
||||
use_shardy_partitioner: bool | None = None
|
||||
# Values set by _StateContextManager context managers.
|
||||
# CAUTION: these must be initialized to `None`! The state context manager
|
||||
# restores these to None on exit. If the object default is not `None`, the
|
||||
# context manager is not a no-op, which leads to problems with stale state
|
||||
# (e.g. spurious cache misses in tests).
|
||||
numpy_rank_promotion: str | None = None
|
||||
numpy_dtype_promotion: str | None = None
|
||||
default_matmul_precision: Any | None = None
|
||||
dynamic_shapes: bool | None = None
|
||||
eager_constant_folding : bool | None = None
|
||||
random_seed_offset: int | None = None
|
||||
threefry_partitionable: bool | None = None
|
||||
threefry_gpu_kernel_lowering: bool | None = None
|
||||
sharding_in_types: bool | None = None
|
||||
softmax_custom_jvp: bool | None = None
|
||||
xla_profile_version: int | None = None
|
||||
pgle_profiling_runs: int | None = None
|
||||
enable_pgle: bool | None = None
|
||||
use_shardy_partitioner: bool | None = None
|
||||
|
||||
|
||||
class _ThreadLocalStateCache(threading.local):
|
||||
"""A thread local cache for _ThreadLocalExtraJitContext
|
||||
class _ThreadLocalStateCache(threading.local):
|
||||
"""A thread local cache for _ThreadLocalExtraJitContext
|
||||
|
||||
The extra_jit_context in jax_jit.thread_local_state() may get updated and thus
|
||||
incur dispatch overhead for comparing this python object during jit calls.
|
||||
We want to deduplicate the objects that have the same hash/equality to also
|
||||
have the same object ID, since the equality check is much faster if the object
|
||||
IDs match.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.canonicalize = functools.lru_cache(128)(lambda x: x)
|
||||
The extra_jit_context in jax_jit.thread_local_state() may get updated and thus
|
||||
incur dispatch overhead for comparing this python object during jit calls.
|
||||
We want to deduplicate the objects that have the same hash/equality to also
|
||||
have the same object ID, since the equality check is much faster if the object
|
||||
IDs match.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.canonicalize = functools.lru_cache(128)(lambda x: x)
|
||||
|
||||
|
||||
_thread_local_state_cache = _ThreadLocalStateCache()
|
||||
_thread_local_state_cache = _ThreadLocalStateCache()
|
||||
|
||||
|
||||
def update_thread_local_jit_state(**kw):
|
||||
tls = jax_jit.thread_local_state()
|
||||
# After xla_client._version >= 70, the thread_local object will necessarily
|
||||
# be initialized when accessed. The following line can be removed when the
|
||||
# minimum jaxlib version is past version 70
|
||||
context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
|
||||
tmp = context._replace(**kw)
|
||||
tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
|
||||
def update_thread_local_jit_state(**kw):
|
||||
tls = jax_jit.thread_local_state()
|
||||
# After xla_client._version >= 70, the thread_local object will necessarily
|
||||
# be initialized when accessed. The following line can be removed when the
|
||||
# minimum jaxlib version is past version 70
|
||||
context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
|
||||
tmp = context._replace(**kw)
|
||||
tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
|
||||
|
||||
class JitConfig:
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
|
||||
def value(self):
|
||||
return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
|
||||
|
||||
def get_local(self):
|
||||
return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
|
||||
|
||||
def set_local(self, value):
|
||||
update_thread_local_jit_state(**{self._name: value})
|
||||
|
||||
trace_state = JitConfig('trace_state')
|
||||
axis_env_state = JitConfig('axis_env_state')
|
||||
mesh_context_manager = JitConfig('mesh_context_manager')
|
||||
compute_on_context_manager = JitConfig('compute_on_context_manager')
|
||||
xla_metadata_context_manager = JitConfig('xla_metadata_context_manager')
|
||||
|
||||
|
||||
# TODO(b/214340779): remove flag when XLA:CPU is improved.
|
||||
jax2tf_associative_scan_reductions = bool_state(
|
||||
@ -1102,10 +1266,7 @@ random_seed_offset = int_state(
|
||||
name='jax_random_seed_offset',
|
||||
default=0,
|
||||
help=('Offset to all random seeds (e.g. argument to jax.random.key()).'),
|
||||
update_global_hook=lambda val: _update_global_jit_state(
|
||||
random_seed_offset=val),
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
random_seed_offset=val)
|
||||
include_in_jit_key=True,
|
||||
)
|
||||
|
||||
legacy_prng_key = enum_state(
|
||||
@ -1140,10 +1301,7 @@ threefry_partitionable = bool_state(
|
||||
'may result in extraneous communication and/or redundant distributed '
|
||||
'computation. With this flag, the communication overheads disappear '
|
||||
'in some cases.'),
|
||||
update_global_hook=lambda val: _update_global_jit_state(
|
||||
threefry_partitionable=val),
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
threefry_partitionable=val))
|
||||
include_in_jit_key=True)
|
||||
|
||||
threefry_gpu_kernel_lowering = bool_state(
|
||||
name='jax_threefry_gpu_kernel_lowering',
|
||||
@ -1151,20 +1309,14 @@ threefry_gpu_kernel_lowering = bool_state(
|
||||
help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
|
||||
'This makes compile times faster at a potential runtime memory '
|
||||
'cost.'),
|
||||
update_global_hook=lambda val: _update_global_jit_state(
|
||||
threefry_gpu_kernel_lowering=val),
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
threefry_gpu_kernel_lowering=val))
|
||||
include_in_jit_key=True)
|
||||
|
||||
sharding_in_types = bool_state(
|
||||
name='jax_sharding_in_types',
|
||||
default=False,
|
||||
help=('When True, enables forward only sharding propagation in JAX and '
|
||||
'avals have sharding on them.'),
|
||||
update_global_hook=lambda val: _update_global_jit_state(
|
||||
sharding_in_types=val),
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
sharding_in_types=val))
|
||||
include_in_jit_key=True)
|
||||
|
||||
data_dependent_tracing_fallback = bool_state(
|
||||
name='jax_data_dependent_tracing_fallback',
|
||||
@ -1179,10 +1331,7 @@ softmax_custom_jvp = bool_state(
|
||||
help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should '
|
||||
'improve memory usage and stability. Set True to use new '
|
||||
'behavior. See https://github.com/jax-ml/jax/pull/15677'),
|
||||
update_global_hook=lambda val: _update_global_jit_state(
|
||||
softmax_custom_jvp=val),
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
softmax_custom_jvp=val))
|
||||
include_in_jit_key=True)
|
||||
|
||||
|
||||
enable_custom_vjp_by_custom_transpose = bool_state(
|
||||
@ -1298,9 +1447,7 @@ enable_pgle = bool_state(
|
||||
'number times with collected data provided to the profile guided latency '
|
||||
'estimator.'
|
||||
),
|
||||
update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val),
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
enable_pgle=val),
|
||||
include_in_jit_key=True,
|
||||
)
|
||||
|
||||
pgle_profiling_runs = int_state(
|
||||
@ -1310,12 +1457,7 @@ pgle_profiling_runs = int_state(
|
||||
'Amount of times module should be profiled before recompilation when '
|
||||
'PGLE is used.'
|
||||
),
|
||||
update_global_hook=lambda val: _update_global_jit_state(
|
||||
pgle_profiling_runs=val
|
||||
),
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
pgle_profiling_runs=val
|
||||
),
|
||||
include_in_jit_key=True,
|
||||
)
|
||||
|
||||
pgle_aggregation_percentile = int_state(
|
||||
@ -1381,10 +1523,7 @@ numpy_dtype_promotion = enum_state(
|
||||
'between arrays. Options are "standard" or "strict"; in strict-mode, '
|
||||
'binary operations between arrays of differing strongly-specified '
|
||||
'dtypes will result in an error.'),
|
||||
update_global_hook=lambda val: \
|
||||
_update_global_jit_state(numpy_dtype_promotion=val),
|
||||
update_thread_local_hook=lambda val: \
|
||||
update_thread_local_jit_state(numpy_dtype_promotion=val))
|
||||
include_in_jit_key=True)
|
||||
|
||||
disallow_mesh_context_manager = bool_state(
|
||||
name='jax_disallow_mesh_context_manager',
|
||||
@ -1470,10 +1609,7 @@ numpy_rank_promotion = enum_state(
|
||||
default='allow',
|
||||
help=('Control NumPy-style automatic rank promotion broadcasting '
|
||||
'("allow", "warn", or "raise").'),
|
||||
update_global_hook=lambda val: \
|
||||
_update_global_jit_state(numpy_rank_promotion=val),
|
||||
update_thread_local_hook=lambda val: \
|
||||
update_thread_local_jit_state(numpy_rank_promotion=val))
|
||||
include_in_jit_key=True)
|
||||
|
||||
default_matmul_precision = optional_enum_state(
|
||||
name='jax_default_matmul_precision',
|
||||
@ -1509,10 +1645,7 @@ default_matmul_precision = optional_enum_state(
|
||||
'"algorithm" for functions that perform matrix multiplications, like '
|
||||
':func:`jax.lax.dot`. To specify an algorithm, set this option to '
|
||||
'the name of a :class:`~jax.lax.DotAlgorithmPreset`.\n\n'),
|
||||
update_global_hook=lambda val: \
|
||||
_update_global_jit_state(default_matmul_precision=val),
|
||||
update_thread_local_hook=lambda val: \
|
||||
update_thread_local_jit_state(default_matmul_precision=val))
|
||||
include_in_jit_key=True)
|
||||
|
||||
traceback_filtering = enum_state(
|
||||
name = 'jax_traceback_filtering',
|
||||
@ -1547,20 +1680,14 @@ dynamic_shapes = bool_state(
|
||||
default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
|
||||
help=('Enables experimental features for staging out computations with '
|
||||
'dynamic shapes.'),
|
||||
update_global_hook=lambda val: \
|
||||
_update_global_jit_state(dynamic_shapes=val),
|
||||
update_thread_local_hook=lambda val: \
|
||||
update_thread_local_jit_state(dynamic_shapes=val))
|
||||
include_in_jit_key=True)
|
||||
|
||||
# This is for stackless backward compat with e.g. equinox
|
||||
eager_constant_folding = bool_state(
|
||||
name='eager_constant_folding',
|
||||
default=False,
|
||||
help=('Attempt constant folding during staging.'),
|
||||
update_global_hook=lambda val: \
|
||||
_update_global_jit_state(eager_constant_folding=val),
|
||||
update_thread_local_hook=lambda val: \
|
||||
update_thread_local_jit_state(eager_constant_folding=val))
|
||||
include_in_jit_key=True)
|
||||
|
||||
# This flag is temporary during rollout of the remat barrier.
|
||||
# TODO(parkers): Remove if there are no complaints.
|
||||
@ -1619,10 +1746,7 @@ jax_xla_profile_version = int_state(
|
||||
'Optional profile version for XLA compilation. This is meaningful '
|
||||
'only when XLA is configured to support the remote compilation '
|
||||
'profile feature.'),
|
||||
update_global_hook=lambda val: _update_global_jit_state(
|
||||
xla_profile_version=val),
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
xla_profile_version=val),
|
||||
include_in_jit_key=True,
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
@ -1821,10 +1945,5 @@ use_shardy_partitioner = bool_state(
|
||||
'framework for MLIR. Currently Shardy is experimental in JAX. See '
|
||||
'www.github.com/openxla/shardy'
|
||||
),
|
||||
update_global_hook=lambda val: _update_global_jit_state(
|
||||
use_shardy_partitioner=val
|
||||
),
|
||||
update_thread_local_hook=lambda val: update_thread_local_jit_state(
|
||||
use_shardy_partitioner=val
|
||||
),
|
||||
include_in_jit_key=True,
|
||||
)
|
||||
|
@ -1016,18 +1016,16 @@ class TracingContext(threading.local):
|
||||
def set_trace(self, trace):
|
||||
self.trace = trace
|
||||
ts = ref(trace) if trace is not None else None
|
||||
config.update_thread_local_jit_state(trace_state=ts)
|
||||
config.trace_state.set_local(ts)
|
||||
|
||||
def set_axis_env(self, axis_env):
|
||||
self.axis_env = axis_env
|
||||
config.update_thread_local_jit_state(
|
||||
axis_env_state=self.axis_env.as_hashable_key())
|
||||
config.axis_env_state.set_local(axis_env.as_hashable_key())
|
||||
|
||||
def update_thread_local_jit_state(self):
|
||||
ts = ref(self.trace) if self.trace is not None else None
|
||||
config.update_thread_local_jit_state(
|
||||
trace_state=ts,
|
||||
axis_env_state=self.axis_env.as_hashable_key())
|
||||
config.trace_state.set_local(ts)
|
||||
config.axis_env_state.set_local(self.axis_env.as_hashable_key())
|
||||
|
||||
trace_ctx = TracingContext()
|
||||
|
||||
@ -1071,10 +1069,7 @@ def _initialize_jax_jit_thread_local_state():
|
||||
|
||||
This function does not live in `config.py`, to prevent circular imports.
|
||||
"""
|
||||
tls = jax_jit.thread_local_state()
|
||||
|
||||
if tls.extra_jit_context is None:
|
||||
trace_ctx.update_thread_local_jit_state()
|
||||
trace_ctx.update_thread_local_jit_state()
|
||||
|
||||
jax_jit.set_thread_local_state_initialization_callback(
|
||||
_initialize_jax_jit_thread_local_state)
|
||||
|
@ -224,17 +224,17 @@ class Mesh(contextlib.ContextDecorator):
|
||||
new_env = thread_resources.stack[-1].with_mesh(self)
|
||||
thread_resources.stack.append(new_env)
|
||||
thread_resources.env = new_env
|
||||
jax_config.update_thread_local_jit_state(
|
||||
mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack
|
||||
if not t.physical_mesh.empty))
|
||||
jax_config.mesh_context_manager.set_local(
|
||||
tuple(t.physical_mesh for t in thread_resources.stack
|
||||
if not t.physical_mesh.empty))
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
thread_resources.stack.pop()
|
||||
thread_resources.env = thread_resources.stack[-1]
|
||||
jax_config.update_thread_local_jit_state(
|
||||
mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack
|
||||
if not t.physical_mesh.empty))
|
||||
jax_config.mesh_context_manager.set_local(
|
||||
tuple(t.physical_mesh for t in thread_resources.stack
|
||||
if not t.physical_mesh.empty))
|
||||
return False
|
||||
|
||||
@property
|
||||
@ -410,7 +410,7 @@ class AbstractMesh:
|
||||
|
||||
@staticmethod
|
||||
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
|
||||
jax_config.update_thread_local_jit_state(mesh_context_manager=mesh)
|
||||
jax_config.mesh_context_manager.set_local(mesh)
|
||||
return
|
||||
|
||||
|
||||
|
@ -41,15 +41,13 @@ def set_xla_metadata(*args, **kwargs):
|
||||
thread_local_metadata.val,
|
||||
new_metadata,
|
||||
)
|
||||
config.update_thread_local_jit_state(
|
||||
xla_metadata_context_manager=tuple(
|
||||
(v, k) for k, v in sorted(new_metadata.items())))
|
||||
config.xla_metadata_context_manager.set_local(
|
||||
tuple((v, k) for k, v in sorted(new_metadata.items()))
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
thread_local_metadata.val = prev_metadata
|
||||
config.update_thread_local_jit_state(
|
||||
xla_metadata_context_manager=tuple(
|
||||
(v, k) for k, v in sorted(prev_metadata.items())
|
||||
)
|
||||
config.xla_metadata_context_manager.set_local(
|
||||
tuple((v, k) for k, v in sorted(prev_metadata.items()))
|
||||
)
|
||||
|
@ -2215,8 +2215,6 @@ class CppPmapTest(PythonPmapTest):
|
||||
pmaped_f(inputs)
|
||||
self.assertEqual(pmaped_f._cache_size, 1)
|
||||
|
||||
config.update_thread_local_jit_state()
|
||||
|
||||
pmaped_f(inputs)
|
||||
self.assertEqual(pmaped_f._cache_size, 1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user