43 Commits

Author SHA1 Message Date
Zac Cranko
5db78e7ae0 add distributed.is_initialized 2025-02-18 16:47:19 -08:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
e1b497078e Rename jtu.create_global_mesh to jtu.create_mesh and use jax.make_mesh inside jtu.create_mesh to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Kyle Gerard Felker
ffc9292365 Squashed commit of the following:
commit 79b8cbf0cb47e32743e0970bc1abeb6a673866a8
Author: Corey Adams <corey.adams@anl.gov>
Date:   Mon Jul 1 14:14:15 2024 -0500

    Fix mypy issues; change variable name to more universally known name

commit 10edc866f568908e536e5c7bd6b59b4e5351781e
Author: Corey Adams <corey.adams@anl.gov>
Date:   Thu Jun 27 13:25:32 2024 -0500

    Change copyright year to the year this was authored

commit f7086cb44cc98d58a96ae804dcd1787bc31470f7
Author: Corey Adams <corey.adams@anl.gov>
Date:   Thu Jun 27 13:15:32 2024 -0500

    Update build file to include mpi4py cluster.

commit 6235eb311b9fca2bd81fe1c49456d164b7332753
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:11:48 2024 -0500

    Update distributed.py

    Clean up documentation slightly.

commit ef3a2e220945b2158cf20edeb1e04bbbf8f290ff
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:09:37 2024 -0500

    Update mpi4py_cluster.py

    Further clean up unneeded comments.

commit 6cc07a9a52fc202ecc65c04c513096391c27d02d
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:08:38 2024 -0500

    Update mpi4py_cluster.py

    Remove unneeded commented code.

commit 6701bd1a9d645a0e08d95df1692f43946f0a5eb8
Merge: 5a91ac342 98b87540a
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:07:25 2024 -0500

    Merge branch 'google:main' into main

commit 5a91ac34248afa6f65af3cae66df7d0d122c1d26
Merge: 301bbc67f 6c51234f9
Author: Corey adams <coreyjadams@gmail.com>
Date:   Tue May 28 22:14:08 2024 -0500

    Merge branch 'google:main' into main

commit 301bbc67f938bc30c543cf300cec8a9c75f3eef8
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue May 28 11:34:51 2024 -0500

    Add test to verify mpi4py based distributed initialization

commit 19e66949a36bb0edb4cd66b0f170f42b326928ec
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue May 28 11:14:40 2024 -0500

    Unify variable naming and fix function argument ordering

commit 72fe093042519e48d9c26b7ede3b266c7a850be6
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue May 28 10:56:25 2024 -0500

    Remove unmerged code

commit 3a96e738a3cdf9b6ed194cb764fa5640a37f6b95
Merge: e4fd97e19 ff3db9b3a
Author: Corey adams <coreyjadams@gmail.com>
Date:   Tue May 28 10:51:41 2024 -0500

    Merge branch 'google:main' into main

commit e4fd97e197211921fb6911054592041015af94ef
Merge: a69729900 72a81e58e
Author: Corey adams <coreyjadams@gmail.com>
Date:   Mon May 13 16:01:35 2024 -0500

    Merge branch 'google:main' into main

commit a6972990070d5d2f405d5ede9f82d35c7e6d157a
Merge: 85bcf42bd 1e48adc69
Author: Corey adams <coreyjadams@gmail.com>
Date:   Mon May 13 14:21:32 2024 -0500

    Merge branch 'google:main' into main

commit 85bcf42bdd36ad88a3d287c357cd12fde74c7fc0
Merge: af1a4f0a1 06cd05d1d
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue Apr 16 09:09:31 2024 -0500

    Merge branch 'main' of https://github.com/google/jax

commit af1a4f0a12008780e9507d1bdd91e9d11ec35916
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue Apr 16 08:58:33 2024 -0500

    update documentation and elaborate on spec_detect_method variable

commit 01f4709d5ecd4af675f4fb23d02d6a69b927adac
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue Apr 16 08:45:38 2024 -0500

    Address feedback and comments on PR 20174; fix typo in documentation.

commit 4f22d86e7358c29ed588267a7d91fe55fb94f143
Merge: 900a0372f 71ec6e33c
Author: Corey adams <coreyjadams@gmail.com>
Date:   Mon Mar 11 11:51:30 2024 -0500

    Merge branch 'google:main' into main

commit 900a0372f6147d3c9ab53c95b6a4262e5cfe4457
Author: Corey Adams <corey.adams@anl.gov>
Date:   Mon Mar 11 11:50:48 2024 -0500

    Auto-detect of mpi4py-based configuration is now strictly opt-in.

commit 1992969da6164e456492fe0f9cd4287f6d8f03cf
Author: Corey Adams <corey.adams@anl.gov>
Date:   Thu Mar 7 12:27:43 2024 -0600

    Enable automatic detection of distrbuted variables with any configuration of MPI, as long as mpi4py is available
