Add a warning if the user calls os.fork().

Fixes https://github.com/google/jax/issues/18852
This commit is contained in:
Peter Hawkins 2023-12-14 18:02:13 -05:00
parent 1559d6495e
commit ec89e5e4c5

View File

@ -106,6 +106,16 @@ _CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
)
# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
warnings.warn(
"os.fork() was called. os.fork() is incompatible with multithreaded code, "
"and JAX is multithreaded, so this will likely lead to a deadlock.",
RuntimeWarning, stacklevel=2)
os.register_at_fork(before=_at_fork)
# Backends