Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Replace host_id with process_index terminology, take 2.

We're switching to the new terminology to avoid confusion in cases where multiple JAX processes are running on a single host, and each process has a unique process_index/host_id. This keeps aliases for the old `host_id` APIs for now, but these will eventually be removed.

This was originally committed in b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in 14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test failures caused by renaming the `local_devices` argument. This change is identical except that it also adds staging for the argument rename.
parent 8310867b41
commit 9128ba0c74
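As a quick orientation (not part of the diff itself), here is a hedged sketch of the rename this commit describes, using only the new public names added below; the old spellings remain available as deprecated aliases for now:

    import jax

    # New spellings introduced by this change:
    idx = jax.process_index()                    # previously jax.host_id()
    n_procs = jax.process_count()                # previously jax.host_count()
    all_procs = range(jax.process_count())       # previously jax.host_ids()
    devs = jax.local_devices(process_index=idx)  # argument renamed from host_id

    # The old names still work during the staging period but emit a warning.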
@@ -17,6 +17,14 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
   * {func}`jax.nonzero` has a new optional `size` argument that allows it to
     be used within `jit` ({jax-issue}`6501`)
 * Breaking changes:
+  * The following function names have changed. There are still aliases, so this
+    should not break existing code, but the aliases will eventually be removed
+    so please change your code.
+    * `host_id` --> {func}`~jax.process_index`
+    * `host_count` --> {func}`~jax.process_count`
+    * `host_ids` --> `range(jax.process_count())`
+  * Similarly, the argument to {func}`~jax.local_devices` has been renamed from
+    `host_id` to `process_index`.
   * Arguments to {func}`jax.jit` other than the function are now marked as
     keyword-only. This change is to prevent accidental breakage when arguments
     are added to `jit`.
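The keyword-only `jit` change noted in the changelog above is easy to adapt to; a minimal sketch (the `scale` function is hypothetical, not part of this diff):

    import jax
    import jax.numpy as jnp

    def scale(x, factor):
      return x * factor

    # Arguments other than the function must now be passed by keyword.
    scale_jit = jax.jit(scale, static_argnums=(1,))
    print(scale_jit(jnp.arange(4.0), 2.0))  # [0. 2. 4. 6.]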
docs/jax.rst
@@ -77,11 +77,10 @@ Parallelization (:code:`pmap`)
     pmap
     devices
     local_devices
-    host_id
-    host_ids
+    process_index
     device_count
     local_device_count
-    host_count
+    process_count


 .. autofunction:: jit
@@ -124,8 +123,7 @@ Parallelization (:code:`pmap`)
 .. autofunction:: pmap
 .. autofunction:: devices
 .. autofunction:: local_devices
-.. autofunction:: host_id
-.. autofunction:: host_ids
+.. autofunction:: process_index
 .. autofunction:: device_count
 .. autofunction:: local_device_count
-.. autofunction:: host_count
+.. autofunction:: process_count
@@ -82,6 +82,8 @@ from ._src.api import (
   named_call,
   partial,  # TODO(phawkins): update callers to use functools.partial.
   pmap,
+  process_count,
+  process_index,
   pxla,  # TODO(phawkins): update users to avoid this.
   remat,
   shapecheck,
@@ -62,8 +62,8 @@ from ..lib import xla_bridge as xb
 from ..lib import xla_client as xc
 # Unused imports to be exported
 from ..lib.xla_bridge import (device_count, local_device_count, devices,
-                              local_devices, host_id, host_ids, host_count,
-                              default_backend)
+                              local_devices, process_index, process_count,
+                              host_id, host_ids, host_count, default_backend)
 from ..core import ConcreteArray, ShapedArray, raise_to_shaped
 from ..interpreters import partial_eval as pe
 from ..interpreters import xla
@@ -1372,20 +1372,20 @@ def pmap(
   :py:func:`pmap` compiles ``fun``, so while it can be combined with
   :py:func:`jit`, it's usually unnecessary.

-  **Multi-host platforms:** On multi-host platforms such as TPU pods,
+  **Multi-process platforms:** On multi-process platforms such as TPU pods,
   :py:func:`pmap` is designed to be used in SPMD Python programs, where every
-  host is running the same Python code such that all hosts run the same pmapped
-  function in the same order. Each host should still call the pmapped function
-  with mapped axis size equal to the number of *local* devices (unless
+  process is running the same Python code such that all processes run the same
+  pmapped function in the same order. Each process should still call the pmapped
+  function with mapped axis size equal to the number of *local* devices (unless
   ``devices`` is specified, see below), and an array of the same leading axis
   size will be returned as usual. However, any collective operations in ``fun``
   will be computed over *all* participating devices, including those on other
-  hosts, via device-to-device communication. Conceptually, this can be thought
-  of as running a pmap over a single array sharded across hosts, where each host
-  "sees" only its local shard of the input and output. The SPMD model requires
-  that the same multi-host pmaps must be run in the same order on all devices,
-  but they can be interspersed with arbitrary operations running on a single
-  host.
+  processes, via device-to-device communication. Conceptually, this can be
+  thought of as running a pmap over a single array sharded across processes,
+  where each process "sees" only its local shard of the input and output. The
+  SPMD model requires that the same multi-process pmaps must be run in the same
+  order on all devices, but they can be interspersed with arbitrary operations
+  running in a single process.

   Args:
     fun: Function to be mapped over argument axes. Its arguments and return
@@ -1519,26 +1519,26 @@ def pmap(
   >>> print(doubly_normed.sum((0, 1)))  # doctest: +SKIP
   1.0

-  On multi-host platforms, collective operations operate over all devices,
-  including those on other hosts. For example, assuming the following code runs
-  on two hosts with 4 XLA devices each:
+  On multi-process platforms, collective operations operate over all devices,
+  including those on other processes. For example, assuming the following code
+  runs on two processes with 4 XLA devices each:

   >>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
-  >>> data = jnp.arange(4) if jax.host_id() == 0 else jnp.arange(4, 8)
+  >>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
   >>> out = pmap(f, axis_name='i')(data)  # doctest: +SKIP
   >>> print(out)  # doctest: +SKIP
-  [28 29 30 31] # on host 0
-  [32 33 34 35] # on host 1
+  [28 29 30 31] # on process 0
+  [32 33 34 35] # on process 1

-  Each host passes in a different length-4 array, corresponding to its 4 local
-  devices, and the psum operates over all 8 values. Conceptually, the two
+  Each process passes in a different length-4 array, corresponding to its 4
+  local devices, and the psum operates over all 8 values. Conceptually, the two
   length-4 arrays can be thought of as a sharded length-8 array (in this example
-  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis
-  given name 'i'. The pmap call on each host then returns the corresponding
-  length-4 output shard.
+  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped
+  axis given name 'i'. The pmap call on each process then returns the
+  corresponding length-4 output shard.

   The ``devices`` argument can be used to specify exactly which devices are used
-  to run the parallel computation. For example, again assuming a single host
+  to run the parallel computation. For example, again assuming a single process
   with 8 devices, the following code defines two parallel computations, one
   which runs on the first six devices and one on the remaining two:

@@ -1556,9 +1556,9 @@ def pmap(
   >>> print(f2(jnp.array([2., 3.])))  # doctest: +SKIP
   [ 13. 13.]
   """
-  # axis_size is an optional integer representing the global axis size.
-  # The aggregate size (across all hosts) size of the mapped axis must match
-  # the given value.
+  # axis_size is an optional integer representing the global axis size. The
+  # aggregate size (across all processes) size of the mapped axis must match the
+  # given value.

   _check_callable(fun)
   axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
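As a companion to the docstring text above, here is a hedged sketch (not part of this diff) of the multi-process SPMD pattern it describes, assuming two processes with four local devices each; the printed values match the docstring example only in that setting:

    import jax
    import jax.numpy as jnp

    n_local = jax.local_device_count()
    # Each process constructs only its local shard of the conceptual length-8 array.
    start = jax.process_index() * n_local
    data = jnp.arange(start, start + n_local)

    # The psum runs over all participating devices, across processes.
    out = jax.pmap(lambda x: x + jax.lax.psum(x, axis_name='i'), axis_name='i')(data)
    print(out)  # process 0: [28 29 30 31]; process 1: [32 33 34 35]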
@@ -54,4 +54,3 @@ from jax._src.api import (
   _std_basis,
   _unravel_array_into_pytree,
 )
-
@@ -640,9 +640,9 @@ def parallel_callable(fun: lu.WrappedFun,

   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
-  # if xb.host_count() > 1 and global_axis_size is None and inner_pmap:
+  # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
   #   raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
-  if (xb.host_count() == 1 and global_axis_size is not None and
+  if (xb.process_count() == 1 and global_axis_size is not None and
       global_axis_size != axis_size):
     raise ValueError(
         f"Specified axis_size {global_axis_size} doesn't match received "
@@ -651,7 +651,7 @@ def parallel_callable(fun: lu.WrappedFun,
   must_run_on_all_devices = False
   no_nested_sharding = False
   if global_axis_size is None:
-    if xb.host_count() == 1:
+    if xb.process_count() == 1:
       global_axis_size = axis_size
     elif devices:
       # This allows each host in a multi-host pmap to run on a different number
@@ -664,13 +664,13 @@ def parallel_callable(fun: lu.WrappedFun,
       # this assumption is true by requiring that the pmap is run on all devices
       # (and making the further assumption that each host has the same number of
       # devices). Nested sharding is ok in this case.
-      global_axis_size = axis_size * xb.host_count()
-      assert all(len(xb.local_devices(host_id)) == xb.local_device_count()
-                 for host_id in xb.host_ids())
+      global_axis_size = axis_size * xb.process_count()
+      assert all(len(xb.local_devices(process_index)) == xb.local_device_count()
+                 for process_index in range(xb.process_count()))
       must_run_on_all_devices = True

   if devices:
-    local_devices = [d for d in devices if d.host_id == xb.host_id()]
+    local_devices = [d for d in devices if d.process_index == xb.process_index()]
     assert len(local_devices) > 0
   else:
     local_devices = None  # type: ignore
@@ -700,7 +700,7 @@ def parallel_callable(fun: lu.WrappedFun,
   if devices is not None:
     is_multi_host_pmap = len(local_devices) != len(devices)
   else:
-    is_multi_host_pmap = xb.host_count() > 1
+    is_multi_host_pmap = xb.process_count() > 1
   if is_multi_host_pmap:
     check_multihost_collective_allowlist(jaxpr)

@@ -734,7 +734,7 @@ def parallel_callable(fun: lu.WrappedFun,
   num_local_shards = num_local_replicas * local_num_partitions
   num_global_shards = num_global_replicas * num_partitions

-  if (xb.host_count() > 1 and must_run_on_all_devices and
+  if (xb.process_count() > 1 and must_run_on_all_devices and
       num_local_shards != xb.local_device_count()):
     if num_local_shards == axis_size:
       raise ValueError(
@@ -803,8 +803,8 @@ def parallel_callable(fun: lu.WrappedFun,
     if num_global_shards > num_local_shards:
       # TODO(skye): use a locality-aware assignment that satisfies the above
       # constraint.
-      devices = [d for host_id in xb.host_ids()
-                 for d in xb.local_devices(host_id)]
+      devices = [d for process_index in range(xb.process_count())
+                 for d in xb.local_devices(process_index)]
     else:
       devices = xb.get_backend(backend).get_default_device_assignment(
           num_global_replicas, num_partitions)
@@ -1300,8 +1300,9 @@ class Mesh:
   def local_mesh(self):
     if not self.devices.ndim:
       return self
-    host_id = xb.host_id()
-    is_local_device = np.vectorize(lambda d: d.host_id == host_id, otypes=[bool])(self.devices)
+    process_index = xb.process_index()
+    is_local_device = np.vectorize(
+        lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
     subcube_indices = []
     # We take the smallest slice of each dimension that doesn't skip any local device.
     for axis in range(self.devices.ndim):
@@ -683,7 +683,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
         f"compiling computation that requires {nreps} replicas, but only "
         f"{xb.device_count(backend)} XLA devices are available")

-  if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
+  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
     raise NotImplementedError(
         "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
         "extra data movement anyway, so maybe you don't want it after all).")
@@ -23,6 +23,7 @@ XLA. There are also a handful of related casting utilities.
 from functools import partial, lru_cache
 import os
 from typing import Callable, Dict, List, Optional, Tuple, Union
+import warnings

 from absl import logging
 # Disable "WARNING: Logging before flag parsing goes to stderr." message
@@ -195,8 +196,9 @@ def device_count(backend: Optional[str] = None) -> int:
   """Returns the total number of devices.

   On most platforms, this is the same as :py:func:`jax.local_device_count`.
-  However, on multi-host platforms, this will return the total number of devices
-  across all hosts.
+  However, on multi-process platforms where different devices are associated
+  with different processes, this will return the total number of devices across
+  all processes.

   Args:
     backend: This is an experimental feature and the API is likely to change.
@@ -205,12 +207,13 @@ def device_count(backend: Optional[str] = None) -> int:

   Returns:
     Number of devices.
+
   """
   return int(get_backend(backend).device_count())


 def local_device_count(backend: Optional[str] = None) -> int:
-  """Returns the number of devices on this host."""
+  """Returns the number of devices addressable by this process."""
   return int(get_backend(backend).local_device_count())


@@ -219,8 +222,9 @@ def devices(backend: Optional[str] = None) -> List[xla_client.Device]:

   Each device is represented by a subclass of :class:`Device` (e.g.
   :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
-  equal to ``device_count(backend)``. Local devices can be identified by comparing
-  :meth:`Device.host_id` to the value returned by :py:func:`jax.host_id`.
+  equal to ``device_count(backend)``. Local devices can be identified by
+  comparing :meth:`Device.process_index` to the value returned by
+  :py:func:`jax.process_index`.

   If ``backend`` is ``None``, returns all the devices from the default backend.
   The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
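A hedged sanity-check sketch of the relationship these docstrings describe (not part of the diff), using only the public functions touched by this commit:

    import jax

    # Local devices are exactly the global devices whose process_index matches ours.
    mine = [d for d in jax.devices() if d.process_index == jax.process_index()]
    assert mine == jax.local_devices()
    assert jax.device_count() == len(jax.devices())
    assert jax.local_device_count() == len(jax.local_devices())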
@@ -242,15 +246,16 @@ def default_backend() -> str:
   return get_backend(None).platform


-def local_devices(host_id: Optional[int] = None,
-                  backend: Optional[str] = None) -> List[xla_client.Device]:
-  """Like :py:func:`jax.devices`, but only returns devices local to a given host.
+def local_devices(process_index: Optional[int] = None,
+                  backend: Optional[str] = None,
+                  host_id: Optional[int] = None) -> List[xla_client.Device]:
+  """Like :py:func:`jax.devices`, but only returns devices local to a given process.

-  If ``host_id`` is ``None``, returns devices local to this host.
+  If ``process_index`` is ``None``, returns devices local to this process.

   Args:
-    host_id: the integer ID of the host. Host IDs can be retrieved via
-      :py:func:`jax.host_ids`.
+    process_index: the integer index of the process. Process indices can be
+      retrieved via ``len(jax.process_count())``.
     backend: This is an experimental feature and the API is likely to change.
       Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
       ``'tpu'``.
@@ -258,17 +263,23 @@ def local_devices(host_id: Optional[int] = None,
   Returns:
     List of Device subclasses.
   """
-  if host_id is None:
-    host_id = get_backend(backend).host_id()
-  if host_id not in host_ids():
-    raise ValueError(f"Unknown host_id {host_id}")
-  return [d for d in devices(backend) if d.host_id == host_id]
+  if host_id is not None:
+    warnings.warn(
+        "The argument to jax.local_devices has been renamed from `host_id` to "
+        "`process_index`. This alias will eventually be removed; please update "
+        "your code.")
+    process_index = host_id
+  if process_index is None:
+    process_index = get_backend(backend).process_index()
+  if not (0 <= process_index < process_count()):
+    raise ValueError(f"Unknown process_index {process_index}")
+  return [d for d in devices(backend) if d.process_index == process_index]


-def host_id(backend: Optional[str] = None) -> int:
-  """Returns the integer host ID of this host.
+def process_index(backend: Optional[str] = None) -> int:
+  """Returns the integer process index of this process.

-  On most platforms, this will always be 0. This will vary on multi-host
+  On most platforms, this will always be 0. This will vary on multi-process
   platforms though.

   Args:
@@ -277,19 +288,39 @@ def host_id(backend: Optional[str] = None) -> int:
       ``'tpu'``.

   Returns:
-    Integer host ID.
+    Integer process index.
   """
-  return get_backend(backend).host_id()
+  return get_backend(backend).process_index()


-def host_ids(backend: Optional[str] = None) -> List[int]:
-  """Returns a sorted list of all host IDs."""
-  return sorted({d.host_id for d in devices(backend)})
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_id(backend=None):
+  warnings.warn(
+      "jax.host_id has been renamed to jax.process_index. This alias "
+      "will eventually be removed; please update your code.")
+  return process_index(backend)


-def host_count(backend: Optional[str] = None) -> int:
-  """Returns the number of hosts."""
-  return len(host_ids(backend))
+def process_count(backend: Optional[str] = None) -> int:
+  """Returns the number of JAX processes associated with the backend."""
+  return max(d.process_index for d in devices(backend)) + 1
+
+
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_count(backend=None):
+  warnings.warn(
+      "jax.host_count has been renamed to jax.process_count. This alias "
+      "will eventually be removed; please update your code.")
+  return process_count(backend)
+
+
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_ids(backend=None):
+  warnings.warn(
+      "jax.host_ids has been deprecated; please use range(jax.process_count()) "
+      "instead. jax.host_ids will eventually be removed; please update your "
+      "code.")
+  return list(range(process_count(backend)))


 ### utility functions
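The xla_bridge changes above stage the renames: the old functions and the old `host_id` argument keep working for now, but warn and forward to the new API. A hedged usage sketch (not part of the diff):

    import jax

    # New spelling:
    devs = jax.local_devices(process_index=jax.process_index())

    # Old spelling still works during the staging period, but emits a warning:
    #   devs = jax.local_devices(host_id=jax.process_index())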
@@ -51,7 +51,7 @@ class XlaBridgeTest(absltest.TestCase):

   def test_local_devices(self):
     self.assertNotEmpty(xb.local_devices())
-    with self.assertRaisesRegex(ValueError, "Unknown host_id 100"):
+    with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):
       xb.local_devices(100)
     with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
       xb.local_devices(backend="foo")
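The existing test only exercises the new error message; a hedged sketch of an additional check (not part of this commit) that the deprecated aliases warn and forward to the new API:

    import warnings
    import jax

    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter("always")
      assert jax.host_id() == jax.process_index()
      assert jax.host_count() == jax.process_count()
      assert list(jax.host_ids()) == list(range(jax.process_count()))
    assert any("process_index" in str(w.message) for w in caught)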