2024-07-02 13:18:05 -05:00
Jake VanderPlas
e6e4acb7c3 tests: set configs with jtu.with_config rather than manually 2024-06-05 13:34:32 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
79c21ffd94 multiprocess_test: print error message on failure 2023-11-22 10:06:19 -08:00
Sergei Lebedev
f9087ab0c6 MAINT Drop underscore from the name of externally-referenced state objects 2023-10-13 21:30:13 +01:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
efc8300d02 Remove the gda flag from multiprocess_gpu_test.py
PiperOrigin-RevId: 523116057
2023-04-10 07:53:37 -07:00
Peter Hawkins
b4402185db Move PartitionSpec into its own file (jax/_src/partition_spec.py).
No functional changes intended.

A subsequent change will move ParsedPartitionSpec and array mapping utilities here also.

PiperOrigin-RevId: 522393166
2023-04-06 11:43:25 -07:00
Yash Katariya
e21aee18a8 Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
PiperOrigin-RevId: 519781715
2023-03-27 11:33:11 -07:00
Yash Katariya
88584290aa Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests
PiperOrigin-RevId: 516881635
2023-03-15 11:34:57 -07:00
Yash Katariya
52a7701dda Replace usage of {in|out}_axis_resources with {in|out}_shardings
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Yash Katariya
418c2f9d2a Rename in_axis_resources and out_axis_resources with in_shardings and out_shardings. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.

PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Jake VanderPlas
43e57db77a Begin deprecation of public jax.ShapedArray 2023-01-30 11:27:58 -08:00
Skye Wanderman-Milne
c0577f70f9 Migrate pytestmark usage to new @jtu.pytest_mark_if_available decorator.
See discussion in https://github.com/google/jax/pull/13977. Marking
entire modules is magical and verbose, plus less precise than marking
individual classes or tests.

I wasn't super careful on which classes to mark, and erred on the side
of marking too many classes (in line with the previous behavior). It's
possible some test classes don't actually benefit from multiple
accelerators.
2023-01-12 22:44:39 +00:00
Nicolas Castet
b86030d86f Add Open MPI automatic distributed initialization 2023-01-11 17:08:09 -06:00
Yash Katariya
5afebba285 Remove _global_avals from infer_params because everything is global in pjit after jax.Array was enabled.
PiperOrigin-RevId: 500012042
2023-01-06 00:08:16 -08:00
Jake VanderPlas
b0e03fb747 Remove whitespace to fix flake8 2022-11-07 09:10:05 -08:00
Rahul Batra
e84a7e25b2 [ROCm]: Enable/update multiprocess gpu tests for ROCm 2022-10-27 16:51:37 +00:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Yash Katariya
607ce88d19 jax.Array is a unified type that will subsume JAX's DeviceArray, ShardedDeviceArray and GlobalDeviceArray.
This change replaces uses of `local_shards` and `local_data` with `addressable_shards` and `addressable_data` which are compatible with both `GDA` and `jax.Array`.

PiperOrigin-RevId: 481229606
2022-10-14 14:09:01 -07:00
Yash Katariya
7b49a3f51d Run tests in multiprocess_gpu_test only if the backend is GPU.
PiperOrigin-RevId: 477750739
2022-09-29 09:54:32 -07:00
jax authors
aafc70d293 Merge pull request #12556 from hawkinsp:rocm
PiperOrigin-RevId: 477440001
2022-09-28 06:50:19 -07:00
Peter Hawkins
f7bafb3d4c Disable multiprocess_gpu_test that fails on ROCm. 2022-09-28 13:40:57 +00:00
Peter Hawkins
eabb91e53f Fix test failure in GPU CI if NCCL_DEBUG is enabled.
If NCCL_DEBUG is enabled, NCCL prints extra status information. Make
test accept this.
2022-09-28 13:06:04 +00:00
Peter Hawkins
f4bc663c31 Wrap multiprocess test popen() uses in a context manager.
Ensures that resources from popen() are cleaned up.
2022-09-26 13:48:56 +00:00
Sudhakar
4dd0d85139 add multihost pjit tests 2022-09-23 12:11:56 -07:00
jax authors
bc08381da3 Merge pull request #12152 from nvcastet:add_slurm_orchestrator_support
PiperOrigin-RevId: 476179963
2022-09-22 13:18:25 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Nicolas Castet
412a5379c1 Add generic interface for auto initialization of distributed JAX service
* Also add slurm cluster support
2022-09-22 14:15:38 -05:00
Jake VanderPlas
2936c8a2c7 multiprocess_gpu_test: kill open subprocesses to avoid warning 2022-09-19 12:31:10 -07:00
Sudhakar
5f1858f533 Add pytest marker inside the test only if pytest is present in the env 2022-09-06 11:45:59 -07:00
Sudhakar
a571db18db Enable one gpu per process in multinode GPU CI 2022-08-29 09:00:19 -07:00
Peter Hawkins
b9d7e05eda [JAX] Handle non-contiguous GPU IDs in NCCL collectives.
Fixes https://github.com/google/jax/issues/12119

PiperOrigin-RevId: 470335156
2022-08-26 14:33:08 -07:00
Sudhakar
4b1a2eaaec combine gpu tests 2022-08-25 15:27:07 -07:00