* Make all arguments to distributed.initialize equal to None.

* On Cloud TPUs, figure out the coordinator address automatically.

PiperOrigin-RevId: 449261786
This commit is contained in:
Yash Katariya 2022-05-17 10:53:17 -07:00 committed by jax authors
parent 969572658a
commit 548a6bf58b
2 changed files with 99 additions and 44 deletions

View File

@ -14,6 +14,7 @@
import os
running_in_cloud_tpu_vm = False
def cloud_tpu_init():
"""Automatically sets Cloud TPU topology and other env vars.
@ -31,6 +32,7 @@ def cloud_tpu_init():
This will not set any env vars if a single topology-related env var is already
set.
"""
global running_in_cloud_tpu_vm
try:
# pylint: disable=import-outside-toplevel
# pytype: disable=import-error
@ -42,6 +44,8 @@ def cloud_tpu_init():
# TPU environment. Exit early if we're not running on Cloud TPU.
return
running_in_cloud_tpu_vm = True
libtpu.configure_library_path()
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
@ -57,35 +61,6 @@ def cloud_tpu_init():
]):
return
# pylint: disable=import-outside-toplevel
# pytype: disable=import-error
import requests
import time
# pytype: enable=import-error
# pylint: enable=import-outside-toplevel
# Based on https://github.com/tensorflow/tensorflow/pull/40317
gce_metadata_endpoint = 'http://' + os.environ.get(
'GCE_METADATA_IP', 'metadata.google.internal')
def get_metadata(key):
retry_count = 0
retrySeconds = 0.500
api_resp = None
while retry_count < 6:
api_resp = requests.get(
f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
headers={'Metadata-Flavor': 'Google'})
if api_resp.status_code == 200:
break
retry_count += 1
time.sleep(retrySeconds)
if api_resp is None:
raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
return api_resp.text
worker_id = get_metadata('agent-worker-number')
accelerator_type = get_metadata('accelerator-type')
@ -112,3 +87,28 @@ def cloud_tpu_init():
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
accelerator_type]
def get_metadata(key):
import requests # pytype: disable=import-error
import time # pytype: disable=import-error
# Based on https://github.com/tensorflow/tensorflow/pull/40317
gce_metadata_endpoint = 'http://' + os.environ.get(
'GCE_METADATA_IP', 'metadata.google.internal')
retry_count = 0
retrySeconds = 0.500
api_resp = None
while retry_count < 6:
api_resp = requests.get(
f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
headers={'Metadata-Flavor': 'Google'})
if api_resp.status_code == 200:
break
retry_count += 1
time.sleep(retrySeconds)
if api_resp is None:
raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
return api_resp.text

View File

@ -12,15 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import functools
from typing import Optional
from absl import logging
from jax._src import cloud_tpu_init
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
_service = None
def initialize(coordinator_address: str, num_processes: int, process_id: int):
jax_service = None
distributed_client = None
def initialize(coordinator_address: Optional[str] = None,
num_processes: Optional[int] = None,
process_id: Optional[int] = None):
"""Initialize distributed system for topology discovery.
Currently, calling ``initialize`` sets up the multi-host GPU backend, and
@ -30,8 +39,16 @@ def initialize(coordinator_address: str, num_processes: int, process_id: int):
coordinator_address: IP address and port of the coordinator. The choice of
port does not matter, so long as the port is available on the coordinator
and all processes agree on the port.
num_processes: Number of processes.
process_id: Id of the current process.
Can be None only for TPU platform. If coordinator_address is None on TPU,
then it will be auto detected.
num_processes: Number of processes. Can be None only for TPU platform and
if None will be determined from the TPU slice metadata.
process_id: Id of the current process. Can be None only for TPU platform and
if None will default to the current TPU worker id determined via the TPU
slice metadata.
Raises:
RuntimeError: If `distributed.initialize` is called more than once.
Example:
@ -47,21 +64,59 @@ def initialize(coordinator_address: str, num_processes: int, process_id: int):
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 1) # doctest: +SKIP
"""
coordinator_address = os.environ.get('JAX_COORDINATOR_ADDRESS',
None) or coordinator_address
if cloud_tpu_init.running_in_cloud_tpu_vm:
worker_endpoints = cloud_tpu_init.get_metadata(
'worker-network-endpoints').split(',')
if coordinator_address is None:
coordinator_address = worker_endpoints[0].split(':')[2] + ':8476'
if num_processes is None:
num_processes = xla_bridge.process_count()
if process_id is None:
process_id = int(cloud_tpu_init.get_metadata('agent-worker-number'))
if num_processes != len(worker_endpoints):
raise RuntimeError('Number of workers does not equal the number of '
'processes. Auto detecting process_id is not possible.'
'Please pass process_id manually.')
if coordinator_address is None:
raise ValueError('coordinator_address should be defined.')
if num_processes is None:
raise ValueError('Number of processes must be defined.')
if process_id is None:
raise ValueError('The process id of the current process must be defined.')
if process_id == 0:
global _service
assert _service is None, 'initialize should be called once only'
global jax_service
if jax_service is not None:
raise RuntimeError('distributed.initialize should only be called once.')
logging.info('Starting JAX distributed service on %s', coordinator_address)
_service = xla_extension.get_distributed_runtime_service(coordinator_address,
num_processes)
jax_service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes)
client = xla_extension.get_distributed_runtime_client(coordinator_address,
process_id)
global distributed_client
if distributed_client is not None:
raise RuntimeError('distributed.initialize should only be called once.')
distributed_client = xla_extension.get_distributed_runtime_client(
coordinator_address, process_id)
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
client.connect()
distributed_client.connect()
factory = functools.partial(xla_client.make_gpu_client, client, process_id,
platform_name='cuda')
factory = functools.partial(
xla_client.make_gpu_client,
distributed_client,
process_id,
platform_name='cuda')
xla_bridge.register_backend_factory('cuda', factory, priority=300)
factory = functools.partial(xla_client.make_gpu_client, client, process_id,
platform_name='rocm')
factory = functools.partial(
xla_client.make_gpu_client,
distributed_client,
process_id,
platform_name='rocm')
xla_bridge.register_backend_factory('rocm', factory, priority=300)