Zac Cranko
5db78e7ae0
add distributed.is_initialized
2025-02-18 16:47:19 -08:00
Michael Hudgins
d4d1518c3d
Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
...
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
e1b497078e
Rename jtu.create_global_mesh
to jtu.create_mesh
and use jax.make_mesh
inside jtu.create_mesh
to get maximum test coverage of the new API.
...
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Yash Katariya
0d5dae09ff
Delete xmap
and the jax.experimental.maps
module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
...
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Kyle Gerard Felker
ffc9292365
Squashed commit of the following:
...
commit 79b8cbf0cb47e32743e0970bc1abeb6a673866a8
Author: Corey Adams <corey.adams@anl.gov>
Date: Mon Jul 1 14:14:15 2024 -0500
Fix mypy issues; change variable name to more universally known name
commit 10edc866f568908e536e5c7bd6b59b4e5351781e
Author: Corey Adams <corey.adams@anl.gov>
Date: Thu Jun 27 13:25:32 2024 -0500
Change copyright year to the year this was authored
commit f7086cb44cc98d58a96ae804dcd1787bc31470f7
Author: Corey Adams <corey.adams@anl.gov>
Date: Thu Jun 27 13:15:32 2024 -0500
Update build file to include mpi4py cluster.
commit 6235eb311b9fca2bd81fe1c49456d164b7332753
Author: Corey adams <coreyjadams@gmail.com>
Date: Thu Jun 27 12:11:48 2024 -0500
Update distributed.py
Clean up documentation slightly.
commit ef3a2e220945b2158cf20edeb1e04bbbf8f290ff
Author: Corey adams <coreyjadams@gmail.com>
Date: Thu Jun 27 12:09:37 2024 -0500
Update mpi4py_cluster.py
Further clean up unneeded comments.
commit 6cc07a9a52fc202ecc65c04c513096391c27d02d
Author: Corey adams <coreyjadams@gmail.com>
Date: Thu Jun 27 12:08:38 2024 -0500
Update mpi4py_cluster.py
Remove unneeded commented code.
commit 6701bd1a9d645a0e08d95df1692f43946f0a5eb8
Merge: 5a91ac342 98b87540a
Author: Corey adams <coreyjadams@gmail.com>
Date: Thu Jun 27 12:07:25 2024 -0500
Merge branch 'google:main' into main
commit 5a91ac34248afa6f65af3cae66df7d0d122c1d26
Merge: 301bbc67f 6c51234f9
Author: Corey adams <coreyjadams@gmail.com>
Date: Tue May 28 22:14:08 2024 -0500
Merge branch 'google:main' into main
commit 301bbc67f938bc30c543cf300cec8a9c75f3eef8
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue May 28 11:34:51 2024 -0500
Add test to verify mpi4py based distributed initialization
commit 19e66949a36bb0edb4cd66b0f170f42b326928ec
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue May 28 11:14:40 2024 -0500
Unify variable naming and fix function argument ordering
commit 72fe093042519e48d9c26b7ede3b266c7a850be6
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue May 28 10:56:25 2024 -0500
Remove unmerged code
commit 3a96e738a3cdf9b6ed194cb764fa5640a37f6b95
Merge: e4fd97e19 ff3db9b3a
Author: Corey adams <coreyjadams@gmail.com>
Date: Tue May 28 10:51:41 2024 -0500
Merge branch 'google:main' into main
commit e4fd97e197211921fb6911054592041015af94ef
Merge: a69729900 72a81e58e
Author: Corey adams <coreyjadams@gmail.com>
Date: Mon May 13 16:01:35 2024 -0500
Merge branch 'google:main' into main
commit a6972990070d5d2f405d5ede9f82d35c7e6d157a
Merge: 85bcf42bd 1e48adc69
Author: Corey adams <coreyjadams@gmail.com>
Date: Mon May 13 14:21:32 2024 -0500
Merge branch 'google:main' into main
commit 85bcf42bdd36ad88a3d287c357cd12fde74c7fc0
Merge: af1a4f0a1 06cd05d1d
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue Apr 16 09:09:31 2024 -0500
Merge branch 'main' of https://github.com/google/jax
commit af1a4f0a12008780e9507d1bdd91e9d11ec35916
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue Apr 16 08:58:33 2024 -0500
update documentation and elaborate on spec_detect_method variable
commit 01f4709d5ecd4af675f4fb23d02d6a69b927adac
Author: Corey Adams <corey.adams@anl.gov>
Date: Tue Apr 16 08:45:38 2024 -0500
Address feedback and comments on PR 20174; fix typo in documentation.
commit 4f22d86e7358c29ed588267a7d91fe55fb94f143
Merge: 900a0372f 71ec6e33c
Author: Corey adams <coreyjadams@gmail.com>
Date: Mon Mar 11 11:51:30 2024 -0500
Merge branch 'google:main' into main
commit 900a0372f6147d3c9ab53c95b6a4262e5cfe4457
Author: Corey Adams <corey.adams@anl.gov>
Date: Mon Mar 11 11:50:48 2024 -0500
Auto-detect of mpi4py-based configuration is now strictly opt-in.
commit 1992969da6164e456492fe0f9cd4287f6d8f03cf
Author: Corey Adams <corey.adams@anl.gov>
Date: Thu Mar 7 12:27:43 2024 -0600
Enable automatic detection of distrbuted variables with any configuration of MPI, as long as mpi4py is available
2024-07-02 13:18:05 -05:00
Jake VanderPlas
e6e4acb7c3
tests: set configs with jtu.with_config rather than manually
2024-06-05 13:34:32 -07:00
Jake VanderPlas
f090074d86
Avoid 'from jax import config' imports
...
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
79c21ffd94
multiprocess_test: print error message on failure
2023-11-22 10:06:19 -08:00
Sergei Lebedev
f9087ab0c6
MAINT Drop underscore from the name of externally-referenced state objects
2023-10-13 21:30:13 +01:00
Sergei Lebedev
cbcaac2756
MAINT Migrate remaining internal/test modules to use state objects
...
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008 .
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c
Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
...
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Peter Hawkins
2c32660a8f
Replace references to DeviceArray with Array.
...
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Jake VanderPlas
fbe4f10403
Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Yash Katariya
efc8300d02
Remove the gda flag from multiprocess_gpu_test.py
...
PiperOrigin-RevId: 523116057
2023-04-10 07:53:37 -07:00
Peter Hawkins
b4402185db
Move PartitionSpec into its own file (jax/_src/partition_spec.py).
...
No functional changes intended.
A subsequent change will move ParsedPartitionSpec and array mapping utilities here also.
PiperOrigin-RevId: 522393166
2023-04-06 11:43:25 -07:00
Yash Katariya
e21aee18a8
Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
...
PiperOrigin-RevId: 519781715
2023-03-27 11:33:11 -07:00
Yash Katariya
88584290aa
Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests
...
PiperOrigin-RevId: 516881635
2023-03-15 11:34:57 -07:00
Yash Katariya
52a7701dda
Replace usage of {in|out}_axis_resources with {in|out}_shardings
...
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Yash Katariya
418c2f9d2a
Rename in_axis_resources
and out_axis_resources
with in_shardings
and out_shardings
. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
...
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.
PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core
to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Jake VanderPlas
43e57db77a
Begin deprecation of public jax.ShapedArray
2023-01-30 11:27:58 -08:00
Skye Wanderman-Milne
c0577f70f9
Migrate pytestmark usage to new @jtu.pytest_mark_if_available
decorator.
...
See discussion in https://github.com/google/jax/pull/13977 . Marking
entire modules is magical and verbose, plus less precise than marking
individual classes or tests.
I wasn't super careful on which classes to mark, and erred on the side
of marking too many classes (in line with the previous behavior). It's
possible some test classes don't actually benefit from multiple
accelerators.
2023-01-12 22:44:39 +00:00
Nicolas Castet
b86030d86f
Add Open MPI automatic distributed initialization
2023-01-11 17:08:09 -06:00
Yash Katariya
5afebba285
Remove _global_avals from infer_params because everything is global in pjit after jax.Array was enabled.
...
PiperOrigin-RevId: 500012042
2023-01-06 00:08:16 -08:00
Jake VanderPlas
b0e03fb747
Remove whitespace to fix flake8
2022-11-07 09:10:05 -08:00
Rahul Batra
e84a7e25b2
[ROCm]: Enable/update multiprocess gpu tests for ROCm
2022-10-27 16:51:37 +00:00
Peter Hawkins
320d531521
Increase the minimum jaxlib version to 0.3.22.
...
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Yash Katariya
607ce88d19
jax.Array is a unified type that will subsume JAX's DeviceArray, ShardedDeviceArray and GlobalDeviceArray.
...
This change replaces uses of `local_shards` and `local_data` with `addressable_shards` and `addressable_data` which are compatible with both `GDA` and `jax.Array`.
PiperOrigin-RevId: 481229606
2022-10-14 14:09:01 -07:00
Yash Katariya
7b49a3f51d
Run tests in multiprocess_gpu_test only if the backend is GPU.
...
PiperOrigin-RevId: 477750739
2022-09-29 09:54:32 -07:00
jax authors
aafc70d293
Merge pull request #12556 from hawkinsp:rocm
...
PiperOrigin-RevId: 477440001
2022-09-28 06:50:19 -07:00
Peter Hawkins
f7bafb3d4c
Disable multiprocess_gpu_test that fails on ROCm.
2022-09-28 13:40:57 +00:00
Peter Hawkins
eabb91e53f
Fix test failure in GPU CI if NCCL_DEBUG is enabled.
...
If NCCL_DEBUG is enabled, NCCL prints extra status information. Make
test accept this.
2022-09-28 13:06:04 +00:00
Peter Hawkins
f4bc663c31
Wrap multiprocess test popen() uses in a context manager.
...
Ensures that resources from popen() are cleaned up.
2022-09-26 13:48:56 +00:00
Sudhakar
4dd0d85139
add multihost pjit tests
2022-09-23 12:11:56 -07:00
jax authors
bc08381da3
Merge pull request #12152 from nvcastet:add_slurm_orchestrator_support
...
PiperOrigin-RevId: 476179963
2022-09-22 13:18:25 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Nicolas Castet
412a5379c1
Add generic interface for auto initialization of distributed JAX service
...
* Also add slurm cluster support
2022-09-22 14:15:38 -05:00
Jake VanderPlas
2936c8a2c7
multiprocess_gpu_test: kill open subprocesses to avoid warning
2022-09-19 12:31:10 -07:00
Sudhakar
5f1858f533
Add pytest marker inside the test only if pytest is present in the env
2022-09-06 11:45:59 -07:00
Sudhakar
a571db18db
Enable one gpu per process in multinode GPU CI
2022-08-29 09:00:19 -07:00
Peter Hawkins
b9d7e05eda
[JAX] Handle non-contiguous GPU IDs in NCCL collectives.
...
Fixes https://github.com/google/jax/issues/12119
PiperOrigin-RevId: 470335156
2022-08-26 14:33:08 -07:00
Sudhakar
4b1a2eaaec
combine gpu tests
2022-08-25 15:27:07 -07:00