mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
* Make all arguments to distributed.initialize equal to None.
* On Cloud TPUs, figure out the coordinator address automatically. PiperOrigin-RevId: 449261786
This commit is contained in:
parent
969572658a
commit
548a6bf58b
@ -14,6 +14,7 @@
|
||||
|
||||
import os
|
||||
|
||||
running_in_cloud_tpu_vm = False
|
||||
|
||||
def cloud_tpu_init():
|
||||
"""Automatically sets Cloud TPU topology and other env vars.
|
||||
@ -31,6 +32,7 @@ def cloud_tpu_init():
|
||||
This will not set any env vars if a single topology-related env var is already
|
||||
set.
|
||||
"""
|
||||
global running_in_cloud_tpu_vm
|
||||
try:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pytype: disable=import-error
|
||||
@ -42,6 +44,8 @@ def cloud_tpu_init():
|
||||
# TPU environment. Exit early if we're not running on Cloud TPU.
|
||||
return
|
||||
|
||||
running_in_cloud_tpu_vm = True
|
||||
|
||||
libtpu.configure_library_path()
|
||||
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
|
||||
|
||||
@ -57,35 +61,6 @@ def cloud_tpu_init():
|
||||
]):
|
||||
return
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pytype: disable=import-error
|
||||
import requests
|
||||
import time
|
||||
# pytype: enable=import-error
|
||||
# pylint: enable=import-outside-toplevel
|
||||
|
||||
# Based on https://github.com/tensorflow/tensorflow/pull/40317
|
||||
gce_metadata_endpoint = 'http://' + os.environ.get(
|
||||
'GCE_METADATA_IP', 'metadata.google.internal')
|
||||
|
||||
def get_metadata(key):
|
||||
retry_count = 0
|
||||
retrySeconds = 0.500
|
||||
api_resp = None
|
||||
|
||||
while retry_count < 6:
|
||||
api_resp = requests.get(
|
||||
f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
|
||||
headers={'Metadata-Flavor': 'Google'})
|
||||
if api_resp.status_code == 200:
|
||||
break
|
||||
retry_count += 1
|
||||
time.sleep(retrySeconds)
|
||||
|
||||
if api_resp is None:
|
||||
raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
|
||||
return api_resp.text
|
||||
|
||||
worker_id = get_metadata('agent-worker-number')
|
||||
accelerator_type = get_metadata('accelerator-type')
|
||||
|
||||
@ -112,3 +87,28 @@ def cloud_tpu_init():
|
||||
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
|
||||
os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
|
||||
accelerator_type]
|
||||
|
||||
|
||||
def get_metadata(key):
  """Reads an instance attribute from the GCE metadata server.

  Retries up to six times, sleeping half a second between attempts, to
  tolerate transient metadata-server unavailability during VM startup.

  Args:
    key: Name of the instance attribute to read (e.g.
      'agent-worker-number' or 'accelerator-type').

  Returns:
    The attribute value as text.

  Raises:
    RuntimeError: If the metadata server did not return HTTP 200 after
      six attempts.
  """
  import requests  # pytype: disable=import-error
  import time  # pytype: disable=import-error
  # Based on https://github.com/tensorflow/tensorflow/pull/40317
  gce_metadata_endpoint = 'http://' + os.environ.get(
      'GCE_METADATA_IP', 'metadata.google.internal')

  retry_seconds = 0.500  # Delay between attempts.
  api_resp = None

  for _ in range(6):
    api_resp = requests.get(
        f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
        headers={'Metadata-Flavor': 'Google'})
    if api_resp.status_code == 200:
      break
    time.sleep(retry_seconds)

  # Bug fix: the original guard was `if api_resp is None`, but after the
  # loop runs once api_resp is always a Response object (requests.get
  # either returns or raises), so a persistent non-200 reply fell through
  # and the error body's text was returned as if it were the metadata
  # value. Treat any final non-200 status as failure too.
  if api_resp is None or api_resp.status_code != 200:
    raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
  return api_resp.text
