mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add a warning if the user calls os.fork().
Fixes https://github.com/google/jax/issues/18852
This commit is contained in:
parent
1559d6495e
commit
ec89e5e4c5
@ -106,6 +106,16 @@ _CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
|
||||
)
|
||||
|
||||
|
||||
# Warn the user if they call fork(), because it's not going to go well for them.
|
||||
def _at_fork():
|
||||
warnings.warn(
|
||||
"os.fork() was called. os.fork() is incompatible with multithreaded code, "
|
||||
"and JAX is multithreaded, so this will likely lead to a deadlock.",
|
||||
RuntimeWarning, stacklevel=2)
|
||||
|
||||
os.register_at_fork(before=_at_fork)
|
||||
|
||||
|
||||
# Backends
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user