unify configuration state handling

This commit is contained in:
Matthew Johnson 2021-03-19 13:49:38 -07:00
parent 22a2be30a2
commit fd7b286ec9
30 changed files with 263 additions and 191 deletions

View File

@ -30,7 +30,8 @@ except Exception as exc:
del _cloud_tpu_init
# flake8: noqa: F401
from .config import config
from .config import (config, enable_checks, check_tracer_leaks, checking_leaks,
debug_nans, debug_infs, log_compiles)
from .api import (
ad, # TODO(phawkins): update users to avoid this.
argnums_partial, # TODO(phawkins): update Haiku to not use this.

View File

@ -1113,7 +1113,7 @@ def _cond_typecheck(*avals, branches, linear):
f'called with operands of type {_avals_short(op_avals)}')
def cond_bind(*args, branches, linear):
if not core.skip_checks:
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
_cond_typecheck(*avals, branches=branches, linear=linear)
for jaxpr in branches:
@ -1876,7 +1876,7 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
f'called with sequence of type\n{_avals_short(x_avals)}')
def scan_bind(*args, **params):
if not core.skip_checks:
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
_scan_typecheck(True, *avals, **params)
core.check_jaxpr(params['jaxpr'].jaxpr)

View File

@ -21,7 +21,6 @@ from typing import Any, Callable
import numpy as np
import jax
from jax.config import config
partial = functools.partial
@ -192,7 +191,7 @@ def cache(max_size=4096):
@functools.wraps(f)
def wrapper(*args, **kwargs):
if jax.core.debug_state.check_leaks:
if config.jax_check_tracer_leaks:
return f(*args, **kwargs)
else:
return cached(bool(config.x64_enabled), *args, **kwargs)

View File

