diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py
index 5123fa3f6..f899b19c0 100644
--- a/jax/_src/cloud_tpu_init.py
+++ b/jax/_src/cloud_tpu_init.py
@@ -14,6 +14,7 @@
 
 import os
 
+running_in_cloud_tpu_vm = False
 
 def cloud_tpu_init():
   """Automatically sets Cloud TPU topology and other env vars.
@@ -31,6 +32,7 @@ def cloud_tpu_init():
   This will not set any env vars if a single topology-related env var is
   already set.
   """
+  global running_in_cloud_tpu_vm
   try:
     # pylint: disable=import-outside-toplevel
    # pytype: disable=import-error
@@ -42,6 +44,8 @@ def cloud_tpu_init():
     # TPU environment. Exit early if we're not running on Cloud TPU.
     return
 
+  running_in_cloud_tpu_vm = True
+
   libtpu.configure_library_path()
   os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
 
@@ -57,35 +61,6 @@ def cloud_tpu_init():
   ]):
     return
 
-  # pylint: disable=import-outside-toplevel
-  # pytype: disable=import-error
-  import requests
-  import time
-  # pytype: enable=import-error
-  # pylint: enable=import-outside-toplevel
-
-  # Based on https://github.com/tensorflow/tensorflow/pull/40317
-  gce_metadata_endpoint = 'http://' + os.environ.get(
-      'GCE_METADATA_IP', 'metadata.google.internal')
-
-  def get_metadata(key):
-    retry_count = 0
-    retrySeconds = 0.500
-    api_resp = None
-
-    while retry_count < 6:
-      api_resp = requests.get(
-          f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
-          headers={'Metadata-Flavor': 'Google'})
-      if api_resp.status_code == 200:
-        break
-      retry_count += 1
-      time.sleep(retrySeconds)
-
-    if api_resp is None:
-      raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
-    return api_resp.text
-
   worker_id = get_metadata('agent-worker-number')
   accelerator_type = get_metadata('accelerator-type')
 
@@ -112,3 +87,28 @@ def cloud_tpu_init():
     os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
     os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
         accelerator_type]
+
+
+def get_metadata(key):
+  import requests  # pytype: disable=import-error
+  import time  # pytype: disable=import-error
+  # Based on https://github.com/tensorflow/tensorflow/pull/40317
+  gce_metadata_endpoint = 'http://' + os.environ.get(
+      'GCE_METADATA_IP', 'metadata.google.internal')
+
+  retry_count = 0
+  retrySeconds = 0.500
+  api_resp = None
+
+  while retry_count < 6:
+    api_resp = requests.get(
+        f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
+        headers={'Metadata-Flavor': 'Google'})
+    if api_resp.status_code == 200:
+      break
+    retry_count += 1
+    time.sleep(retrySeconds)
+
+  if api_resp is None:
+    raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
+  return api_resp.text
diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py
index d46cc9477..eda240290 100644
--- a/jax/_src/distributed.py
+++ b/jax/_src/distributed.py
@@ -12,15 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import functools
+from typing import Optional
+
 from absl import logging
 
+from jax._src import cloud_tpu_init
 from jax._src.lib import xla_bridge
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension
 
-_service = None
-def initialize(coordinator_address: str, num_processes: int, process_id: int):
+jax_service = None
+distributed_client = None
+
+
+def initialize(coordinator_address: Optional[str] = None,
+               num_processes: Optional[int] = None,
+               process_id: Optional[int] = None):
   """Initialize distributed system for topology discovery.
 
   Currently, calling ``initialize`` sets up the multi-host GPU backend, and
@@ -30,8 +39,16 @@ def initialize(coordinator_address: str, num_processes: int, process_id: int):
     coordinator_address: IP address and port of the coordinator. The choice
       of port does not matter, so long as the port is available on the
       coordinator and all processes agree on the port.
-    num_processes: Number of processes.
-    process_id: Id of the current process.
+      Can be None only for the TPU platform; if None, it will be
+      auto-detected.
+    num_processes: Number of processes. Can be None only for the TPU
+      platform; if None, it will be determined from the TPU slice metadata.
+    process_id: Id of the current process. Can be None only for the TPU
+      platform; if None, it will default to the current TPU worker id
+      determined via the TPU slice metadata.
+
+  Raises:
+    RuntimeError: If `distributed.initialize` is called more than once.
 
   Example:
 
@@ -47,21 +64,59 @@ def initialize(coordinator_address: str, num_processes: int, process_id: int):
 
   >>> jax.distributed.initialize('10.0.0.1:1234', 2, 1)  # doctest: +SKIP
   """
+
+  coordinator_address = os.environ.get('JAX_COORDINATOR_ADDRESS',
+                                       None) or coordinator_address
+
+  if cloud_tpu_init.running_in_cloud_tpu_vm:
+    worker_endpoints = cloud_tpu_init.get_metadata(
+        'worker-network-endpoints').split(',')
+    if coordinator_address is None:
+      coordinator_address = worker_endpoints[0].split(':')[2] + ':8476'
+    if num_processes is None:
+      num_processes = xla_bridge.process_count()
+    if process_id is None:
+      process_id = int(cloud_tpu_init.get_metadata('agent-worker-number'))
+
+    if num_processes != len(worker_endpoints):
+      raise RuntimeError('Number of workers does not equal the number of '
+                         'processes. Auto-detecting process_id is not possible. '
+                         'Please pass process_id manually.')
+
+  if coordinator_address is None:
+    raise ValueError('coordinator_address should be defined.')
+  if num_processes is None:
+    raise ValueError('Number of processes must be defined.')
+  if process_id is None:
+    raise ValueError('The process id of the current process must be defined.')
+
   if process_id == 0:
-    global _service
-    assert _service is None, 'initialize should be called once only'
+    global jax_service
+    if jax_service is not None:
+      raise RuntimeError('distributed.initialize should only be called once.')
+
     logging.info('Starting JAX distributed service on %s', coordinator_address)
-    _service = xla_extension.get_distributed_runtime_service(coordinator_address,
-                                                             num_processes)
+    jax_service = xla_extension.get_distributed_runtime_service(
+        coordinator_address, num_processes)
 
-  client = xla_extension.get_distributed_runtime_client(coordinator_address,
-                                                        process_id)
+  global distributed_client
+  if distributed_client is not None:
+    raise RuntimeError('distributed.initialize should only be called once.')
+
+  distributed_client = xla_extension.get_distributed_runtime_client(
+      coordinator_address, process_id)
   logging.info('Connecting to JAX distributed service on %s', coordinator_address)
-  client.connect()
+  distributed_client.connect()
 
-  factory = functools.partial(xla_client.make_gpu_client, client, process_id,
-                              platform_name='cuda')
+  factory = functools.partial(
+      xla_client.make_gpu_client,
+      distributed_client,
+      process_id,
+      platform_name='cuda')
   xla_bridge.register_backend_factory('cuda', factory, priority=300)
-  factory = functools.partial(xla_client.make_gpu_client, client, process_id,
-                              platform_name='rocm')
+  factory = functools.partial(
+      xla_client.make_gpu_client,
+      distributed_client,
+      process_id,
+      platform_name='rocm')
   xla_bridge.register_backend_factory('rocm', factory, priority=300)
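Reviewer note (not part of the diff): a minimal usage sketch of how the reworked jax.distributed.initialize entry point is expected to be called after this change. The explicit form matches the docstring example above; the no-argument form relies on the TPU auto-detection added here (coordinator on worker 0's endpoint at port 8476, worker id from the agent-worker-number metadata, and the JAX_COORDINATOR_ADDRESS environment override).

import jax

# Multi-process GPU cluster (behaviour unchanged): every process passes the
# same coordinator address and process count, plus its own process id.
# jax.distributed.initialize(coordinator_address='10.0.0.1:1234',
#                            num_processes=2,
#                            process_id=1)

# Cloud TPU VM: all three arguments may now be omitted. The coordinator
# address, process count, and worker id are read from the TPU slice metadata;
# setting JAX_COORDINATOR_ADDRESS overrides the detected coordinator address.
jax.distributed.initialize()

Because get_metadata now lives at module level in cloud_tpu_init.py, distributed.py reuses the same retrying metadata lookup instead of duplicating it.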