Replace host_id with process_index terminology, take 2.

We're switching to the new terminology to avoid confusion in cases where
multiple JAX processes run on a single host; each process has a unique
process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally committed in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument. This change is
identical except that it also adds staging for the argument rename.
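The alias-and-warn staging described above follows a common pattern: the new name is the real implementation, and the old name becomes a thin wrapper that warns and forwards. A minimal standalone sketch of that pattern (hypothetical values, not JAX's actual module):

```python
import warnings

def process_index() -> int:
    """New-style name: the index of this process (0 in a single-process run)."""
    return 0  # hypothetical single-process value for this sketch

def host_id() -> int:
    """Deprecated alias kept temporarily; warns and forwards to the new name."""
    warnings.warn(
        "host_id has been renamed to process_index. This alias will "
        "eventually be removed; please update your code.",
        stacklevel=2)
    return process_index()
```

Call sites migrate from `host_id()` to `process_index()` at their leisure; the alias keeps old code working until it is removed.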
Author: Skye Wanderman-Milne
Date:   2021-04-20 17:56:41 -07:00
Commit: 9128ba0c74 (parent 8310867b41)
9 changed files with 115 additions and 76 deletions

View File

@@ -17,6 +17,14 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 * {func}`jax.nonzero` has a new optional `size` argument that allows it to
   be used within `jit` ({jax-issue}`6501`)
 * Breaking changes:
+  * The following function names have changed. There are still aliases, so this
+    should not break existing code, but the aliases will eventually be removed
+    so please change your code.
+    * `host_id` --> {func}`~jax.process_index`
+    * `host_count` --> {func}`~jax.process_count`
+    * `host_ids` --> `range(jax.process_count())`
+  * Similarly, the argument to {func}`~jax.local_devices` has been renamed from
+    `host_id` to `process_index`.
 * Arguments to {func}`jax.jit` other than the function are now marked as
   keyword-only. This change is to prevent accidental breakage when arguments
   are added to `jit`.

View File

@@ -77,11 +77,10 @@ Parallelization (:code:`pmap`)
     pmap
     devices
     local_devices
-    host_id
-    host_ids
+    process_index
     device_count
     local_device_count
-    host_count
+    process_count

 .. autofunction:: jit
@@ -124,8 +123,7 @@ Parallelization (:code:`pmap`)
 .. autofunction:: pmap
 .. autofunction:: devices
 .. autofunction:: local_devices
-.. autofunction:: host_id
-.. autofunction:: host_ids
+.. autofunction:: process_index
 .. autofunction:: device_count
 .. autofunction:: local_device_count
-.. autofunction:: host_count
+.. autofunction:: process_count

View File

@@ -82,6 +82,8 @@ from ._src.api import (
   named_call,
   partial,  # TODO(phawkins): update callers to use functools.partial.
   pmap,
+  process_count,
+  process_index,
   pxla,  # TODO(phawkins): update users to avoid this.
   remat,
   shapecheck,

View File

@@ -62,8 +62,8 @@ from ..lib import xla_bridge as xb
 from ..lib import xla_client as xc
 # Unused imports to be exported
 from ..lib.xla_bridge import (device_count, local_device_count, devices,
-                              local_devices, host_id, host_ids, host_count,
-                              default_backend)
+                              local_devices, process_index, process_count,
+                              host_id, host_ids, host_count, default_backend)
 from ..core import ConcreteArray, ShapedArray, raise_to_shaped
 from ..interpreters import partial_eval as pe
 from ..interpreters import xla
@@ -1372,20 +1372,20 @@ def pmap(
   :py:func:`pmap` compiles ``fun``, so while it can be combined with
   :py:func:`jit`, it's usually unnecessary.

-  **Multi-host platforms:** On multi-host platforms such as TPU pods,
-  :py:func:`pmap` is designed to be used in SPMD Python programs, where every
-  host is running the same Python code such that all hosts run the same pmapped
-  function in the same order. Each host should still call the pmapped function
-  with mapped axis size equal to the number of *local* devices (unless
-  ``devices`` is specified, see below), and an array of the same leading axis
-  size will be returned as usual. However, any collective operations in ``fun``
-  will be computed over *all* participating devices, including those on other
-  hosts, via device-to-device communication. Conceptually, this can be thought
-  of as running a pmap over a single array sharded across hosts, where each host
-  "sees" only its local shard of the input and output. The SPMD model requires
-  that the same multi-host pmaps must be run in the same order on all devices,
-  but they can be interspersed with arbitrary operations running on a single
-  host.
+  **Multi-process platforms:** On multi-process platforms such as TPU pods,
+  :py:func:`pmap` is designed to be used in SPMD Python programs, where every
+  process is running the same Python code such that all processes run the same
+  pmapped function in the same order. Each process should still call the pmapped
+  function with mapped axis size equal to the number of *local* devices (unless
+  ``devices`` is specified, see below), and an array of the same leading axis
+  size will be returned as usual. However, any collective operations in ``fun``
+  will be computed over *all* participating devices, including those on other
+  processes, via device-to-device communication. Conceptually, this can be
+  thought of as running a pmap over a single array sharded across processes,
+  where each process "sees" only its local shard of the input and output. The
+  SPMD model requires that the same multi-process pmaps must be run in the same
+  order on all devices, but they can be interspersed with arbitrary operations
+  running in a single process.

   Args:
     fun: Function to be mapped over argument axes. Its arguments and return
@@ -1519,26 +1519,26 @@ def pmap(
   >>> print(doubly_normed.sum((0, 1)))  # doctest: +SKIP
   1.0

-  On multi-host platforms, collective operations operate over all devices,
-  including those on other hosts. For example, assuming the following code runs
-  on two hosts with 4 XLA devices each:
+  On multi-process platforms, collective operations operate over all devices,
+  including those on other processes. For example, assuming the following code
+  runs on two processes with 4 XLA devices each:

   >>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
-  >>> data = jnp.arange(4) if jax.host_id() == 0 else jnp.arange(4, 8)
+  >>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
   >>> out = pmap(f, axis_name='i')(data)  # doctest: +SKIP
   >>> print(out)  # doctest: +SKIP
-  [28 29 30 31] # on host 0
-  [32 33 34 35] # on host 1
+  [28 29 30 31] # on process 0
+  [32 33 34 35] # on process 1

-  Each host passes in a different length-4 array, corresponding to its 4 local
-  devices, and the psum operates over all 8 values. Conceptually, the two
-  length-4 arrays can be thought of as a sharded length-8 array (in this example
-  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis
-  given name 'i'. The pmap call on each host then returns the corresponding
-  length-4 output shard.
+  Each process passes in a different length-4 array, corresponding to its 4
+  local devices, and the psum operates over all 8 values. Conceptually, the two
+  length-4 arrays can be thought of as a sharded length-8 array (in this example
+  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped
+  axis given name 'i'. The pmap call on each process then returns the
+  corresponding length-4 output shard.

   The ``devices`` argument can be used to specify exactly which devices are used
-  to run the parallel computation. For example, again assuming a single host
+  to run the parallel computation. For example, again assuming a single process
   with 8 devices, the following code defines two parallel computations, one
   which runs on the first six devices and one on the remaining two:
@@ -1556,9 +1556,9 @@ def pmap(
   >>> print(f2(jnp.array([2., 3.])))  # doctest: +SKIP
   [ 13. 13.]
   """
-  # axis_size is an optional integer representing the global axis size.
-  # The aggregate size (across all hosts) size of the mapped axis must match
-  # the given value.
+  # axis_size is an optional integer representing the global axis size. The
+  # aggregate size (across all processes) size of the mapped axis must match the
+  # given value.
   _check_callable(fun)
   axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
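The psum arithmetic in the docstring example above can be checked without any devices at all: conceptually the two length-4 shards form one length-8 mapped axis, and `psum` sums all 8 values. A plain NumPy sketch (simulating, not running, the pmap):

```python
import numpy as np

# Two processes with 4 devices each contribute one length-8 mapped axis.
global_data = np.arange(8)       # conceptually jnp.arange(8), sharded 4 + 4
total = int(global_data.sum())   # what psum(x, axis_name='i') computes: 28
out = global_data + total        # f(x) = x + psum(x), applied elementwise

shard0, shard1 = out[:4], out[4:]  # what process 0 / process 1 would print
```

This reproduces the shards `[28 29 30 31]` and `[32 33 34 35]` shown in the docstring.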

View File

@@ -54,4 +54,3 @@ from jax._src.api import (
   _std_basis,
   _unravel_array_into_pytree,
 )

View File

@@ -640,9 +640,9 @@ def parallel_callable(fun: lu.WrappedFun,
   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
-  # if xb.host_count() > 1 and global_axis_size is None and inner_pmap:
+  # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
   #   raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
-  if (xb.host_count() == 1 and global_axis_size is not None and
+  if (xb.process_count() == 1 and global_axis_size is not None and
       global_axis_size != axis_size):
     raise ValueError(
         f"Specified axis_size {global_axis_size} doesn't match received "
@@ -651,7 +651,7 @@ def parallel_callable(fun: lu.WrappedFun,
   must_run_on_all_devices = False
   no_nested_sharding = False
   if global_axis_size is None:
-    if xb.host_count() == 1:
+    if xb.process_count() == 1:
       global_axis_size = axis_size
     elif devices:
       # This allows each host in a multi-host pmap to run on a different number
@@ -664,13 +664,13 @@ def parallel_callable(fun: lu.WrappedFun,
       # this assumption is true by requiring that the pmap is run on all devices
       # (and making the further assumption that each host has the same number of
       # devices). Nested sharding is ok in this case.
-      global_axis_size = axis_size * xb.host_count()
-      assert all(len(xb.local_devices(host_id)) == xb.local_device_count()
-                 for host_id in xb.host_ids())
+      global_axis_size = axis_size * xb.process_count()
+      assert all(len(xb.local_devices(process_index)) == xb.local_device_count()
+                 for process_index in range(xb.process_count()))
       must_run_on_all_devices = True

   if devices:
-    local_devices = [d for d in devices if d.host_id == xb.host_id()]
+    local_devices = [d for d in devices if d.process_index == xb.process_index()]
     assert len(local_devices) > 0
   else:
     local_devices = None  # type: ignore
@@ -700,7 +700,7 @@ def parallel_callable(fun: lu.WrappedFun,
   if devices is not None:
     is_multi_host_pmap = len(local_devices) != len(devices)
   else:
-    is_multi_host_pmap = xb.host_count() > 1
+    is_multi_host_pmap = xb.process_count() > 1
   if is_multi_host_pmap:
     check_multihost_collective_allowlist(jaxpr)
@@ -734,7 +734,7 @@ def parallel_callable(fun: lu.WrappedFun,
   num_local_shards = num_local_replicas * local_num_partitions
   num_global_shards = num_global_replicas * num_partitions

-  if (xb.host_count() > 1 and must_run_on_all_devices and
+  if (xb.process_count() > 1 and must_run_on_all_devices and
       num_local_shards != xb.local_device_count()):
     if num_local_shards == axis_size:
       raise ValueError(
@@ -803,8 +803,8 @@ def parallel_callable(fun: lu.WrappedFun,
     if num_global_shards > num_local_shards:
       # TODO(skye): use a locality-aware assignment that satisfies the above
       # constraint.
-      devices = [d for host_id in xb.host_ids()
-                 for d in xb.local_devices(host_id)]
+      devices = [d for process_index in range(xb.process_count())
+                 for d in xb.local_devices(process_index)]
     else:
       devices = xb.get_backend(backend).get_default_device_assignment(
           num_global_replicas, num_partitions)
@@ -1300,8 +1300,9 @@ class Mesh:
   def local_mesh(self):
     if not self.devices.ndim:
       return self
-    host_id = xb.host_id()
-    is_local_device = np.vectorize(lambda d: d.host_id == host_id, otypes=[bool])(self.devices)
+    process_index = xb.process_index()
+    is_local_device = np.vectorize(
+        lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
     subcube_indices = []
     # We take the smallest slice of each dimension that doesn't skip any local device.
     for axis in range(self.devices.ndim):
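The `np.vectorize` trick in `local_mesh` above builds a boolean mask over the device mesh, marking which entries belong to this process. It can be exercised with a fake device class (a hypothetical stand-in for an XLA `Device`, not JAX code):

```python
import numpy as np
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for an XLA Device with a process_index attribute."""
    process_index: int

# A 2x4 mesh: the first row lives on process 0, the second on process 1.
devices = np.array([[FakeDevice(0)] * 4, [FakeDevice(1)] * 4])

# Same pattern as Mesh.local_mesh: mask the entries local to process 0.
is_local = np.vectorize(lambda d: d.process_index == 0, otypes=[bool])(devices)
```

The mask is then used to take the smallest slice along each mesh axis that contains every local device.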

View File

@@ -683,7 +683,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
         f"compiling computation that requires {nreps} replicas, but only "
         f"{xb.device_count(backend)} XLA devices are available")

-  if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
+  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
     raise NotImplementedError(
         "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
         "extra data movement anyway, so maybe you don't want it after all).")

View File

@@ -23,6 +23,7 @@ XLA. There are also a handful of related casting utilities.
 from functools import partial, lru_cache
 import os
 from typing import Callable, Dict, List, Optional, Tuple, Union
+import warnings

 from absl import logging
 # Disable "WARNING: Logging before flag parsing goes to stderr." message
@@ -195,8 +196,9 @@ def device_count(backend: Optional[str] = None) -> int:
   """Returns the total number of devices.

   On most platforms, this is the same as :py:func:`jax.local_device_count`.
-  However, on multi-host platforms, this will return the total number of devices
-  across all hosts.
+  However, on multi-process platforms where different devices are associated
+  with different processes, this will return the total number of devices across
+  all processes.

   Args:
     backend: This is an experimental feature and the API is likely to change.
@@ -205,12 +207,13 @@ def device_count(backend: Optional[str] = None) -> int:

   Returns:
     Number of devices.
   """
   return int(get_backend(backend).device_count())


 def local_device_count(backend: Optional[str] = None) -> int:
-  """Returns the number of devices on this host."""
+  """Returns the number of devices addressable by this process."""
   return int(get_backend(backend).local_device_count())
@@ -219,8 +222,9 @@ def devices(backend: Optional[str] = None) -> List[xla_client.Device]:

   Each device is represented by a subclass of :class:`Device` (e.g.
   :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
-  equal to ``device_count(backend)``. Local devices can be identified by comparing
-  :meth:`Device.host_id` to the value returned by :py:func:`jax.host_id`.
+  equal to ``device_count(backend)``. Local devices can be identified by
+  comparing :meth:`Device.process_index` to the value returned by
+  :py:func:`jax.process_index`.

   If ``backend`` is ``None``, returns all the devices from the default backend.
   The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
@@ -242,15 +246,16 @@ def default_backend() -> str:
   return get_backend(None).platform


-def local_devices(host_id: Optional[int] = None,
-                  backend: Optional[str] = None) -> List[xla_client.Device]:
-  """Like :py:func:`jax.devices`, but only returns devices local to a given host.
+def local_devices(process_index: Optional[int] = None,
+                  backend: Optional[str] = None,
+                  host_id: Optional[int] = None) -> List[xla_client.Device]:
+  """Like :py:func:`jax.devices`, but only returns devices local to a given process.

-  If ``host_id`` is ``None``, returns devices local to this host.
+  If ``process_index`` is ``None``, returns devices local to this process.

   Args:
-    host_id: the integer ID of the host. Host IDs can be retrieved via
-      :py:func:`jax.host_ids`.
+    process_index: the integer index of the process. Process indices can be
+      retrieved via ``range(jax.process_count())``.
     backend: This is an experimental feature and the API is likely to change.
       Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
       ``'tpu'``.
@@ -258,17 +263,23 @@ def local_devices(host_id: Optional[int] = None,
   Returns:
     List of Device subclasses.
   """
-  if host_id is None:
-    host_id = get_backend(backend).host_id()
-  if host_id not in host_ids():
-    raise ValueError(f"Unknown host_id {host_id}")
-  return [d for d in devices(backend) if d.host_id == host_id]
+  if host_id is not None:
+    warnings.warn(
+        "The argument to jax.local_devices has been renamed from `host_id` to "
+        "`process_index`. This alias will eventually be removed; please update "
+        "your code.")
+    process_index = host_id
+  if process_index is None:
+    process_index = get_backend(backend).process_index()
+  if not (0 <= process_index < process_count()):
+    raise ValueError(f"Unknown process_index {process_index}")
+  return [d for d in devices(backend) if d.process_index == process_index]


-def host_id(backend: Optional[str] = None) -> int:
-  """Returns the integer host ID of this host.
+def process_index(backend: Optional[str] = None) -> int:
+  """Returns the integer process index of this process.

-  On most platforms, this will always be 0. This will vary on multi-host
+  On most platforms, this will always be 0. This will vary on multi-process
   platforms though.

   Args:
@@ -277,19 +288,39 @@ def host_id(backend: Optional[str] = None) -> int:
       ``'tpu'``.

   Returns:
-    Integer host ID.
+    Integer process index.
   """
-  return get_backend(backend).host_id()
+  return get_backend(backend).process_index()


-def host_ids(backend: Optional[str] = None) -> List[int]:
-  """Returns a sorted list of all host IDs."""
-  return sorted({d.host_id for d in devices(backend)})
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_id(backend=None):
+  warnings.warn(
+      "jax.host_id has been renamed to jax.process_index. This alias "
+      "will eventually be removed; please update your code.")
+  return process_index(backend)


-def host_count(backend: Optional[str] = None) -> int:
-  """Returns the number of hosts."""
-  return len(host_ids(backend))
+def process_count(backend: Optional[str] = None) -> int:
+  """Returns the number of JAX processes associated with the backend."""
+  return max(d.process_index for d in devices(backend)) + 1
+
+
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_count(backend=None):
+  warnings.warn(
+      "jax.host_count has been renamed to jax.process_count. This alias "
+      "will eventually be removed; please update your code.")
+  return process_count(backend)
+
+
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_ids(backend=None):
+  warnings.warn(
+      "jax.host_ids has been deprecated; please use range(jax.process_count()) "
+      "instead. jax.host_ids will eventually be removed; please update your "
+      "code.")
+  return list(range(process_count(backend)))


 ### utility functions
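The `local_devices` change above stages an *argument* rename: the new `process_index` parameter takes the old one's leading position (so positional callers keep working), while the old `host_id` keyword survives as a warning alias. A self-contained sketch of that staging, with a fake single-process backend and fake device names (everything here is hypothetical, not JAX's implementation):

```python
import warnings
from typing import List, Optional

def local_devices(process_index: Optional[int] = None,
                  backend: Optional[str] = None,
                  host_id: Optional[int] = None) -> List[str]:
    """Staged rename: `host_id` still works as a keyword, but warns."""
    if host_id is not None:
        warnings.warn(
            "The `host_id` argument has been renamed to `process_index`. "
            "This alias will eventually be removed; please update your code.")
        process_index = host_id
    if process_index is None:
        process_index = 0  # pretend this process has index 0
    if not 0 <= process_index < 1:  # single-process sketch: only index 0 exists
        raise ValueError(f"Unknown process_index {process_index}")
    return [f"cpu:{i}" for i in range(4)]  # pretend 4 local devices
```

Because `process_index` stays first positionally, an old positional call like `local_devices(100)` still hits the same validation path, which is exactly what the updated test at the bottom of this commit checks.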

View File

@@ -51,7 +51,7 @@ class XlaBridgeTest(absltest.TestCase):

   def test_local_devices(self):
     self.assertNotEmpty(xb.local_devices())
-    with self.assertRaisesRegex(ValueError, "Unknown host_id 100"):
+    with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):
       xb.local_devices(100)
     with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
       xb.local_devices(backend="foo")