@ -41,7 +41,7 @@ from . import lib
from . import linear_util as lu
from . import ad_util
from . import dtypes
from .core import eval_jaxpr, checking_leaks
from .core import eval_jaxpr
from .api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
flatten_fun_nokwargs2, argnums_partial,
argnums_partial_except, flatten_axes, donation_vector,
@ -362,7 +362,7 @@ def _cpp_jit(
context = (getattr(core.thread_local_state.trace_state.trace_stack,
"dynamic", None), config.x64_enabled)
# TODO(jblespiau): Move this to C++.
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
if (config.jax_debug_nans or config.jax_debug_infs) and not _jit_is_disabled():
device_arrays = cpp_jitted_f(context, *args, **kwargs)
try:
xla.check_special(xla.xla_call_p, [
@ -372,7 +372,7 @@ def _cpp_jit(
])
return device_arrays
except FloatingPointError:
assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid nan value encountered in the output of a C++-jit "
"function. Calling the de-optimized version.")
return cache_miss(*args, **kwargs)[0] # probably won't return
@ -389,7 +389,7 @@ def _cpp_jit(
@api_boundary
def f_jitted(*args, **kwargs):
# TODO(jblespiau): Move this to C++.
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
if (config.jax_debug_nans or config.jax_debug_infs) and not _jit_is_disabled():
device_arrays = cpp_jitted_f(*args, **kwargs)
try:
xla.check_special(xla.xla_call_p, [
@ -399,7 +399,7 @@ def _cpp_jit(
])
return device_arrays
except FloatingPointError:
assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid nan value encountered in the output of a C++-jit "
"function. Calling the de-optimized version.")
return cache_miss(*args, **kwargs)[0] # probably won't return

View File

@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import os
import sys
import threading
from jax import lib
def bool_env(varname: str, default: bool) -> bool:
@ -42,11 +46,16 @@ def int_env(varname: str, default: int) -> int:
class Config:
_HAS_DYNAMIC_ATTRIBUTES = True
def __init__(self):
self.values = {}
self.meta = {}
self.FLAGS = NameSpace(self.read)
self.use_absl = False
self._contextmanager_flags = set()
# TODO(mattjj): delete these when only omnistaging is available
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
self._omnistaging_disablers = []
@ -65,6 +74,13 @@ class Config:
lib.jax_jit.global_state().enable_x64 = val
def read(self, name):
if name in self._contextmanager_flags:
raise AttributeError(
"For flags with a corresponding contextmanager, read their value "
f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.")
return self._read(name)
def _read(self, name):
if self.use_absl:
return getattr(self.absl_flags.FLAGS, name)
else:
@ -143,14 +159,82 @@ class Config:
disabler()
self.omnistaging_enabled = False
@property
def x64_enabled(self):
return lib.jax_jit.get_enable_x64()
# # TODO(jakevdp, mattjj): unify this with `define_bool_state` stuff below
# @property
# def x64_enabled(self):
# return lib.jax_jit.get_enable_x64()
# TODO(jakevdp): make this public when thread-local x64 is fully implemented.
def _set_x64_enabled(self, state):
lib.jax_jit.thread_local_state().enable_x64 = bool(state)
# def _set_x64_enabled(self, state):
# lib.jax_jit.thread_local_state().enable_x64 = bool(state)
def define_bool_state(self, name: str, default: bool, help: str):
"""Set up thread-local state and return a contextmanager for managing it.
This function is a convenience wrapper. It defines a flag and corresponding
thread-local state, which can be managed via the contextmanager it returns.
The thread-local state value can be read via the ``config.<option_name>``
attribute, where ``config`` is the singleton ``Config`` instance.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
default: boolean, a default value for the option.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
Returns:
A contextmanager to control the thread-local state value.
Example:
enable_foo = config.define_bool_state(
name='jax_enable_foo',
default=False,
help='Enable foo.')
# Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo
# command-line flag can be used to control the process-level value of
# the configuration option, in addition to using e.g.
# ``config.update("jax_enable_foo", True)`` directly. We can also use a
# context manager:
with enable_foo(True):
...
The value of the thread-local state or flag can be accessed via
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
an error.
"""
name = name.lower()
self.DEFINE_bool(name, bool_env(name.upper(), default), help)
self._contextmanager_flags.add(name)
def get_state(self):
val = getattr(_thread_local_state, name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))
@contextlib.contextmanager
def set_state(new_val: bool):
prev_val = getattr(_thread_local_state, name, unset)
setattr(_thread_local_state, name, new_val)
try:
yield
finally:
if prev_val is unset:
delattr(_thread_local_state, name)
else:
setattr(_thread_local_state, name, prev_val)
set_state.__name__ = name[4:] if name.startswith('jax_') else name
set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
return set_state
_thread_local_state = threading.local()
class Unset: pass
unset = Unset()
class NameSpace(object):
def __init__(self, getter):
@ -166,11 +250,6 @@ FLAGS = flags.FLAGS
already_configured_with_absl = False
flags.DEFINE_bool(
'jax_enable_checks',
bool_env('JAX_ENABLE_CHECKS', False),
help='Turn on invariant checking (core.skip_checks = False)'
)
flags.DEFINE_bool(
'jax_omnistaging',
@ -184,14 +263,6 @@ flags.DEFINE_integer(
help='Set the number of stack frames in JAX tracer error messages.'
)
flags.DEFINE_bool(
'jax_check_tracer_leaks',
bool_env('JAX_CHECK_TRACER_LEAKS', False),
help=('Turn on checking for leaked tracers as soon as a trace completes. '
'Enabling leak checking may have performance impacts: some caching '
'is disabled, and other overheads may be added.'),
)
flags.DEFINE_bool(
'jax_host_callback_inline',
bool_env('JAX_HOST_CALLBACK_INLINE', False),
@ -206,3 +277,72 @@ flags.DEFINE_integer(
'until the Python callback consume more outfeeds.'),
lower_bound=int(16 * 1e6)
)
enable_checks = config.define_bool_state(
name='jax_enable_checks',
default=False,
help='Turn on invariant checking for JAX internals. Makes things slower.')
check_tracer_leaks = config.define_bool_state(
name='jax_check_tracer_leaks',
default=False,
help=('Turn on checking for leaked tracers as soon as a trace completes. '
'Enabling leak checking may have performance impacts: some caching '
'is disabled, and other overheads may be added.'))
checking_leaks = functools.partial(check_tracer_leaks, True)
debug_nans = config.define_bool_state(
name='jax_debug_nans',
default=False,
help=('Add nan checks to every operation. When a nan is detected on the '
'output of a jit-compiled computation, call into the un-compiled '
'version in an attempt to more precisely identify the operation '
'which produced the nan.'))
debug_infs = config.define_bool_state(
name='jax_debug_infs',
default=False,
help=('Add inf checks to every operation. When an inf is detected on the '
'output of a jit-compiled computation, call into the un-compiled '
'version in an attempt to more precisely identify the operation '
'which produced the inf.'))
log_compiles = config.define_bool_state(
name='jax_log_compiles',
default=False,
help=('Log a message each time every time `jit` or `pmap` compiles an XLA '
'computation. Logging is performed with `absl.logging`. When this '
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))
# Because jax_enable_x64 is managed by C++ code, we don't reuse the
# config.define_bool_state mechanism, though conceptually it is the same.
config.DEFINE_bool('jax_enable_x64', bool_env('JAX_ENABLE_X64', False),
help='Enable 64-bit types to be used')
lib.jax_jit.global_state().enable_x64 = bool_env('JAX_ENABLE_X64', False)
@contextlib.contextmanager
def enable_x64(new_val: bool = True):
"""Experimental context manager to temporarily enable X64 mode.
Usage::
>>> import jax.numpy as jnp
>>> with enable_x64(True):
... print(jnp.arange(10.0).dtype)
...
float64
"""
prev_val = config.jax_enable_x64
lib.jax_jit.thread_local_state().enable_x64 = bool(new_val)
try:
yield
finally:
lib.jax_jit.thread_local_state().enable_x64 = prev_val
Config.jax_enable_x64 = property(lambda self: lib.jax_jit.get_enable_x64())
# config._contextmanager_flags.add('jax_enable_x64') # TODO(mattjj): remove footgun
# The `x64_enabled` property doesn't fit the naming scheme, but we use it for
# backward compatibility.
Config.x64_enabled = Config.jax_enable_x64

View File

@ -45,33 +45,6 @@ from ._src.pprint_util import pp, vcat, PrettyPrint
from ._src import traceback_util
traceback_util.register_exclusion(__file__)
# TODO(mattjj): move this into debug_state
skip_checks = not FLAGS.jax_enable_checks
@contextmanager
def skipping_checks():
"""Context manager for temporarily disabling internal checks."""
global skip_checks
old_value, skip_checks = skip_checks, True
try:
yield
finally:
skip_checks = old_value
@contextmanager
def checking_leaks():
"""Context manager for temporarily enabling tracer leak checks."""
old_value, debug_state.check_leaks = debug_state.check_leaks, True
try:
yield
finally:
debug_state.check_leaks = old_value
class DebugState(threading.local):
def __init__(self):
self.check_leaks = FLAGS.jax_check_tracer_leaks
debug_state = DebugState()
zip = safe_zip
map = safe_map
@ -279,8 +252,8 @@ class Primitive:
def bind(self, *args, **params):
assert skip_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
assert (not config.jax_enable_checks or
all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
top_trace = find_top_trace(args)
tracers = map(top_trace.full_raise, args)
out = top_trace.process_primitive(self, tracers, params)
@ -569,7 +542,7 @@ class Tracer:
def __getattr__(self, name):
# if the aval property raises an AttributeError, gets caught here
assert skip_checks or name != "aval"
assert not config.jax_enable_checks or name != "aval"
try:
attr = getattr(self.aval, name)
@ -783,7 +756,7 @@ def new_main(trace_type: Type[Trace],
if lib._xla_extension_version >= 11:
jit_tls.extra_jit_context = extra_jit_context(stack)
if debug_state.check_leaks:
if config.jax_check_tracer_leaks:
t = ref(main)
del main
if t() is not None:
@ -807,7 +780,7 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
if lib._xla_extension_version >= 11:
jit_tls.extra_jit_context = extra_jit_context(stack)
if debug_state.check_leaks:
if config.jax_check_tracer_leaks:
t = ref(main)
del main
if t() is not None:
@ -827,7 +800,7 @@ def new_sublevel() -> Generator[None, None, None]:
finally:
thread_local_state.trace_state.substack.pop()
if debug_state.check_leaks:
if config.jax_check_tracer_leaks:
t = ref(sublevel)
del sublevel
if t() is not None:
@ -899,7 +872,7 @@ class AbstractUnit(AbstractValue):
# _num_buffers = 0
def at_least_vspace(self): return self
def join(self, other):
if not skip_checks:
if config.jax_enable_checks:
assert other is abstract_unit, other
return self
def _eq(self, self_traced, other): return get_aval(other) is self
@ -1932,7 +1905,7 @@ def omnistaging_disabler() -> None:
finally:
thread_local_state.trace_state.trace_stack.pop(bottom)
if debug_state.check_leaks:
if config.jax_check_tracer_leaks:
t = ref(main)
del main
if t() is not None:
@ -1949,7 +1922,7 @@ def omnistaging_disabler() -> None:
yield # dummy implementation for forward compatibility
def bind(self, *args, **kwargs):
assert skip_checks or all(isinstance(arg, Tracer)
assert not config.jax_enable_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
top_trace = find_top_trace(args)
if top_trace is None:

View File

@ -914,7 +914,7 @@ def closure_convert(fun, *example_args):
"""
flat_args, in_tree = tree_flatten(example_args)
in_avals = tuple(map(abstractify, flat_args))
if core.debug_state.check_leaks:
if config.jax_check_tracer_leaks:
return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
else:
return _closure_convert_for_avals(fun, in_tree, in_avals)

View File

@ -20,27 +20,19 @@
# so we need our own implementation that deviates from NumPy in places.
from distutils.util import strtobool
import functools
import os
from typing import Dict
import numpy as np
from ._src import util
from .config import flags, config
from . import lib
from .lib import xla_client
from ._src import traceback_util
traceback_util.register_exclusion(__file__)
FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_enable_x64',
strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
'Enable 64-bit types to be used.')
lib.jax_jit.global_state().enable_x64 = strtobool(
os.getenv('JAX_ENABLE_X64', 'False'))
# bfloat16 support
bfloat16: type = xla_client.bfloat16

View File

@ -70,9 +70,9 @@ def _is_tfval(v: TfVal) -> bool:
if isinstance(v, (tf.Tensor, tf.Variable)):
return True
try:
# Note: this conversion is overkill and just intended as a type check; this code
# is in principle only run if core.skip_checks is False.
# TODO: it is not true that this code is run only without skip_checks
# Note: this conversion is overkill and just intended as a type check; this
# code is in principle only run if config.jax_enable_checks is True.
# TODO: it is not true that this code is run only with jax_enable_checks.
_safe_convert_to_tensor(v)
return True
except ValueError:
@ -353,7 +353,7 @@ def _tfval_shape_dtype(val: TfVal) -> Tuple[Sequence[Optional[int]], DType]:
# May be partially known
return tuple(val.shape), to_jax_dtype(val.dtype)
else: # Must be a numeric value
assert core.skip_checks or _is_tfval(val), f"Non TfVal: {val}"
assert not config.jax_enable_checks or _is_tfval(val), f"Non TfVal: {val}"
raw_aval = xla.abstractify(val)
return raw_aval.shape, raw_aval.dtype # type: ignore[attr-defined]
@ -605,7 +605,7 @@ class TensorFlowTracer(core.Tracer):
val = tf.cast(val, dtype=aval_dtype)
val_dtype = aval_dtype
if not core.skip_checks:
if config.jax_enable_checks:
assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}"
for aval_dim, val_dim in util.safe_zip(self._aval.shape, val_shape): # type: ignore[attr-defined]
if val_dim is None:
@ -703,7 +703,7 @@ class TensorFlowTrace(core.Trace):
# Check that the impl rule returned a value of expected shape and dtype
# TODO: adapt this to match polymorphic shapes
if not core.skip_checks:
if config.jax_enable_checks:
if primitive.multiple_results:
for o, expected_aval in zip(out, out_aval): # type: ignore
assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), (
@ -1530,7 +1530,7 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions,
reducer_fn = tf.function(reducer, autograph=False).get_concrete_function(o_spec, o_spec)
if not isinstance(init_val, tf.Tensor):
assert core.skip_checks or _is_tfval(init_val), f"Non TfVal: {init_val}"
assert not config.jax_enable_checks or _is_tfval(init_val), f"Non TfVal: {init_val}"
init_val = tf.constant(init_val, operand.dtype)
out = tfxla.reduce_window(operand, init_val,
reducer_fn, window_dimensions,

View File

@ -135,7 +135,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
dtype=dtype)
for dtype in [np.int64, np.float64]))
def test_converts_64bit(self, dtype=np.int64, with_function=False):
if not config.FLAGS.jax_enable_x64:
if not config.jax_enable_x64:
self.skipTest("requires x64 mode")
big_const = np.full((5,), 2 ** 33, dtype=dtype)
self.ConvertAndCompare(jnp.sin, big_const)

View File

@ -17,31 +17,14 @@
**Experimental: please give feedback, and expect changes.**
"""
# This file provides
# 1. a jax.experimental API endpoint;
# 2. the `disable_x64` wrapper.
# TODO(jakevdp): remove this file, and consider removing `disable_x64` for
# uniformity
from contextlib import contextmanager
from jax import config
@contextmanager
def enable_x64():
"""Experimental context manager to temporarily enable X64 mode.
Usage::
>>> import jax.numpy as jnp
>>> with enable_x64():
... print(jnp.arange(10.0).dtype)
...
float64
See Also
--------
jax.experimental.disable_x64 : temporarily disable X64 mode.
"""
_x64_state = config.x64_enabled
config._set_x64_enabled(True)
try:
yield
finally:
config._set_x64_enabled(_x64_state)
from jax.config import enable_x64
@contextmanager
def disable_x64():
@ -59,9 +42,5 @@ def disable_x64():
--------
jax.experimental.enable_x64 : temporarily enable X64 mode.
"""
_x64_state = config.x64_enabled
config._set_x64_enabled(False)
try:
with enable_x64(False):
yield
finally:
config._set_x64_enabled(_x64_state)

View File

@ -174,7 +174,7 @@ def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
# assert v.aval == ct.aval, (prim, v.aval, ct.aval)
return
ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
if not core.skip_checks:
if config.jax_enable_checks:
ct_aval = core.get_aval(ct_env[v])
joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type()
assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval)
@ -389,7 +389,7 @@ class JVPTracer(Tracer):
__slots__ = ['primal', 'tangent']
def __init__(self, trace, primal, tangent):
if not core.skip_checks:
if config.jax_enable_checks:
_primal_tangent_shapes_match(primal, tangent)
self._trace = trace
self.primal = primal

View File

@ -108,7 +108,7 @@ class BatchTracer(Tracer):
__slots__ = ['val', 'batch_dim']
def __init__(self, trace, val, batch_dim: Optional[int]):
assert core.skip_checks or type(batch_dim) in (int, NotMapped) # type: ignore
assert not config.jax_enable_checks or type(batch_dim) in (int, NotMapped) # type: ignore
self._trace = trace
self.val = val
self.batch_dim = batch_dim

View File

@ -49,7 +49,7 @@ class PartialVal(tuple):
"""
def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
pv, const = xs
if not core.skip_checks:
if config.jax_enable_checks:
# type checks
assert isinstance(pv, (AbstractValue, type(None))), xs
assert isinstance(const, core.Tracer) or type(const) is Zero or core.valid_jaxtype(const), xs
@ -648,25 +648,25 @@ def tracers_to_jaxpr(
const_vars, const_vals = unzip2(consts.items())
# The env_vars are pre-pended to the invars
jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
core.skip_checks or core.check_jaxpr(jaxpr)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
return jaxpr, const_vals, env_vals
@cache()
def convert_constvars_jaxpr(jaxpr: Jaxpr):
"""Moves the constvars to the start of invars."""
core.skip_checks or core.check_jaxpr(jaxpr)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int):
core.skip_checks or core.check_jaxpr(jaxpr)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns)
core.skip_checks or core.check_jaxpr(converted_jaxpr)
config.jax_enable_checks and core.check_jaxpr(converted_jaxpr)
return converted_jaxpr

View File

@ -486,7 +486,7 @@ class ShardedDeviceArray(xla.DeviceArray): # type: ignore
self.indices = indices
self._npy_value = None
self._one_replica_buffer_indices = None
if not core.skip_checks:
if config.jax_enable_checks:
assert type(aval) is ShapedArray
@property
@ -792,7 +792,7 @@ def parallel_callable(fun: lu.WrappedFun,
f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
f"{jaxpr_replicas} and nested_partitions={num_partitions}")
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
f"Compiling {fun.__name__} ({id(fun)}) for {num_global_shards} "
f"devices with args {avals}. (num_replicas={num_global_replicas}"
@ -1387,7 +1387,7 @@ def mesh_callable(fun: lu.WrappedFun,
global_axis_sizes = mesh.shape
local_axis_sizes = local_mesh.shape
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
f"Compiling {fun.__name__} ({id(fun)}) for {tuple(global_axis_sizes.items())} "
f"mesh with args {local_in_untiled_avals}. Argument mapping: {in_axes}.")

View File

@ -144,7 +144,7 @@ def _sharded_callable(
for out, parts, lparts
in safe_zip(global_out_avals, out_parts, local_out_parts)]
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
f"Compiling {fun.__name__} for {nparts} devices with "
f"args {global_abstract_args}.")

View File

@ -23,7 +23,7 @@ from warnings import warn
from absl import logging
import numpy as np
from ..config import flags, bool_env, config
from ..config import config
from .. import core
from .. import ad_util
from .. import dtypes
@ -58,17 +58,6 @@ XlaShape = Any # xla_client.Shape
XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder
XlaExecutable = Any # xla_extension.LocalExecutable
FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_debug_nans',
bool_env('JAX_DEBUG_NANS', False),
'Add nan checks to every operation.')
flags.DEFINE_bool('jax_debug_infs',
bool_env('JAX_DEBUG_INFS', False),
'Add inf checks to every operation.')
flags.DEFINE_bool('jax_log_compiles',
bool_env('JAX_LOG_COMPILES', False),
'Print a message each time a `jit` computation is compiled.')
# This flag is set on exit; no logging should be attempted
_on_exit = False
@ -244,7 +233,7 @@ def apply_primitive(prim, *args, **params):
def _partition_outputs(avals, outs):
nouts = [aval._num_buffers for aval in avals]
if not core.skip_checks:
if config.jax_enable_checks:
assert sum(nouts) == len(outs), f"Internal error: sum(nouts)={sum(nouts)} should equal len(outs)={len(outs)}."
outs = iter(outs)
return [[next(outs) for _ in range(nout)] for nout in nouts]
@ -372,7 +361,7 @@ def _execute_replicated_primitive(prim, compiled, result_handler, *args):
return result_handler(*out_bufs)
def needs_check_special():
return FLAGS.jax_debug_infs or FLAGS.jax_debug_nans
return config.jax_debug_infs or config.jax_debug_nans
def check_special(name, bufs):
if needs_check_special():
@ -382,9 +371,9 @@ def check_special(name, bufs):
def _check_special(name, xla_shape, buf):
assert not xla_shape.is_tuple()
if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
if FLAGS.jax_debug_nans and np.any(np.isnan(buf.to_py())):
if config.jax_debug_nans and np.any(np.isnan(buf.to_py())):
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
if FLAGS.jax_debug_infs and np.any(np.isinf(buf.to_py())):
if config.jax_debug_infs and np.any(np.isinf(buf.to_py())):
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
### compiling jaxprs
@ -590,13 +579,13 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_inv
try:
return compiled_fun(*args)
except FloatingPointError:
assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid value encountered in the output of a jit function. "
"Calling the de-optimized version.")
# We want to run the wrapped function again (after _xla_callable already ran
# it), but linear_util.WrappedFun instances are meant to be run only once.
# In addition to re-executing the Python code, which is usually undesirable
# but which FLAGS.jax_debug_nans is meant to opt into, we'll be re-executing
# but which config.jax_debug_nans is meant to opt into, we'll be re-executing
# any linear_util.py-style side effects, i.e. re-populating Stores created
# by any transformation_with_aux's applied to fun. Since this is
# intentional here, to avoid "Store occupied" errors we reset the stores to
@ -688,7 +677,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers)
if not _on_exit:
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority, "Compiling %s (%s) for args %s.",
fun.__name__, id(fun), abstract_args)
@ -1096,7 +1085,7 @@ class _DeviceArray(DeviceArray): # type: ignore
self._device = device
self._npy_value = None
if not core.skip_checks:
if config.jax_enable_checks:
assert type(aval) is ShapedArray
npy_value = self._value
assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape

