Georg Stefan Schmid
7bdb2bf998
[jax.distributed] Enable grpc channel compression
2024-11-05 16:47:29 +00:00
Jaroslav Sevcik
fafebd254a
Pass the heartbeat timeouts in parameter
2024-10-10 09:58:58 -07:00
Jaroslav Sevcik
53e781a1ee
Options to control heartbeat monitor timeouts.
2024-10-10 04:57:03 -07:00
jax authors
4957ab9a5e
Clean up JAX backend for all backends to avoid dangling PyClient references.
...
PiperOrigin-RevId: 673102539
2024-09-10 14:19:00 -07:00
Brian Wieder
ee31e95ecd
Register shutdown code at import to hopefully get registered before any other atexit callbacks.
...
`atexit` callbacks are called in LIFO order, meaning that since JAX currently registers its callback at runtime rather than import time, it gets called before any `atexit` callbacks registered at import time.
PiperOrigin-RevId: 662164776
2024-08-12 11:29:08 -07:00
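The LIFO ordering of `atexit` callbacks that motivates the commit above can be checked with a small sketch. It does not use JAX at all: it launches a child interpreter that registers two `print` callbacks and records the order in which they fire at exit. The callback labels and the `run_demo` helper are hypothetical names chosen for illustration.

```python
import subprocess
import sys
import textwrap

# Script run in a child interpreter so that its atexit callbacks actually fire.
# The labels are illustrative only; they stand in for "import time" and
# "runtime" registration, not for any real JAX callback.
SCRIPT = textwrap.dedent(
    """
    import atexit
    atexit.register(print, "registered first")
    atexit.register(print, "registered second")
    """
)


def run_demo() -> list[str]:
    """Run SCRIPT in a fresh interpreter and return its stdout lines."""
    result = subprocess.run(
        [sys.executable, "-c", SCRIPT],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.splitlines()


if __name__ == "__main__":
    # The callback registered last is the one that runs first.
    for line in run_demo():
        print(line)
```

A callback registered earlier (for example at import time) therefore runs later during interpreter shutdown, which is the behavior the commit relies on.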
Georg Stefan Schmid
f9bc4c643b
[jax.distributed] Allow setting local device ids via env var
2024-08-09 10:23:17 +00:00
Kyle Gerard Felker
ffc9292365
Squashed commit of the following:
...
commit 79b8cbf0cb47e32743e0970bc1abeb6a673866a8
Author: Corey Adams <corey.adams@anl.gov>
Date: Mon Jul 1 14:14:15 2024 -0500
Fix mypy issues; change variable name to more universally known name
commit 10edc866f568908e536e5c7bd6b59b4e5351781e
Author: Corey Adams <corey.adams@anl.gov>
Date: Thu Jun 27 13:25:32 2024 -0500
Change copyright year to the year this was authored
commit f7086cb44cc98d58a96ae804dcd1787bc31470f7
Author: Corey Adams <corey.adams@anl.gov>
Date: Thu Jun 27 13:15:32 2024 -0500
Update build file to include mpi4py cluster.
commit 6235eb311b9fca2bd81fe1c49456d164b7332753
Author: Corey adams <coreyjadams@gmail.com>
Date: Thu Jun 27 12:11:48 2024 -0500
Update distributed.py
Clean up documentation slightly.
commit ef3a2e220945b2158cf20edeb1e04bbbf8f290ff
Author: Corey adams <coreyjadams@gmail.com>
Date: Thu Jun 27 12:09:37 2024 -0500
Update mpi4py_cluster.py
Further clean up unneeded comments.
commit 6cc07a9a52fc202ecc65c04c513096391c27d02d
Author: Corey adams <coreyjadams@gmail.com>
Date: Thu Jun 27 12:08:38 2024 -0500
Update mpi4py_cluster.py
Remove unneeded commented code.
commit 6701bd1a9d645a0e08d95df1692f43946f0a5eb8
Merge: 5a91ac342 98b87540a
Author: Corey adams <coreyjadams@gmail.com>
Date: Thu Jun 27 12:07:25 2024 -0500
Merge branch 'google:main' into main
commit 5a91ac34248afa6f65af3cae66df7d0d122c1d26
Merge: 301bbc67f 6c51234f9
Author: Corey adams <coreyjadams@gmail.com>
Date: Tue May 28 22:14:08 2024 -0500
Merge branch 'google:main' into main
commit 301bbc67f938bc30c543cf300cec8a9c75f3eef8
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue May 28 11:34:51 2024 -0500
Add test to verify mpi4py based distributed initialization
commit 19e66949a36bb0edb4cd66b0f170f42b326928ec
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue May 28 11:14:40 2024 -0500
Unify variable naming and fix function argument ordering
commit 72fe093042519e48d9c26b7ede3b266c7a850be6
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue May 28 10:56:25 2024 -0500
Remove unmerged code
commit 3a96e738a3cdf9b6ed194cb764fa5640a37f6b95
Merge: e4fd97e19 ff3db9b3a
Author: Corey adams <coreyjadams@gmail.com>
Date: Tue May 28 10:51:41 2024 -0500
Merge branch 'google:main' into main
commit e4fd97e197211921fb6911054592041015af94ef
Merge: a69729900 72a81e58e
Author: Corey adams <coreyjadams@gmail.com>
Date: Mon May 13 16:01:35 2024 -0500
Merge branch 'google:main' into main
commit a6972990070d5d2f405d5ede9f82d35c7e6d157a
Merge: 85bcf42bd 1e48adc69
Author: Corey adams <coreyjadams@gmail.com>
Date: Mon May 13 14:21:32 2024 -0500
Merge branch 'google:main' into main
commit 85bcf42bdd36ad88a3d287c357cd12fde74c7fc0
Merge: af1a4f0a1 06cd05d1d
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue Apr 16 09:09:31 2024 -0500
Merge branch 'main' of https://github.com/google/jax
commit af1a4f0a12008780e9507d1bdd91e9d11ec35916
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue Apr 16 08:58:33 2024 -0500
update documentation and elaborate on spec_detect_method variable
commit 01f4709d5ecd4af675f4fb23d02d6a69b927adac
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue Apr 16 08:45:38 2024 -0500
Address feedback and comments on PR 20174; fix typo in documentation.
commit 4f22d86e7358c29ed588267a7d91fe55fb94f143
Merge: 900a0372f 71ec6e33c
Author: Corey adams <coreyjadams@gmail.com>
Date: Mon Mar 11 11:51:30 2024 -0500
Merge branch 'google:main' into main
commit 900a0372f6147d3c9ab53c95b6a4262e5cfe4457
Author: Corey Adams <corey.adams@anl.gov>
Date: Mon Mar 11 11:50:48 2024 -0500
Auto-detect of mpi4py-based configuration is now strictly opt-in.
commit 1992969da6164e456492fe0f9cd4287f6d8f03cf
Author: Corey Adams <corey.adams@anl.gov>
Date: Thu Mar 7 12:27:43 2024 -0600
Enable automatic detection of distributed variables with any configuration of MPI, as long as mpi4py is available
2024-07-02 13:18:05 -05:00
Dan Foreman-Mackey
6d35b109fd
Rename "Example" to "Examples" in docstrings.
...
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
Dateng Lin
20379c636d
Fixed the logging due to a recent change.
...
PiperOrigin-RevId: 639392396
2024-06-01 14:28:55 -07:00
Sergei Lebedev
f5617d7323
Removed noop # type: ignore comments
...
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Frederic Bastien
f578d78f2b
Update doc with the other error that can be thrown.
2024-05-14 08:35:01 -07:00
rajasekharporeddy
aaddba0c20
Fix doc Typos
2024-04-22 10:32:51 +05:30
Olli Lupton
2dd1b3d6c8
jax.distributed.initialize: specify bind address.
...
By default, the coordinator process listens on all interfaces.
2024-04-03 17:13:27 +02:00
jax authors
c5869feb92
Add option to set coordinator lookup timeout for TPU clusters
...
PiperOrigin-RevId: 617383458
2024-03-19 20:55:45 -07:00
Sergei Lebedev
f936613b06
Upgrade remaining sources to Python 3.9
...
This PR is a follow-up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Peter Hawkins
30a0136813
Increase minimum jaxlib version to 0.4.19.
...
0.4.19 has xla_extension version 207 and mlir_api_version 54.
PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Peter Hawkins
b85ea68fba
Move test for backend initialization into jax.distributed.initialize() wrapper.
...
This allows us to skip the check for tests.
PiperOrigin-RevId: 580168674
2023-11-07 07:12:40 -08:00
Peter Hawkins
eeafff5891
Raise an exception if jax.distributed.initialize() is called after backends have been initialized.
...
Fixes https://github.com/google/jax/issues/18237
PiperOrigin-RevId: 579936065
2023-11-06 13:12:26 -08:00
Sergei Lebedev
cbcaac2756
MAINT Migrate remaining internal/test modules to use state objects
...
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow-up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
65d3058944
Migrate a subset of internal modules to use state objects
...
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
Jake VanderPlas
4a5bd9e046
Fix typos across the package
2023-09-22 14:54:31 -07:00
Peter Hawkins
b11e245772
Remove stale reference to coordination_service flag.
...
Fixes https://github.com/google/jax/issues/17288
PiperOrigin-RevId: 560103075
2023-08-25 08:47:53 -07:00
Peter Hawkins
c879f65aa6
[JAX] Remove the non-coordination service distributed service implementation from JAX.
...
The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.
PiperOrigin-RevId: 554608165
2023-08-07 15:17:25 -07:00
Yash Katariya
2fa0bb0d32
Add initialization_timeout
as a parameter to allow users to increase or decrease the init_timeout parameter.
...
PiperOrigin-RevId: 554545535
2023-08-07 11:49:41 -07:00
Peter Hawkins
319ab98980
Apply pyupgrade --py39-plus.
...
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Jieying Luo
cb3b7ec93a
[PJRT PLUGIN] Add num_processes to distributed.global_state.
...
The number of processes is needed for multi-process GPU when plugin is used.
PiperOrigin-RevId: 535696950
2023-05-26 13:14:40 -07:00
Colin Gaffney
6ace6667dd
Set coordinator address to allow it to later be used to initialize OCDBT coordinator server. Allow user to pass ts.Context when serializing or deserializing.
...
PiperOrigin-RevId: 520064049
2023-03-28 10:25:07 -07:00
Peter Hawkins
88c2898e36
Use pytype_strict_library() in Bazel build rules.
...
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Nicolas Castet
b86030d86f
Add Open MPI automatic distributed initialization
2023-01-11 17:08:09 -06:00
Leopold Cambier
d1edad6a68
Typos: suited -> suitable, node -> host
2023-01-10 11:01:54 -08:00
Leopold Cambier
7e395c9bbe
DOC: add note about localhost & friends in jax.distributed.initialize
2023-01-10 09:17:31 -08:00
Nicholas Junge
efd61b73f6
Migrate JAX internals to builtin Python logging
...
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):
- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.
Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:
```py
import logging
logger = logging.getLogger(__name__)
logger.debug(...)
logger.info(...)
```
The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.
The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
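The propagation behavior described in the commit above can be sketched with the standard library alone. The example assumes only what the commit message states, namely that module-level loggers are descendants of the top-level "jax" logger. The module name `jax._src.example_module` and the `ListHandler` and `capture_jax_logs` helpers are hypothetical and exist only for this illustration; JAX itself is not imported.

```python
import logging


class ListHandler(logging.Handler):
    """Collects formatted records in memory instead of writing them out."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: list[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(f"{record.name}: {record.getMessage()}")


def capture_jax_logs() -> list[str]:
    """Attach a handler to the "jax" logger and log from a descendant."""
    handler = ListHandler()
    jax_logger = logging.getLogger("jax")
    previous_level = jax_logger.level
    jax_logger.addHandler(handler)
    jax_logger.setLevel(logging.DEBUG)
    try:
        # Hypothetical module-level logger, mirroring the
        # logging.getLogger(__name__) idiom from the commit message.
        module_logger = logging.getLogger("jax._src.example_module")
        module_logger.debug("hello from a descendant logger")
    finally:
        # Restore the "jax" logger so the sketch leaves no global side effects.
        jax_logger.removeHandler(handler)
        jax_logger.setLevel(previous_level)
    return handler.messages


if __name__ == "__main__":
    print(capture_jax_logs())
```

Because the handler sits on the "jax" logger rather than on the root logger, a downstream application can configure the root logger independently, which is the design choice the commit message calls out.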
jax authors
bc08381da3
Merge pull request #12152 from nvcastet:add_slurm_orchestrator_support
...
PiperOrigin-RevId: 476179963
2022-09-22 13:18:25 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Nicolas Castet
412a5379c1
Add generic interface for auto initialization of distributed JAX service
...
* Also add slurm cluster support
2022-09-22 14:15:38 -05:00
Peter Hawkins
6c59d72c75
Bump the minimum jaxlib version to 0.3.15.
2022-09-08 16:43:46 -04:00
jax authors
0df81ffc16
Merge pull request #11946 from nvcastet:fix_distributed_timeout
...
PiperOrigin-RevId: 469721289
2022-08-24 07:28:44 -07:00
jax authors
a73a6a8de0
Shut down preemption sync manager if enabled.
...
PiperOrigin-RevId: 469497215
2022-08-23 10:38:10 -07:00
Nicolas Castet
475d4d1497
Increase jax.distributed timeout to 5 min
2022-08-16 14:12:20 -05:00
Peter Hawkins
a2c21958a5
Document multiprocess GPU support.
...
Fixes #2731
2022-08-09 11:31:05 -04:00
Peter Hawkins
0b4b0ba072
Update minimum jaxlib version to 0.3.14.
2022-07-08 00:36:02 +00:00
Peter Hawkins
bdbdecd458
Refactor distributed GPU device initialization.
...
Avoid reregistering backend factories; instead simply have the usual
factory function support distributed GPU.
2022-07-07 00:45:19 +00:00
Haoyu Zhang
3fc24ceb35
Save an on-demand checkpoint when any worker receives a preemption signal.
...
PiperOrigin-RevId: 458525108
2022-07-01 12:45:30 -07:00
Peter Hawkins
3e699ddec0
Unbreak jax.distributed initialization.
...
A recent change broke jax.distributed initialization, which was unsurprising because those APIs were not tested. In particular, we need to only initialize the service from the first process.
Fix it and add some tests that use the distributed service from multiple threads within a unit test. Move the state of jax.distributed into an object so it can be instantiated multiple times from a test case in parallel rather than being process-global.
[XLA:Python] Add gil release guards around distributed system init/shutdown. This allows testing using multiple threads.
PiperOrigin-RevId: 453480351
2022-06-07 11:10:17 -07:00
jax authors
24ad82685c
Fix reference to jax_coordination_service flag.
...
PiperOrigin-RevId: 453224722
2022-06-06 10:08:31 -07:00
jax authors
6c89e90808
Allow JAX OSS users to switch between experimental coordination service and default PjRT distributed runtime service via a flag.
...
PiperOrigin-RevId: 452625054
2022-06-02 14:41:40 -07:00
Peter Hawkins
0e072bcf49
Fix distributed system initialization.
...
There was a version compatibility problem that meant that the distributed system did not correctly initialize when mixing the current released versions of jax and jaxlib. We don't have tests for multiple GPUs in our CI, so this was not caught.
2022-05-25 07:31:19 -07:00
Yash Katariya
0574eb2141
Use JAX's distributed system for fully asynchronous checkpointing.
...
PiperOrigin-RevId: 449380175
2022-05-17 20:17:57 -07:00
Yash Katariya
548a6bf58b
* Make all arguments to distributed.initialize equal to None.
...
* On Cloud TPUs, figure out the coordinator address automatically.
PiperOrigin-RevId: 449261786
2022-05-17 10:53:54 -07:00
Peter Hawkins
931bf3674b
[JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
...
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kinds of rules to coexist.
[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.
PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00