6 Commits

Author SHA1 Message Date
Zac Cranko
5db78e7ae0 add distributed.is_initialized 2025-02-18 16:47:19 -08:00
Jake VanderPlas
5cc689976f Use PEP484-style exports in several submodules 2024-08-14 08:59:56 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
3e699ddec0 Unbreak jax.distributed initialization.
A recent change broke jax.distributed initialization, which was unsurprising because those APIs were not tested. In particular, we need to only initialize the service from the first process.

Fix it and add some tests that use the distributed service from multiple threads within a unit test. Move the state of jax.distributed into an object so it can be instantiated multiple times from a test case in parallel rather than being process-global.

[XLA:Python] Add gil release guards around distributed system init/shutdown. This allows testing using multiple threads.

PiperOrigin-RevId: 453480351
2022-06-07 11:10:17 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Qiao Zhang
0be30fbf96 Add jax.distributed.initialize for multi-host GPU. 2021-10-26 14:37:54 -07:00