View File

@ -248,7 +248,7 @@ def cache(call: Callable):
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if core.debug_state.check_leaks:
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args, config.x64_enabled)
else:
key = (fun.transforms, fun.params, args, config.x64_enabled)

View File

@ -445,7 +445,7 @@ def skip_on_flag(flag_name, skip_value):
def skip(test_method): # pylint: disable=missing-docstring
@functools.wraps(test_method)
def test_method_wrapper(self, *args, **kwargs):
flag_value = getattr(FLAGS, flag_name)
flag_value = config._read(flag_name)
if flag_value == skip_value:
test_name = getattr(test_method, '__name__', '[unknown test]')
raise unittest.SkipTest(
@ -819,7 +819,7 @@ class JaxTestCase(parameterized.TestCase):
def setUp(self):
super(JaxTestCase, self).setUp()
core.skip_checks = False
config.update('jax_enable_checks', True)
# We use the adler32 hash for two reasons.
# a) it is deterministic run to run, unlike hash() which is randomized.
# b) it returns values in int32 range, which RandomState requires.

View File

@ -1,5 +1,6 @@
[mypy]
show_error_codes=True
show_error_codes = True
disable_error_code = attr-defined
[mypy-absl.*]
ignore_missing_imports = True

View File

@ -2217,7 +2217,7 @@ class APITest(jtu.JaxTestCase):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
with jax.checking_leaks():
lst = []
@jit
@ -2232,7 +2232,7 @@ class APITest(jtu.JaxTestCase):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
with jax.checking_leaks():
lst = []
@api.pmap
@ -2247,7 +2247,7 @@ class APITest(jtu.JaxTestCase):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
with jax.checking_leaks():
lst = []
def f(x):
@ -2261,7 +2261,7 @@ class APITest(jtu.JaxTestCase):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
with jax.checking_leaks():
@jit
def f(x):
return x
@ -2279,7 +2279,7 @@ class APITest(jtu.JaxTestCase):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
with jax.checking_leaks():
lst = []
to_scan = lambda c, x: (lst.append(c) or jnp.sin(c), None)
@ -2291,7 +2291,7 @@ class APITest(jtu.JaxTestCase):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
with jax.checking_leaks():
to_scan = lambda c, x: (jnp.sin(c), None)
lax.scan(to_scan, 1., np.arange(3.)) # doesn't crash
@ -2299,7 +2299,7 @@ class APITest(jtu.JaxTestCase):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
with jax.checking_leaks():
to_scan = lambda c, x: (c, None)
def f(x):
@ -2310,7 +2310,7 @@ class APITest(jtu.JaxTestCase):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
with jax.checking_leaks():
to_scan = lambda c, _: (1., None)
@api.vmap
@ -2322,7 +2322,7 @@ class APITest(jtu.JaxTestCase):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
with jax.checking_leaks():
to_scan = lambda c, _: (c, None)
@api.vmap
@ -2334,7 +2334,7 @@ class APITest(jtu.JaxTestCase):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
with jax.checking_leaks():
@jit
def f(x):
lst = []
@ -4957,29 +4957,29 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
api.grad(lambda x, y: f(x, y)[0])(1., 2.) # doesn't crash
def test_custom_transforms_vjp_nones(self):
core.skip_checks = True # Fails with checks
# issue raised by jsnoek@ and jumper@
@jax.custom_transforms
def solve(a, b):
return jnp.dot(jnp.linalg.inv(a), b)
# print(solve(a, b))
with jax.enable_checks(False): # fails with checks
# issue raised by jsnoek@ and jumper@
@jax.custom_transforms
def solve(a, b):
return jnp.dot(jnp.linalg.inv(a), b)
# print(solve(a, b))
def solve_vjp(a, b):
x = solve(a, b)
def vjp(x_tangent):
dx = jnp.dot(solve(a, x_tangent), x.T)
out = (dx, b * 0.)
return out
return x, vjp
jax.defvjp_all(solve, solve_vjp)
gf = grad(lambda a,b: jnp.sum(solve(a, b)))
def solve_vjp(a, b):
x = solve(a, b)
def vjp(x_tangent):
dx = jnp.dot(solve(a, x_tangent), x.T)
out = (dx, b * 0.)
return out
return x, vjp
jax.defvjp_all(solve, solve_vjp)
gf = grad(lambda a,b: jnp.sum(solve(a, b)))
n = 3
a_in = jnp.linspace(0, 1, n)[:, None]
a = jnp.dot(a_in, a_in.T) + jnp.eye(n) * 0.1
real_x = np.random.RandomState(0).randn(n)
b = jnp.dot(a + jnp.eye(a.shape[0]), real_x)
print(gf(a, b)) # doesn't crash
n = 3
a_in = jnp.linspace(0, 1, n)[:, None]
a = jnp.dot(a_in, a_in.T) + jnp.eye(n) * 0.1
real_x = np.random.RandomState(0).randn(n)
b = jnp.dot(a + jnp.eye(a.shape[0]), real_x)
print(gf(a, b)) # doesn't crash
class BufferDonationTest(jtu.BufferDonationTestCase):

View File

@ -30,7 +30,7 @@ config.parse_flags_with_absl()
class DebugNaNsTest(jtu.JaxTestCase):
def setUp(self):
self.cfg = config.read("jax_debug_nans")
self.cfg = config._read("jax_debug_nans")
config.update("jax_debug_nans", True)
def tearDown(self):
@ -144,7 +144,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
class DebugInfsTest(jtu.JaxTestCase):
def setUp(self):
self.cfg = config.read("jax_debug_infs")
self.cfg = config._read("jax_debug_infs")
config.update("jax_debug_infs", True)
def tearDown(self):

View File

@ -148,7 +148,7 @@ class DJaxADTests(jtu.JaxTestCase):
y = sin(x)
return reduce_sum(y, axes=(0,))
x = bbarray((5,), jnp.arange(2.))
with jax.core.skipping_checks(): # TODO implement dxla_call abs eval rule
with jax.enable_checks(False): # TODO implement dxla_call abs eval rule
z, f_lin = jax.linearize(f, x)
z_dot = f_lin(ones_like(x))

View File

@ -25,7 +25,6 @@ import numpy as np
import jax
from jax import api
from jax import core
from jax import dtypes
from jax import lax
from jax import test_util as jtu
@ -992,7 +991,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
expected = np.array(0.0)
self.assertAllClose(ans, expected, check_dtypes=False)
with core.skipping_checks():
with jax.enable_checks(False):
with self.assertRaises(TypeError):
lax.stop_gradient(lambda x: x)

View File

@ -1709,7 +1709,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(carry_out[1], carry_init, check_dtypes=False)
self.assertAllClose(carry_out[0], jnp.array([2., 2., 2.]), check_dtypes = False)
# TODO(mattjj, dougalm): fix this test when skip_checks is False
def testIssue757(self):
# code from https://github.com/google/jax/issues/757
def fn(a):

View File

@ -208,7 +208,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
))
def test_bicgstab_against_scipy(
self, shape, dtype, preconditioner):
if not config.FLAGS.jax_enable_x64:
if not config.jax_enable_x64:
raise unittest.SkipTest("requires x64 mode")
rng = jtu.rand_default(self.rng())

