From 9128ba0c742a8d5791f9af9445bdc6bfa90fd11c Mon Sep 17 00:00:00 2001
From: Skye Wanderman-Milne
Date: Tue, 20 Apr 2021 17:56:41 -0700
Subject: [PATCH] Replace `host_id` with `process_index` terminology, take 2.

We're switching to the new terminology to avoid confusion in cases where
multiple jax processes are running on a single host, and each process has a
unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally committed in b77ef5138b631378e6a8ceb8bafc94fe91239bae, but
reverted in 14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal
test failures from renaming the `local_devices` argument. This change is
identical except it also adds staging for the argument name change.
---
 CHANGELOG.md             |  8 ++++
 docs/jax.rst             | 10 ++---
 jax/__init__.py          |  2 +
 jax/_src/api.py          | 54 ++++++++++++-------------
 jax/api.py               |  1 -
 jax/interpreters/pxla.py | 27 +++++++------
 jax/interpreters/xla.py  |  2 +-
 jax/lib/xla_bridge.py    | 85 +++++++++++++++++++++++++++-------------
 tests/xla_bridge_test.py |  2 +-
 9 files changed, 115 insertions(+), 76 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b6ca821b5..3f2ad693a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,14 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
   * {func}`jax.nonzero` has a new optional `size` argument that allows it to be
     used within `jit` ({jax-issue}`6501`)
 * Breaking changes:
+  * The following function names have changed. There are still aliases, so this
+    should not break existing code, but the aliases will eventually be removed
+    so please change your code.
+    * `host_id` --> {func}`~jax.process_index`
+    * `host_count` --> {func}`~jax.process_count`
+    * `host_ids` --> `range(jax.process_count())`
+  * Similarly, the argument to {func}`~jax.local_devices` has been renamed from
+    `host_id` to `process_index`.
   * Arguments to {func}`jax.jit` other than the function are now marked as
     keyword-only. This change is to prevent accidental breakage when arguments
     are added to `jit`.
diff --git a/docs/jax.rst b/docs/jax.rst
index f7cfb673a..d027f6360 100644
--- a/docs/jax.rst
+++ b/docs/jax.rst
@@ -77,11 +77,10 @@ Parallelization (:code:`pmap`)
     pmap
     devices
     local_devices
-    host_id
-    host_ids
+    process_index
     device_count
     local_device_count
-    host_count
+    process_count
 
 .. autofunction:: jit
 
@@ -124,8 +123,7 @@ Parallelization (:code:`pmap`)
 .. autofunction:: pmap
 .. autofunction:: devices
 .. autofunction:: local_devices
-.. autofunction:: host_id
-.. autofunction:: host_ids
+.. autofunction:: process_index
 .. autofunction:: device_count
 .. autofunction:: local_device_count
-.. autofunction:: host_count
+.. autofunction:: process_count
diff --git a/jax/__init__.py b/jax/__init__.py
index c3e4f71a3..a8e1824bf 100644
--- a/jax/__init__.py
+++ b/jax/__init__.py
@@ -82,6 +82,8 @@ from ._src.api import (
   named_call,
   partial,  # TODO(phawkins): update callers to use functools.partial.
   pmap,
+  process_count,
+  process_index,
   pxla,  # TODO(phawkins): update users to avoid this.
remat,
   shapecheck,
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 42db8fddd..61e1ba519 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -62,8 +62,8 @@ from ..lib import xla_bridge as xb
 from ..lib import xla_client as xc
 # Unused imports to be exported
 from ..lib.xla_bridge import (device_count, local_device_count, devices,
-                              local_devices, host_id, host_ids, host_count,
-                              default_backend)
+                              local_devices, process_index, process_count,
+                              host_id, host_ids, host_count, default_backend)
 from ..core import ConcreteArray, ShapedArray, raise_to_shaped
 from ..interpreters import partial_eval as pe
 from ..interpreters import xla
@@ -1372,20 +1372,20 @@ def pmap(
   :py:func:`pmap` compiles ``fun``, so while it can be combined with
   :py:func:`jit`, it's usually unnecessary.
 
-  **Multi-host platforms:** On multi-host platforms such as TPU pods,
+  **Multi-process platforms:** On multi-process platforms such as TPU pods,
   :py:func:`pmap` is designed to be used in SPMD Python programs, where every
-  host is running the same Python code such that all hosts run the same pmapped
-  function in the same order. Each host should still call the pmapped function
-  with mapped axis size equal to the number of *local* devices (unless
+  process is running the same Python code such that all processes run the same
+  pmapped function in the same order. Each process should still call the pmapped
+  function with mapped axis size equal to the number of *local* devices (unless
   ``devices`` is specified, see below), and an array of the same leading axis
   size will be returned as usual. However, any collective operations in ``fun``
   will be computed over *all* participating devices, including those on other
-  hosts, via device-to-device communication. Conceptually, this can be thought
-  of as running a pmap over a single array sharded across hosts, where each host
-  "sees" only its local shard of the input and output. The SPMD model requires
-  that the same multi-host pmaps must be run in the same order on all devices,
-  but they can be interspersed with arbitrary operations running on a single
-  host.
+  processes, via device-to-device communication. Conceptually, this can be
+  thought of as running a pmap over a single array sharded across processes,
+  where each process "sees" only its local shard of the input and output. The
+  SPMD model requires that the same multi-process pmaps be run in the same
+  order on all devices, but they can be interspersed with arbitrary operations
+  running in a single process.
 
   Args:
     fun: Function to be mapped over argument axes. Its arguments and return
@@ -1519,26 +1519,26 @@ def pmap(
   >>> print(doubly_normed.sum((0, 1)))  # doctest: +SKIP
   1.0
 
-  On multi-host platforms, collective operations operate over all devices,
-  including those on other hosts. For example, assuming the following code runs
-  on two hosts with 4 XLA devices each:
+  On multi-process platforms, collective operations operate over all devices,
+  including those on other processes. 
For example, assuming the following code
+  runs on two processes with 4 XLA devices each:
 
   >>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
-  >>> data = jnp.arange(4) if jax.host_id() == 0 else jnp.arange(4, 8)
+  >>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
   >>> out = pmap(f, axis_name='i')(data)  # doctest: +SKIP
   >>> print(out)  # doctest: +SKIP
-  [28 29 30 31] # on host 0
-  [32 33 34 35] # on host 1
+  [28 29 30 31] # on process 0
+  [32 33 34 35] # on process 1
 
-  Each host passes in a different length-4 array, corresponding to its 4 local
-  devices, and the psum operates over all 8 values. Conceptually, the two
+  Each process passes in a different length-4 array, corresponding to its 4
+  local devices, and the psum operates over all 8 values. Conceptually, the two
   length-4 arrays can be thought of as a sharded length-8 array (in this example
-  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis
-  given name 'i'. The pmap call on each host then returns the corresponding
-  length-4 output shard.
+  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped
+  axis given name 'i'. The pmap call on each process then returns the
+  corresponding length-4 output shard.
 
   The ``devices`` argument can be used to specify exactly which devices are used
-  to run the parallel computation. For example, again assuming a single host
+  to run the parallel computation. For example, again assuming a single process
   with 8 devices, the following code defines two parallel computations, one
   which runs on the first six devices and one on the remaining two:
 
@@ -1556,9 +1556,9 @@ def pmap(
   >>> print(f2(jnp.array([2., 3.])))  # doctest: +SKIP
   [ 13. 13.]
   """
-  # axis_size is an optional integer representing the global axis size.
-  # The aggregate size (across all hosts) size of the mapped axis must match
-  # the given value.
+  # axis_size is an optional integer representing the global axis size. The
+  # aggregate size (across all processes) of the mapped axis must match the
+  # given value.
   _check_callable(fun)
   axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
diff --git a/jax/api.py b/jax/api.py
index b7106841a..d156032b3 100644
--- a/jax/api.py
+++ b/jax/api.py
@@ -54,4 +54,3 @@ from jax._src.api import (
   _std_basis,
   _unravel_array_into_pytree,
 )
-
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 32aed8373..90dc5ed4f 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -640,9 +640,9 @@ def parallel_callable(fun: lu.WrappedFun,
   # Determine global_axis_size for use in AxisEnv.
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now) - # if xb.host_count() > 1 and global_axis_size is None and inner_pmap: + # if xb.process_count() > 1 and global_axis_size is None and inner_pmap: # raise ValueError("'axis_size' must be specified for nested multi-host pmaps") - if (xb.host_count() == 1 and global_axis_size is not None and + if (xb.process_count() == 1 and global_axis_size is not None and global_axis_size != axis_size): raise ValueError( f"Specified axis_size {global_axis_size} doesn't match received " @@ -651,7 +651,7 @@ def parallel_callable(fun: lu.WrappedFun, must_run_on_all_devices = False no_nested_sharding = False if global_axis_size is None: - if xb.host_count() == 1: + if xb.process_count() == 1: global_axis_size = axis_size elif devices: # This allows each host in a multi-host pmap to run on a different number @@ -664,13 +664,13 @@ def parallel_callable(fun: lu.WrappedFun, # this assumption is true by requiring that the pmap is run on all devices # (and making the further assumption that each host has the same number of # devices). Nested sharding is ok in this case. - global_axis_size = axis_size * xb.host_count() - assert all(len(xb.local_devices(host_id)) == xb.local_device_count() - for host_id in xb.host_ids()) + global_axis_size = axis_size * xb.process_count() + assert all(len(xb.local_devices(process_index)) == xb.local_device_count() + for process_index in range(xb.process_count())) must_run_on_all_devices = True if devices: - local_devices = [d for d in devices if d.host_id == xb.host_id()] + local_devices = [d for d in devices if d.process_index == xb.process_index()] assert len(local_devices) > 0 else: local_devices = None # type: ignore @@ -700,7 +700,7 @@ def parallel_callable(fun: lu.WrappedFun, if devices is not None: is_multi_host_pmap = len(local_devices) != len(devices) else: - is_multi_host_pmap = xb.host_count() > 1 + is_multi_host_pmap = xb.process_count() > 1 if is_multi_host_pmap: check_multihost_collective_allowlist(jaxpr) @@ -734,7 +734,7 @@ def parallel_callable(fun: lu.WrappedFun, num_local_shards = num_local_replicas * local_num_partitions num_global_shards = num_global_replicas * num_partitions - if (xb.host_count() > 1 and must_run_on_all_devices and + if (xb.process_count() > 1 and must_run_on_all_devices and num_local_shards != xb.local_device_count()): if num_local_shards == axis_size: raise ValueError( @@ -803,8 +803,8 @@ def parallel_callable(fun: lu.WrappedFun, if num_global_shards > num_local_shards: # TODO(skye): use a locality-aware assignment that satisfies the above # constraint. - devices = [d for host_id in xb.host_ids() - for d in xb.local_devices(host_id)] + devices = [d for process_index in range(xb.process_count()) + for d in xb.local_devices(process_index)] else: devices = xb.get_backend(backend).get_default_device_assignment( num_global_replicas, num_partitions) @@ -1300,8 +1300,9 @@ class Mesh: def local_mesh(self): if not self.devices.ndim: return self - host_id = xb.host_id() - is_local_device = np.vectorize(lambda d: d.host_id == host_id, otypes=[bool])(self.devices) + process_index = xb.process_index() + is_local_device = np.vectorize( + lambda d: d.process_index == process_index, otypes=[bool])(self.devices) subcube_indices = [] # We take the smallest slice of each dimension that doesn't skip any local device. 
for axis in range(self.devices.ndim): diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index d3d6750f1..95dc70520 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -683,7 +683,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar f"compiling computation that requires {nreps} replicas, but only " f"{xb.device_count(backend)} XLA devices are available") - if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)): + if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)): raise NotImplementedError( "jit of multi-host pmap not implemented (and jit-of-pmap can cause " "extra data movement anyway, so maybe you don't want it after all).") diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index 53122665c..5bcc98d50 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -23,6 +23,7 @@ XLA. There are also a handful of related casting utilities. from functools import partial, lru_cache import os from typing import Callable, Dict, List, Optional, Tuple, Union +import warnings from absl import logging # Disable "WARNING: Logging before flag parsing goes to stderr." message @@ -195,8 +196,9 @@ def device_count(backend: Optional[str] = None) -> int: """Returns the total number of devices. On most platforms, this is the same as :py:func:`jax.local_device_count`. - However, on multi-host platforms, this will return the total number of devices - across all hosts. + However, on multi-process platforms where different devices are associated + with different processes, this will return the total number of devices across + all processes. Args: backend: This is an experimental feature and the API is likely to change. @@ -205,12 +207,13 @@ def device_count(backend: Optional[str] = None) -> int: Returns: Number of devices. + """ return int(get_backend(backend).device_count()) def local_device_count(backend: Optional[str] = None) -> int: - """Returns the number of devices on this host.""" + """Returns the number of devices addressable by this process.""" return int(get_backend(backend).local_device_count()) @@ -219,8 +222,9 @@ def devices(backend: Optional[str] = None) -> List[xla_client.Device]: Each device is represented by a subclass of :class:`Device` (e.g. :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is - equal to ``device_count(backend)``. Local devices can be identified by comparing - :meth:`Device.host_id` to the value returned by :py:func:`jax.host_id`. + equal to ``device_count(backend)``. Local devices can be identified by + comparing :meth:`Device.process_index` to the value returned by + :py:func:`jax.process_index`. If ``backend`` is ``None``, returns all the devices from the default backend. The default backend is generally ``'gpu'`` or ``'tpu'`` if available, @@ -242,15 +246,16 @@ def default_backend() -> str: return get_backend(None).platform -def local_devices(host_id: Optional[int] = None, - backend: Optional[str] = None) -> List[xla_client.Device]: - """Like :py:func:`jax.devices`, but only returns devices local to a given host. +def local_devices(process_index: Optional[int] = None, + backend: Optional[str] = None, + host_id: Optional[int] = None) -> List[xla_client.Device]: + """Like :py:func:`jax.devices`, but only returns devices local to a given process. - If ``host_id`` is ``None``, returns devices local to this host. + If ``process_index`` is ``None``, returns devices local to this process. Args: - host_id: the integer ID of the host. 
Host IDs can be retrieved via
-      :py:func:`jax.host_ids`.
+    process_index: the integer index of the process. Process indices can be
+      retrieved via ``range(jax.process_count())``.
     backend: This is an experimental feature and the API is likely to change.
       Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
       ``'tpu'``.
@@ -258,17 +263,23 @@ def local_devices(host_id: Optional[int] = None,
   Returns:
     List of Device subclasses.
   """
-  if host_id is None:
-    host_id = get_backend(backend).host_id()
-  if host_id not in host_ids():
-    raise ValueError(f"Unknown host_id {host_id}")
-  return [d for d in devices(backend) if d.host_id == host_id]
+  if host_id is not None:
+    warnings.warn(
+        "The argument to jax.local_devices has been renamed from `host_id` to "
+        "`process_index`. This alias will eventually be removed; please update "
+        "your code.")
+    process_index = host_id
+  if process_index is None:
+    process_index = get_backend(backend).process_index()
+  if not (0 <= process_index < process_count()):
+    raise ValueError(f"Unknown process_index {process_index}")
+  return [d for d in devices(backend) if d.process_index == process_index]
 
 
-def host_id(backend: Optional[str] = None) -> int:
-  """Returns the integer host ID of this host.
+def process_index(backend: Optional[str] = None) -> int:
+  """Returns the integer process index of this process.
 
-  On most platforms, this will always be 0. This will vary on multi-host
+  On most platforms, this will always be 0. This will vary on multi-process
   platforms though.
 
   Args:
@@ -277,19 +288,39 @@ def host_id(backend: Optional[str] = None) -> int:
     ``'tpu'``.
 
   Returns:
-    Integer host ID.
+    Integer process index.
   """
-  return get_backend(backend).host_id()
+  return get_backend(backend).process_index()
 
 
-def host_ids(backend: Optional[str] = None) -> List[int]:
-  """Returns a sorted list of all host IDs."""
-  return sorted({d.host_id for d in devices(backend)})
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_id(backend=None):
+  warnings.warn(
+      "jax.host_id has been renamed to jax.process_index. This alias "
+      "will eventually be removed; please update your code.")
+  return process_index(backend)
 
 
-def host_count(backend: Optional[str] = None) -> int:
-  """Returns the number of hosts."""
-  return len(host_ids(backend))
+def process_count(backend: Optional[str] = None) -> int:
+  """Returns the number of JAX processes associated with the backend."""
+  return max(d.process_index for d in devices(backend)) + 1
+
+
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_count(backend=None):
+  warnings.warn(
+      "jax.host_count has been renamed to jax.process_count. This alias "
+      "will eventually be removed; please update your code.")
+  return process_count(backend)
+
+
+# TODO: remove this sometime after jax 0.2.13 is released
+def host_ids(backend=None):
+  warnings.warn(
+      "jax.host_ids has been deprecated; please use range(jax.process_count()) "
+      "instead. 
jax.host_ids will eventually be removed; please update your " + "code.") + return list(range(process_count(backend))) ### utility functions diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 4cb3c4ca9..d37d15e9c 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -51,7 +51,7 @@ class XlaBridgeTest(absltest.TestCase): def test_local_devices(self): self.assertNotEmpty(xb.local_devices()) - with self.assertRaisesRegex(ValueError, "Unknown host_id 100"): + with self.assertRaisesRegex(ValueError, "Unknown process_index 100"): xb.local_devices(100) with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"): xb.local_devices(backend="foo")
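
Below is a minimal migration sketch for downstream code, showing the renamed
APIs and the staged aliases this patch keeps around. It assumes the patch is
applied (jax around 0.2.13) and a single-process setup; the values printed
will vary by platform.

    import warnings

    import jax

    # New spellings: one index per JAX process, 0 <= index < process_count().
    print(jax.process_index(), jax.process_count())

    # host_ids() has no direct replacement; iterate over the process count and
    # query each process's addressable devices.
    for idx in range(jax.process_count()):
      print(idx, jax.local_devices(process_index=idx))

    # The old spellings still work but warn until they are removed.
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")  # silence the deprecation warnings
      assert jax.host_id() == jax.process_index()
      assert jax.host_count() == jax.process_count()
      assert jax.host_ids() == list(range(jax.process_count()))
      # The staged `host_id` argument to local_devices still works too:
      assert jax.local_devices(host_id=0) == jax.local_devices(process_index=0)

Note that process_count is derived from the devices themselves
(max process_index + 1) rather than stored separately, so it always agrees
with the process indices reported by jax.devices().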