make c++ jit sensitive to global omnistaging state

fixes #5206

Co-authored-by: Jean-Baptiste Lespiau <jblespiau@google.com>
This commit is contained in:
Matthew Johnson 2021-01-05 08:14:16 -08:00
parent daaeb2bf6f
commit 7a9f8f96ea
2 changed files with 26 additions and 5 deletions

View File

@ -259,7 +259,7 @@ def _cpp_jit(
raise ValueError("can't specify both a device and a backend for jit, "
f"got device={device} and backend={backend}.")
def cache_miss(*args, **kwargs):
def cache_miss(_, *args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
@ -362,17 +362,20 @@ def _cpp_jit(
"""
return config.read("jax_disable_jit")
static_argnums_ = (0,) + tuple(i + 1 for i in static_argnums)
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
get_jax_enable_x64, get_jax_disable_jit_flag,
static_argnums)
static_argnums_)
# TODO(mattjj): make cpp callable follow descriptor protocol for bound methods
@wraps(fun)
@api_boundary
def f_jitted(*args, **kwargs):
context = getattr(core.thread_local_state.trace_state.trace_stack,
'dynamic', None)
# TODO(jblespiau): Move this to C++.
if FLAGS.jax_debug_nans and not _jit_is_disabled():
device_arrays = cpp_jitted_f(*args, **kwargs)
device_arrays = cpp_jitted_f(context, *args, **kwargs)
try:
xla.check_nans(xla.xla_call_p, [
da.device_buffer
@ -384,9 +387,11 @@ def _cpp_jit(
assert FLAGS.jax_debug_nans # compiled_fun can only raise in this case
print("Invalid nan value encountered in the output of a C++-jit "
"function. Calling the de-optimized version.")
return cache_miss(*args, **kwargs)[0] # probably won't return
else:
return cache_miss(context, *args, **kwargs)[0] # probably won't return
elif _jit_is_disabled():
return cpp_jitted_f(*args, **kwargs)
else:
return cpp_jitted_f(context, *args, **kwargs)
f_jitted._cpp_jitted_f = cpp_jitted_f
return f_jitted

View File

@ -506,6 +506,22 @@ class CPPJitTest(jtu.JaxTestCase):
f"explicit inner-jit backend specification cpu."):
f(1.)
def test_omnistaging(self):
# See https://github.com/google/jax/issues/5206
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
key_list = [None]
def init():
key, subkey = jax.random.split(key_list[0])
key_list[0] = key
return jax.random.normal(subkey, ())
key_list[0] = np.array([2384771982, 3928867769], dtype=np.uint32)
init()
self.jit(init)()
self.assertIsInstance(key_list[0], core.Tracer)
class PythonJitTest(CPPJitTest):