mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 11:46:08 +00:00
make core.trace_state resetting be thread-local
This commit is contained in:
parent
b78b7a0309
commit
6d4987cc04
@ -544,11 +544,10 @@ trace_state = TraceState()
|
||||
|
||||
def reset_trace_state() -> bool:
|
||||
"Reset the global trace state and return True if it was already clean."
|
||||
global trace_state
|
||||
if (trace_state.substack != [Sublevel(0)] or
|
||||
trace_state.trace_stack.downward or
|
||||
trace_state.trace_stack.upward):
|
||||
trace_state = TraceState()
|
||||
trace_state.__init__()
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
Loading…
x
Reference in New Issue
Block a user