mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add jax.distributed.initialize for multi-host GPU.
This commit is contained in:
parent
821fcaa750
commit
0be30fbf96
@ -14,6 +14,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* [GitHub
|
||||
commits](https://github.com/google/jax/compare/jax-v0.2.24...main).
|
||||
|
||||
* New features:
|
||||
* (Experimental) `jax.distributed.initialize` exposes the multi-host GPU backend.
|
||||
* Breaking changes
|
||||
* Moved `jax.experimental.stax` to `jax.example_libraries.stax`
|
||||
* Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`
|
||||
|
@ -120,6 +120,7 @@ from .version import __version__ as __version__
|
||||
# jax and rely on the names imported above.
|
||||
from . import abstract_arrays as abstract_arrays
|
||||
from . import api_util as api_util
|
||||
from . import distributed as distributed
|
||||
from . import dtypes as dtypes
|
||||
from . import errors as errors
|
||||
from . import image as image
|
||||
|
59
jax/_src/distributed.py
Normal file
59
jax/_src/distributed.py
Normal file
@ -0,0 +1,59 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
|
||||
from absl import logging
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
|
||||
# Handle to the distributed runtime service. It stays None until initialize()
# is called with process_id == 0, which stores the started service here.
# NOTE(review): presumably kept at module scope so the service object outlives
# the initialize() call -- confirm against the xla_extension API.
_service = None
|
||||
def initialize(coordinator_address: str, num_processes: int, process_id: int):
  """Initialize distributed system for topology discovery.

  Currently, calling ``initialize`` sets up the multi-host GPU backend, and
  is not required for CPU or TPU backends.

  Args:
    coordinator_address: Address of the coordinator, in the ``'host:port'``
      form shown in the example below.
    num_processes: Number of processes.
    process_id: Id of the current process. The process with id 0 also starts
      the distributed service on ``coordinator_address``.

  Raises:
    RuntimeError: If the distributed service has already been started by an
      earlier call in this process (only checked when ``process_id`` is 0).

  Example:

  Suppose there are two GPU hosts, and host 0 is the designated coordinator
  with address '10.0.0.1:1234', to initialize the GPU cluster, run the
  following commands before anything else.

  On host 0

  >>> jax.distributed.initialize('10.0.0.1:1234', 2, 0)  # doctest: +SKIP

  On host 1

  >>> jax.distributed.initialize('10.0.0.1:1234', 2, 1)  # doctest: +SKIP
  """
  global _service

  if process_id == 0:
    # Raise explicitly instead of using `assert`: assertions are stripped when
    # Python runs with -O, which would let a second service be started
    # silently.
    if _service is not None:
      raise RuntimeError('initialize should be called once only')
    logging.info('Starting JAX distributed service on %s', coordinator_address)
    _service = xla_extension.get_distributed_runtime_service(
        coordinator_address, num_processes)

  # Every process, including process 0, connects a client to the service.
  client = xla_extension.get_distributed_runtime_client(
      coordinator_address, process_id)
  logging.info('Connecting to JAX distributed service on %s',
               coordinator_address)
  client.connect()

  # Registering under the name 'gpu' replaces any factory previously stored
  # under that name; the factory is bound to the connected client and this
  # process id.
  factory = functools.partial(xla_client.make_gpu_client, client, process_id)
  xla_bridge.register_backend_factory('gpu', factory, priority=300)
|
@ -170,8 +170,15 @@ def tpu_client_timer_callback(timer_secs: float):
|
||||
# example, there could be multiple backends that provide the same kind of
|
||||
# device.
|
||||
_backend_factories = {}
|
||||
_default_backend = None
|
||||
_backends : Dict[str, Any] = {}
|
||||
_backends_errors : Dict[str, str] = {}
|
||||
_backend_lock = threading.Lock()
|
||||
|
||||
def register_backend_factory(name, factory, *, priority=0):
  """Record ``factory`` as the factory for the backend called ``name``.

  The factory is stored together with ``priority`` in ``_backend_factories``,
  replacing any entry already stored under the same name.

  Args:
    name: Name of the backend the factory creates.
    factory: Callable stored for later use when the backend is created.
    priority: Value stored alongside the factory (keyword-only).

  Raises:
    RuntimeError: If a backend named ``name`` is already present in
      ``_backends``, i.e. it has already been initialized.
  """
  entry = (factory, priority)
  with _backend_lock:
    if name not in _backends:
      _backend_factories[name] = entry
      return
  raise RuntimeError(f"Backend {name} already initialized")
|
||||
|
||||
|
||||
@ -187,11 +194,6 @@ register_backend_factory('gpu', xla_client.make_gpu_client,
|
||||
register_backend_factory(
|
||||
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
|
||||
|
||||
_default_backend = None
|
||||
_backends : Dict[str, Any] = {}
|
||||
_backends_errors : Dict[str, str] = {}
|
||||
_backend_lock = threading.Lock()
|
||||
|
||||
|
||||
def backends():
|
||||
global _backends
|
||||
|
16
jax/distributed.py
Normal file
16
jax/distributed.py
Normal file
@ -0,0 +1,16 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.distributed import initialize
|
Loading…
x
Reference in New Issue
Block a user