mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
add experimental jax.no_tracing context manager
This commit is contained in:
parent
c6c701e6a7
commit
670a648b7b
@ -56,6 +56,7 @@ from jax._src.config import (
|
||||
debug_nans as debug_nans,
|
||||
debug_infs as debug_infs,
|
||||
log_compiles as log_compiles,
|
||||
no_tracing as no_tracing,
|
||||
explain_cache_misses as explain_cache_misses,
|
||||
default_device as default_device,
|
||||
default_matmul_precision as default_matmul_precision,
|
||||
|
@ -1501,6 +1501,11 @@ eager_pmap = bool_state(
|
||||
upgrade=True,
|
||||
help='Enable eager-mode pmap when jax_disable_jit is activated.')
|
||||
|
||||
no_tracing = bool_state(
|
||||
name='jax_no_tracing',
|
||||
default=False,
|
||||
help='Disallow tracing for JIT compilation.')
|
||||
|
||||
disable_vmap_shmap_error = bool_state(
|
||||
name='jax_disable_vmap_shmap_error',
|
||||
default=False,
|
||||
|
@ -353,6 +353,9 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
|
||||
|
||||
@api_boundary
|
||||
def cache_miss(*args, **kwargs):
|
||||
if config.no_tracing.value:
|
||||
raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
|
||||
"`jit`, but 'no_tracing' is set")
|
||||
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
|
||||
fun, jit_info, *args, **kwargs)
|
||||
executable = _read_most_recent_pjit_call_executable(jaxpr)
|
||||
|
@ -1463,6 +1463,20 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(f(np.nan), np.nan)
|
||||
self.assertAllClose(jit(f)(np.nan), np.nan)
|
||||
|
||||
def test_no_tracing(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
x = jnp.arange(3)
|
||||
y = jnp.arange(4)
|
||||
|
||||
_ = f(x) # no crash
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'no_tracing'):
|
||||
with jax.no_tracing():
|
||||
_ = f(y) # crash!
|
||||
|
||||
|
||||
class APITest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user