mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
make c++ jit sensitive to global omnistaging state
fixes #5206 Co-authored-by: Jean-Baptiste Lespiau <jblespiau@google.com>
This commit is contained in:
parent
daaeb2bf6f
commit
7a9f8f96ea
15
jax/api.py
15
jax/api.py
@ -259,7 +259,7 @@ def _cpp_jit(
|
||||
raise ValueError("can't specify both a device and a backend for jit, "
|
||||
f"got device={device} and backend={backend}.")
|
||||
|
||||
def cache_miss(*args, **kwargs):
|
||||
def cache_miss(_, *args, **kwargs):
|
||||
### This first part is basically the same code as in _python_jit.
|
||||
# An alternative would be for cache_miss to accept from C++ the arguments
|
||||
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
|
||||
@ -362,17 +362,20 @@ def _cpp_jit(
|
||||
"""
|
||||
return config.read("jax_disable_jit")
|
||||
|
||||
static_argnums_ = (0,) + tuple(i + 1 for i in static_argnums)
|
||||
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
|
||||
get_jax_enable_x64, get_jax_disable_jit_flag,
|
||||
static_argnums)
|
||||
static_argnums_)
|
||||
|
||||
# TODO(mattjj): make cpp callable follow descriptor protocol for bound methods
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
def f_jitted(*args, **kwargs):
|
||||
context = getattr(core.thread_local_state.trace_state.trace_stack,
|
||||
'dynamic', None)
|
||||
# TODO(jblespiau): Move this to C++.
|
||||
if FLAGS.jax_debug_nans and not _jit_is_disabled():
|
||||
device_arrays = cpp_jitted_f(*args, **kwargs)
|
||||
device_arrays = cpp_jitted_f(context, *args, **kwargs)
|
||||
try:
|
||||
xla.check_nans(xla.xla_call_p, [
|
||||
da.device_buffer
|
||||
@ -384,9 +387,11 @@ def _cpp_jit(
|
||||
assert FLAGS.jax_debug_nans # compiled_fun can only raise in this case
|
||||
print("Invalid nan value encountered in the output of a C++-jit "
|
||||
"function. Calling the de-optimized version.")
|
||||
return cache_miss(*args, **kwargs)[0] # probably won't return
|
||||
else:
|
||||
return cache_miss(context, *args, **kwargs)[0] # probably won't return
|
||||
elif _jit_is_disabled():
|
||||
return cpp_jitted_f(*args, **kwargs)
|
||||
else:
|
||||
return cpp_jitted_f(context, *args, **kwargs)
|
||||
f_jitted._cpp_jitted_f = cpp_jitted_f
|
||||
|
||||
return f_jitted
|
||||
|
@ -506,6 +506,22 @@ class CPPJitTest(jtu.JaxTestCase):
|
||||
f"explicit inner-jit backend specification cpu."):
|
||||
f(1.)
|
||||
|
||||
def test_omnistaging(self):
|
||||
# See https://github.com/google/jax/issues/5206
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
key_list = [None]
|
||||
|
||||
def init():
|
||||
key, subkey = jax.random.split(key_list[0])
|
||||
key_list[0] = key
|
||||
return jax.random.normal(subkey, ())
|
||||
|
||||
key_list[0] = np.array([2384771982, 3928867769], dtype=np.uint32)
|
||||
init()
|
||||
self.jit(init)()
|
||||
self.assertIsInstance(key_list[0], core.Tracer)
|
||||
|
||||
class PythonJitTest(CPPJitTest):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user