Make JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES env vars

Before, these values could only be specified via jax.config or
flags. This PR makes them proper configs, so they also work as env
vars.
This commit is contained in:
Skye Wanderman-Milne 2025-01-28 17:00:06 -08:00
parent 809e1133c8
commit 2aa810fe60
4 changed files with 30 additions and 26 deletions

View File

@ -22,6 +22,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
details.
* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` can now be set
as environment variables. Previously, they could only be specified via
`jax.config` or command-line flags.
## jax 0.5.0 (Jan 17, 2025)
As of this release, JAX now uses

View File

@ -1716,3 +1716,21 @@ memory_fitting_effort = float_state(
default=0.0,
help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].'
)
# Config option choosing the cross-process collectives implementation for CPU.
# It stays unset (None) unless a value is supplied.
# NOTE(review): "megascale" is an accepted enum value but is not advertised in
# the help text below -- presumably intentional; confirm before changing either.
cpu_collectives_implementation = optional_enum_state(
    name='jax_cpu_collectives_implementation',
    default=None,
    enum_values=["gloo", "mpi", "megascale"],
    help=(
        "Cross-process collective implementation used on CPU. Must be one of "
        '("gloo", "mpi")'
    ),
)
# Config option for how many CPU devices JAX should create.
# The default of -1 presumably acts as the "not provided" sentinel described in
# the help text (fall back to the XLA flag) -- verify against the consumer.
num_cpu_devices = int_state(
    name="jax_num_cpu_devices",
    default=-1,
    help=(
        "Number of CPU devices to use. If not provided, the value of "
        "the XLA flag --xla_force_host_platform_device_count is used."
        " Must be set before JAX is initialized."
    ),
)

View File

@ -518,7 +518,7 @@ def request_cpu_devices(nr_devices: int):
invoked. Test cases that require a specific number of devices should skip
themselves if that number is not met.
"""
if xla_bridge.NUM_CPU_DEVICES.value < nr_devices:
if config.num_cpu_devices.value < nr_devices:
xla_bridge.get_backend.cache_clear()
config.update("jax_num_cpu_devices", nr_devices)

View File

@ -104,17 +104,6 @@ _CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag(
help="Deprecated, please use jax_cpu_collectives_implementation instead.",
)
CPU_COLLECTIVES_IMPLEMENTATIONS = ["none", "gloo", "mpi"]
CPU_COLLECTIVES_IMPLEMENTATION = config.enum_flag(
name="jax_cpu_collectives_implementation",
default="none",
enum_values=CPU_COLLECTIVES_IMPLEMENTATIONS,
help=(
"Cross-process collective implementation used on CPU. Must be one of"
f" {CPU_COLLECTIVES_IMPLEMENTATIONS}"
),
)
_CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
name="jax_cpu_enable_async_dispatch",
default=True,
@ -122,14 +111,6 @@ _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
"inline without async dispatch.",
)
NUM_CPU_DEVICES = config.int_flag(
name="jax_num_cpu_devices",
default=-1,
help="Number of CPU devices to use. If not provided, the value of "
"the XLA flag --xla_force_host_platform_device_count is used."
" Must be set before JAX is initialized.",
)
# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
@ -255,7 +236,7 @@ def make_cpu_client(
The created CPU client.
"""
if collectives is None:
collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value
collectives_impl = config.cpu_collectives_implementation.value
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
collectives_impl = 'gloo'
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
@ -271,12 +252,13 @@ def make_cpu_client(
collectives = xla_client._xla.make_mpi_collectives()
collectives.Init()
atexit.register(collectives.Finalize)
elif collectives_impl != 'none':
raise RuntimeError(f"Unknown collectives implementation "
f"{collectives_impl}. Available implementations are "
f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")
elif collectives_impl == 'megascale':
raise ValueError('JAX_CPU_COLLECTIVES_IMPLEMENTATION must "gloo" or "mpi"')
else:
# Already validated by config module
assert collectives_impl is None
num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
num_devices = config.num_cpu_devices.value if config.num_cpu_devices.value >= 0 else None
return xla_client.make_cpu_client(
asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
distributed_client=distributed.global_state.client,