A recent change broke jax.distributed initialization, which was unsurprising because those APIs were not tested. In particular, we need to only initialize the service from the first process.
Fix it and add some tests that use the distributed service from multiple threads within a unit test. Move the state of jax.distributed into an object so it can be instantiated multiple times from a test case in parallel rather than being process-global.
[XLA:Python] Add gil release guards around distributed system init/shutdown. This allows testing using multiple threads.
PiperOrigin-RevId: 453480351