diff --git a/CHANGELOG.md b/CHANGELOG.md index 038c0131a..2b30b08ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. more cases. Previously non-parallel computations were always dispatched synchronously. You can recover the old behavior by setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`. + * Added new {func}`jax.process_indices` function to replace the + `jax.host_ids()` function that was deprecated in JAX v0.2.13. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the diff --git a/docs/jax.rst b/docs/jax.rst index b112490a0..7be3e6301 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -138,6 +138,7 @@ Parallelization (:code:`pmap`) device_count local_device_count process_count + process_indices Callbacks --------- diff --git a/jax/__init__.py b/jax/__init__.py index d9c4de6bb..dc3d9af3a 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -119,6 +119,7 @@ from jax._src.api import named_scope as named_scope from jax._src.api import pmap as pmap from jax._src.xla_bridge import process_count as process_count from jax._src.xla_bridge import process_index as process_index +from jax._src.xla_bridge import process_indices as process_indices from jax._src.callback import pure_callback as pure_callback from jax._src.ad_checkpoint import checkpoint_wrapper as remat from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 91d761fec..1d3c50403 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -1187,15 +1187,30 @@ def host_count(backend: str | xla_client.Client | None = None) -> int: return process_count(backend) +def process_indices( + backend: str | xla_client.Client | None = None +) -> list[int]: + """Returns the list of all JAX process indices associated with the backend. + + Args: + backend: This is an experimental feature and the API is likely to change. + Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or + ``'tpu'``. + + Returns: + List of integer process indices. + """ + return list(range(process_count(backend))) + + # TODO: remove this sometime after jax 0.2.13 is released def host_ids( backend: str | xla_client.Client | None = None ) -> list[int]: warnings.warn( - "jax.host_ids has been deprecated; please use range(jax.process_count()) " - "instead. jax.host_ids will eventually be removed; please update your " - "code.") - return list(range(process_count(backend))) + "jax.host_ids has been renamed to jax.process_indices. This alias " + "will eventually be removed; please update your code.") + return process_indices(backend) def using_pjrt_c_api(backend=None):