mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Default JAX_CPU_COLLECTIVES_IMPLEMENTATION to 'gloo'.
This enables CPU collectives by default, making multi-process CPU communication work without extra configuration.

PiperOrigin-RevId: 724076284
This commit is contained in:
parent
4b86ff22e9
commit
f07243a73a
@@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* Changes
|
||||
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
|
||||
env vars. Before they could only be specified via jax.config or flags.
|
||||
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning
|
||||
multi-process CPU communication works out-of-the-box.
|
||||
* The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package.
|
||||
This package may safely be removed if it is present on your machine; JAX now
|
||||
uses `libtpu` instead.
|
||||
|
@@ -62,6 +62,8 @@ XlaBackend = xla_client.Client
|
||||
|
||||
MIN_COMPUTE_CAPABILITY = 52
|
||||
|
||||
_DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo'
|
||||
|
||||
# TODO(phawkins): Remove jax_xla_backend.
|
||||
_XLA_BACKEND = config.string_flag(
|
||||
'jax_xla_backend', '',
|
||||
@@ -235,7 +237,9 @@ def make_cpu_client(
|
||||
Returns:
|
||||
The created CPU client.
|
||||
"""
|
||||
if collectives is None:
|
||||
# TODO(skyewm): use distributed.is_initialized() after
|
||||
# https://github.com/jax-ml/jax/pull/26172 goes in.
|
||||
if collectives is None and distributed.global_state.client is not None:
|
||||
collectives_impl = config.cpu_collectives_implementation.value
|
||||
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
|
||||
collectives_impl = 'gloo'
|
||||
@@ -244,6 +248,9 @@ def make_cpu_client(
|
||||
'"jax_cpu_collectives_implementation", "gloo")` instead.',
|
||||
DeprecationWarning,
|
||||
)
|
||||
if collectives_impl is None:
|
||||
collectives_impl = _DEFAULT_CPU_COLLECTIVES_IMPL
|
||||
|
||||
if collectives_impl == 'gloo':
|
||||
collectives = xla_client._xla.make_gloo_tcp_collectives(
|
||||
distributed_client=distributed.global_state.client,
|
||||
@@ -252,8 +259,6 @@ def make_cpu_client(
|
||||
collectives = xla_client._xla.make_mpi_collectives()
|
||||
collectives.Init()
|
||||
atexit.register(collectives.Finalize)
|
||||
elif collectives_impl == 'megascale':
|
||||
raise ValueError('JAX_CPU_COLLECTIVES_IMPLEMENTATION must "gloo" or "mpi"')
|
||||
else:
|
||||
# Already validated by config module
|
||||
assert collectives_impl is None
|
||||
|
Loading…
x
Reference in New Issue
Block a user