mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax.distributed] Allow setting local device ids via env var
This commit is contained in:
parent
a8b425cac5
commit
f9bc4c643b
@ -45,10 +45,12 @@ class State:
|
||||
initialization_timeout: int = 300,
|
||||
coordinator_bind_address: str | None = None):
|
||||
coordinator_address = (coordinator_address or
|
||||
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
|
||||
os.environ.get('JAX_COORDINATOR_ADDRESS'))
|
||||
if isinstance(local_device_ids, int):
|
||||
local_device_ids = [local_device_ids]
|
||||
|
||||
if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')):
|
||||
local_device_ids = list(map(int, env_ids.split(",")))
|
||||
|
||||
(coordinator_address, num_processes, process_id, local_device_ids) = (
|
||||
clusters.ClusterEnv.auto_detect_unset_distributed_params(
|
||||
|
Loading…
x
Reference in New Issue
Block a user