mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for synchronizing and communication across multiple hosts."""

from __future__ import annotations

from functools import partial, lru_cache
import zlib

from typing import Any
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import core
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src import array
from jax._src import sharding_impls
from jax._src.interpreters import pxla
from jax.interpreters import xla
from jax._src import pjit as pjit_lib
from jax.sharding import PartitionSpec as P
from jax._src import distributed
from jax._src.util import safe_zip
from jax._src import xla_bridge
from jax._src.lib import xla_client
import numpy as np


def _psum(x: Any) -> Any:
  return jax.tree.map(partial(jnp.sum, axis=0), x)


def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any:
  """Broadcast data from a source host (host 0 by default) to all other hosts.

  Args:
    in_tree: pytree of arrays - each array *must* have the same shape across the
      hosts.
    is_source: optional bool denoting whether the caller is the source host.
      Only the source host contributes the data for the broadcast. If None,
      then host 0 is used.

  Returns:
    A pytree matching in_tree where the leaves now all contain the data from the
    source host.
  """
  if jax.process_count() == 1:
    return jax.tree.map(np.asarray, in_tree)

  if is_source is None:
    is_source = jax.process_index() == 0

  devices: np.ndarray = np.array(
      jax.devices()).reshape(jax.process_count(), jax.local_device_count())
  global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
  pspec = P('processes')

  def pre_jit(x):
    if is_source:
      inp = x
    else:
      inp = np.zeros_like(x)
    inp = np.expand_dims(inp, axis=0)
    return host_local_array_to_global_array(inp, global_mesh, pspec)

  def post_jit(x):
    return jax.device_get(x.addressable_data(0))

  in_tree = jax.tree.map(pre_jit, in_tree)
  out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding(
      global_mesh, P()))(in_tree)
  return jax.tree.map(post_jit, out_tree)


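# Illustrative usage sketch (not part of this module's API): agree on a single
# PRNG seed across all hosts by broadcasting host 0's randomly drawn value.
# The seed range below is an assumption made up for the example.
def _example_broadcast_shared_seed() -> np.ndarray:
  # Every host draws its own seed, but only host 0's draw survives the
  # broadcast, so all hosts end up with the same value.
  local_seed = np.int32(np.random.randint(0, 2**31 - 1))
  return broadcast_one_to_all(local_seed)

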
def sync_global_devices(name: str):
  """Creates a barrier across all hosts/devices."""
  h = np.uint32(zlib.crc32(name.encode()))
  assert_equal(h, f"sync_global_devices name mismatch ('{name}')")


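# Illustrative usage sketch (not part of this module's API): make sure every
# host has finished writing its checkpoint shard before any host starts
# cleaning up older checkpoints. The checkpoint-writing step itself is elided.
def _example_checkpoint_barrier(step: int) -> None:
  # ... each host writes its local checkpoint shard here ...
  sync_global_devices(f'checkpoint_written_{step}')
  if jax.process_index() == 0:
    # Only reached once all hosts have passed the barrier above, so it is now
    # safe for host 0 to delete older checkpoints (elided).
    pass

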
# Identity function is at the top level so that `process_allgather` doesn't
# recompile on every invocation.
def _identity_fn(x):
  return x


def _handle_array_process_allgather(inp, tiled):
  if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
    if isinstance(inp.sharding, sharding_impls.NamedSharding):
      reps = inp.sharding.with_spec(P())
    else:
      reps = sharding_impls.GSPMDSharding.get_replicated(
          inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind)
    out = jax.jit(_identity_fn, out_shardings=reps)(inp)
  else:
    # All inputs here will be fully addressable.
    if jax.process_count() == 1:
      out = np.asarray(inp)
      return np.expand_dims(out, axis=0) if not tiled else out

    devices = np.array(jax.devices()).reshape(jax.process_count(),
                                              jax.local_device_count())
    global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
    pspec = P('processes')
    s = jax.sharding.NamedSharding(global_mesh, pspec)

    host_np_arr = np.asarray(inp)
    if host_np_arr.ndim == 0 or not tiled:
      host_np_arr = np.expand_dims(host_np_arr, axis=0)

    aval = core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
    global_aval = pxla.mesh_local_to_global(
        global_mesh, pxla.get_array_mapping(pspec), aval)

    bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
    global_arr = array.make_array_from_single_device_arrays(
        global_aval.shape, s, bufs)
    out = jax.jit(_identity_fn,
                  out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr)

  return np.asarray(out.addressable_data(0))


def process_allgather(in_tree: Any, tiled: bool = False) -> Any:
  """Gather data from across processes.

  Args:
    in_tree: pytree of arrays - each array _must_ have the same shape across the
      hosts.
    tiled: Whether to stack or concat the output. Defaults to False i.e. stack
      into a new positional axis at index 0.

  Returns:
    Pytree of numpy arrays.
      * If the input is a non-fully addressable jax.Array, then the data is
        fully replicated.
      * If the input is a numpy array or a fully addressable jax.Array, then the
        output shape depends on the `tiled` argument: if it is False, the
        outputs are stacked; otherwise they are concatenated.
      * If the input is a scalar, then the output will be stacked.
  """

  def _pjit(inp):
    return _handle_array_process_allgather(inp, tiled)
  return jax.tree.map(_pjit, in_tree)


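# Illustrative usage sketch (not part of this module's API): gather a per-host
# scalar metric so that every host can see the values reported by all
# processes. The metric name is an assumption made up for the example.
def _example_gather_losses(local_loss: float) -> np.ndarray:
  # With the default tiled=False, scalars are stacked, so the result has shape
  # (process_count,) with one entry per host.
  return process_allgather(np.float32(local_loss))

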
def assert_equal(in_tree, fail_message: str = ''):
  """Verifies that all the hosts have the same tree of values."""
  expected = broadcast_one_to_all(in_tree)
  if not jax.tree_util.tree_all(
      jax.tree_util.tree_map(lambda *x: np.all(np.equal(*x)), in_tree, expected)):
    raise AssertionError(
        f'{fail_message} Expected: {expected}; got: {in_tree}.')


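# Illustrative usage sketch (not part of this module's API): fail fast if the
# hosts disagree on a value that must be identical everywhere, such as the
# current training step. The argument name is made up for the example.
def _example_check_same_step(step: int) -> None:
  assert_equal(np.int32(step),
               fail_message='Training step differs across hosts.')

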
def reached_preemption_sync_point(step_id: int) -> bool:
  """Determine whether all hosts have reached a preemption sync step.

  When any host receives a preemption notice, the notice is propagated to all
  hosts and triggers a synchronization protocol in the background. The
  synchronization protocol calculates the maximum step id across all hosts, and
  uses the next step id (i.e., max + 1) as the safe step at which to save a
  checkpoint. All hosts should continue training until this method returns
  True, indicating that `step_id` is equal to the safe step and the hosts
  should start saving a checkpoint.

  To use this API, all hosts must start training from the same step and call it
  at every training step. Example usage:

  ```
  def should_save(step_id: int) -> bool:

    # Should save an on-demand checkpoint for preemption
    if multihost_utils.reached_preemption_sync_point(step_id):
      return True

    # Should save a regular checkpoint
    return step_id - last_saved_checkpoint_step >= save_interval_steps
  ```

  The preemption notice is provided by the cluster scheduler to notify the
  application in advance before it gets evicted. By default, we use SIGTERM as
  the signal for the preemption notice.

  TODO(b/230630494): Add instructions for customized preemption notice.

  Returns:
    A boolean indicating whether all hosts have reached a synchronization step
    after some hosts are preempted.

  Raises:
    RuntimeError: if the preemption sync manager has not been initialized.
  """
  if distributed.global_state.client is None:
    return False
  sync_manager = distributed.global_state.preemption_sync_manager
  if sync_manager is None:
    raise RuntimeError("Preemption sync manager has not been initialized.")
  return sync_manager.reached_sync_point(step_id)


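# Illustrative usage sketch (not part of this module's API): poll the sync
# point once per training step and break out to save an on-demand checkpoint.
# The training-step body is elided; `num_steps` is made up for the example.
def _example_train_with_preemption_checkpoints(num_steps: int) -> None:
  for step in range(num_steps):
    # ... run one training step here ...
    if reached_preemption_sync_point(step):
      # All hosts agree that `step` is the safe step: save a checkpoint here
      # (elided) and exit before the scheduler evicts the job.
      break

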
@lru_cache
def _flatten_pspecs(name, in_tree, pspecs_thunk):
  return pjit_lib.flatten_axis_resources(
      name, in_tree, pspecs_thunk(), tupled_args=True)

@lru_cache
def _local_to_global_aval(local_aval, mesh, pspec):
  return pxla.mesh_local_to_global(mesh, pxla.get_array_mapping(pspec),
                                   local_aval)

@lru_cache
def _global_to_local_aval(global_aval, mesh, pspec):
  return pxla.mesh_global_to_local(mesh, pxla.get_array_mapping(pspec),
                                   global_aval)


def host_local_array_to_global_array_impl(
    arr: Any, *, global_mesh: jax.sharding.Mesh, pspec: Any):
  if pspec is None:
    raise ValueError(
        '`None` is not a valid input to the pspecs argument. Please use '
        'jax.sharding.PartitionSpec() if you wanted to replicate your input.')
  # If the Array is not fully addressable i.e. not host local, return it.
  if isinstance(arr, array.ArrayImpl) and not arr.is_fully_addressable:
    return arr
  if isinstance(arr, array.ArrayImpl) and isinstance(
      arr.sharding, jax.sharding.PmapSharding):
    arr = np.array(arr)

  local_sharding = jax.sharding.NamedSharding(global_mesh.local_mesh, pspec)

  # If the input is a concrete jax.Array and the input array sharding
  # matches the `local_sharding`, then there's no need to reshard and create
  # copies.
  if (isinstance(arr, array.ArrayImpl) and
      arr.sharding.is_equivalent_to(local_sharding, arr.ndim)):
    arrays = [x.data for x in arr.addressable_shards]
  else:
    arr = xla.canonicalize_dtype(arr)
    arrays = [
        arr[index]
        for d, index in local_sharding.devices_indices_map(arr.shape).items()]

  global_aval = _local_to_global_aval(
      core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)

  return pxla.batched_device_put(
      global_aval, jax.sharding.NamedSharding(global_mesh, pspec),
      arrays, list(global_mesh.local_mesh.devices.flat))


def host_local_array_to_global_array(
    local_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
  r"""Converts a host local value to a globally sharded jax.Array.

  This function takes host-local data (which might differ across hosts) and
  populates a global array with this data, where each device on each host gets
  the appropriate slice of the data according to the sharding defined by
  global_mesh/pspecs.

  For example:

  >>> global_mesh = jax.sharding.Mesh(jax.devices(), 'x')
  >>> pspecs = jax.sharding.PartitionSpec('x')
  >>> host_id = jax.process_index()
  >>> arr = host_local_array_to_global_array(np.arange(4) * host_id, global_mesh, pspecs)  # NB: assumes jax.local_device_count() divides 4.  # doctest: +SKIP

  The resulting array will have the shape (4 * num_processes) and will
  have distributed value of: (0, 1, 2, 3, 0, 2, 4, 6, 0, 3, 6, 9, ... ),
  where each slice np.arange(4) * host_id will be partitioned across the
  corresponding host's devices.

  Similarly:

  >>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(jax.process_count(), jax.local_device_count()), ['host', 'dev'])
  >>> pspecs = jax.sharding.PartitionSpec('host')
  >>> host_id = jax.process_index()
  >>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs)  # doctest: +SKIP

  will create the same distributed value (0, 1, 2, 3, 0, 2, 4, 6, ...),
  however each slice np.arange(4) * host_id will be *replicated* across the
  corresponding host's devices.

  On the other hand, if pspecs = PartitionSpec(), which means
  replication across all axes, then this snippet:

  >>> pspecs = jax.sharding.PartitionSpec()
  >>> arr = host_local_array_to_global_array(np.arange(4), mesh, pspecs)  # doctest: +SKIP

  will have the shape (4,) and the value (0, 1, 2, 3) will be replicated
  across all hosts and devices.

  Passing non-identical local_inputs together with a pspec that indicates
  replication is undefined behavior.

  You can use this function to transition to jax.Array. Using jax.Array with
  pjit has the same semantics as using GDA with pjit, i.e. all jax.Array
  inputs to pjit should be globally shaped.

  If you are currently passing host local values to pjit, you can use this
  function to convert your host local values to global Arrays and then pass
  those to pjit.

  Example usage:

  >>> from jax.experimental import multihost_utils # doctest: +SKIP
  >>>
  >>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
  >>>
  >>> with mesh: # doctest: +SKIP
  ...   global_out = pjitted_fun(global_inputs) # doctest: +SKIP
  >>>
  >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP

  Please note this function requires the global mesh to be a contiguous mesh,
  meaning that the devices that belong to each host should form a subcube in
  this mesh. To move local data to a global array with a non-contiguous mesh,
  use jax.make_array_from_callback or jax.make_array_from_single_device_arrays
  instead.

  Args:
    local_inputs: A Pytree of host local values.
    global_mesh: A jax.sharding.Mesh object. The mesh must be a contiguous mesh,
      that is all hosts' devices must form a subcube in this mesh.
    pspecs: A Pytree of jax.sharding.PartitionSpec's.

  Returns:
    A pytree of global arrays.
  """
  flat_inps, in_tree = tree_flatten(local_inputs)
  in_pspecs = _flatten_pspecs('input pspecs', in_tree,
                              pjit_lib.hashable_pytree(pspecs))
  out_flat = [
      host_local_array_to_global_array_p.bind(inp, global_mesh=global_mesh,
                                              pspec=in_spec)
      for inp, in_spec in safe_zip(flat_inps, in_pspecs)
  ]
  return tree_unflatten(in_tree, out_flat)


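# Illustrative usage sketch (not part of this module's API): turn a per-host
# batch of examples into a single global jax.Array whose leading axis is
# sharded across hosts and replicated across each host's local devices. The
# mesh axis names are assumptions made up for the example.
def _example_global_batch(local_batch: np.ndarray) -> jax.Array:
  devices = np.array(jax.devices()).reshape(
      jax.process_count(), jax.local_device_count())
  mesh = jax.sharding.Mesh(devices, ('hosts', 'devices'))
  # P('hosts') shards the leading (batch) axis across the 'hosts' mesh axis
  # only, so each host's slice is replicated over that host's local devices.
  return host_local_array_to_global_array(local_batch, mesh, P('hosts'))

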
host_local_array_to_global_array_p = core.Primitive('host_local_array_to_global_array')
host_local_array_to_global_array_p.def_impl(host_local_array_to_global_array_impl)

def ltg_abstract_eval(arr, *, global_mesh, pspec):
  return _local_to_global_aval(
      core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)
host_local_array_to_global_array_p.def_abstract_eval(ltg_abstract_eval)

ad.deflinear2(host_local_array_to_global_array_p,
              lambda ct, _, **params: (
                  host_local_array_to_global_array_p.bind(ct, **params),))

def ltg_batcher(insert_axis, axis_data, vals_in, dims_in, global_mesh, pspec):
  x, = vals_in
  d, = dims_in
  new_parts = None if axis_data.spmd_name is None else axis_data.spmd_name
  new_pspec = list(pspec)
  new_pspec.insert(d, new_parts)
  new_pspec = P(*new_pspec)
  y = host_local_array_to_global_array_p.bind(
      x, global_mesh=global_mesh, pspec=new_pspec)
  return y, d
batching.fancy_primitive_batchers[host_local_array_to_global_array_p] = partial(
    ltg_batcher, False)

def _ltg_lowering(ctx, x, *, global_mesh, pspec):
  return [x]
mlir.register_lowering(host_local_array_to_global_array_p, _ltg_lowering)


def global_array_to_host_local_array_impl(
    arr: Any, *, global_mesh: jax.sharding.Mesh, pspec: Any):
  if pspec is None:
    raise ValueError(
        '`None` is not a valid input to the pspecs argument. Please use '
        'jax.sharding.PartitionSpec() if you wanted to replicate your input.')
  # If the Array is already fully addressable i.e. host local, return it.
  if isinstance(arr, array.ArrayImpl) and arr.is_fully_addressable:
    return arr

  global_sharding = jax.sharding.NamedSharding(global_mesh, pspec)
  local_sharding = jax.sharding.NamedSharding(global_mesh.local_mesh, pspec)
  local_aval = _global_to_local_aval(
      core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)

  if isinstance(arr, array.ArrayImpl):
    if arr.sharding.is_equivalent_to(global_sharding, arr.ndim):
      arrays = arr._arrays
    else:
      resharded_array = jax.device_put(arr, global_sharding)
      arrays = resharded_array._arrays
    return array.ArrayImpl(local_aval, local_sharding, arrays, committed=True)
  else:
    # numpy array can show up here during AD.
    arr = xla.canonicalize_dtype(arr)
    arrays = [
        arr[index]
        for d, index in local_sharding.devices_indices_map(arr.shape).items()]
    return pxla.batched_device_put(
        local_aval, local_sharding, arrays,
        list(global_mesh.local_mesh.devices.flat))


def global_array_to_host_local_array(
    global_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
  r"""Converts a global `jax.Array` to a host local `jax.Array`.

  You can use this function to transition to `jax.Array`. Using `jax.Array` with
  pjit has the same semantics as using GDA with pjit, i.e. all `jax.Array`
  inputs to pjit should be globally shaped and the outputs from pjit will also
  be globally shaped `jax.Array`\s.

  You can use this function to convert the globally shaped `jax.Array` output
  from pjit to host local values again so that the transition to jax.Array can
  be a mechanical change.

  Example usage:

  >>> from jax.experimental import multihost_utils # doctest: +SKIP
  >>>
  >>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
  >>>
  >>> with mesh: # doctest: +SKIP
  ...   global_out = pjitted_fun(global_inputs) # doctest: +SKIP
  >>>
  >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP

  Args:
    global_inputs: A Pytree of global jax.Array's.
    global_mesh: A :class:`jax.sharding.Mesh` object. The mesh must be
      contiguous, meaning that each host's local devices must form a subcube
      in the mesh.
    pspecs: A Pytree of :class:`jax.sharding.PartitionSpec` objects.

  Returns:
    A Pytree of host local arrays.
  """
  flat_inps, out_tree = tree_flatten(global_inputs)
  out_pspecs = _flatten_pspecs('output pspecs', out_tree,
                               pjit_lib.hashable_pytree(pspecs))
  out_flat = [
      global_array_to_host_local_array_p.bind(inp, global_mesh=global_mesh,
                                              pspec=o)
      for inp, o in safe_zip(flat_inps, out_pspecs)
  ]
  return tree_unflatten(out_tree, out_flat)


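# Illustrative round-trip sketch (not part of this module's API): stitch
# per-host inputs into a global array, run a jitted computation over the whole
# mesh, then pull each host's local portion back out. The mesh axis names and
# the doubling computation are assumptions made up for the example.
def _example_round_trip(local_x: np.ndarray) -> jax.Array:
  devices = np.array(jax.devices()).reshape(
      jax.process_count(), jax.local_device_count())
  mesh = jax.sharding.Mesh(devices, ('hosts', 'devices'))
  pspec = P('hosts')
  global_x = host_local_array_to_global_array(local_x, mesh, pspec)
  global_y = jax.jit(lambda a: a * 2)(global_x)
  # Each host gets back only the slice of `global_y` that its devices hold.
  return global_array_to_host_local_array(global_y, mesh, pspec)

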
global_array_to_host_local_array_p = core.Primitive('global_array_to_host_local_array')
global_array_to_host_local_array_p.def_impl(global_array_to_host_local_array_impl)

def gtl_abstract_eval(arr, *, global_mesh, pspec):
  return _global_to_local_aval(
      core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)
global_array_to_host_local_array_p.def_abstract_eval(gtl_abstract_eval)

ad.deflinear2(global_array_to_host_local_array_p,
              lambda ct, _, **params: (
                  global_array_to_host_local_array_p.bind(ct, **params),))
batching.defvectorized(global_array_to_host_local_array_p)

def _gtl_lowering(ctx, x, *, global_mesh, pspec):
  return [x]
mlir.register_lowering(global_array_to_host_local_array_p, _gtl_lowering)


def live_devices(devices: list[xla_client.Device]) -> list[xla_client.Device]:
  """Returns the subset of the provided devices that are live and healthy.

  This API is under active development and is not stable.

  `live_devices` is a low-level fault tolerance primitive that can be used to
  implement fault tolerant multi-process JAX programs.

  Barrier Semantics

  It's important that every process agrees on which devices are live, to keep
  the processes' behavior from diverging. For example, imagine a set of
  processes trying to run an AllGather, but they all disagree on which devices
  should be participating in the AllGather. This is buggy.

  To ensure that every process agrees on the set of live devices, the
  `live_devices` function has barrier-like semantics. Consider an invocation
  `live_devices(devices)` where `devices` includes devices across a set of
  processes P. The invocation acts as a barrier, waiting for every process in P
  to call `live_devices(devices)`. Afterwards, `live_devices` returns the same
  set of live devices `A` to all the processes in P. This ensures that every
  process agrees on the set of live devices.

  `live_devices` does not actually act as a barrier for *every* process in P
  because some processes in P might have failed. Instead, the `live_devices`
  function waits only for the processes with a device in the returned set of
  live devices A.

  An Example

  Imagine we have four processes, each with two devices:

  Process A: Devices 1 and 2
  Process B: Devices 3 and 4
  Process C: Devices 5 and 6
  Process D: Devices 7 and 8

  Further imagine that process D fails and that every process calls
  `live_devices(jax.devices())`. The invocation returns devices 1, 2, 3, 4, 5,
  and 6. Because these devices are hosted by processes A, B, and C, the call to
  `live_devices` acts as a barrier across processes A, B, and C. Process D,
  which failed, is ignored.

  Args:
    devices: A list of devices. The provided devices must include at least one
      local device.

  Returns:
    The subset of the provided devices that are live and healthy.

  Raises:
    RuntimeError: If the distributed runtime was not initialized.
    ValueError: If no local devices are provided.
  """
  client = distributed.global_state.client
  if client is None:
    raise RuntimeError('Distributed JAX not initialized.')

  if not devices:
    # TODO(mwhittaker): Make devices optional. If it's not provided, use
    # jax.devices() as a default.
    raise ValueError('No devices provided.')

  process_ids = {d.process_index for d in devices}
  if xla_bridge.process_index() not in process_ids:
    # A process can only participate in a live_devices call if it hosts some
    # of the provided devices.
    raise ValueError('Provided devices do not have any local devices.')

  if len(process_ids) == 1:
    # If the provided devices are hosted by a single process (this one), then
    # we don't have to perform any distributed computation. We know our local
    # devices are all live.
    return devices

  live_process_ids = client.get_live_nodes(list(process_ids))
  return [d for d in devices if d.process_index in live_process_ids]


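# Illustrative fault-tolerance sketch (not part of this module's API): after a
# process failure, shrink the computation to the devices that are still
# healthy by building a fresh one-axis mesh from them. Re-sharding of program
# state and error handling are deliberately omitted.
def _example_healthy_mesh() -> jax.sharding.Mesh:
  healthy = live_devices(jax.devices())
  # Every surviving process receives the same `healthy` list, so the meshes
  # constructed here agree across processes.
  return jax.sharding.Mesh(np.array(healthy), ('devices',))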