mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11390 from hawkinsp:distributed_init
PiperOrigin-RevId: 459518348
This commit is contained in:
commit
fb7e39b13e
@ -26,6 +26,7 @@ from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
|
||||
class State:
|
||||
process_id: int = 0
|
||||
service: Optional[Any] = None
|
||||
client: Optional[Any] = None
|
||||
preemption_sync_manager: Optional[Any] = None
|
||||
@ -34,8 +35,8 @@ class State:
|
||||
coordinator_address: Optional[str] = None,
|
||||
num_processes: Optional[int] = None,
|
||||
process_id: Optional[int] = None):
|
||||
coordinator_address = os.environ.get('JAX_COORDINATOR_ADDRESS',
|
||||
None) or coordinator_address
|
||||
coordinator_address = (coordinator_address or
|
||||
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
|
||||
|
||||
if cloud_tpu_init.running_in_cloud_tpu_vm:
|
||||
worker_endpoints = cloud_tpu_init.get_metadata(
|
||||
@ -59,6 +60,8 @@ class State:
|
||||
if process_id is None:
|
||||
raise ValueError('The process id of the current process must be defined.')
|
||||
|
||||
self.process_id = process_id
|
||||
|
||||
if process_id == 0:
|
||||
if self.service is not None:
|
||||
raise RuntimeError('distributed.initialize should only be called once.')
|
||||
@ -149,26 +152,6 @@ def initialize(coordinator_address: Optional[str] = None,
|
||||
"""
|
||||
global_state.initialize(coordinator_address, num_processes, process_id)
|
||||
atexit.register(shutdown)
|
||||
if xla_client._version >= 65:
|
||||
factory = functools.partial(
|
||||
xla_client.make_gpu_client,
|
||||
global_state.client,
|
||||
process_id,
|
||||
platform_name='cuda')
|
||||
xla_bridge.register_backend_factory('cuda', factory, priority=300)
|
||||
factory = functools.partial(
|
||||
xla_client.make_gpu_client,
|
||||
global_state.client,
|
||||
process_id,
|
||||
platform_name='rocm')
|
||||
xla_bridge.register_backend_factory('rocm', factory, priority=300)
|
||||
else:
|
||||
factory = functools.partial(
|
||||
xla_client.make_gpu_client,
|
||||
global_state.client,
|
||||
process_id)
|
||||
xla_bridge.register_backend_factory('gpu', factory, priority=300)
|
||||
|
||||
|
||||
|
||||
def shutdown():
|
||||
|
@ -32,6 +32,7 @@ logging._warn_preinit_stderr = 0
|
||||
|
||||
import jax._src.lib
|
||||
from jax._src.config import flags, bool_env, int_env
|
||||
from jax._src import distributed
|
||||
from jax._src.lib import tpu_driver_client
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import util, traceback_util
|
||||
@ -210,17 +211,23 @@ register_backend_factory('cpu',
|
||||
register_backend_factory('tpu_driver', _make_tpu_driver_client,
|
||||
priority=100)
|
||||
|
||||
|
||||
def make_gpu_client(platform_name=None):
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
platform_name=platform_name)
|
||||
|
||||
if hasattr(xla_client, "make_gpu_client"):
|
||||
if xla_client._version >= 65:
|
||||
register_backend_factory(
|
||||
'cuda', partial(xla_client.make_gpu_client, platform_name='cuda'),
|
||||
'cuda', partial(make_gpu_client, platform_name='cuda'),
|
||||
priority=200)
|
||||
register_backend_factory(
|
||||
'rocm', partial(xla_client.make_gpu_client, platform_name='rocm'),
|
||||
'rocm', partial(make_gpu_client, platform_name='rocm'),
|
||||
priority=200)
|
||||
else:
|
||||
register_backend_factory('gpu', xla_client.make_gpu_client,
|
||||
priority=200)
|
||||
register_backend_factory('gpu', make_gpu_client, priority=200)
|
||||
|
||||
if hasattr(xla_client, "make_tpu_client"):
|
||||
register_backend_factory(
|
||||
|
Loading…
x
Reference in New Issue
Block a user