Merge pull request #11390 from hawkinsp:distributed_init

PiperOrigin-RevId: 459518348
This commit is contained in:
jax authors 2022-07-07 08:23:26 -07:00
commit fb7e39b13e
2 changed files with 16 additions and 26 deletions

View File

@ -26,6 +26,7 @@ from jax._src.lib import xla_client
from jax._src.lib import xla_extension
class State:
process_id: int = 0
service: Optional[Any] = None
client: Optional[Any] = None
preemption_sync_manager: Optional[Any] = None
@ -34,8 +35,8 @@ class State:
coordinator_address: Optional[str] = None,
num_processes: Optional[int] = None,
process_id: Optional[int] = None):
coordinator_address = os.environ.get('JAX_COORDINATOR_ADDRESS',
None) or coordinator_address
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
if cloud_tpu_init.running_in_cloud_tpu_vm:
worker_endpoints = cloud_tpu_init.get_metadata(
@ -59,6 +60,8 @@ class State:
if process_id is None:
raise ValueError('The process id of the current process must be defined.')
self.process_id = process_id
if process_id == 0:
if self.service is not None:
raise RuntimeError('distributed.initialize should only be called once.')
@ -149,26 +152,6 @@ def initialize(coordinator_address: Optional[str] = None,
"""
global_state.initialize(coordinator_address, num_processes, process_id)
atexit.register(shutdown)
if xla_client._version >= 65:
factory = functools.partial(
xla_client.make_gpu_client,
global_state.client,
process_id,
platform_name='cuda')
xla_bridge.register_backend_factory('cuda', factory, priority=300)
factory = functools.partial(
xla_client.make_gpu_client,
global_state.client,
process_id,
platform_name='rocm')
xla_bridge.register_backend_factory('rocm', factory, priority=300)
else:
factory = functools.partial(
xla_client.make_gpu_client,
global_state.client,
process_id)
xla_bridge.register_backend_factory('gpu', factory, priority=300)
def shutdown():

View File

@ -32,6 +32,7 @@ logging._warn_preinit_stderr = 0
import jax._src.lib
from jax._src.config import flags, bool_env, int_env
from jax._src import distributed
from jax._src.lib import tpu_driver_client
from jax._src.lib import xla_client
from jax._src import util, traceback_util
@ -210,17 +211,23 @@ register_backend_factory('cpu',
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
def make_gpu_client(platform_name=None):
  """Build a GPU backend client hooked up to the distributed runtime state.

  Thin factory passed to ``register_backend_factory``: it forwards the
  process-wide distributed client and process id from
  ``distributed.global_state`` so multi-process GPU backends share one
  coordination service.

  Args:
    platform_name: optional backend platform ('cuda' or 'rocm' in the
      registrations below), or None for the legacy unified 'gpu' backend.

  Returns:
    The client object produced by ``xla_client.make_gpu_client``.
  """
  state = distributed.global_state
  return xla_client.make_gpu_client(
      distributed_client=state.client,
      node_id=state.process_id,
      platform_name=platform_name)
if hasattr(xla_client, "make_gpu_client"):
if xla_client._version >= 65:
register_backend_factory(
'cuda', partial(xla_client.make_gpu_client, platform_name='cuda'),
'cuda', partial(make_gpu_client, platform_name='cuda'),
priority=200)
register_backend_factory(
'rocm', partial(xla_client.make_gpu_client, platform_name='rocm'),
'rocm', partial(make_gpu_client, platform_name='rocm'),
priority=200)
else:
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory('gpu', make_gpu_client, priority=200)
if hasattr(xla_client, "make_tpu_client"):
register_backend_factory(