mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
unify configuration state handling
This commit is contained in:
parent
22a2be30a2
commit
fd7b286ec9
@ -30,7 +30,8 @@ except Exception as exc:
|
||||
del _cloud_tpu_init
|
||||
|
||||
# flake8: noqa: F401
|
||||
from .config import config
|
||||
from .config import (config, enable_checks, check_tracer_leaks, checking_leaks,
|
||||
debug_nans, debug_infs, log_compiles)
|
||||
from .api import (
|
||||
ad, # TODO(phawkins): update users to avoid this.
|
||||
argnums_partial, # TODO(phawkins): update Haiku to not use this.
|
||||
|
@ -1113,7 +1113,7 @@ def _cond_typecheck(*avals, branches, linear):
|
||||
f'called with operands of type {_avals_short(op_avals)}')
|
||||
|
||||
def cond_bind(*args, branches, linear):
|
||||
if not core.skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
avals = _map(core.get_aval, args)
|
||||
_cond_typecheck(*avals, branches=branches, linear=linear)
|
||||
for jaxpr in branches:
|
||||
@ -1876,7 +1876,7 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
|
||||
f'called with sequence of type\n{_avals_short(x_avals)}')
|
||||
|
||||
def scan_bind(*args, **params):
|
||||
if not core.skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
avals = _map(core.get_aval, args)
|
||||
_scan_typecheck(True, *avals, **params)
|
||||
core.check_jaxpr(params['jaxpr'].jaxpr)
|
||||
|
@ -21,7 +21,6 @@ from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
|
||||
partial = functools.partial
|
||||
@ -192,7 +191,7 @@ def cache(max_size=4096):
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
if jax.core.debug_state.check_leaks:
|
||||
if config.jax_check_tracer_leaks:
|
||||
return f(*args, **kwargs)
|
||||
else:
|
||||
return cached(bool(config.x64_enabled), *args, **kwargs)
|
||||
|
10
jax/api.py
10
jax/api.py
@ -41,7 +41,7 @@ from . import lib
|
||||
from . import linear_util as lu
|
||||
from . import ad_util
|
||||
from . import dtypes
|
||||
from .core import eval_jaxpr, checking_leaks
|
||||
from .core import eval_jaxpr
|
||||
from .api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
|
||||
flatten_fun_nokwargs2, argnums_partial,
|
||||
argnums_partial_except, flatten_axes, donation_vector,
|
||||
@ -362,7 +362,7 @@ def _cpp_jit(
|
||||
context = (getattr(core.thread_local_state.trace_state.trace_stack,
|
||||
"dynamic", None), config.x64_enabled)
|
||||
# TODO(jblespiau): Move this to C++.
|
||||
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
|
||||
if (config.jax_debug_nans or config.jax_debug_infs) and not _jit_is_disabled():
|
||||
device_arrays = cpp_jitted_f(context, *args, **kwargs)
|
||||
try:
|
||||
xla.check_special(xla.xla_call_p, [
|
||||
@ -372,7 +372,7 @@ def _cpp_jit(
|
||||
])
|
||||
return device_arrays
|
||||
except FloatingPointError:
|
||||
assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case
|
||||
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
|
||||
print("Invalid nan value encountered in the output of a C++-jit "
|
||||
"function. Calling the de-optimized version.")
|
||||
return cache_miss(*args, **kwargs)[0] # probably won't return
|
||||
@ -389,7 +389,7 @@ def _cpp_jit(
|
||||
@api_boundary
|
||||
def f_jitted(*args, **kwargs):
|
||||
# TODO(jblespiau): Move this to C++.
|
||||
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
|
||||
if (config.jax_debug_nans or config.jax_debug_infs) and not _jit_is_disabled():
|
||||
device_arrays = cpp_jitted_f(*args, **kwargs)
|
||||
try:
|
||||
xla.check_special(xla.xla_call_p, [
|
||||
@ -399,7 +399,7 @@ def _cpp_jit(
|
||||
])
|
||||
return device_arrays
|
||||
except FloatingPointError:
|
||||
assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case
|
||||
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
|
||||
print("Invalid nan value encountered in the output of a C++-jit "
|
||||
"function. Calling the de-optimized version.")
|
||||
return cache_miss(*args, **kwargs)[0] # probably won't return
|
||||
|
178
jax/config.py
178
jax/config.py
@ -12,8 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from jax import lib
|
||||
|
||||
def bool_env(varname: str, default: bool) -> bool:
|
||||
@ -42,11 +46,16 @@ def int_env(varname: str, default: int) -> int:
|
||||
|
||||
|
||||
class Config:
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
|
||||
def __init__(self):
|
||||
self.values = {}
|
||||
self.meta = {}
|
||||
self.FLAGS = NameSpace(self.read)
|
||||
self.use_absl = False
|
||||
self._contextmanager_flags = set()
|
||||
|
||||
# TODO(mattjj): delete these when only omnistaging is available
|
||||
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
|
||||
self._omnistaging_disablers = []
|
||||
|
||||
@ -65,6 +74,13 @@ class Config:
|
||||
lib.jax_jit.global_state().enable_x64 = val
|
||||
|
||||
def read(self, name):
|
||||
if name in self._contextmanager_flags:
|
||||
raise AttributeError(
|
||||
"For flags with a corresponding contextmanager, read their value "
|
||||
f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.")
|
||||
return self._read(name)
|
||||
|
||||
def _read(self, name):
|
||||
if self.use_absl:
|
||||
return getattr(self.absl_flags.FLAGS, name)
|
||||
else:
|
||||
@ -143,14 +159,82 @@ class Config:
|
||||
disabler()
|
||||
self.omnistaging_enabled = False
|
||||
|
||||
@property
|
||||
def x64_enabled(self):
|
||||
return lib.jax_jit.get_enable_x64()
|
||||
# # TODO(jakevdp, mattjj): unify this with `define_bool_state` stuff below
|
||||
# @property
|
||||
# def x64_enabled(self):
|
||||
# return lib.jax_jit.get_enable_x64()
|
||||
|
||||
# TODO(jakevdp): make this public when thread-local x64 is fully implemented.
|
||||
def _set_x64_enabled(self, state):
|
||||
lib.jax_jit.thread_local_state().enable_x64 = bool(state)
|
||||
# def _set_x64_enabled(self, state):
|
||||
# lib.jax_jit.thread_local_state().enable_x64 = bool(state)
|
||||
|
||||
def define_bool_state(self, name: str, default: bool, help: str):
|
||||
"""Set up thread-local state and return a contextmanager for managing it.
|
||||
|
||||
This function is a convenience wrapper. It defines a flag and corresponding
|
||||
thread-local state, which can be managed via the contextmanager it returns.
|
||||
|
||||
The thread-local state value can be read via the ``config.<option_name>``
|
||||
attribute, where ``config`` is the singleton ``Config`` instance.
|
||||
|
||||
Args:
|
||||
name: string, converted to lowercase to define the name of the config
|
||||
option (and absl flag). It is converted to uppercase to define the
|
||||
corresponding shell environment variable.
|
||||
default: boolean, a default value for the option.
|
||||
help: string, used to populate the flag help information as well as the
|
||||
docstring of the returned context manager.
|
||||
|
||||
Returns:
|
||||
A contextmanager to control the thread-local state value.
|
||||
|
||||
Example:
|
||||
|
||||
enable_foo = config.define_bool_state(
|
||||
name='jax_enable_foo',
|
||||
default=False,
|
||||
help='Enable foo.')
|
||||
|
||||
# Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo
|
||||
# command-line flag can be used to control the process-level value of
|
||||
# the configuration option, in addition to using e.g.
|
||||
# ``config.update("jax_enable_foo", True)`` directly. We can also use a
|
||||
# context manager:
|
||||
|
||||
with enable_foo(True):
|
||||
...
|
||||
|
||||
The value of the thread-local state or flag can be accessed via
|
||||
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
|
||||
an error.
|
||||
"""
|
||||
name = name.lower()
|
||||
self.DEFINE_bool(name, bool_env(name.upper(), default), help)
|
||||
self._contextmanager_flags.add(name)
|
||||
|
||||
def get_state(self):
|
||||
val = getattr(_thread_local_state, name, unset)
|
||||
return val if val is not unset else self._read(name)
|
||||
setattr(Config, name, property(get_state))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_state(new_val: bool):
|
||||
prev_val = getattr(_thread_local_state, name, unset)
|
||||
setattr(_thread_local_state, name, new_val)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev_val is unset:
|
||||
delattr(_thread_local_state, name)
|
||||
else:
|
||||
setattr(_thread_local_state, name, prev_val)
|
||||
set_state.__name__ = name[4:] if name.startswith('jax_') else name
|
||||
set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
|
||||
return set_state
|
||||
|
||||
_thread_local_state = threading.local()
|
||||
|
||||
class Unset: pass
|
||||
unset = Unset()
|
||||
|
||||
class NameSpace(object):
|
||||
def __init__(self, getter):
|
||||
@ -166,11 +250,6 @@ FLAGS = flags.FLAGS
|
||||
|
||||
already_configured_with_absl = False
|
||||
|
||||
flags.DEFINE_bool(
|
||||
'jax_enable_checks',
|
||||
bool_env('JAX_ENABLE_CHECKS', False),
|
||||
help='Turn on invariant checking (core.skip_checks = False)'
|
||||
)
|
||||
|
||||
flags.DEFINE_bool(
|
||||
'jax_omnistaging',
|
||||
@ -184,14 +263,6 @@ flags.DEFINE_integer(
|
||||
help='Set the number of stack frames in JAX tracer error messages.'
|
||||
)
|
||||
|
||||
flags.DEFINE_bool(
|
||||
'jax_check_tracer_leaks',
|
||||
bool_env('JAX_CHECK_TRACER_LEAKS', False),
|
||||
help=('Turn on checking for leaked tracers as soon as a trace completes. '
|
||||
'Enabling leak checking may have performance impacts: some caching '
|
||||
'is disabled, and other overheads may be added.'),
|
||||
)
|
||||
|
||||
flags.DEFINE_bool(
|
||||
'jax_host_callback_inline',
|
||||
bool_env('JAX_HOST_CALLBACK_INLINE', False),
|
||||
@ -206,3 +277,72 @@ flags.DEFINE_integer(
|
||||
'until the Python callback consume more outfeeds.'),
|
||||
lower_bound=int(16 * 1e6)
|
||||
)
|
||||
|
||||
|
||||
enable_checks = config.define_bool_state(
|
||||
name='jax_enable_checks',
|
||||
default=False,
|
||||
help='Turn on invariant checking for JAX internals. Makes things slower.')
|
||||
|
||||
check_tracer_leaks = config.define_bool_state(
|
||||
name='jax_check_tracer_leaks',
|
||||
default=False,
|
||||
help=('Turn on checking for leaked tracers as soon as a trace completes. '
|
||||
'Enabling leak checking may have performance impacts: some caching '
|
||||
'is disabled, and other overheads may be added.'))
|
||||
checking_leaks = functools.partial(check_tracer_leaks, True)
|
||||
|
||||
debug_nans = config.define_bool_state(
|
||||
name='jax_debug_nans',
|
||||
default=False,
|
||||
help=('Add nan checks to every operation. When a nan is detected on the '
|
||||
'output of a jit-compiled computation, call into the un-compiled '
|
||||
'version in an attempt to more precisely identify the operation '
|
||||
'which produced the nan.'))
|
||||
|
||||
debug_infs = config.define_bool_state(
|
||||
name='jax_debug_infs',
|
||||
default=False,
|
||||
help=('Add inf checks to every operation. When an inf is detected on the '
|
||||
'output of a jit-compiled computation, call into the un-compiled '
|
||||
'version in an attempt to more precisely identify the operation '
|
||||
'which produced the inf.'))
|
||||
|
||||
log_compiles = config.define_bool_state(
|
||||
name='jax_log_compiles',
|
||||
default=False,
|
||||
help=('Log a message each time every time `jit` or `pmap` compiles an XLA '
|
||||
'computation. Logging is performed with `absl.logging`. When this '
|
||||
'option is set, the log level is WARNING; otherwise the level is '
|
||||
'DEBUG.'))
|
||||
|
||||
# Because jax_enable_x64 is managed by C++ code, we don't reuse the
|
||||
# config.define_bool_state mechanism, though conceptually it is the same.
|
||||
config.DEFINE_bool('jax_enable_x64', bool_env('JAX_ENABLE_X64', False),
|
||||
help='Enable 64-bit types to be used')
|
||||
lib.jax_jit.global_state().enable_x64 = bool_env('JAX_ENABLE_X64', False)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def enable_x64(new_val: bool = True):
|
||||
"""Experimental context manager to temporarily enable X64 mode.
|
||||
|
||||
Usage::
|
||||
|
||||
>>> import jax.numpy as jnp
|
||||
>>> with enable_x64(True):
|
||||
... print(jnp.arange(10.0).dtype)
|
||||
...
|
||||
float64
|
||||
"""
|
||||
prev_val = config.jax_enable_x64
|
||||
lib.jax_jit.thread_local_state().enable_x64 = bool(new_val)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
lib.jax_jit.thread_local_state().enable_x64 = prev_val
|
||||
Config.jax_enable_x64 = property(lambda self: lib.jax_jit.get_enable_x64())
|
||||
# config._contextmanager_flags.add('jax_enable_x64') # TODO(mattjj): remove footgun
|
||||
|
||||
# The `x64_enabled` property doesn't fit the naming scheme, but we use it for
|
||||
# backward compatibility.
|
||||
Config.x64_enabled = Config.jax_enable_x64
|
||||
|
45
jax/core.py
45
jax/core.py
@ -45,33 +45,6 @@ from ._src.pprint_util import pp, vcat, PrettyPrint
|
||||
from ._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
# TODO(mattjj): move this into debug_state
|
||||
skip_checks = not FLAGS.jax_enable_checks
|
||||
|
||||
@contextmanager
|
||||
def skipping_checks():
|
||||
"""Context manager for temporarily disabling internal checks."""
|
||||
global skip_checks
|
||||
old_value, skip_checks = skip_checks, True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
skip_checks = old_value
|
||||
|
||||
@contextmanager
|
||||
def checking_leaks():
|
||||
"""Context manager for temporarily enabling tracer leak checks."""
|
||||
old_value, debug_state.check_leaks = debug_state.check_leaks, True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
debug_state.check_leaks = old_value
|
||||
|
||||
class DebugState(threading.local):
|
||||
def __init__(self):
|
||||
self.check_leaks = FLAGS.jax_check_tracer_leaks
|
||||
debug_state = DebugState()
|
||||
|
||||
zip = safe_zip
|
||||
map = safe_map
|
||||
|
||||
@ -279,8 +252,8 @@ class Primitive:
|
||||
|
||||
|
||||
def bind(self, *args, **params):
|
||||
assert skip_checks or all(isinstance(arg, Tracer)
|
||||
or valid_jaxtype(arg) for arg in args), args
|
||||
assert (not config.jax_enable_checks or
|
||||
all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
|
||||
top_trace = find_top_trace(args)
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
out = top_trace.process_primitive(self, tracers, params)
|
||||
@ -569,7 +542,7 @@ class Tracer:
|
||||
|
||||
def __getattr__(self, name):
|
||||
# if the aval property raises an AttributeError, gets caught here
|
||||
assert skip_checks or name != "aval"
|
||||
assert not config.jax_enable_checks or name != "aval"
|
||||
|
||||
try:
|
||||
attr = getattr(self.aval, name)
|
||||
@ -783,7 +756,7 @@ def new_main(trace_type: Type[Trace],
|
||||
if lib._xla_extension_version >= 11:
|
||||
jit_tls.extra_jit_context = extra_jit_context(stack)
|
||||
|
||||
if debug_state.check_leaks:
|
||||
if config.jax_check_tracer_leaks:
|
||||
t = ref(main)
|
||||
del main
|
||||
if t() is not None:
|
||||
@ -807,7 +780,7 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
|
||||
if lib._xla_extension_version >= 11:
|
||||
jit_tls.extra_jit_context = extra_jit_context(stack)
|
||||
|
||||
if debug_state.check_leaks:
|
||||
if config.jax_check_tracer_leaks:
|
||||
t = ref(main)
|
||||
del main
|
||||
if t() is not None:
|
||||
@ -827,7 +800,7 @@ def new_sublevel() -> Generator[None, None, None]:
|
||||
finally:
|
||||
thread_local_state.trace_state.substack.pop()
|
||||
|
||||
if debug_state.check_leaks:
|
||||
if config.jax_check_tracer_leaks:
|
||||
t = ref(sublevel)
|
||||
del sublevel
|
||||
if t() is not None:
|
||||
@ -899,7 +872,7 @@ class AbstractUnit(AbstractValue):
|
||||
# _num_buffers = 0
|
||||
def at_least_vspace(self): return self
|
||||
def join(self, other):
|
||||
if not skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
assert other is abstract_unit, other
|
||||
return self
|
||||
def _eq(self, self_traced, other): return get_aval(other) is self
|
||||
@ -1932,7 +1905,7 @@ def omnistaging_disabler() -> None:
|
||||
finally:
|
||||
thread_local_state.trace_state.trace_stack.pop(bottom)
|
||||
|
||||
if debug_state.check_leaks:
|
||||
if config.jax_check_tracer_leaks:
|
||||
t = ref(main)
|
||||
del main
|
||||
if t() is not None:
|
||||
@ -1949,7 +1922,7 @@ def omnistaging_disabler() -> None:
|
||||
yield # dummy implementation for forward compatibility
|
||||
|
||||
def bind(self, *args, **kwargs):
|
||||
assert skip_checks or all(isinstance(arg, Tracer)
|
||||
assert not config.jax_enable_checks or all(isinstance(arg, Tracer)
|
||||
or valid_jaxtype(arg) for arg in args), args
|
||||
top_trace = find_top_trace(args)
|
||||
if top_trace is None:
|
||||
|
@ -914,7 +914,7 @@ def closure_convert(fun, *example_args):
|
||||
"""
|
||||
flat_args, in_tree = tree_flatten(example_args)
|
||||
in_avals = tuple(map(abstractify, flat_args))
|
||||
if core.debug_state.check_leaks:
|
||||
if config.jax_check_tracer_leaks:
|
||||
return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
|
||||
else:
|
||||
return _closure_convert_for_avals(fun, in_tree, in_avals)
|
||||
|
@ -20,27 +20,19 @@
|
||||
# so we need our own implementation that deviates from NumPy in places.
|
||||
|
||||
|
||||
from distutils.util import strtobool
|
||||
import functools
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._src import util
|
||||
from .config import flags, config
|
||||
from . import lib
|
||||
from .lib import xla_client
|
||||
|
||||
from ._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_bool('jax_enable_x64',
|
||||
strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
|
||||
'Enable 64-bit types to be used.')
|
||||
lib.jax_jit.global_state().enable_x64 = strtobool(
|
||||
os.getenv('JAX_ENABLE_X64', 'False'))
|
||||
|
||||
# bfloat16 support
|
||||
bfloat16: type = xla_client.bfloat16
|
||||
|
@ -70,9 +70,9 @@ def _is_tfval(v: TfVal) -> bool:
|
||||
if isinstance(v, (tf.Tensor, tf.Variable)):
|
||||
return True
|
||||
try:
|
||||
# Note: this conversion is overkill and just intended as a type check; this code
|
||||
# is in principle only run if core.skip_checks is False.
|
||||
# TODO: it is not true that this code is run only without skip_checks
|
||||
# Note: this conversion is overkill and just intended as a type check; this
|
||||
# code is in principle only run if config.jax_enable_checks is True.
|
||||
# TODO: it is not true that this code is run only with jax_enable_checks.
|
||||
_safe_convert_to_tensor(v)
|
||||
return True
|
||||
except ValueError:
|
||||
@ -353,7 +353,7 @@ def _tfval_shape_dtype(val: TfVal) -> Tuple[Sequence[Optional[int]], DType]:
|
||||
# May be partially known
|
||||
return tuple(val.shape), to_jax_dtype(val.dtype)
|
||||
else: # Must be a numeric value
|
||||
assert core.skip_checks or _is_tfval(val), f"Non TfVal: {val}"
|
||||
assert not config.jax_enable_checks or _is_tfval(val), f"Non TfVal: {val}"
|
||||
raw_aval = xla.abstractify(val)
|
||||
return raw_aval.shape, raw_aval.dtype # type: ignore[attr-defined]
|
||||
|
||||
@ -605,7 +605,7 @@ class TensorFlowTracer(core.Tracer):
|
||||
val = tf.cast(val, dtype=aval_dtype)
|
||||
val_dtype = aval_dtype
|
||||
|
||||
if not core.skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}"
|
||||
for aval_dim, val_dim in util.safe_zip(self._aval.shape, val_shape): # type: ignore[attr-defined]
|
||||
if val_dim is None:
|
||||
@ -703,7 +703,7 @@ class TensorFlowTrace(core.Trace):
|
||||
|
||||
# Check that the impl rule returned a value of expected shape and dtype
|
||||
# TODO: adapt this to match polymorphic shapes
|
||||
if not core.skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
if primitive.multiple_results:
|
||||
for o, expected_aval in zip(out, out_aval): # type: ignore
|
||||
assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), (
|
||||
@ -1530,7 +1530,7 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions,
|
||||
reducer_fn = tf.function(reducer, autograph=False).get_concrete_function(o_spec, o_spec)
|
||||
|
||||
if not isinstance(init_val, tf.Tensor):
|
||||
assert core.skip_checks or _is_tfval(init_val), f"Non TfVal: {init_val}"
|
||||
assert not config.jax_enable_checks or _is_tfval(init_val), f"Non TfVal: {init_val}"
|
||||
init_val = tf.constant(init_val, operand.dtype)
|
||||
out = tfxla.reduce_window(operand, init_val,
|
||||
reducer_fn, window_dimensions,
|
||||
|
@ -135,7 +135,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
dtype=dtype)
|
||||
for dtype in [np.int64, np.float64]))
|
||||
def test_converts_64bit(self, dtype=np.int64, with_function=False):
|
||||
if not config.FLAGS.jax_enable_x64:
|
||||
if not config.jax_enable_x64:
|
||||
self.skipTest("requires x64 mode")
|
||||
big_const = np.full((5,), 2 ** 33, dtype=dtype)
|
||||
self.ConvertAndCompare(jnp.sin, big_const)
|
||||
|
@ -17,31 +17,14 @@
|
||||
**Experimental: please give feedback, and expect changes.**
|
||||
"""
|
||||
|
||||
# This file provides
|
||||
# 1. a jax.experimental API endpoint;
|
||||
# 2. the `disable_x64` wrapper.
|
||||
# TODO(jakevdp): remove this file, and consider removing `disable_x64` for
|
||||
# uniformity
|
||||
|
||||
from contextlib import contextmanager
|
||||
from jax import config
|
||||
|
||||
@contextmanager
|
||||
def enable_x64():
|
||||
"""Experimental context manager to temporarily enable X64 mode.
|
||||
|
||||
Usage::
|
||||
|
||||
>>> import jax.numpy as jnp
|
||||
>>> with enable_x64():
|
||||
... print(jnp.arange(10.0).dtype)
|
||||
...
|
||||
float64
|
||||
|
||||
See Also
|
||||
--------
|
||||
jax.experimental.disable_x64 : temporarily disable X64 mode.
|
||||
"""
|
||||
_x64_state = config.x64_enabled
|
||||
config._set_x64_enabled(True)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
config._set_x64_enabled(_x64_state)
|
||||
from jax.config import enable_x64
|
||||
|
||||
@contextmanager
|
||||
def disable_x64():
|
||||
@ -59,9 +42,5 @@ def disable_x64():
|
||||
--------
|
||||
jax.experimental.enable_x64 : temporarily enable X64 mode.
|
||||
"""
|
||||
_x64_state = config.x64_enabled
|
||||
config._set_x64_enabled(False)
|
||||
try:
|
||||
with enable_x64(False):
|
||||
yield
|
||||
finally:
|
||||
config._set_x64_enabled(_x64_state)
|
||||
|
@ -174,7 +174,7 @@ def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
|
||||
# assert v.aval == ct.aval, (prim, v.aval, ct.aval)
|
||||
return
|
||||
ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
|
||||
if not core.skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
ct_aval = core.get_aval(ct_env[v])
|
||||
joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type()
|
||||
assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval)
|
||||
@ -389,7 +389,7 @@ class JVPTracer(Tracer):
|
||||
__slots__ = ['primal', 'tangent']
|
||||
|
||||
def __init__(self, trace, primal, tangent):
|
||||
if not core.skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
_primal_tangent_shapes_match(primal, tangent)
|
||||
self._trace = trace
|
||||
self.primal = primal
|
||||
|
@ -108,7 +108,7 @@ class BatchTracer(Tracer):
|
||||
__slots__ = ['val', 'batch_dim']
|
||||
|
||||
def __init__(self, trace, val, batch_dim: Optional[int]):
|
||||
assert core.skip_checks or type(batch_dim) in (int, NotMapped) # type: ignore
|
||||
assert not config.jax_enable_checks or type(batch_dim) in (int, NotMapped) # type: ignore
|
||||
self._trace = trace
|
||||
self.val = val
|
||||
self.batch_dim = batch_dim
|
||||
|
@ -49,7 +49,7 @@ class PartialVal(tuple):
|
||||
"""
|
||||
def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
|
||||
pv, const = xs
|
||||
if not core.skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
# type checks
|
||||
assert isinstance(pv, (AbstractValue, type(None))), xs
|
||||
assert isinstance(const, core.Tracer) or type(const) is Zero or core.valid_jaxtype(const), xs
|
||||
@ -648,25 +648,25 @@ def tracers_to_jaxpr(
|
||||
const_vars, const_vals = unzip2(consts.items())
|
||||
# The env_vars are pre-pended to the invars
|
||||
jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
|
||||
core.skip_checks or core.check_jaxpr(jaxpr)
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
||||
return jaxpr, const_vals, env_vals
|
||||
|
||||
@cache()
|
||||
def convert_constvars_jaxpr(jaxpr: Jaxpr):
|
||||
"""Moves the constvars to the start of invars."""
|
||||
core.skip_checks or core.check_jaxpr(jaxpr)
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
||||
lifted_jaxpr = Jaxpr(constvars=(),
|
||||
invars=jaxpr.constvars + jaxpr.invars,
|
||||
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
|
||||
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
|
||||
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
|
||||
return lifted_jaxpr
|
||||
|
||||
def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int):
|
||||
core.skip_checks or core.check_jaxpr(jaxpr)
|
||||
config.jax_enable_checks and core.check_jaxpr(jaxpr)
|
||||
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
|
||||
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
|
||||
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns)
|
||||
core.skip_checks or core.check_jaxpr(converted_jaxpr)
|
||||
config.jax_enable_checks and core.check_jaxpr(converted_jaxpr)
|
||||
return converted_jaxpr
|
||||
|
||||
|
||||
|
@ -486,7 +486,7 @@ class ShardedDeviceArray(xla.DeviceArray): # type: ignore
|
||||
self.indices = indices
|
||||
self._npy_value = None
|
||||
self._one_replica_buffer_indices = None
|
||||
if not core.skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
assert type(aval) is ShapedArray
|
||||
|
||||
@property
|
||||
@ -792,7 +792,7 @@ def parallel_callable(fun: lu.WrappedFun,
|
||||
f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
|
||||
f"{jaxpr_replicas} and nested_partitions={num_partitions}")
|
||||
|
||||
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logging.log(log_priority,
|
||||
f"Compiling {fun.__name__} ({id(fun)}) for {num_global_shards} "
|
||||
f"devices with args {avals}. (num_replicas={num_global_replicas}"
|
||||
@ -1387,7 +1387,7 @@ def mesh_callable(fun: lu.WrappedFun,
|
||||
global_axis_sizes = mesh.shape
|
||||
local_axis_sizes = local_mesh.shape
|
||||
|
||||
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logging.log(log_priority,
|
||||
f"Compiling {fun.__name__} ({id(fun)}) for {tuple(global_axis_sizes.items())} "
|
||||
f"mesh with args {local_in_untiled_avals}. Argument mapping: {in_axes}.")
|
||||
|
@ -144,7 +144,7 @@ def _sharded_callable(
|
||||
for out, parts, lparts
|
||||
in safe_zip(global_out_avals, out_parts, local_out_parts)]
|
||||
|
||||
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logging.log(log_priority,
|
||||
f"Compiling {fun.__name__} for {nparts} devices with "
|
||||
f"args {global_abstract_args}.")
|
||||
|
@ -23,7 +23,7 @@ from warnings import warn
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
from ..config import flags, bool_env, config
|
||||
from ..config import config
|
||||
from .. import core
|
||||
from .. import ad_util
|
||||
from .. import dtypes
|
||||
@ -58,17 +58,6 @@ XlaShape = Any # xla_client.Shape
|
||||
XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder
|
||||
XlaExecutable = Any # xla_extension.LocalExecutable
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_bool('jax_debug_nans',
|
||||
bool_env('JAX_DEBUG_NANS', False),
|
||||
'Add nan checks to every operation.')
|
||||
flags.DEFINE_bool('jax_debug_infs',
|
||||
bool_env('JAX_DEBUG_INFS', False),
|
||||
'Add inf checks to every operation.')
|
||||
flags.DEFINE_bool('jax_log_compiles',
|
||||
bool_env('JAX_LOG_COMPILES', False),
|
||||
'Print a message each time a `jit` computation is compiled.')
|
||||
|
||||
# This flag is set on exit; no logging should be attempted
|
||||
_on_exit = False
|
||||
|
||||
@ -244,7 +233,7 @@ def apply_primitive(prim, *args, **params):
|
||||
|
||||
def _partition_outputs(avals, outs):
|
||||
nouts = [aval._num_buffers for aval in avals]
|
||||
if not core.skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
assert sum(nouts) == len(outs), f"Internal error: sum(nouts)={sum(nouts)} should equal len(outs)={len(outs)}."
|
||||
outs = iter(outs)
|
||||
return [[next(outs) for _ in range(nout)] for nout in nouts]
|
||||
@ -372,7 +361,7 @@ def _execute_replicated_primitive(prim, compiled, result_handler, *args):
|
||||
return result_handler(*out_bufs)
|
||||
|
||||
def needs_check_special():
|
||||
return FLAGS.jax_debug_infs or FLAGS.jax_debug_nans
|
||||
return config.jax_debug_infs or config.jax_debug_nans
|
||||
|
||||
def check_special(name, bufs):
|
||||
if needs_check_special():
|
||||
@ -382,9 +371,9 @@ def check_special(name, bufs):
|
||||
def _check_special(name, xla_shape, buf):
|
||||
assert not xla_shape.is_tuple()
|
||||
if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
|
||||
if FLAGS.jax_debug_nans and np.any(np.isnan(buf.to_py())):
|
||||
if config.jax_debug_nans and np.any(np.isnan(buf.to_py())):
|
||||
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
|
||||
if FLAGS.jax_debug_infs and np.any(np.isinf(buf.to_py())):
|
||||
if config.jax_debug_infs and np.any(np.isinf(buf.to_py())):
|
||||
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
|
||||
|
||||
### compiling jaxprs
|
||||
@ -590,13 +579,13 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_inv
|
||||
try:
|
||||
return compiled_fun(*args)
|
||||
except FloatingPointError:
|
||||
assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case
|
||||
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
|
||||
print("Invalid value encountered in the output of a jit function. "
|
||||
"Calling the de-optimized version.")
|
||||
# We want to run the wrapped function again (after _xla_callable already ran
|
||||
# it), but linear_util.WrappedFun instances are meant to be run only once.
|
||||
# In addition to re-executing the Python code, which is usually undesirable
|
||||
# but which FLAGS.jax_debug_nans is meant to opt into, we'll be re-executing
|
||||
# but which config.jax_debug_nans is meant to opt into, we'll be re-executing
|
||||
# any linear_util.py-style side effects, i.e. re-populating Stores created
|
||||
# by any transformation_with_aux's applied to fun. Since this is
|
||||
# intentional here, to avoid "Store occupied" errors we reset the stores to
|
||||
@ -688,7 +677,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers)
|
||||
|
||||
if not _on_exit:
|
||||
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logging.log(log_priority, "Compiling %s (%s) for args %s.",
|
||||
fun.__name__, id(fun), abstract_args)
|
||||
|
||||
@ -1096,7 +1085,7 @@ class _DeviceArray(DeviceArray): # type: ignore
|
||||
self._device = device
|
||||
|
||||
self._npy_value = None
|
||||
if not core.skip_checks:
|
||||
if config.jax_enable_checks:
|
||||
assert type(aval) is ShapedArray
|
||||
npy_value = self._value
|
||||
assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape
|
||||
|
@ -248,7 +248,7 @@ def cache(call: Callable):
|
||||
|
||||
def memoized_fun(fun: WrappedFun, *args):
|
||||
cache = fun_caches.setdefault(fun.f, {})
|
||||
if core.debug_state.check_leaks:
|
||||
if config.jax_check_tracer_leaks:
|
||||
key = (_copy_main_traces(fun.transforms), fun.params, args, config.x64_enabled)
|
||||
else:
|
||||
key = (fun.transforms, fun.params, args, config.x64_enabled)
|
||||
|
@ -445,7 +445,7 @@ def skip_on_flag(flag_name, skip_value):
|
||||
def skip(test_method): # pylint: disable=missing-docstring
|
||||
@functools.wraps(test_method)
|
||||
def test_method_wrapper(self, *args, **kwargs):
|
||||
flag_value = getattr(FLAGS, flag_name)
|
||||
flag_value = config._read(flag_name)
|
||||
if flag_value == skip_value:
|
||||
test_name = getattr(test_method, '__name__', '[unknown test]')
|
||||
raise unittest.SkipTest(
|
||||
@ -819,7 +819,7 @@ class JaxTestCase(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(JaxTestCase, self).setUp()
|
||||
core.skip_checks = False
|
||||
config.update('jax_enable_checks', True)
|
||||
# We use the adler32 hash for two reasons.
|
||||
# a) it is deterministic run to run, unlike hash() which is randomized.
|
||||
# b) it returns values in int32 range, which RandomState requires.
|
||||
|
3
mypy.ini
3
mypy.ini
@ -1,5 +1,6 @@
|
||||
[mypy]
|
||||
show_error_codes=True
|
||||
show_error_codes = True
|
||||
disable_error_code = attr-defined
|
||||
|
||||
[mypy-absl.*]
|
||||
ignore_missing_imports = True
|
||||
|
@ -2217,7 +2217,7 @@ class APITest(jtu.JaxTestCase):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
with jax.checking_leaks():
|
||||
lst = []
|
||||
|
||||
@jit
|
||||
@ -2232,7 +2232,7 @@ class APITest(jtu.JaxTestCase):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
with jax.checking_leaks():
|
||||
lst = []
|
||||
|
||||
@api.pmap
|
||||
@ -2247,7 +2247,7 @@ class APITest(jtu.JaxTestCase):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
with jax.checking_leaks():
|
||||
lst = []
|
||||
|
||||
def f(x):
|
||||
@ -2261,7 +2261,7 @@ class APITest(jtu.JaxTestCase):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
with jax.checking_leaks():
|
||||
@jit
|
||||
def f(x):
|
||||
return x
|
||||
@ -2279,7 +2279,7 @@ class APITest(jtu.JaxTestCase):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
with jax.checking_leaks():
|
||||
lst = []
|
||||
|
||||
to_scan = lambda c, x: (lst.append(c) or jnp.sin(c), None)
|
||||
@ -2291,7 +2291,7 @@ class APITest(jtu.JaxTestCase):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
with jax.checking_leaks():
|
||||
to_scan = lambda c, x: (jnp.sin(c), None)
|
||||
lax.scan(to_scan, 1., np.arange(3.)) # doesn't crash
|
||||
|
||||
@ -2299,7 +2299,7 @@ class APITest(jtu.JaxTestCase):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
with jax.checking_leaks():
|
||||
to_scan = lambda c, x: (c, None)
|
||||
|
||||
def f(x):
|
||||
@ -2310,7 +2310,7 @@ class APITest(jtu.JaxTestCase):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
with jax.checking_leaks():
|
||||
to_scan = lambda c, _: (1., None)
|
||||
|
||||
@api.vmap
|
||||
@ -2322,7 +2322,7 @@ class APITest(jtu.JaxTestCase):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
with jax.checking_leaks():
|
||||
to_scan = lambda c, _: (c, None)
|
||||
|
||||
@api.vmap
|
||||
@ -2334,7 +2334,7 @@ class APITest(jtu.JaxTestCase):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
with jax.checking_leaks():
|
||||
@jit
|
||||
def f(x):
|
||||
lst = []
|
||||
@ -4957,29 +4957,29 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
|
||||
api.grad(lambda x, y: f(x, y)[0])(1., 2.) # doesn't crash
|
||||
|
||||
def test_custom_transforms_vjp_nones(self):
|
||||
core.skip_checks = True # Fails with checks
|
||||
# issue raised by jsnoek@ and jumper@
|
||||
@jax.custom_transforms
|
||||
def solve(a, b):
|
||||
return jnp.dot(jnp.linalg.inv(a), b)
|
||||
# print(solve(a, b))
|
||||
with jax.enable_checks(False): # fails with checks
|
||||
# issue raised by jsnoek@ and jumper@
|
||||
@jax.custom_transforms
|
||||
def solve(a, b):
|
||||
return jnp.dot(jnp.linalg.inv(a), b)
|
||||
# print(solve(a, b))
|
||||
|
||||
def solve_vjp(a, b):
|
||||
x = solve(a, b)
|
||||
def vjp(x_tangent):
|
||||
dx = jnp.dot(solve(a, x_tangent), x.T)
|
||||
out = (dx, b * 0.)
|
||||
return out
|
||||
return x, vjp
|
||||
jax.defvjp_all(solve, solve_vjp)
|
||||
gf = grad(lambda a,b: jnp.sum(solve(a, b)))
|
||||
def solve_vjp(a, b):
|
||||
x = solve(a, b)
|
||||
def vjp(x_tangent):
|
||||
dx = jnp.dot(solve(a, x_tangent), x.T)
|
||||
out = (dx, b * 0.)
|
||||
return out
|
||||
return x, vjp
|
||||
jax.defvjp_all(solve, solve_vjp)
|
||||
gf = grad(lambda a,b: jnp.sum(solve(a, b)))
|
||||
|
||||
n = 3
|
||||
a_in = jnp.linspace(0, 1, n)[:, None]
|
||||
a = jnp.dot(a_in, a_in.T) + jnp.eye(n) * 0.1
|
||||
real_x = np.random.RandomState(0).randn(n)
|
||||
b = jnp.dot(a + jnp.eye(a.shape[0]), real_x)
|
||||
print(gf(a, b)) # doesn't crash
|
||||
n = 3
|
||||
a_in = jnp.linspace(0, 1, n)[:, None]
|
||||
a = jnp.dot(a_in, a_in.T) + jnp.eye(n) * 0.1
|
||||
real_x = np.random.RandomState(0).randn(n)
|
||||
b = jnp.dot(a + jnp.eye(a.shape[0]), real_x)
|
||||
print(gf(a, b)) # doesn't crash
|
||||
|
||||
|
||||
class BufferDonationTest(jtu.BufferDonationTestCase):
|
||||
|
@ -30,7 +30,7 @@ config.parse_flags_with_absl()
|
||||
class DebugNaNsTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = config.read("jax_debug_nans")
|
||||
self.cfg = config._read("jax_debug_nans")
|
||||
config.update("jax_debug_nans", True)
|
||||
|
||||
def tearDown(self):
|
||||
@ -144,7 +144,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
class DebugInfsTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = config.read("jax_debug_infs")
|
||||
self.cfg = config._read("jax_debug_infs")
|
||||
config.update("jax_debug_infs", True)
|
||||
|
||||
def tearDown(self):
|
||||
|
@ -148,7 +148,7 @@ class DJaxADTests(jtu.JaxTestCase):
|
||||
y = sin(x)
|
||||
return reduce_sum(y, axes=(0,))
|
||||
x = bbarray((5,), jnp.arange(2.))
|
||||
with jax.core.skipping_checks(): # TODO implement dxla_call abs eval rule
|
||||
with jax.enable_checks(False): # TODO implement dxla_call abs eval rule
|
||||
z, f_lin = jax.linearize(f, x)
|
||||
z_dot = f_lin(ones_like(x))
|
||||
|
||||
|
@ -25,7 +25,6 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import api
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
@ -992,7 +991,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
expected = np.array(0.0)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
with core.skipping_checks():
|
||||
with jax.enable_checks(False):
|
||||
with self.assertRaises(TypeError):
|
||||
lax.stop_gradient(lambda x: x)
|
||||
|
||||
|
@ -1709,7 +1709,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(carry_out[1], carry_init, check_dtypes=False)
|
||||
self.assertAllClose(carry_out[0], jnp.array([2., 2., 2.]), check_dtypes = False)
|
||||
|
||||
# TODO(mattjj, dougalm): fix this test when skip_checks is False
|
||||
def testIssue757(self):
|
||||
# code from https://github.com/google/jax/issues/757
|
||||
def fn(a):
|
||||
|
@ -208,7 +208,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
))
|
||||
def test_bicgstab_against_scipy(
|
||||
self, shape, dtype, preconditioner):
|
||||
if not config.FLAGS.jax_enable_x64:
|
||||
if not config.jax_enable_x64:
|
||||
raise unittest.SkipTest("requires x64 mode")
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
@ -2190,7 +2190,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
# api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.)
|
||||
|
||||
def test_primitive_jaxtype_error(self):
|
||||
with core.skipping_checks():
|
||||
with jax.enable_checks(False):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Argument .* of type .* is not a valid JAX type"):
|
||||
lax.add(1, 'hi')
|
||||
|
@ -117,12 +117,12 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
|
||||
def testEluMemory(self):
|
||||
# see https://github.com/google/jax/pull/1640
|
||||
with core.skipping_checks(): # With checks we materialize the array
|
||||
with jax.enable_checks(False): # With checks we materialize the array
|
||||
jax.make_jaxpr(lambda: nn.elu(jnp.ones((10 ** 12,)))) # don't oom
|
||||
|
||||
def testHardTanhMemory(self):
|
||||
# see https://github.com/google/jax/pull/1640
|
||||
with core.skipping_checks(): # With checks we materialize the array
|
||||
with jax.enable_checks(False): # With checks we materialize the array
|
||||
jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
|
||||
|
||||
def testOneHot(self):
|
||||
|
@ -914,7 +914,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
raise SkipTest("after deleting lazy constants, requires omnistaging")
|
||||
def f(x):
|
||||
return random.normal(random.PRNGKey(x), (int(1e12),))
|
||||
with core.skipping_checks(): # check_jaxpr will materialize array
|
||||
with jax.enable_checks(False): # check_jaxpr will materialize array
|
||||
api.eval_shape(f, 0) # doesn't error
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
@ -73,7 +73,7 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
func = _maybe_jit(jit, lambda: jnp.arange(10.0))
|
||||
func()
|
||||
|
||||
expected_dtype = "float64" if config.read("jax_enable_x64") else "float32"
|
||||
expected_dtype = "float64" if config._read("jax_enable_x64") else "float32"
|
||||
self.assertEqual(func().dtype, expected_dtype)
|
||||
|
||||
with enable_x64():
|
||||
|
Loading…
x
Reference in New Issue
Block a user