mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Use no_tracing config in _create_pjit_jaxpr
to so that AOT path can also error if we re-trace.
PiperOrigin-RevId: 683392069
This commit is contained in:
parent
76d5938062
commit
a9e9f97f00
@ -1273,6 +1273,9 @@ def _create_pjit_jaxpr(
|
||||
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
|
||||
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
|
||||
del ignored_inline # just for explain_cache_miss
|
||||
if config.no_tracing.value:
|
||||
raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but "
|
||||
"'no_tracing' is set")
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
|
||||
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
||||
|
Loading…
x
Reference in New Issue
Block a user