mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6203 from hawkinsp:x64
PiperOrigin-RevId: 364871527
This commit is contained in:
commit
7b4c2e3c3a
@ -19,6 +19,7 @@ import sys
|
||||
import threading
|
||||
|
||||
from jax import lib
|
||||
from typing import Callable, Optional
|
||||
|
||||
def bool_env(varname: str, default: bool) -> bool:
|
||||
"""Read an environment variable and interpret it as a boolean.
|
||||
@ -58,6 +59,7 @@ class Config:
|
||||
# TODO(mattjj): delete these when only omnistaging is available
|
||||
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
|
||||
self._omnistaging_disablers = []
|
||||
self._update_hooks = {}
|
||||
|
||||
def update(self, name, val):
|
||||
if self.use_absl:
|
||||
@ -68,10 +70,12 @@ class Config:
|
||||
raise Exception("Unrecognized config option: {}".format(name))
|
||||
self.values[name] = val
|
||||
|
||||
hook = self._update_hooks.get(name, None)
|
||||
if hook:
|
||||
hook(val)
|
||||
|
||||
if name == "jax_disable_jit":
|
||||
lib.jax_jit.global_state().disable_jit = val
|
||||
elif name == "jax_enable_x64":
|
||||
lib.jax_jit.global_state().enable_x64 = val
|
||||
|
||||
def read(self, name):
|
||||
if name in self._contextmanager_flags:
|
||||
@ -87,27 +91,36 @@ class Config:
|
||||
self.check_exists(name)
|
||||
return self.values[name]
|
||||
|
||||
def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
|
||||
def add_option(self, name, default, opt_type, meta_args, meta_kwargs,
|
||||
update_hook=None):
|
||||
if name in self.values:
|
||||
raise Exception("Config option {} already defined".format(name))
|
||||
self.values[name] = default
|
||||
self.meta[name] = (opt_type, meta_args, meta_kwargs)
|
||||
if update_hook:
|
||||
self._update_hooks[name] = update_hook
|
||||
update_hook(default)
|
||||
|
||||
def check_exists(self, name):
|
||||
if name not in self.values:
|
||||
raise AttributeError("Unrecognized config option: {}".format(name))
|
||||
|
||||
def DEFINE_bool(self, name, default, *args, **kwargs):
|
||||
self.add_option(name, default, bool, args, kwargs)
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, bool, args, kwargs, update_hook=update_hook)
|
||||
|
||||
def DEFINE_integer(self, name, default, *args, **kwargs):
|
||||
self.add_option(name, default, int, args, kwargs)
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
|
||||
|
||||
def DEFINE_string(self, name, default, *args, **kwargs):
|
||||
self.add_option(name, default, str, args, kwargs)
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
|
||||
|
||||
def DEFINE_enum(self, name, default, *args, **kwargs):
|
||||
self.add_option(name, default, 'enum', args, kwargs)
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, 'enum', args, kwargs,
|
||||
update_hook=update_hook)
|
||||
|
||||
def config_with_absl(self):
|
||||
# Run this before calling `app.run(main)` etc
|
||||
@ -162,7 +175,10 @@ class Config:
|
||||
disabler()
|
||||
self.omnistaging_enabled = False
|
||||
|
||||
def define_bool_state(self, name: str, default: bool, help: str):
|
||||
def define_bool_state(
|
||||
self, name: str, default: bool, help: str, *,
|
||||
update_global_hook: Optional[Callable[[bool], None]] = None,
|
||||
update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None):
|
||||
"""Set up thread-local state and return a contextmanager for managing it.
|
||||
|
||||
This function is a convenience wrapper. It defines a flag and corresponding
|
||||
@ -178,6 +194,11 @@ class Config:
|
||||
default: boolean, a default value for the option.
|
||||
help: string, used to populate the flag help information as well as the
|
||||
docstring of the returned context manager.
|
||||
update_global_hook: a optional callback that is called with the updated
|
||||
value of the global state when it is altered or set initially.
|
||||
update_thread_local_hook: a optional callback that is called with the
|
||||
updated value of the thread-local state when it is altered or set
|
||||
initially.
|
||||
|
||||
Returns:
|
||||
A contextmanager to control the thread-local state value.
|
||||
@ -203,7 +224,8 @@ class Config:
|
||||
an error.
|
||||
"""
|
||||
name = name.lower()
|
||||
self.DEFINE_bool(name, bool_env(name.upper(), default), help)
|
||||
self.DEFINE_bool(name, bool_env(name.upper(), default), help,
|
||||
update_hook=update_global_hook)
|
||||
self._contextmanager_flags.add(name)
|
||||
|
||||
def get_state(self):
|
||||
@ -215,13 +237,20 @@ class Config:
|
||||
def set_state(new_val: bool):
|
||||
prev_val = getattr(_thread_local_state, name, unset)
|
||||
setattr(_thread_local_state, name, new_val)
|
||||
if update_thread_local_hook:
|
||||
update_thread_local_hook(new_val)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev_val is unset:
|
||||
delattr(_thread_local_state, name)
|
||||
if update_thread_local_hook:
|
||||
update_thread_local_hook(None)
|
||||
else:
|
||||
setattr(_thread_local_state, name, prev_val)
|
||||
if update_thread_local_hook:
|
||||
update_thread_local_hook(prev_val)
|
||||
|
||||
set_state.__name__ = name[4:] if name.startswith('jax_') else name
|
||||
set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
|
||||
return set_state
|
||||
@ -311,33 +340,20 @@ log_compiles = config.define_bool_state(
|
||||
'option is set, the log level is WARNING; otherwise the level is '
|
||||
'DEBUG.'))
|
||||
|
||||
# Because jax_enable_x64 is managed by C++ code, we don't reuse the
|
||||
# config.define_bool_state mechanism, though conceptually it is the same.
|
||||
config.DEFINE_bool('jax_enable_x64', bool_env('JAX_ENABLE_X64', False),
|
||||
help='Enable 64-bit types to be used')
|
||||
lib.jax_jit.global_state().enable_x64 = bool_env('JAX_ENABLE_X64', False)
|
||||
def _update_x64_global(val):
|
||||
lib.jax_jit.global_state().enable_x64 = val
|
||||
|
||||
@contextlib.contextmanager
|
||||
def enable_x64(new_val: bool = True):
|
||||
"""Experimental context manager to temporarily enable X64 mode.
|
||||
def _update_x64_thread_local(val):
|
||||
lib.jax_jit.thread_local_state().enable_x64 = val
|
||||
|
||||
Usage::
|
||||
enable_x64 = config.define_bool_state(
|
||||
name='jax_enable_x64',
|
||||
default=False,
|
||||
help='Enable 64-bit types to be used',
|
||||
update_global_hook=_update_x64_global,
|
||||
update_thread_local_hook=_update_x64_thread_local)
|
||||
|
||||
>>> import jax.numpy as jnp
|
||||
>>> with enable_x64(True):
|
||||
... print(jnp.arange(10.0).dtype)
|
||||
...
|
||||
float64
|
||||
"""
|
||||
prev_val = config.jax_enable_x64
|
||||
lib.jax_jit.thread_local_state().enable_x64 = bool(new_val)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
lib.jax_jit.thread_local_state().enable_x64 = prev_val
|
||||
Config.jax_enable_x64 = property(lambda self: lib.jax_jit.get_enable_x64())
|
||||
# config._contextmanager_flags.add('jax_enable_x64') # TODO(mattjj): remove footgun
|
||||
# TODO(phawkins): remove after fixing users of FLAGS.x64_enabled.
|
||||
config._contextmanager_flags.remove("jax_enable_x64")
|
||||
|
||||
# The `x64_enabled` property doesn't fit the naming scheme, but we use it for
|
||||
# backward compatibility.
|
||||
Config.x64_enabled = Config.jax_enable_x64
|
||||
Config.x64_enabled = Config.jax_enable_x64 # type: ignore
|
||||
|
@ -24,7 +24,26 @@
|
||||
# uniformity
|
||||
|
||||
from contextlib import contextmanager
|
||||
from jax.config import enable_x64
|
||||
from ..config import enable_x64 as _jax_enable_x64
|
||||
|
||||
@contextmanager
|
||||
def enable_x64(new_val: bool = True):
|
||||
"""Experimental context manager to temporarily enable X64 mode.
|
||||
|
||||
Usage::
|
||||
|
||||
>>> import jax.numpy as jnp
|
||||
>>> with enable_x64():
|
||||
... print(jnp.arange(10.0).dtype)
|
||||
...
|
||||
float64
|
||||
|
||||
See Also
|
||||
--------
|
||||
jax.experimental.enable_x64 : temporarily enable X64 mode.
|
||||
"""
|
||||
with _jax_enable_x64(new_val):
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def disable_x64():
|
||||
@ -42,5 +61,5 @@ def disable_x64():
|
||||
--------
|
||||
jax.experimental.enable_x64 : temporarily enable X64 mode.
|
||||
"""
|
||||
with enable_x64(False):
|
||||
with _jax_enable_x64(False):
|
||||
yield
|
||||
|
Loading…
x
Reference in New Issue
Block a user