View File

@ -2190,7 +2190,7 @@ class LaxTest(jtu.JaxTestCase):
# api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.)
def test_primitive_jaxtype_error(self):
with core.skipping_checks():
with jax.enable_checks(False):
with self.assertRaisesRegex(
TypeError, "Argument .* of type .* is not a valid JAX type"):
lax.add(1, 'hi')

View File

@ -117,12 +117,12 @@ class NNFunctionsTest(jtu.JaxTestCase):
def testEluMemory(self):
# see https://github.com/google/jax/pull/1640
with core.skipping_checks(): # With checks we materialize the array
with jax.enable_checks(False): # With checks we materialize the array
jax.make_jaxpr(lambda: nn.elu(jnp.ones((10 ** 12,)))) # don't oom
def testHardTanhMemory(self):
# see https://github.com/google/jax/pull/1640
with core.skipping_checks(): # With checks we materialize the array
with jax.enable_checks(False): # With checks we materialize the array
jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
def testOneHot(self):

View File

@ -914,7 +914,7 @@ class LaxRandomTest(jtu.JaxTestCase):
raise SkipTest("after deleting lazy constants, requires omnistaging")
def f(x):
return random.normal(random.PRNGKey(x), (int(1e12),))
with core.skipping_checks(): # check_jaxpr will materialize array
with jax.enable_checks(False): # check_jaxpr will materialize array
api.eval_shape(f, 0) # doesn't error
@parameterized.named_parameters(jtu.cases_from_list(

View File

@ -73,7 +73,7 @@ class X64ContextTests(jtu.JaxTestCase):
func = _maybe_jit(jit, lambda: jnp.arange(10.0))
func()
expected_dtype = "float64" if config.read("jax_enable_x64") else "float32"
expected_dtype = "float64" if config._read("jax_enable_x64") else "float32"
self.assertEqual(func().dtype, expected_dtype)
with enable_x64():