diff --git a/CHANGELOG.md b/CHANGELOG.md index d7a8e0b93..6869160bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more details. +* Changes + * `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as + env vars. Before they could only be specified via jax.config or flags. + ## jax 0.5.0 (Jan 17, 2025) As of this release, JAX now uses diff --git a/jax/_src/config.py b/jax/_src/config.py index c457acb0e..af7287602 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1716,3 +1716,21 @@ memory_fitting_effort = float_state( default=0.0, help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].' ) + +cpu_collectives_implementation = optional_enum_state( + name='jax_cpu_collectives_implementation', + enum_values=["gloo", "mpi", "megascale"], + default=None, + help=( + "Cross-process collective implementation used on CPU. Must be one of " + '("gloo", "mpi")'), +) + +num_cpu_devices = int_state( + name="jax_num_cpu_devices", + default=-1, + help=( + "Number of CPU devices to use. If not provided, the value of " + "the XLA flag --xla_force_host_platform_device_count is used." + " Must be set before JAX is initialized."), +) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index f34b2c8f3..3f23c4c76 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -518,7 +518,7 @@ def request_cpu_devices(nr_devices: int): invoked. Test cases that require a specific number of devices should skip themselves if that number is not met. """ - if xla_bridge.NUM_CPU_DEVICES.value < nr_devices: + if config.num_cpu_devices.value < nr_devices: xla_bridge.get_backend.cache_clear() config.update("jax_num_cpu_devices", nr_devices) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index cbbd04da4..51f00c56d 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -104,17 +104,6 @@ _CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag( help="Deprecated, please use jax_cpu_collectives_implementation instead.", ) -CPU_COLLECTIVES_IMPLEMENTATIONS = ["none", "gloo", "mpi"] -CPU_COLLECTIVES_IMPLEMENTATION = config.enum_flag( - name="jax_cpu_collectives_implementation", - default="none", - enum_values=CPU_COLLECTIVES_IMPLEMENTATIONS, - help=( - "Cross-process collective implementation used on CPU. Must be one of" - f" {CPU_COLLECTIVES_IMPLEMENTATIONS}" - ), -) - _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag( name="jax_cpu_enable_async_dispatch", default=True, @@ -122,14 +111,6 @@ _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag( "inline without async dispatch.", ) -NUM_CPU_DEVICES = config.int_flag( - name="jax_num_cpu_devices", - default=-1, - help="Number of CPU devices to use. If not provided, the value of " - "the XLA flag --xla_force_host_platform_device_count is used." - " Must be set before JAX is initialized.", -) - # Warn the user if they call fork(), because it's not going to go well for them. def _at_fork(): @@ -255,7 +236,7 @@ def make_cpu_client( The created CPU client. """ if collectives is None: - collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value + collectives_impl = config.cpu_collectives_implementation.value if _CPU_ENABLE_GLOO_COLLECTIVES.value: collectives_impl = 'gloo' warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is ' @@ -271,12 +252,13 @@ def make_cpu_client( collectives = xla_client._xla.make_mpi_collectives() collectives.Init() atexit.register(collectives.Finalize) - elif collectives_impl != 'none': - raise RuntimeError(f"Unknown collectives implementation " - f"{collectives_impl}. Available implementations are " - f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.") + elif collectives_impl == 'megascale': + raise ValueError('JAX_CPU_COLLECTIVES_IMPLEMENTATION must "gloo" or "mpi"') + else: + # Already validated by config module + assert collectives_impl is None - num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None + num_devices = config.num_cpu_devices.value if config.num_cpu_devices.value >= 0 else None return xla_client.make_cpu_client( asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value, distributed_client=distributed.global_state.client,