Add jax.distributed.initialize for multi-host GPU.

This commit is contained in:
Qiao Zhang 2021-10-25 15:34:57 -07:00
parent 821fcaa750
commit 0be30fbf96
5 changed files with 85 additions and 5 deletions

View File

@ -14,6 +14,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.24...main).
* New features:
* (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
* Breaking changes
* Moved `jax.experimental.stax` to `jax.example_libraries.stax`
* Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`

View File

@ -120,6 +120,7 @@ from .version import __version__ as __version__
# jax and rely on the names imported above.
from . import abstract_arrays as abstract_arrays
from . import api_util as api_util
from . import distributed as distributed
from . import dtypes as dtypes
from . import errors as errors
from . import image as image

59
jax/_src/distributed.py Normal file
View File

@ -0,0 +1,59 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from absl import logging
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
_service = None
def initialize(coordinator_address: str, num_processes: int, process_id: int):
"""Initialize distributed system for topology discovery.
Currently, calling ``initialize`` sets up the multi-host GPU backend, and
is not required for CPU or TPU backends.
Args:
coordinator_address: IP address of the coordinator.
num_processes: Number of processes.
process_id: Id of the current processe.
Example:
Suppose there are two GPU hosts, and host 0 is the designated coordinator
with address '10.0.0.1:1234', to initialize the GPU cluster, run the
following commands before anything else.
On host 0
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 0) # doctest: +SKIP
On host 1
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 1) # doctest: +SKIP
"""
if process_id == 0:
global _service
assert _service is None, 'initialize should be called once only'
logging.info('Starting JAX distributed service on %s', coordinator_address)
_service = xla_extension.get_distributed_runtime_service(coordinator_address,
num_processes)
client = xla_extension.get_distributed_runtime_client(coordinator_address,
process_id)
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
client.connect()
factory = functools.partial(xla_client.make_gpu_client, client, process_id)
xla_bridge.register_backend_factory('gpu', factory, priority=300)

View File

@ -170,8 +170,15 @@ def tpu_client_timer_callback(timer_secs: float):
# example, there could be multiple backends that provide the same kind of
# device.
_backend_factories = {}
_default_backend = None
_backends : Dict[str, Any] = {}
_backends_errors : Dict[str, str] = {}
_backend_lock = threading.Lock()
def register_backend_factory(name, factory, *, priority=0):
with _backend_lock:
if name in _backends:
raise RuntimeError(f"Backend {name} already initialized")
_backend_factories[name] = (factory, priority)
@ -187,11 +194,6 @@ register_backend_factory('gpu', xla_client.make_gpu_client,
register_backend_factory(
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
_default_backend = None
_backends : Dict[str, Any] = {}
_backends_errors : Dict[str, str] = {}
_backend_lock = threading.Lock()
def backends():
global _backends

16
jax/distributed.py Normal file
View File

@ -0,0 +1,16 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
from jax._src.distributed import initialize