Merge pull request #6203 from hawkinsp:x64

PiperOrigin-RevId: 364871527
This commit is contained in:
jax authors 2021-03-24 12:58:20 -07:00
commit 7b4c2e3c3a
2 changed files with 72 additions and 37 deletions

View File

@ -19,6 +19,7 @@ import sys
import threading
from jax import lib
from typing import Callable, Optional
def bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
@ -58,6 +59,7 @@ class Config:
# TODO(mattjj): delete these when only omnistaging is available
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
self._omnistaging_disablers = []
self._update_hooks = {}
def update(self, name, val):
if self.use_absl:
@ -68,10 +70,12 @@ class Config:
raise Exception("Unrecognized config option: {}".format(name))
self.values[name] = val
hook = self._update_hooks.get(name, None)
if hook:
hook(val)
if name == "jax_disable_jit":
lib.jax_jit.global_state().disable_jit = val
elif name == "jax_enable_x64":
lib.jax_jit.global_state().enable_x64 = val
def read(self, name):
if name in self._contextmanager_flags:
@ -87,27 +91,36 @@ class Config:
self.check_exists(name)
return self.values[name]
def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
def add_option(self, name, default, opt_type, meta_args, meta_kwargs,
update_hook=None):
if name in self.values:
raise Exception("Config option {} already defined".format(name))
self.values[name] = default
self.meta[name] = (opt_type, meta_args, meta_kwargs)
if update_hook:
self._update_hooks[name] = update_hook
update_hook(default)
def check_exists(self, name):
if name not in self.values:
raise AttributeError("Unrecognized config option: {}".format(name))
def DEFINE_bool(self, name, default, *args, **kwargs):
self.add_option(name, default, bool, args, kwargs)
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, bool, args, kwargs, update_hook=update_hook)
def DEFINE_integer(self, name, default, *args, **kwargs):
self.add_option(name, default, int, args, kwargs)
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
def DEFINE_string(self, name, default, *args, **kwargs):
self.add_option(name, default, str, args, kwargs)
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
def DEFINE_enum(self, name, default, *args, **kwargs):
self.add_option(name, default, 'enum', args, kwargs)
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, 'enum', args, kwargs,
update_hook=update_hook)
def config_with_absl(self):
# Run this before calling `app.run(main)` etc
@ -162,7 +175,10 @@ class Config:
disabler()
self.omnistaging_enabled = False
def define_bool_state(self, name: str, default: bool, help: str):
def define_bool_state(
self, name: str, default: bool, help: str, *,
update_global_hook: Optional[Callable[[bool], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None):
"""Set up thread-local state and return a contextmanager for managing it.
This function is a convenience wrapper. It defines a flag and corresponding
@ -178,6 +194,11 @@ class Config:
default: boolean, a default value for the option.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
update_global_hook: a optional callback that is called with the updated
value of the global state when it is altered or set initially.
update_thread_local_hook: a optional callback that is called with the
updated value of the thread-local state when it is altered or set
initially.
Returns:
A contextmanager to control the thread-local state value.
@ -203,7 +224,8 @@ class Config:
an error.
"""
name = name.lower()
self.DEFINE_bool(name, bool_env(name.upper(), default), help)
self.DEFINE_bool(name, bool_env(name.upper(), default), help,
update_hook=update_global_hook)
self._contextmanager_flags.add(name)
def get_state(self):
@ -215,13 +237,20 @@ class Config:
def set_state(new_val: bool):
prev_val = getattr(_thread_local_state, name, unset)
setattr(_thread_local_state, name, new_val)
if update_thread_local_hook:
update_thread_local_hook(new_val)
try:
yield
finally:
if prev_val is unset:
delattr(_thread_local_state, name)
if update_thread_local_hook:
update_thread_local_hook(None)
else:
setattr(_thread_local_state, name, prev_val)
if update_thread_local_hook:
update_thread_local_hook(prev_val)
set_state.__name__ = name[4:] if name.startswith('jax_') else name
set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
return set_state
@ -311,33 +340,20 @@ log_compiles = config.define_bool_state(
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))
# Because jax_enable_x64 is managed by C++ code, we don't reuse the
# config.define_bool_state mechanism, though conceptually it is the same.
config.DEFINE_bool('jax_enable_x64', bool_env('JAX_ENABLE_X64', False),
help='Enable 64-bit types to be used')
lib.jax_jit.global_state().enable_x64 = bool_env('JAX_ENABLE_X64', False)
def _update_x64_global(val):
lib.jax_jit.global_state().enable_x64 = val
@contextlib.contextmanager
def enable_x64(new_val: bool = True):
"""Experimental context manager to temporarily enable X64 mode.
def _update_x64_thread_local(val):
lib.jax_jit.thread_local_state().enable_x64 = val
Usage::
enable_x64 = config.define_bool_state(
name='jax_enable_x64',
default=False,
help='Enable 64-bit types to be used',
update_global_hook=_update_x64_global,
update_thread_local_hook=_update_x64_thread_local)
>>> import jax.numpy as jnp
>>> with enable_x64(True):
... print(jnp.arange(10.0).dtype)
...
float64
"""
prev_val = config.jax_enable_x64
lib.jax_jit.thread_local_state().enable_x64 = bool(new_val)
try:
yield
finally:
lib.jax_jit.thread_local_state().enable_x64 = prev_val
Config.jax_enable_x64 = property(lambda self: lib.jax_jit.get_enable_x64())
# config._contextmanager_flags.add('jax_enable_x64') # TODO(mattjj): remove footgun
# TODO(phawkins): remove after fixing users of FLAGS.x64_enabled.
config._contextmanager_flags.remove("jax_enable_x64")
# The `x64_enabled` property doesn't fit the naming scheme, but we use it for
# backward compatibility.
Config.x64_enabled = Config.jax_enable_x64
Config.x64_enabled = Config.jax_enable_x64 # type: ignore

View File

@ -24,7 +24,26 @@
# uniformity
from contextlib import contextmanager
from jax.config import enable_x64
from ..config import enable_x64 as _jax_enable_x64
@contextmanager
def enable_x64(new_val: bool = True):
"""Experimental context manager to temporarily enable X64 mode.
Usage::
>>> import jax.numpy as jnp
>>> with enable_x64():
... print(jnp.arange(10.0).dtype)
...
float64
See Also
--------
jax.experimental.enable_x64 : temporarily enable X64 mode.
"""
with _jax_enable_x64(new_val):
yield
@contextmanager
def disable_x64():
@ -42,5 +61,5 @@ def disable_x64():
--------
jax.experimental.enable_x64 : temporarily enable X64 mode.
"""
with enable_x64(False):
with _jax_enable_x64(False):
yield