mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Automate arguments for jax.distributed.initialize for cloud TPU environments.
PiperOrigin-RevId: 586892544
This commit is contained in:
parent
a07ed22b02
commit
8ad774fb10
@ -14,6 +14,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Changes
|
||||
* The minimum jaxlib version is now 0.4.19.
|
||||
* Released wheels are now built with clang instead of gcc.
|
||||
* Enforce that the device backend has not been initialized prior to calling `jax.distributed.initialize()`.
|
||||
* Automate arguments to `jax.distributed.initialize()` in cloud TPU environments.
|
||||
|
||||
* Deprecations
|
||||
* The previously-deprecated `sym_pos` argument has been removed from
|
||||
|
@ -22,4 +22,6 @@ from .cluster import ClusterEnv
|
||||
# available one from the list will be picked.
|
||||
from .ompi_cluster import OmpiCluster
|
||||
from .slurm_cluster import SlurmCluster
|
||||
from .cloud_tpu_cluster import TpuCluster
|
||||
from .cloud_tpu_cluster import GkeTpuCluster
|
||||
from .cloud_tpu_cluster import MultisliceGceTpuCluster
|
||||
from .cloud_tpu_cluster import SingleSliceGceTpuCluster
|
||||
|
@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import time
|
||||
from typing import Optional
|
||||
from jax._src import xla_bridge
|
||||
from jax._src import clusters
|
||||
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
|
||||
|
||||
@ -43,32 +45,129 @@ def get_metadata(key):
|
||||
raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
|
||||
return api_resp.text
|
||||
|
||||
def get_tpu_env_value_from_metadata(key):
  """Look up `key` in the `tpu-env` metadata entry.

  The entry is fetched with `get_metadata('tpu-env')` and scanned line by
  line. Lines without a colon are ignored.

  Args:
    key: Name to look for. It is compared against the whitespace-stripped
      text that precedes the first colon on each line.

  Returns:
    The text after the first colon of the first matching line, with
    surrounding whitespace and single quotes stripped, or None when no line
    matches.
  """
  tpu_env_data = get_metadata('tpu-env')
  for line in tpu_env_data.split('\n'):
    # Typical line is MEGASCALE_NUM_SLICES: '2'
    # Split on the first colon only, since a value may itself contain colons
    # (e.g. an address followed by a port).
    row_key, colon, value = line.partition(':')
    if colon and row_key.strip() == key:
      return value.strip().strip("'")
  return None


def get_tpu_env_value(key):
  """Return the TPU environment value for `key`.

  The process environment takes precedence. Only when `key` is not set there
  is the `tpu-env` metadata entry consulted.

  Args:
    key: Name of the environment variable / metadata row.

  Returns:
    The value as a string, or None when it is found in neither place.
  """
  value = os.environ.get(key, None)
  return value if value is not None else get_tpu_env_value_from_metadata(key)
|
||||
|
||||
def is_gce_env():
  """Report whether the `agent-worker-number` metadata value is an integer.

  Errors raised by `get_metadata` itself are not handled here and propagate
  to the caller.

  Returns:
    True if the value returned by `get_metadata('agent-worker-number')`
    parses as an int, False otherwise.
  """
  worker_number_string = get_metadata('agent-worker-number')
  try:
    int(worker_number_string)
  except (TypeError, ValueError):
    # The value is not numeric text (or not text at all).
    return False
  return True
|
||||
|
||||
def is_multislice_gce_env():
  """Report whether this is a GCE environment with a megascale coordinator.

  Returns:
    True when `is_gce_env()` holds and `MEGASCALE_COORDINATOR_ADDRESS` can be
    resolved through `get_tpu_env_value`; False otherwise.
  """
  if not is_gce_env():
    return False
  coordinator = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
  return coordinator is not None
|
||||
|
||||
def is_gke_env():
  """Report whether the `TPU_WORKER_HOSTNAMES` environment variable is set.

  Returns:
    True if the variable exists in the process environment (even when it is
    empty), False otherwise.
  """
  return "TPU_WORKER_HOSTNAMES" in os.environ
|
||||
|
||||
def get_gce_worker_endpoints() -> list[str]:
  """Return the entries of the `worker-network-endpoints` metadata value.

  Returns:
    The metadata text split on commas, one list item per entry.
  """
  return get_metadata('worker-network-endpoints').split(',')