|
||||
|
@ -12,15 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import functools
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from absl import logging
|
||||
from jax._src import cloud_tpu_init
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
|
||||
_service = None
|
||||
def initialize(coordinator_address: str, num_processes: int, process_id: int):
|
||||
jax_service = None
|
||||
distributed_client = None
|
||||
|
||||
|
||||
def initialize(coordinator_address: Optional[str] = None,
|
||||
num_processes: Optional[int] = None,
|
||||
process_id: Optional[int] = None):
|
||||
"""Initialize distributed system for topology discovery.
|
||||
|
||||
Currently, calling ``initialize`` sets up the multi-host GPU backend, and
|
||||
@ -30,8 +39,16 @@ def initialize(coordinator_address: str, num_processes: int, process_id: int):
|
||||
coordinator_address: IP address and port of the coordinator. The choice of
|
||||
port does not matter, so long as the port is available on the coordinator
|
||||
and all processes agree on the port.
|
||||
num_processes: Number of processes.
|
||||
process_id: Id of the current process.
|
||||
Can be None only for TPU platform. If coordinator_address is None on TPU,
|
||||
then it will be auto detected.
|
||||
num_processes: Number of processes. Can be None only for TPU platform and
|
||||
if None will be determined from the TPU slice metadata.
|
||||
process_id: Id of the current process. Can be None only for TPU platform and
|
||||
if None will default to the current TPU worker id determined via the TPU
|
||||
slice metadata.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `distributed.initialize` is called more than once.
|
||||
|
||||
Example:
|
||||
|
||||
@ -47,21 +64,59 @@ def initialize(coordinator_address: str, num_processes: int, process_id: int):
|
||||
|
||||
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 1) # doctest: +SKIP
|
||||
"""
|
||||
|
||||
coordinator_address = os.environ.get('JAX_COORDINATOR_ADDRESS',
|
||||
None) or coordinator_address
|
||||
|
||||
if cloud_tpu_init.running_in_cloud_tpu_vm:
|
||||
worker_endpoints = cloud_tpu_init.get_metadata(
|
||||
'worker-network-endpoints').split(',')
|
||||
if coordinator_address is None:
|
||||
coordinator_address = worker_endpoints[0].split(':')[2] + ':8476'
|
||||
if num_processes is None:
|
||||
num_processes = xla_bridge.process_count()
|
||||
if process_id is None:
|
||||
process_id = int(cloud_tpu_init.get_metadata('agent-worker-number'))
|
||||
|
||||
if num_processes != len(worker_endpoints):
|
||||
raise RuntimeError('Number of workers does not equal the number of '
|
||||
'processes. Auto detecting process_id is not possible.'
|
||||
'Please pass process_id manually.')
|
||||
|
||||
if coordinator_address is None:
|
||||
raise ValueError('coordinator_address should be defined.')
|
||||
if num_processes is None:
|
||||
raise ValueError('Number of processes must be defined.')
|
||||
if process_id is None:
|
||||
raise ValueError('The process id of the current process must be defined.')
|
||||
|
||||
if process_id == 0:
|
||||
global _service
|
||||
assert _service is None, 'initialize should be called once only'
|
||||
global jax_service
|
||||
if jax_service is not None:
|
||||
raise RuntimeError('distributed.initialize should only be called once.')
|
||||
|
||||
logging.info('Starting JAX distributed service on %s', coordinator_address)
|
||||
_service = xla_extension.get_distributed_runtime_service(coordinator_address,
|
||||
num_processes)
|
||||
jax_service = xla_extension.get_distributed_runtime_service(
|
||||
coordinator_address, num_processes)
|
||||
|
||||
client = xla_extension.get_distributed_runtime_client(coordinator_address,
|
||||
process_id)
|
||||
global distributed_client
|
||||
if distributed_client is not None:
|
||||
raise RuntimeError('distributed.initialize should only be called once.')
|
||||
|
||||
distributed_client = xla_extension.get_distributed_runtime_client(
|
||||
coordinator_address, process_id)
|
||||
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
|
||||
client.connect()
|
||||
distributed_client.connect()
|
||||
|
||||
factory = functools.partial(xla_client.make_gpu_client, client, process_id,
|
||||
platform_name='cuda')
|
||||
factory = functools.partial(
|
||||
xla_client.make_gpu_client,
|
||||
distributed_client,
|
||||
process_id,
|
||||
platform_name='cuda')
|
||||
xla_bridge.register_backend_factory('cuda', factory, priority=300)
|
||||
factory = functools.partial(xla_client.make_gpu_client, client, process_id,
|
||||
platform_name='rocm')
|
||||
factory = functools.partial(
|
||||
xla_client.make_gpu_client,
|
||||
distributed_client,
|
||||
process_id,
|
||||
platform_name='rocm')
|
||||
xla_bridge.register_backend_factory('rocm', factory, priority=300)
|
||||
|
Loading…
x
Reference in New Issue
Block a user