diff --git a/CHANGELOG.md b/CHANGELOG.md
index f97d953d7..ed1456dcf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,8 @@ Remember to align the itemized text with the first line of an item within a list
 * Changes
   * The minimum jaxlib version is now 0.4.19.
   * Released wheels are built now with clang instead of gcc.
+  * Enforce that the device backend has not been initialized prior to calling `jax.distributed.initialize()`.
+  * Automate arguments to `jax.distributed.initialize()` in cloud TPU environments.
 
 * Deprecations
   * The previously-deprecated `sym_pos` argument has been removed from
diff --git a/jax/_src/clusters/__init__.py b/jax/_src/clusters/__init__.py
index c2afe9a7f..d933af613 100644
--- a/jax/_src/clusters/__init__.py
+++ b/jax/_src/clusters/__init__.py
@@ -22,4 +22,6 @@ from .cluster import ClusterEnv
 # available one from the list will be picked.
 from .ompi_cluster import OmpiCluster
 from .slurm_cluster import SlurmCluster
-from .cloud_tpu_cluster import TpuCluster
+from .cloud_tpu_cluster import GkeTpuCluster
+from .cloud_tpu_cluster import MultisliceGceTpuCluster
+from .cloud_tpu_cluster import SingleSliceGceTpuCluster
diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py
index bd666dd2e..e8628c23b 100644
--- a/jax/_src/clusters/cloud_tpu_cluster.py
+++ b/jax/_src/clusters/cloud_tpu_cluster.py
@@ -13,8 +13,10 @@
 # limitations under the License.
 
 import os
+import re
+import socket
+import time
 from typing import Optional
-from jax._src import xla_bridge
 from jax._src import clusters
 from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
 
@@ -43,32 +45,161 @@ def get_metadata(key):
     raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
   return api_resp.text
 
+def get_tpu_env_value(key):
+  """Returns the TPU environment value for `key`, or None if it is not set.
+
+  The process environment takes precedence; otherwise the 'tpu-env' metadata
+  entry is parsed for a matching `KEY: 'value'` row.
+  """
+  def get_tpu_env_value_from_metadata(key):
+    tpu_env_data = get_metadata('tpu-env')
+    key_value_pairs = tpu_env_data.split('\n')
+    for key_value_pair in key_value_pairs:
+      # Typical line is MEGASCALE_NUM_SLICES: '2'
+      if ':' in key_value_pair:
+        # maxsplit is passed by keyword; the positional form is deprecated.
+        row_key, value = re.split(':', key_value_pair, maxsplit=1)
+        row_key = row_key.strip()
+        if row_key == key:
+          return value.strip().strip("'")
+    return None
 
-class TpuCluster(clusters.ClusterEnv):
+  value = os.environ.get(key, None)
+  return value if value is not None else get_tpu_env_value_from_metadata(key)
+
+def is_gce_env():
+  """Returns True if the 'agent-worker-number' metadata entry is an integer."""
+  worker_number_string = get_metadata('agent-worker-number')
+  try:
+    int(worker_number_string)
+  except (TypeError, ValueError):
+    # Only a failed integer parse means "not GCE"; unrelated exceptions such
+    # as KeyboardInterrupt must not be swallowed.
+    return False
+  return True
+
+def is_multislice_gce_env():
+  """Returns True on GCE when MEGASCALE_COORDINATOR_ADDRESS is available."""
+  return is_gce_env() and get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None
+
+def is_gke_env():
+  """Returns True if the TPU_WORKER_HOSTNAMES environment variable is set."""
+  return os.environ.get("TPU_WORKER_HOSTNAMES", None) is not None
+
+def get_gce_worker_endpoints() -> list:
+  """Returns the 'worker-network-endpoints' metadata entries as a list."""
+  return get_metadata('worker-network-endpoints').split(',')
+
+class SingleSliceGceTpuCluster(clusters.ClusterEnv):
+  """Cluster detection for a GCE TPU environment that is not multislice."""
+
   @classmethod
   def is_env_present(cls) -> bool:
-    return running_in_cloud_tpu_vm
+    return running_in_cloud_tpu_vm and is_gce_env() and not is_multislice_gce_env()
 
   @classmethod
   def get_coordinator_address(cls) -> str:
-    return cls._get_worker_endpoints()[0].split(':')[2] + ':8476'
+    return get_gce_worker_endpoints()[0].split(':')[2] + ':8476'
 
   @classmethod
   def get_process_count(cls) -> int:
