Use no_tracing config in _create_pjit_jaxpr to so that AOT path can also error if we re-trace.

PiperOrigin-RevId: 683392069
This commit is contained in:
Yash Katariya 2024-10-07 17:48:35 -07:00 committed by jax authors
parent 76d5938062
commit a9e9f97f00

View File

@ -1273,6 +1273,9 @@ def _create_pjit_jaxpr(
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
del ignored_inline # just for explain_cache_miss
if config.no_tracing.value:
raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but "
"'no_tracing' is set")
with dispatch.log_elapsed_time(
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):