add experimental jax.no_tracing context manager

This commit is contained in:
Matthew Johnson 2024-08-23 21:21:55 +00:00
parent c6c701e6a7
commit 670a648b7b
4 changed files with 23 additions and 0 deletions

View File

@ -56,6 +56,7 @@ from jax._src.config import (
debug_nans as debug_nans,
debug_infs as debug_infs,
log_compiles as log_compiles,
no_tracing as no_tracing,
explain_cache_misses as explain_cache_misses,
default_device as default_device,
default_matmul_precision as default_matmul_precision,

View File

@ -1501,6 +1501,11 @@ eager_pmap = bool_state(
upgrade=True,
help='Enable eager-mode pmap when jax_disable_jit is activated.')
no_tracing = bool_state(
name='jax_no_tracing',
default=False,
help='Disallow tracing for JIT compilation.')
disable_vmap_shmap_error = bool_state(
name='jax_disable_vmap_shmap_error',
default=False,

View File

@ -353,6 +353,9 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
@api_boundary
def cache_miss(*args, **kwargs):
if config.no_tracing.value:
raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
"`jit`, but 'no_tracing' is set")
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
fun, jit_info, *args, **kwargs)
executable = _read_most_recent_pjit_call_executable(jaxpr)

View File

@ -1463,6 +1463,20 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertAllClose(f(np.nan), np.nan)
self.assertAllClose(jit(f)(np.nan), np.nan)
def test_no_tracing(self):
@jax.jit
def f(x):
return x
x = jnp.arange(3)
y = jnp.arange(4)
_ = f(x) # no crash
with self.assertRaisesRegex(RuntimeError, 'no_tracing'):
with jax.no_tracing():
_ = f(y) # crash!
class APITest(jtu.JaxTestCase):