mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make JAX_CPU_COLLECTIVES_IMPLEMENTATION
and JAX_NUM_CPU_DEVICES
work as env vars
Before, these values could only be specified via jax.config or flags. This PR makes them proper configs, so they also work as env vars.
This commit is contained in:
parent
809e1133c8
commit
2aa810fe60
@ -22,6 +22,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
|
||||
details.
|
||||
|
||||
* Changes
|
||||
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
|
||||
env vars. Previously, they could only be specified via jax.config or flags.
|
||||
|
||||
## jax 0.5.0 (Jan 17, 2025)
|
||||
|
||||
As of this release, JAX now uses
|
||||
|
@ -1716,3 +1716,21 @@ memory_fitting_effort = float_state(
|
||||
default=0.0,
|
||||
help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].'
|
||||
)
|
||||
|
||||
cpu_collectives_implementation = optional_enum_state(
|
||||
name='jax_cpu_collectives_implementation',
|
||||
enum_values=["gloo", "mpi", "megascale"],
|
||||
default=None,
|
||||
help=(
|
||||
"Cross-process collective implementation used on CPU. Must be one of "
|
||||
'("gloo", "mpi")'),
|
||||
)
|
||||
|
||||
num_cpu_devices = int_state(
|
||||
name="jax_num_cpu_devices",
|
||||
default=-1,
|
||||
help=(
|
||||
"Number of CPU devices to use. If not provided, the value of "
|
||||
"the XLA flag --xla_force_host_platform_device_count is used."
|
||||
" Must be set before JAX is initialized."),
|
||||
)
|
||||
|
@ -518,7 +518,7 @@ def request_cpu_devices(nr_devices: int):
|
||||
invoked. Test cases that require a specific number of devices should skip
|
||||
themselves if that number is not met.
|
||||
"""
|
||||
if xla_bridge.NUM_CPU_DEVICES.value < nr_devices:
|
||||
if config.num_cpu_devices.value < nr_devices:
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
config.update("jax_num_cpu_devices", nr_devices)
|
||||
|
||||
|
@ -104,17 +104,6 @@ _CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag(
|
||||
help="Deprecated, please use jax_cpu_collectives_implementation instead.",
|
||||
)
|
||||
|
||||
CPU_COLLECTIVES_IMPLEMENTATIONS = ["none", "gloo", "mpi"]
|
||||
CPU_COLLECTIVES_IMPLEMENTATION = config.enum_flag(
|
||||
name="jax_cpu_collectives_implementation",
|
||||
default="none",
|
||||
enum_values=CPU_COLLECTIVES_IMPLEMENTATIONS,
|
||||
help=(
|
||||
"Cross-process collective implementation used on CPU. Must be one of"
|
||||
f" {CPU_COLLECTIVES_IMPLEMENTATIONS}"
|
||||
),
|
||||
)
|
||||
|
||||
_CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
|
||||
name="jax_cpu_enable_async_dispatch",
|
||||
default=True,
|
||||
@ -122,14 +111,6 @@ _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
|
||||
"inline without async dispatch.",
|
||||
)
|
||||
|
||||
NUM_CPU_DEVICES = config.int_flag(
|
||||
name="jax_num_cpu_devices",
|
||||
default=-1,
|
||||
help="Number of CPU devices to use. If not provided, the value of "
|
||||
"the XLA flag --xla_force_host_platform_device_count is used."
|
||||
" Must be set before JAX is initialized.",
|
||||
)
|
||||
|
||||
|
||||
# Warn the user if they call fork(), because it's not going to go well for them.
|
||||
def _at_fork():
|
||||
@ -255,7 +236,7 @@ def make_cpu_client(
|
||||
The created CPU client.
|
||||
"""
|
||||
if collectives is None:
|
||||
collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value
|
||||
collectives_impl = config.cpu_collectives_implementation.value
|
||||
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
|
||||
collectives_impl = 'gloo'
|
||||
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
|
||||
@ -271,12 +252,13 @@ def make_cpu_client(
|
||||
collectives = xla_client._xla.make_mpi_collectives()
|
||||
collectives.Init()
|
||||
atexit.register(collectives.Finalize)
|
||||
elif collectives_impl != 'none':
|
||||
raise RuntimeError(f"Unknown collectives implementation "
|
||||
f"{collectives_impl}. Available implementations are "
|
||||
f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")
|
||||
elif collectives_impl == 'megascale':
|
||||
raise ValueError('JAX_CPU_COLLECTIVES_IMPLEMENTATION must "gloo" or "mpi"')
|
||||
else:
|
||||
# Already validated by config module
|
||||
assert collectives_impl is None
|
||||
|
||||
num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
|
||||
num_devices = config.num_cpu_devices.value if config.num_cpu_devices.value >= 0 else None
|
||||
return xla_client.make_cpu_client(
|
||||
asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
|
||||
distributed_client=distributed.global_state.client,
|
||||
|
Loading…
x
Reference in New Issue
Block a user