27 Commits

Author SHA1 Message Date
Jieying Luo
cb3b7ec93a [PJRT PLUGIN] Add num_processes to distributed.global_state.
The number of processes is needed for multi-process GPU when a plugin is used.

PiperOrigin-RevId: 535696950
2023-05-26 13:14:40 -07:00
Colin Gaffney
6ace6667dd Set coordinator address to allow it to later be used to initialize OCDBT coordinator server. Allow user to pass ts.Context when serializing or deserializing.
PiperOrigin-RevId: 520064049
2023-03-28 10:25:07 -07:00
Peter Hawkins
88c2898e36 Use pytype_strict_library() in Bazel build rules.
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Nicolas Castet
b86030d86f Add Open MPI automatic distributed initialization 2023-01-11 17:08:09 -06:00
Leopold Cambier
d1edad6a68 Typos: suited -> suitable, node -> host 2023-01-10 11:01:54 -08:00
Leopold Cambier
7e395c9bbe DOC: add note about localhost & friends in jax.distributed.initialize 2023-01-10 09:17:31 -08:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. JAX code currently uses the latter; switching to Python builtin logging is advisable for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. Some Python namespaces within JAX already used module-scoped logging via absl.vlog; the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code logs to descendants of the top-level "jax" logger, so by virtue of log propagation every record also reaches the "jax" logger.

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
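
The propagation behaviour described above can be illustrated with the standard library alone. This is a minimal sketch, not JAX code: the logger name `jax.example_module` is hypothetical, and the handler and format string are assumptions chosen for the demonstration. It attaches a handler to the top-level "jax" logger, leaves the root logger untouched, and shows that a record emitted by a descendant logger reaches that handler.

```python
import io
import logging

# Hypothetical module-level logger below the "jax" namespace. The name is
# illustrative only; inside a real module it would be getLogger(__name__).
logger = logging.getLogger("jax.example_module")

# Attach a handler to the top-level "jax" logger only. The root logger
# is not configured or modified here.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(name)s:%(levelname)s:%(message)s"))
jax_logger = logging.getLogger("jax")
jax_logger.addHandler(handler)
jax_logger.setLevel(logging.DEBUG)

# The descendant logger has no handler of its own; the record propagates
# upward and is handled by the "jax" logger's handler.
logger.debug("initializing")
captured = stream.getvalue()
```

A downstream user could swap the `StreamHandler` for a `logging.FileHandler` to get a log dump without touching the root logger.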
2022-10-13 21:32:44 +02:00
jax authors
bc08381da3 Merge pull request #12152 from nvcastet:add_slurm_orchestrator_support
PiperOrigin-RevId: 476179963
2022-09-22 13:18:25 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Nicolas Castet
412a5379c1 Add generic interface for auto initialization of distributed JAX service
* Also add slurm cluster support
2022-09-22 14:15:38 -05:00
Peter Hawkins
6c59d72c75 Bump the minimum jaxlib version to 0.3.15. 2022-09-08 16:43:46 -04:00
jax authors
0df81ffc16 Merge pull request #11946 from nvcastet:fix_distributed_timeout
PiperOrigin-RevId: 469721289
2022-08-24 07:28:44 -07:00
jax authors
a73a6a8de0 Shut down preemption sync manager if enabled.
PiperOrigin-RevId: 469497215
2022-08-23 10:38:10 -07:00
Nicolas Castet
475d4d1497 Increase jax.distributed timeout to 5 min 2022-08-16 14:12:20 -05:00
Peter Hawkins
a2c21958a5 Document multiprocess GPU support.
Fixes #2731
2022-08-09 11:31:05 -04:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Peter Hawkins
bdbdecd458 Refactor distributed GPU device initialization.
Avoid reregistering backend factories; instead simply have the usual
factory function support distributed GPU.
2022-07-07 00:45:19 +00:00
Haoyu Zhang
3fc24ceb35 Save an on-demand checkpoint when any worker receives a preemption signal.
PiperOrigin-RevId: 458525108
2022-07-01 12:45:30 -07:00
Peter Hawkins
3e699ddec0 Unbreak jax.distributed initialization.
A recent change broke jax.distributed initialization, which was unsurprising because those APIs were not tested. In particular, we need to only initialize the service from the first process.

Fix it and add some tests that use the distributed service from multiple threads within a unit test. Move the state of jax.distributed into an object so it can be instantiated multiple times from a test case in parallel rather than being process-global.

[XLA:Python] Add gil release guards around distributed system init/shutdown. This allows testing using multiple threads.
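
The design described in this commit (per-instance state, with only the first process starting the service, exercised from several threads) can be sketched roughly as follows. This is an illustration under stated assumptions, not the actual JAX implementation: the class `DistributedState`, its fields, the helper `simulate`, and the address `localhost:1234` are all hypothetical, and no real service or client is started.

```python
import threading
from dataclasses import dataclass


@dataclass
class DistributedState:
    """Hypothetical per-instance state; not the actual JAX class."""
    process_id: int = 0
    service_started: bool = False
    client_connected: bool = False

    def initialize(self, coordinator_address: str, num_processes: int,
                   process_id: int) -> None:
        self.process_id = process_id
        # Only the first process hosts the (simulated) service.
        if process_id == 0:
            self.service_started = True
        # Every process, including the first, connects as a client.
        self.client_connected = True


def simulate(num_processes: int) -> list:
    """Run one state object per simulated process, each in its own thread."""
    states = [DistributedState() for _ in range(num_processes)]
    threads = [
        threading.Thread(target=state.initialize,
                         args=("localhost:1234", num_processes, i))
        for i, state in enumerate(states)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return states


states = simulate(4)
```

Because the state lives in separate objects rather than in a process-global variable, one test process can hold several independent instances at once.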

PiperOrigin-RevId: 453480351
2022-06-07 11:10:17 -07:00
jax authors
24ad82685c Fix reference to jax_coordination_service flag.
PiperOrigin-RevId: 453224722
2022-06-06 10:08:31 -07:00
jax authors
6c89e90808 Allow JAX OSS users to switch between experimental coordination service and default PjRT distributed runtime service via a flag.
PiperOrigin-RevId: 452625054
2022-06-02 14:41:40 -07:00
Peter Hawkins
0e072bcf49 Fix distributed system initialization.
There was a version compatibility problem that meant that the distributed system did not correctly initialize when mixing the current released versions of jax and jaxlib. We don't have tests for multiple GPUs in our CI, so this was not caught.
2022-05-25 07:31:19 -07:00
Yash Katariya
0574eb2141 Use JAX's distributed system for fully asynchronous checkpointing.
PiperOrigin-RevId: 449380175
2022-05-17 20:17:57 -07:00
Yash Katariya
548a6bf58b * Make all arguments to distributed.initialize default to None.
* On Cloud TPUs, figure out the coordinator address automatically.

PiperOrigin-RevId: 449261786
2022-05-17 10:53:54 -07:00
Peter Hawkins
931bf3674b [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kinds of rules to coexist.

[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.

PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00
Peter Hawkins
3e5ecfe363 Add jax.distributed and jax.dlpack to the docs.
Reorder the doc modules into something closer to alphabetical order.

Add missing functions from jax.scipy.linalg and jax.scipy.signal to the docs.
2022-02-17 16:10:07 -05:00
Qiao Zhang
0be30fbf96 Add jax.distributed.initialize for multi-host GPU. 2021-10-26 14:37:54 -07:00