From 60bf5b7727c9cdcc5928ca6b8b9ae4f7695892cd Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 12 Aug 2024 10:29:15 -0700 Subject: [PATCH] Add a `jax.process_indices` function. The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface. PiperOrigin-RevId: 662142636 --- CHANGELOG.md | 2 ++ docs/jax.rst | 1 + jax/__init__.py | 1 + jax/_src/xla_bridge.py | 23 +++++++++++++++++++---- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 038c0131a..2b30b08ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. more cases. Previously non-parallel computations were always dispatched synchronously. You can recover the old behavior by setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`. + * Added new {func}`jax.process_indices` function to replace the + `jax.host_ids()` function that was deprecated in JAX v0.2.13. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the diff --git a/docs/jax.rst b/docs/jax.rst index b112490a0..7be3e6301 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -138,6 +138,7 @@ Parallelization (:code:`pmap`) device_count local_device_count process_count + process_indices Callbacks --------- diff --git a/jax/__init__.py b/jax/__init__.py index d9c4de6bb..dc3d9af3a 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -119,6 +119,7 @@ from jax._src.api import named_scope as named_scope from jax._src.api import pmap as pmap from jax._src.xla_bridge import process_count as process_count from jax._src.xla_bridge import process_index as process_index +from jax._src.xla_bridge import process_indices as process_indices from jax._src.callback import pure_callback as pure_callback from jax._src.ad_checkpoint import checkpoint_wrapper as remat from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 91d761fec..1d3c50403 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -1187,15 +1187,30 @@ def host_count(backend: str | xla_client.Client | None = None) -> int: return process_count(backend) +def process_indices( + backend: str | xla_client.Client | None = None +) -> list[int]: + """Returns the list of all JAX process indices associated with the backend. + + Args: + backend: This is an experimental feature and the API is likely to change. + Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or + ``'tpu'``. + + Returns: + List of integer process indices. + """ + return list(range(process_count(backend))) + + # TODO: remove this sometime after jax 0.2.13 is released def host_ids( backend: str | xla_client.Client | None = None ) -> list[int]: warnings.warn( - "jax.host_ids has been deprecated; please use range(jax.process_count()) " - "instead. jax.host_ids will eventually be removed; please update your " - "code.") - return list(range(process_count(backend))) + "jax.host_ids has been renamed to jax.process_indices. This alias " + "will eventually be removed; please update your code.") + return process_indices(backend) def using_pjrt_c_api(backend=None):