Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00.
Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases where multiple JAX processes are running on a single host, and each process has a unique process_index/host_id. This keeps aliases for the old `host_id` APIs for now, but these will eventually be removed.

This was originally committed in b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in 14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test failures from renaming the argument of `local_devices`. This change is identical except that it also adds staging for the argument-name change.
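For illustration, here is how the renamed APIs line up on a single-process run; this snippet is not part of the commit, and the warning behavior matches the compatibility shims added in the diff below:

    import jax

    # New spellings introduced by this commit:
    jax.process_index()             # was: jax.host_id()
    jax.process_count()             # was: jax.host_count()
    range(jax.process_count())      # replaces: jax.host_ids()

    # The old names still work for now, but emit a warning and will
    # eventually be removed:
    jax.host_id()  # warns: "jax.host_id has been renamed to jax.process_index..."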
Parent: 8310867b41
Commit: 9128ba0c74
@@ -17,6 +17,14 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 * {func}`jax.nonzero` has a new optional `size` argument that allows it to
   be used within `jit` ({jax-issue}`6501`)
 * Breaking changes:
+  * The following function names have changed. There are still aliases, so this
+    should not break existing code, but the aliases will eventually be removed
+    so please change your code.
+    * `host_id` --> {func}`~jax.process_index`
+    * `host_count` --> {func}`~jax.process_count`
+    * `host_ids` --> `range(jax.process_count())`
+  * Similarly, the argument to {func}`~jax.local_devices` has been renamed from
+    `host_id` to `process_index`.
   * Arguments to {func}`jax.jit` other than the function are now marked as
     keyword-only. This change is to prevent accidental breakage when arguments
     are added to `jit`.
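The second breaking change recorded in this changelog hunk (keyword-only `jit` arguments) is easy to illustrate; this sketch is not part of the commit:

    import jax

    def scale(x, n):
      return x * n

    # OK: options other than the function are passed by keyword.
    f = jax.jit(scale, static_argnums=(1,))
    print(f(2.0, 3))  # 6.0

    # Passing options positionally, e.g. jax.jit(scale, (1,)), now raises a
    # TypeError instead of silently binding to whatever parameter happens to
    # come second in jit's signature.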
docs/jax.rst
@@ -77,11 +77,10 @@ Parallelization (:code:`pmap`)
     pmap
     devices
     local_devices
-    host_id
-    host_ids
+    process_index
     device_count
     local_device_count
-    host_count
+    process_count
 
 
 .. autofunction:: jit
@@ -124,8 +123,7 @@ Parallelization (:code:`pmap`)
 .. autofunction:: pmap
 .. autofunction:: devices
 .. autofunction:: local_devices
-.. autofunction:: host_id
-.. autofunction:: host_ids
+.. autofunction:: process_index
 .. autofunction:: device_count
 .. autofunction:: local_device_count
-.. autofunction:: host_count
+.. autofunction:: process_count
@@ -82,6 +82,8 @@ from ._src.api import (
   named_call,
   partial,  # TODO(phawkins): update callers to use functools.partial.
   pmap,
+  process_count,
+  process_index,
   pxla,  # TODO(phawkins): update users to avoid this.
   remat,
   shapecheck,
@@ -62,8 +62,8 @@ from ..lib import xla_bridge as xb
 from ..lib import xla_client as xc
 # Unused imports to be exported
 from ..lib.xla_bridge import (device_count, local_device_count, devices,
-                              local_devices, host_id, host_ids, host_count,
-                              default_backend)
+                              local_devices, process_index, process_count,
+                              host_id, host_ids, host_count, default_backend)
 from ..core import ConcreteArray, ShapedArray, raise_to_shaped
 from ..interpreters import partial_eval as pe
 from ..interpreters import xla
@@ -1372,20 +1372,20 @@ def pmap(
   :py:func:`pmap` compiles ``fun``, so while it can be combined with
   :py:func:`jit`, it's usually unnecessary.
 
-  **Multi-host platforms:** On multi-host platforms such as TPU pods,
+  **Multi-process platforms:** On multi-process platforms such as TPU pods,
   :py:func:`pmap` is designed to be used in SPMD Python programs, where every
-  host is running the same Python code such that all hosts run the same pmapped
-  function in the same order. Each host should still call the pmapped function
-  with mapped axis size equal to the number of *local* devices (unless
+  process is running the same Python code such that all processes run the same
+  pmapped function in the same order. Each process should still call the pmapped
+  function with mapped axis size equal to the number of *local* devices (unless
   ``devices`` is specified, see below), and an array of the same leading axis
   size will be returned as usual. However, any collective operations in ``fun``
   will be computed over *all* participating devices, including those on other
-  hosts, via device-to-device communication. Conceptually, this can be thought
-  of as running a pmap over a single array sharded across hosts, where each host
-  "sees" only its local shard of the input and output. The SPMD model requires
-  that the same multi-host pmaps must be run in the same order on all devices,
-  but they can be interspersed with arbitrary operations running on a single
-  host.
+  processes, via device-to-device communication. Conceptually, this can be
+  thought of as running a pmap over a single array sharded across processes,
+  where each process "sees" only its local shard of the input and output. The
+  SPMD model requires that the same multi-process pmaps must be run in the same
+  order on all devices, but they can be interspersed with arbitrary operations
+  running in a single process.
 
   Args:
     fun: Function to be mapped over argument axes. Its arguments and return
@@ -1519,26 +1519,26 @@ def pmap(
   >>> print(doubly_normed.sum((0, 1)))  # doctest: +SKIP
   1.0
 
-  On multi-host platforms, collective operations operate over all devices,
-  including those on other hosts. For example, assuming the following code runs
-  on two hosts with 4 XLA devices each:
+  On multi-process platforms, collective operations operate over all devices,
+  including those on other processes. For example, assuming the following code
+  runs on two processes with 4 XLA devices each:
 
   >>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
-  >>> data = jnp.arange(4) if jax.host_id() == 0 else jnp.arange(4, 8)
+  >>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
   >>> out = pmap(f, axis_name='i')(data)  # doctest: +SKIP
   >>> print(out)  # doctest: +SKIP
-  [28 29 30 31] # on host 0
-  [32 33 34 35] # on host 1
+  [28 29 30 31] # on process 0
+  [32 33 34 35] # on process 1
 
-  Each host passes in a different length-4 array, corresponding to its 4 local
-  devices, and the psum operates over all 8 values. Conceptually, the two
+  Each process passes in a different length-4 array, corresponding to its 4
+  local devices, and the psum operates over all 8 values. Conceptually, the two
   length-4 arrays can be thought of as a sharded length-8 array (in this example
-  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis
-  given name 'i'. The pmap call on each host then returns the corresponding
-  length-4 output shard.
+  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped
+  axis given name 'i'. The pmap call on each process then returns the
+  corresponding length-4 output shard.
 
   The ``devices`` argument can be used to specify exactly which devices are used
-  to run the parallel computation. For example, again assuming a single host
+  to run the parallel computation. For example, again assuming a single process
   with 8 devices, the following code defines two parallel computations, one
   which runs on the first six devices and one on the remaining two:
 
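A single-process analogue of the doctest in this hunk can be run locally; this sketch is not part of the commit and assumes eight host devices are forced via the XLA_FLAGS environment variable:

    # Run with: XLA_FLAGS=--xla_force_host_platform_device_count=8 python demo.py
    import jax
    import jax.numpy as jnp

    # With one process owning all 8 devices, a single pmap call sees the whole
    # length-8 array; across two processes, each call would see a length-4 shard.
    f = lambda x: x + jax.lax.psum(x, axis_name='i')
    out = jax.pmap(f, axis_name='i')(jnp.arange(8))
    print(out)  # [28 29 30 31 32 33 34 35]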
@@ -1556,9 +1556,9 @@ def pmap(
   >>> print(f2(jnp.array([2., 3.])))  # doctest: +SKIP
   [ 13.  13.]
   """
-  # axis_size is an optional integer representing the global axis size.
-  # The aggregate size (across all hosts) size of the mapped axis must match
-  # the given value.
+  # axis_size is an optional integer representing the global axis size. The
+  # aggregate size (across all processes) size of the mapped axis must match the
+  # given value.
 
   _check_callable(fun)
   axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
@@ -54,4 +54,3 @@ from jax._src.api import (
   _std_basis,
   _unravel_array_into_pytree,
 )
-
@@ -640,9 +640,9 @@ def parallel_callable(fun: lu.WrappedFun,
 
   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
-  # if xb.host_count() > 1 and global_axis_size is None and inner_pmap:
+  # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
   #   raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
-  if (xb.host_count() == 1 and global_axis_size is not None and
+  if (xb.process_count() == 1 and global_axis_size is not None and
       global_axis_size != axis_size):
     raise ValueError(
         f"Specified axis_size {global_axis_size} doesn't match received "
@@ -651,7 +651,7 @@ def parallel_callable(fun: lu.WrappedFun,
   must_run_on_all_devices = False
   no_nested_sharding = False
   if global_axis_size is None:
-    if xb.host_count() == 1:
+    if xb.process_count() == 1:
       global_axis_size = axis_size
     elif devices:
       # This allows each host in a multi-host pmap to run on a different number
@@ -664,13 +664,13 @@ def parallel_callable(fun: lu.WrappedFun,
       # this assumption is true by requiring that the pmap is run on all devices
       # (and making the further assumption that each host has the same number of
       # devices). Nested sharding is ok in this case.
-      global_axis_size = axis_size * xb.host_count()
-      assert all(len(xb.local_devices(host_id)) == xb.local_device_count()
-                 for host_id in xb.host_ids())
+      global_axis_size = axis_size * xb.process_count()
+      assert all(len(xb.local_devices(process_index)) == xb.local_device_count()
+                 for process_index in range(xb.process_count()))
       must_run_on_all_devices = True
 
   if devices:
-    local_devices = [d for d in devices if d.host_id == xb.host_id()]
+    local_devices = [d for d in devices if d.process_index == xb.process_index()]
     assert len(local_devices) > 0
   else:
     local_devices = None  # type: ignore
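The invariant asserted in this hunk is worth spelling out; a small illustrative computation with assumed values (not from the commit):

    # Every process maps over all of its local devices, and each process is
    # assumed to have the same number of them, so:
    local_axis_size = 4      # devices per process (assumed)
    num_processes = 2        # xb.process_count() (assumed)
    global_axis_size = local_axis_size * num_processes
    assert global_axis_size == 8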
@@ -700,7 +700,7 @@ def parallel_callable(fun: lu.WrappedFun,
   if devices is not None:
     is_multi_host_pmap = len(local_devices) != len(devices)
   else:
-    is_multi_host_pmap = xb.host_count() > 1
+    is_multi_host_pmap = xb.process_count() > 1
   if is_multi_host_pmap:
     check_multihost_collective_allowlist(jaxpr)
 
@@ -734,7 +734,7 @@ def parallel_callable(fun: lu.WrappedFun,
   num_local_shards = num_local_replicas * local_num_partitions
   num_global_shards = num_global_replicas * num_partitions
 
-  if (xb.host_count() > 1 and must_run_on_all_devices and
+  if (xb.process_count() > 1 and must_run_on_all_devices and
       num_local_shards != xb.local_device_count()):
     if num_local_shards == axis_size:
       raise ValueError(
@@ -803,8 +803,8 @@ def parallel_callable(fun: lu.WrappedFun,
     if num_global_shards > num_local_shards:
       # TODO(skye): use a locality-aware assignment that satisfies the above
       # constraint.
-      devices = [d for host_id in xb.host_ids()
-                 for d in xb.local_devices(host_id)]
+      devices = [d for process_index in range(xb.process_count())
+                 for d in xb.local_devices(process_index)]
     else:
       devices = xb.get_backend(backend).get_default_device_assignment(
           num_global_replicas, num_partitions)
@@ -1300,8 +1300,9 @@ class Mesh:
   def local_mesh(self):
     if not self.devices.ndim:
       return self
-    host_id = xb.host_id()
-    is_local_device = np.vectorize(lambda d: d.host_id == host_id, otypes=[bool])(self.devices)
+    process_index = xb.process_index()
+    is_local_device = np.vectorize(
+        lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
     subcube_indices = []
     # We take the smallest slice of each dimension that doesn't skip any local device.
     for axis in range(self.devices.ndim):
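The np.vectorize trick used in the local_mesh hunk above also works standalone; a minimal sketch (not part of the commit, runnable on a single-process CPU setup):

    import numpy as np
    import jax

    # Build an ndarray of devices and mask out the ones owned by this process.
    devices = np.array(jax.devices())
    process_index = jax.process_index()
    is_local = np.vectorize(
        lambda d: d.process_index == process_index, otypes=[bool])(devices)
    print(devices[is_local])  # this process's slice of the device array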
@@ -683,7 +683,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
           f"compiling computation that requires {nreps} replicas, but only "
           f"{xb.device_count(backend)} XLA devices are available")
 
-  if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
+  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
     raise NotImplementedError(
         "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
         "extra data movement anyway, so maybe you don't want it after all).")
@@ -23,6 +23,7 @@ XLA. There are also a handful of related casting utilities.
 from functools import partial, lru_cache
 import os
 from typing import Callable, Dict, List, Optional, Tuple, Union
+import warnings
 
 from absl import logging
 # Disable "WARNING: Logging before flag parsing goes to stderr." message
@@ -195,8 +196,9 @@ def device_count(backend: Optional[str] = None) -> int:
   """Returns the total number of devices.
 
   On most platforms, this is the same as :py:func:`jax.local_device_count`.
-  However, on multi-host platforms, this will return the total number of devices
-  across all hosts.
+  However, on multi-process platforms where different devices are associated
+  with different processes, this will return the total number of devices across
+  all processes.
 
   Args:
     backend: This is an experimental feature and the API is likely to change.
@@ -205,12 +207,13 @@ def device_count(backend: Optional[str] = None) -> int:
 
   Returns:
     Number of devices.
+
   """
   return int(get_backend(backend).device_count())
 
 
 def local_device_count(backend: Optional[str] = None) -> int:
-  """Returns the number of devices on this host."""
+  """Returns the number of devices addressable by this process."""
   return int(get_backend(backend).local_device_count())
 
 
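The relationship between these counters can be checked directly; a sketch (not from the commit) that assumes every process has the same number of local devices, as the parallel_callable comment earlier also assumes:

    import jax

    # Globally visible devices = per-process devices x number of processes,
    # on homogeneous platforms. On a single process all three collapse.
    assert jax.device_count() == jax.local_device_count() * jax.process_count()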
@@ -219,8 +222,9 @@ def devices(backend: Optional[str] = None) -> List[xla_client.Device]:
 
   Each device is represented by a subclass of :class:`Device` (e.g.
   :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
-  equal to ``device_count(backend)``. Local devices can be identified by comparing
-  :meth:`Device.host_id` to the value returned by :py:func:`jax.host_id`.
+  equal to ``device_count(backend)``. Local devices can be identified by
+  comparing :meth:`Device.process_index` to the value returned by
+  :py:func:`jax.process_index`.
 
   If ``backend`` is ``None``, returns all the devices from the default backend.
   The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
@@ -242,15 +246,16 @@ def default_backend() -> str:
   return get_backend(None).platform
 
 
-def local_devices(host_id: Optional[int] = None,
-                  backend: Optional[str] = None) -> List[xla_client.Device]:
-  """Like :py:func:`jax.devices`, but only returns devices local to a given host.
+def local_devices(process_index: Optional[int] = None,
+                  backend: Optional[str] = None,
+                  host_id: Optional[int] = None) -> List[xla_client.Device]:
+  """Like :py:func:`jax.devices`, but only returns devices local to a given process.
 
-  If ``host_id`` is ``None``, returns devices local to this host.
+  If ``process_index`` is ``None``, returns devices local to this process.
 
   Args:
-    host_id: the integer ID of the host. Host IDs can be retrieved via
-      :py:func:`jax.host_ids`.
+    process_index: the integer index of the process. Process indices can be
+      retrieved via ``len(jax.process_count())``.
     backend: This is an experimental feature and the API is likely to change.
       Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
       ``'tpu'``.
@@ -258,17 +263,23 @@ def local_devices(host_id: Optional[int] = None,
   Returns:
     List of Device subclasses.
   """
-  if host_id is None:
-    host_id = get_backend(backend).host_id()
-  if host_id not in host_ids():
-    raise ValueError(f"Unknown host_id {host_id}")
-  return [d for d in devices(backend) if d.host_id == host_id]
+  if host_id is not None:
+    warnings.warn(
+        "The argument to jax.local_devices has been renamed from `host_id` to "
+        "`process_index`. This alias will eventually be removed; please update "
+        "your code.")
+    process_index = host_id
+  if process_index is None:
+    process_index = get_backend(backend).process_index()
+  if not (0 <= process_index < process_count()):
+    raise ValueError(f"Unknown process_index {process_index}")
+  return [d for d in devices(backend) if d.process_index == process_index]
 
 
-def host_id(backend: Optional[str] = None) -> int:
-  """Returns the integer host ID of this host.
+def process_index(backend: Optional[str] = None) -> int:
+  """Returns the integer process index of this process.
 
-  On most platforms, this will always be 0. This will vary on multi-host
+  On most platforms, this will always be 0. This will vary on multi-process
   platforms though.
 
   Args:
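A quick way to see the staged argument rename above in action, on a single-process run (this snippet is illustrative and not part of the commit):

    import warnings
    import jax

    new_style = jax.local_devices(process_index=0)

    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter("always")
      old_style = jax.local_devices(host_id=0)  # still works, but warns

    assert new_style == old_style
    print(caught[0].message)  # the rename warning added in this hunk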
@@ -277,19 +288,39 @@ def host_id(backend: Optional[str] = None) -> int:
       ``'tpu'``.
 
   Returns:
-    Integer host ID.
+    Integer process index.
   """
-  return get_backend(backend).host_id()
+  return get_backend(backend).process_index()
 
 
-def host_ids(backend: Optional[str] = None) -> List[int]:
-  """Returns a sorted list of all host IDs."""
-  return sorted({d.host_id for d in devices(backend)})
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_id(backend=None):
+  warnings.warn(
+      "jax.host_id has been renamed to jax.process_index. This alias "
+      "will eventually be removed; please update your code.")
+  return process_index(backend)
 
 
-def host_count(backend: Optional[str] = None) -> int:
-  """Returns the number of hosts."""
-  return len(host_ids(backend))
+def process_count(backend: Optional[str] = None) -> int:
+  """Returns the number of JAX processes associated with the backend."""
+  return max(d.process_index for d in devices(backend)) + 1
+
+
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_count(backend=None):
+  warnings.warn(
+      "jax.host_count has been renamed to jax.process_count. This alias "
+      "will eventually be removed; please update your code.")
+  return process_count(backend)
+
+
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_ids(backend=None):
+  warnings.warn(
+      "jax.host_ids has been deprecated; please use range(jax.process_count()) "
+      "instead. jax.host_ids will eventually be removed; please update your "
+      "code.")
+  return list(range(process_count(backend)))
 
 
 ### utility functions
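Taken together, the three shims above keep old call sites working while steering users to the new names; a minimal check (illustrative, single-process assumptions):

    import warnings
    import jax

    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter("always")
      assert jax.host_id() == jax.process_index()
      assert jax.host_count() == jax.process_count()
      assert jax.host_ids() == list(range(jax.process_count()))

    print(len(caught))  # 3 rename/deprecation warnings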
@@ -51,7 +51,7 @@ class XlaBridgeTest(absltest.TestCase):
 
   def test_local_devices(self):
     self.assertNotEmpty(xb.local_devices())
-    with self.assertRaisesRegex(ValueError, "Unknown host_id 100"):
+    with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):
       xb.local_devices(100)
     with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
       xb.local_devices(backend="foo")