Default JAX_CPU_COLLECTIVES_IMPLEMENTATION to 'gloo'.

This enables CPU collectives by default, making multi-process CPU
communication work without extra configuration.

PiperOrigin-RevId: 724076284
This commit is contained in:
Skye Wanderman-Milne 2025-02-06 14:29:52 -08:00 committed by jax authors
parent 4b86ff22e9
commit f07243a73a
2 changed files with 10 additions and 3 deletions

View File

@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
env vars. Before they could only be specified via jax.config or flags.
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning
multi-process CPU communication works out-of-the-box.
* The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package.
This package may safely be removed if it is present on your machine; JAX now
uses `libtpu` instead.

View File

@ -62,6 +62,8 @@ XlaBackend = xla_client.Client
MIN_COMPUTE_CAPABILITY = 52
_DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo'
# TODO(phawkins): Remove jax_xla_backend.
_XLA_BACKEND = config.string_flag(
'jax_xla_backend', '',
@ -235,7 +237,9 @@ def make_cpu_client(
Returns:
The created CPU client.
"""
if collectives is None:
# TODO(skyewm): use distributed.is_initialized() after
# https://github.com/jax-ml/jax/pull/26172 goes in.
if collectives is None and distributed.global_state.client is not None:
collectives_impl = config.cpu_collectives_implementation.value
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
collectives_impl = 'gloo'
@ -244,6 +248,9 @@ def make_cpu_client(
'"jax_cpu_collectives_implementation", "gloo")` instead.',
DeprecationWarning,
)
if collectives_impl is None:
collectives_impl = _DEFAULT_CPU_COLLECTIVES_IMPL
if collectives_impl == 'gloo':
collectives = xla_client._xla.make_gloo_tcp_collectives(
distributed_client=distributed.global_state.client,
@ -252,8 +259,6 @@ def make_cpu_client(
collectives = xla_client._xla.make_mpi_collectives()
collectives.Init()
atexit.register(collectives.Finalize)
elif collectives_impl == 'megascale':
raise ValueError('JAX_CPU_COLLECTIVES_IMPLEMENTATION must "gloo" or "mpi"')
else:
# Already validated by config module
assert collectives_impl is None