Replace host_id with process_index terminology, take 2.

We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally committed in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures caused by renaming the `local_devices` argument. This change is
identical except that it also adds staging for the argument rename.
Skye Wanderman-Milne 2021-04-20 17:56:41 -07:00
parent 8310867b41
commit 9128ba0c74
9 changed files with 115 additions and 76 deletions

View File

@ -17,6 +17,14 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* {func}`jax.numpy.nonzero` has a new optional `size` argument that allows it to
be used within `jit` ({jax-issue}`6501`)
* Breaking changes:
* The following function names have changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed,
so please update your code (a brief migration sketch follows this list).
* `host_id` --> {func}`~jax.process_index`
* `host_count` --> {func}`~jax.process_count`
* `host_ids` --> `range(jax.process_count())`
* Similarly, the argument to {func}`~jax.local_devices` has been renamed from
`host_id` to `process_index`.
* Arguments to {func}`jax.jit` other than the function are now marked as
keyword-only. This change is to prevent accidental breakage when arguments
are added to `jit`.
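A hedged migration sketch for the renames and the keyword-only `jit` change
listed above:

import jax

idx = jax.process_index()           # was jax.host_id()
num = jax.process_count()           # was jax.host_count()
ids = range(jax.process_count())    # was jax.host_ids()

# The argument to jax.local_devices is now called process_index; host_id is
# kept as a staged alias that emits a warning.
devs = jax.local_devices(process_index=idx)

# Arguments to jax.jit other than the function are now keyword-only:
f = jax.jit(lambda x: x + 1, static_argnums=())    # fine
# jax.jit(lambda x: x + 1, ())  # positional arguments other than fun are now rejected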

View File

@ -77,11 +77,10 @@ Parallelization (:code:`pmap`)
pmap
devices
local_devices
host_id
host_ids
process_index
device_count
local_device_count
host_count
process_count
.. autofunction:: jit
@ -124,8 +123,7 @@ Parallelization (:code:`pmap`)
.. autofunction:: pmap
.. autofunction:: devices
.. autofunction:: local_devices
.. autofunction:: host_id
.. autofunction:: host_ids
.. autofunction:: process_index
.. autofunction:: device_count
.. autofunction:: local_device_count
.. autofunction:: host_count
.. autofunction:: process_count

View File

@ -82,6 +82,8 @@ from ._src.api import (
named_call,
partial, # TODO(phawkins): update callers to use functools.partial.
pmap,
process_count,
process_index,
pxla, # TODO(phawkins): update users to avoid this.
remat,
shapecheck,

View File

@ -62,8 +62,8 @@ from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
# Unused imports to be exported
from ..lib.xla_bridge import (device_count, local_device_count, devices,
local_devices, host_id, host_ids, host_count,
default_backend)
local_devices, process_index, process_count,
host_id, host_ids, host_count, default_backend)
from ..core import ConcreteArray, ShapedArray, raise_to_shaped
from ..interpreters import partial_eval as pe
from ..interpreters import xla
@ -1372,20 +1372,20 @@ def pmap(
:py:func:`pmap` compiles ``fun``, so while it can be combined with
:py:func:`jit`, it's usually unnecessary.
**Multi-host platforms:** On multi-host platforms such as TPU pods,
**Multi-process platforms:** On multi-process platforms such as TPU pods,
:py:func:`pmap` is designed to be used in SPMD Python programs, where every
host is running the same Python code such that all hosts run the same pmapped
function in the same order. Each host should still call the pmapped function
with mapped axis size equal to the number of *local* devices (unless
process is running the same Python code such that all processes run the same
pmapped function in the same order. Each process should still call the pmapped
function with mapped axis size equal to the number of *local* devices (unless
``devices`` is specified, see below), and an array of the same leading axis
size will be returned as usual. However, any collective operations in ``fun``
will be computed over *all* participating devices, including those on other
hosts, via device-to-device communication. Conceptually, this can be thought
of as running a pmap over a single array sharded across hosts, where each host
"sees" only its local shard of the input and output. The SPMD model requires
that the same multi-host pmaps must be run in the same order on all devices,
but they can be interspersed with arbitrary operations running on a single
host.
processes, via device-to-device communication. Conceptually, this can be
thought of as running a pmap over a single array sharded across processes,
where each process "sees" only its local shard of the input and output. The
SPMD model requires that the same multi-process pmaps must be run in the same
order on all devices, but they can be interspersed with arbitrary operations
running in a single process.
Args:
fun: Function to be mapped over argument axes. Its arguments and return
@ -1519,26 +1519,26 @@ def pmap(
>>> print(doubly_normed.sum((0, 1))) # doctest: +SKIP
1.0
On multi-host platforms, collective operations operate over all devices,
including those on other hosts. For example, assuming the following code runs
on two hosts with 4 XLA devices each:
On multi-process platforms, collective operations operate over all devices,
including those on other processes. For example, assuming the following code
runs on two processes with 4 XLA devices each:
>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = jnp.arange(4) if jax.host_id() == 0 else jnp.arange(4, 8)
>>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
>>> out = pmap(f, axis_name='i')(data) # doctest: +SKIP
>>> print(out) # doctest: +SKIP
[28 29 30 31] # on host 0
[32 33 34 35] # on host 1
[28 29 30 31] # on process 0
[32 33 34 35] # on process 1
Each host passes in a different length-4 array, corresponding to its 4 local
devices, and the psum operates over all 8 values. Conceptually, the two
Each process passes in a different length-4 array, corresponding to its 4
local devices, and the psum operates over all 8 values. Conceptually, the two
length-4 arrays can be thought of as a sharded length-8 array (in this example
equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis
given name 'i'. The pmap call on each host then returns the corresponding
length-4 output shard.
equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped
axis given name 'i'. The pmap call on each process then returns the
corresponding length-4 output shard.
The ``devices`` argument can be used to specify exactly which devices are used
to run the parallel computation. For example, again assuming a single host
to run the parallel computation. For example, again assuming a single process
with 8 devices, the following code defines two parallel computations, one
which runs on the first six devices and one on the remaining two:
@ -1556,9 +1556,9 @@ def pmap(
>>> print(f2(jnp.array([2., 3.]))) # doctest: +SKIP
[ 13. 13.]
"""
# axis_size is an optional integer representing the global axis size.
# The aggregate size (across all hosts) size of the mapped axis must match
# the given value.
# axis_size is an optional integer representing the global axis size. The
# aggregate size (across all processes) size of the mapped axis must match the
# given value.
_check_callable(fun)
axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
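As a sanity check on the multi-process psum example in the pmap docstring
above, the same numbers can be reproduced with plain NumPy on one machine;
this sketch only illustrates the arithmetic, not multi-process execution:

import numpy as np

# The two per-process shards together form np.arange(8); psum adds the global
# sum (28) to every element, and each process keeps its shard of the result.
global_data = np.arange(8)
result = global_data + global_data.sum()
print(result[:4])   # [28 29 30 31]  -- what process 0 prints
print(result[4:])   # [32 33 34 35]  -- what process 1 prints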

View File

@ -54,4 +54,3 @@ from jax._src.api import (
_std_basis,
_unravel_array_into_pytree,
)

View File

@ -640,9 +640,9 @@ def parallel_callable(fun: lu.WrappedFun,
# Determine global_axis_size for use in AxisEnv.
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
# if xb.host_count() > 1 and global_axis_size is None and inner_pmap:
# if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
# raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
if (xb.host_count() == 1 and global_axis_size is not None and
if (xb.process_count() == 1 and global_axis_size is not None and
global_axis_size != axis_size):
raise ValueError(
f"Specified axis_size {global_axis_size} doesn't match received "
@ -651,7 +651,7 @@ def parallel_callable(fun: lu.WrappedFun,
must_run_on_all_devices = False
no_nested_sharding = False
if global_axis_size is None:
if xb.host_count() == 1:
if xb.process_count() == 1:
global_axis_size = axis_size
elif devices:
# This allows each host in a multi-host pmap to run on a different number
@ -664,13 +664,13 @@ def parallel_callable(fun: lu.WrappedFun,
# this assumption is true by requiring that the pmap is run on all devices
# (and making the further assumption that each host has the same number of
# devices). Nested sharding is ok in this case.
global_axis_size = axis_size * xb.host_count()
assert all(len(xb.local_devices(host_id)) == xb.local_device_count()
for host_id in xb.host_ids())
global_axis_size = axis_size * xb.process_count()
assert all(len(xb.local_devices(process_index)) == xb.local_device_count()
for process_index in range(xb.process_count()))
must_run_on_all_devices = True
if devices:
local_devices = [d for d in devices if d.host_id == xb.host_id()]
local_devices = [d for d in devices if d.process_index == xb.process_index()]
assert len(local_devices) > 0
else:
local_devices = None # type: ignore
@ -700,7 +700,7 @@ def parallel_callable(fun: lu.WrappedFun,
if devices is not None:
is_multi_host_pmap = len(local_devices) != len(devices)
else:
is_multi_host_pmap = xb.host_count() > 1
is_multi_host_pmap = xb.process_count() > 1
if is_multi_host_pmap:
check_multihost_collective_allowlist(jaxpr)
@ -734,7 +734,7 @@ def parallel_callable(fun: lu.WrappedFun,
num_local_shards = num_local_replicas * local_num_partitions
num_global_shards = num_global_replicas * num_partitions
if (xb.host_count() > 1 and must_run_on_all_devices and
if (xb.process_count() > 1 and must_run_on_all_devices and
num_local_shards != xb.local_device_count()):
if num_local_shards == axis_size:
raise ValueError(
@ -803,8 +803,8 @@ def parallel_callable(fun: lu.WrappedFun,
if num_global_shards > num_local_shards:
# TODO(skye): use a locality-aware assignment that satisfies the above
# constraint.
devices = [d for host_id in xb.host_ids()
for d in xb.local_devices(host_id)]
devices = [d for process_index in range(xb.process_count())
for d in xb.local_devices(process_index)]
else:
devices = xb.get_backend(backend).get_default_device_assignment(
num_global_replicas, num_partitions)
@ -1300,8 +1300,9 @@ class Mesh:
def local_mesh(self):
if not self.devices.ndim:
return self
host_id = xb.host_id()
is_local_device = np.vectorize(lambda d: d.host_id == host_id, otypes=[bool])(self.devices)
process_index = xb.process_index()
is_local_device = np.vectorize(
lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
subcube_indices = []
# We take the smallest slice of each dimension that doesn't skip any local device.
for axis in range(self.devices.ndim):
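For illustration, a self-contained sketch of the device-filtering step that
Mesh.local_mesh performs above; FakeDevice is a hypothetical stand-in for a
real Device object, and only its process_index attribute matters here:

import numpy as np

class FakeDevice:
    def __init__(self, device_id, process_index):
        self.id = device_id
        self.process_index = process_index

# A hypothetical 2x4 device mesh: row 0 lives on process 0, row 1 on process 1.
mesh_devices = np.array(
    [[FakeDevice(i, i // 4) for i in range(4)],
     [FakeDevice(i, i // 4) for i in range(4, 8)]])

this_process = 0  # in the real code this comes from xb.process_index()
is_local_device = np.vectorize(
    lambda d: d.process_index == this_process, otypes=[bool])(mesh_devices)
print(is_local_device)
# [[ True  True  True  True]
#  [False False False False]]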

View File

@ -683,7 +683,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
f"compiling computation that requires {nreps} replicas, but only "
f"{xb.device_count(backend)} XLA devices are available")
if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
raise NotImplementedError(
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")

View File

@ -23,6 +23,7 @@ XLA. There are also a handful of related casting utilities.
from functools import partial, lru_cache
import os
from typing import Callable, Dict, List, Optional, Tuple, Union
import warnings
from absl import logging
# Disable "WARNING: Logging before flag parsing goes to stderr." message
@ -195,8 +196,9 @@ def device_count(backend: Optional[str] = None) -> int:
"""Returns the total number of devices.
On most platforms, this is the same as :py:func:`jax.local_device_count`.
However, on multi-host platforms, this will return the total number of devices
across all hosts.
However, on multi-process platforms where different devices are associated
with different processes, this will return the total number of devices across
all processes.
Args:
backend: This is an experimental feature and the API is likely to change.
@ -205,12 +207,13 @@ def device_count(backend: Optional[str] = None) -> int:
Returns:
Number of devices.
"""
return int(get_backend(backend).device_count())
def local_device_count(backend: Optional[str] = None) -> int:
"""Returns the number of devices on this host."""
"""Returns the number of devices addressable by this process."""
return int(get_backend(backend).local_device_count())
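A hedged sanity check of how these counters relate, assuming every process
has the same number of local devices (the usual multi-process setup; a
single-process run is the degenerate case where process_count() is 1):

import jax

assert jax.device_count() == jax.process_count() * jax.local_device_count()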
@ -219,8 +222,9 @@ def devices(backend: Optional[str] = None) -> List[xla_client.Device]:
Each device is represented by a subclass of :class:`Device` (e.g.
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
equal to ``device_count(backend)``. Local devices can be identified by comparing
:meth:`Device.host_id` to the value returned by :py:func:`jax.host_id`.
equal to ``device_count(backend)``. Local devices can be identified by
comparing :meth:`Device.process_index` to the value returned by
:py:func:`jax.process_index`.
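A short sketch of the identification described above; with the definitions in
this change, the comprehension matches jax.local_devices():

import jax

local = [d for d in jax.devices() if d.process_index == jax.process_index()]
assert local == jax.local_devices()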
If ``backend`` is ``None``, returns all the devices from the default backend.
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
@ -242,15 +246,16 @@ def default_backend() -> str:
return get_backend(None).platform
def local_devices(host_id: Optional[int] = None,
backend: Optional[str] = None) -> List[xla_client.Device]:
"""Like :py:func:`jax.devices`, but only returns devices local to a given host.
def local_devices(process_index: Optional[int] = None,
backend: Optional[str] = None,
host_id: Optional[int] = None) -> List[xla_client.Device]:
"""Like :py:func:`jax.devices`, but only returns devices local to a given process.
If ``host_id`` is ``None``, returns devices local to this host.
If ``process_index`` is ``None``, returns devices local to this process.
Args:
host_id: the integer ID of the host. Host IDs can be retrieved via
:py:func:`jax.host_ids`.
process_index: the integer index of the process. Process indices can be
retrieved via ``range(jax.process_count())``.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
@ -258,17 +263,23 @@ def local_devices(host_id: Optional[int] = None,
Returns:
List of Device subclasses.
"""
if host_id is None:
host_id = get_backend(backend).host_id()
if host_id not in host_ids():
raise ValueError(f"Unknown host_id {host_id}")
return [d for d in devices(backend) if d.host_id == host_id]
if host_id is not None:
warnings.warn(
"The argument to jax.local_devices has been renamed from `host_id` to "
"`process_index`. This alias will eventually be removed; please update "
"your code.")
process_index = host_id
if process_index is None:
process_index = get_backend(backend).process_index()
if not (0 <= process_index < process_count()):
raise ValueError(f"Unknown process_index {process_index}")
return [d for d in devices(backend) if d.process_index == process_index]
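A hedged usage sketch of the staged rename above (process 0 always exists, so
index 0 is a safe example value):

import jax

jax.local_devices(process_index=0)  # new argument name
jax.local_devices(0)                # positional use is unchanged
jax.local_devices(host_id=0)        # old name still works, but emits a warning
                                    # and will eventually be removed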
def host_id(backend: Optional[str] = None) -> int:
"""Returns the integer host ID of this host.
def process_index(backend: Optional[str] = None) -> int:
"""Returns the integer process index of this process.
On most platforms, this will always be 0. This will vary on multi-host
On most platforms, this will always be 0. This will vary on multi-process
platforms though.
Args:
@ -277,19 +288,39 @@ def host_id(backend: Optional[str] = None) -> int:
``'tpu'``.
Returns:
Integer host ID.
Integer process index.
"""
return get_backend(backend).host_id()
return get_backend(backend).process_index()
def host_ids(backend: Optional[str] = None) -> List[int]:
"""Returns a sorted list of all host IDs."""
return sorted({d.host_id for d in devices(backend)})
# TODO: remove this sometime after jax 0.2.13 is released
def host_id(backend=None):
warnings.warn(
"jax.host_id has been renamed to jax.process_index. This alias "
"will eventually be removed; please update your code.")
return process_index(backend)
def host_count(backend: Optional[str] = None) -> int:
"""Returns the number of hosts."""
return len(host_ids(backend))
def process_count(backend: Optional[str] = None) -> int:
"""Returns the number of JAX processes associated with the backend."""
return max(d.process_index for d in devices(backend)) + 1
# TODO: remove this sometime after jax 0.2.13 is released
def host_count(backend=None):
warnings.warn(
"jax.host_count has been renamed to jax.process_count. This alias "
"will eventually be removed; please update your code.")
return process_count(backend)
# TODO: remove this sometime after jax 0.2.13 is released
def host_ids(backend=None):
warnings.warn(
"jax.host_ids has been deprecated; please use range(jax.process_count()) "
"instead. jax.host_ids will eventually be removed; please update your "
"code.")
return list(range(process_count(backend)))
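A hedged sketch of what the new and deprecated entry points return in an
ordinary single-process run:

import warnings
import jax

assert jax.process_index() == 0
assert jax.process_count() == 1

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert jax.host_id() == jax.process_index()      # deprecated alias, warns
    assert jax.host_count() == jax.process_count()   # deprecated alias, warns
    assert list(jax.host_ids()) == list(range(jax.process_count()))
assert len(caught) == 3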
### utility functions

View File

@ -51,7 +51,7 @@ class XlaBridgeTest(absltest.TestCase):
def test_local_devices(self):
self.assertNotEmpty(xb.local_devices())
with self.assertRaisesRegex(ValueError, "Unknown host_id 100"):
with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):
xb.local_devices(100)
with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
xb.local_devices(backend="foo")