[jax.distributed] Allow setting local device ids via env var

This commit is contained in:
Georg Stefan Schmid 2024-07-10 10:15:35 +00:00
parent a8b425cac5
commit f9bc4c643b

View File

@ -45,10 +45,12 @@ class State:
initialization_timeout: int = 300,
coordinator_bind_address: str | None = None):
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
os.environ.get('JAX_COORDINATOR_ADDRESS'))
if isinstance(local_device_ids, int):
local_device_ids = [local_device_ids]
if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')):
local_device_ids = list(map(int, env_ids.split(",")))
(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(