|
||||
|
||||
class SingleSliceGceTpuCluster(clusters.ClusterEnv):
  """Cluster parameters for a GCE TPU environment that is not multislice.

  All values are derived from the metadata helpers in this module
  (`get_gce_worker_endpoints` and `get_metadata`).
  """

  @classmethod
  def is_env_present(cls) -> bool:
    """Return whether this is a cloud TPU VM on GCE without multislice."""
    return running_in_cloud_tpu_vm and is_gce_env() and not is_multislice_gce_env()

  @classmethod
  def get_coordinator_address(cls) -> str:
    """Return the coordinator address, using port 8476.

    The host is the third colon-separated field of the first worker endpoint
    entry.
    """
    return get_gce_worker_endpoints()[0].split(':')[2] + ':8476'

  @classmethod
  def get_process_count(cls) -> int:
    """Return the number of worker endpoint entries."""
    return len(get_gce_worker_endpoints())

  @classmethod
  def get_process_id(cls) -> int:
    """Return the `agent-worker-number` metadata value as an int."""
    return int(get_metadata('agent-worker-number'))

  @classmethod
  def get_local_process_id(cls) -> Optional[int]:
    """Return None; no local process id is derived for this environment."""
    return None
|
||||
|
||||
class MultisliceGceTpuCluster(clusters.ClusterEnv):
  """Cluster parameters for a multislice GCE TPU environment.

  Global process ids are laid out slice by slice:
  `process_id_in_slice + slice_id * processes_per_slice`.
  """

  @classmethod
  def is_env_present(cls) -> bool:
    """Return whether this is a cloud TPU VM in a multislice GCE environment."""
    return running_in_cloud_tpu_vm and is_multislice_gce_env()

  @classmethod
  def get_coordinator_address(cls) -> str:
    """Return the coordinator host followed by port 8476.

    The host comes from `MEGASCALE_COORDINATOR_ADDRESS` (the text before its
    first colon). Name resolution is retried before the address is returned.

    Raises:
      RuntimeError: If `MEGASCALE_COORDINATOR_ADDRESS` is unavailable, or if
        the host still cannot be resolved after all lookup attempts.
    """
    coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
    if coordinator_address is None:
      # Fail with an explicit message instead of an AttributeError on None.
      raise RuntimeError('MEGASCALE_COORDINATOR_ADDRESS is not set; cannot '
                         'determine the coordinator address.')
    coordinator_address = coordinator_address.split(':')[0]

    # The coordinator may not be up before the other hosts try to
    # communicate with it. We check for its existence with retries.
    max_coordinator_lookups = 50
    for lookup_attempt in range(1, max_coordinator_lookups + 1):
      try:
        socket.gethostbyname(coordinator_address)
      except socket.gaierror:
        print(f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying...")
        time.sleep(5)
      else:
        break
    else:
      # Every attempt failed.
      raise RuntimeError(f"Failed to recognize coordinator address {coordinator_address}")

    # Use a different port for the jax coordinator than the MXLA coordinator,
    # which is set to 8080 in multislice GCE.
    return coordinator_address + ':8476'

  @classmethod
  def get_process_count(cls) -> int:
    """Return processes per slice times `MEGASCALE_NUM_SLICES`."""
    processes_per_slice = cls._get_process_count_per_slice()
    num_slices = int(get_tpu_env_value('MEGASCALE_NUM_SLICES'))
    return processes_per_slice * num_slices

  @classmethod
  def get_process_id(cls) -> int:
    """Return this process's global id across all slices."""
    process_id_in_slice = cls._get_process_id_in_slice()
    slice_id = int(get_tpu_env_value('MEGASCALE_SLICE_ID'))
    processes_per_slice = cls._get_process_count_per_slice()
    return process_id_in_slice + slice_id * processes_per_slice

  @classmethod
  def get_local_process_id(cls) -> Optional[int]:
    """Return None; no local process id is derived for this environment."""
    return None

  @staticmethod
  def _get_process_count_per_slice() -> int:
    """Return the number of worker endpoint entries reported by metadata."""
    return len(get_gce_worker_endpoints())

  @staticmethod
  def _get_process_id_in_slice() -> int:
    """Return the `agent-worker-number` metadata value as an int."""
    return int(get_metadata('agent-worker-number'))
|
||||
|
||||
class GkeTpuCluster(MultisliceGceTpuCluster):
  """Cluster parameters for TPUs on GKE.

  Covers both single-slice and multislice GKE, as the environment variables
  are set the same in both cases. Per-slice values are read from the
  `TPU_WORKER_HOSTNAMES` and `TPU_WORKER_ID` environment variables; everything
  else is inherited from `MultisliceGceTpuCluster`.
  """

  @classmethod
  def is_env_present(cls) -> bool:
    """Return whether this is a cloud TPU VM in a GKE environment."""
    return running_in_cloud_tpu_vm and is_gke_env()

  @staticmethod
  def _get_process_count_per_slice() -> int:
    """Return the number of comma-separated items in `TPU_WORKER_HOSTNAMES`."""
    hostnames = os.environ.get('TPU_WORKER_HOSTNAMES')
    # One more item than there are separating commas.
    return str(hostnames).count(',') + 1

  @staticmethod
  def _get_process_id_in_slice() -> int:
    """Return `TPU_WORKER_ID` from the environment as an int."""
    worker_id = os.environ.get('TPU_WORKER_ID')
    return int(str(worker_id))
|
||||
|
Loading…
x
Reference in New Issue
Block a user