Add a jax.process_indices function.

The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface.

PiperOrigin-RevId: 662142636
This commit is contained in:
Dan Foreman-Mackey 2024-08-12 10:29:15 -07:00 committed by jax authors
parent ad74e55dbc
commit 60bf5b7727
4 changed files with 23 additions and 4 deletions

View File

@ -20,6 +20,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
more cases. Previously non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Added new {func}`jax.process_indices` function to replace the
`jax.host_ids()` function that was deprecated in JAX v0.2.13.
* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the

View File

@ -138,6 +138,7 @@ Parallelization (:code:`pmap`)
device_count
local_device_count
process_count
process_indices
Callbacks
---------

View File

@ -119,6 +119,7 @@ from jax._src.api import named_scope as named_scope
from jax._src.api import pmap as pmap
from jax._src.xla_bridge import process_count as process_count
from jax._src.xla_bridge import process_index as process_index
from jax._src.xla_bridge import process_indices as process_indices
from jax._src.callback import pure_callback as pure_callback
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct

View File

@ -1187,15 +1187,30 @@ def host_count(backend: str | xla_client.Client | None = None) -> int:
return process_count(backend)
def process_indices(
backend: str | xla_client.Client | None = None
) -> list[int]:
"""Returns the list of all JAX process indices associated with the backend.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
List of integer process indices.
"""
return list(range(process_count(backend)))
# TODO: remove this sometime after jax 0.2.13 is released
def host_ids(
backend: str | xla_client.Client | None = None
) -> list[int]:
warnings.warn(
"jax.host_ids has been deprecated; please use range(jax.process_count()) "
"instead. jax.host_ids will eventually be removed; please update your "
"code.")
return list(range(process_count(backend)))
"jax.host_ids has been renamed to jax.process_indices. This alias "
"will eventually be removed; please update your code.")
return process_indices(backend)
def using_pjrt_c_api(backend=None):