-    return xla_bridge.process_count()
+    return len(get_gce_worker_endpoints())
 
   @classmethod
   def get_process_id(cls) -> int:
-    if cls.get_process_count() != len(cls._get_worker_endpoints()):
-      raise RuntimeError('Number of workers does not equal the number of '
-                         'processes. Auto detecting process_id is not possible.'
-                         'Please pass process_id to jax.distributed.initialize() manually.')
     return int(get_metadata('agent-worker-number'))
 
   @classmethod
   def get_local_process_id(cls) -> Optional[int]:
     return None
 
+class MultisliceGceTpuCluster(clusters.ClusterEnv):
+  """Cluster detection for a multislice GCE TPU environment."""
+
+  @classmethod
+  def is_env_present(cls) -> bool:
+    return running_in_cloud_tpu_vm and is_multislice_gce_env()
+
+  @classmethod
+  def get_coordinator_address(cls) -> str:
+    """Returns '<host>:8476' once the coordinator host name resolves.
+
+    Raises:
+      RuntimeError: if MEGASCALE_COORDINATOR_ADDRESS is not available, or if
+        the coordinator host cannot be resolved after all lookup attempts.
+    """
+    coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
+    if coordinator_address is None:
+      raise RuntimeError("MEGASCALE_COORDINATOR_ADDRESS is not set; cannot "
+                         "determine the coordinator address")
+    coordinator_address = coordinator_address.split(':')[0]
+
+    # The coordinator may not be up before the other hosts try to
+    # communicate with it. We check for its existence with retries.
+    coordinator_found = False
+    lookup_attempt = 1
+    max_coordinator_lookups = 50
+    while not coordinator_found and lookup_attempt <= max_coordinator_lookups:
+      try:
+        # Only resolvability matters; the resolved IP itself is not used.
+        socket.gethostbyname(coordinator_address)
+        coordinator_found = True
+      except socket.gaierror:
+        print(f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying...")
+        lookup_attempt += 1
+        time.sleep(5)
+
+    if not coordinator_found:
+      raise RuntimeError(f"Failed to recognize coordinator address {coordinator_address}")
+
+    # Use a different port for the jax coordinator than the MXLA coordinator,
+    # which is set to 8080 in multislice GCE.
+    coordinator_address = coordinator_address + ':8476'
+    return coordinator_address
+
+  @classmethod
+  def get_process_count(cls) -> int:
+    processes_per_slice = cls._get_process_count_per_slice()
+    num_slices = int(get_tpu_env_value('MEGASCALE_NUM_SLICES'))
+    return processes_per_slice * num_slices
+
+  @classmethod
+  def get_process_id(cls) -> int:
+    process_id_in_slice = cls._get_process_id_in_slice()
+    slice_id = int(get_tpu_env_value('MEGASCALE_SLICE_ID'))
+    processes_per_slice = cls._get_process_count_per_slice()
+    return process_id_in_slice + slice_id * processes_per_slice
+
+  @classmethod
+  def get_local_process_id(cls) -> Optional[int]:
+    return None
+
   @staticmethod
-  def _get_worker_endpoints() -> str:
-    return get_metadata('worker-network-endpoints').split(',')
+  def _get_process_count_per_slice() -> int:
+    return len(get_gce_worker_endpoints())
+
+  @staticmethod
+  def _get_process_id_in_slice() -> int:
+    return int(get_metadata('agent-worker-number'))
+
+class GkeTpuCluster(MultisliceGceTpuCluster):
+  # This class handles both single and multislice GKE as the environment
+  # variables are set the same in both cases.
+  @classmethod
+  def is_env_present(cls) -> bool:
+    return running_in_cloud_tpu_vm and is_gke_env()
+
+  @staticmethod
+  def _get_process_count_per_slice() -> int:
+    tpu_worker_hostnames = os.environ.get('TPU_WORKER_HOSTNAMES')
+    if tpu_worker_hostnames is None:
+      # str(None) would otherwise be counted as a single host named 'None'.
+      raise RuntimeError("TPU_WORKER_HOSTNAMES is not set")
+    return len(tpu_worker_hostnames.split(','))
+
+  @staticmethod
+  def _get_process_id_in_slice() -> int:
+    tpu_worker_id = os.environ.get('TPU_WORKER_ID')
+    if tpu_worker_id is None:
+      raise RuntimeError("TPU_WORKER_ID is not set")
+    return int(tpu_worker_id)