mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add a jax.process_indices
function.
The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface. PiperOrigin-RevId: 662142636
This commit is contained in:
parent
ad74e55dbc
commit
60bf5b7727
@ -20,6 +20,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
more cases. Previously non-parallel computations were always dispatched
|
||||
synchronously. You can recover the old behavior by setting
|
||||
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
|
||||
* Added new {func}`jax.process_indices` function to replace the
|
||||
`jax.host_ids()` function that was deprecated in JAX v0.2.13.
|
||||
|
||||
* Breaking changes
|
||||
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
|
||||
|
@ -138,6 +138,7 @@ Parallelization (:code:`pmap`)
|
||||
device_count
|
||||
local_device_count
|
||||
process_count
|
||||
process_indices
|
||||
|
||||
Callbacks
|
||||
---------
|
||||
|
@ -119,6 +119,7 @@ from jax._src.api import named_scope as named_scope
|
||||
from jax._src.api import pmap as pmap
|
||||
from jax._src.xla_bridge import process_count as process_count
|
||||
from jax._src.xla_bridge import process_index as process_index
|
||||
from jax._src.xla_bridge import process_indices as process_indices
|
||||
from jax._src.callback import pure_callback as pure_callback
|
||||
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
|
||||
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
|
||||
|
@ -1187,15 +1187,30 @@ def host_count(backend: str | xla_client.Client | None = None) -> int:
|
||||
return process_count(backend)
|
||||
|
||||
|
||||
def process_indices(
|
||||
backend: str | xla_client.Client | None = None
|
||||
) -> list[int]:
|
||||
"""Returns the list of all JAX process indices associated with the backend.
|
||||
|
||||
Args:
|
||||
backend: This is an experimental feature and the API is likely to change.
|
||||
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
|
||||
``'tpu'``.
|
||||
|
||||
Returns:
|
||||
List of integer process indices.
|
||||
"""
|
||||
return list(range(process_count(backend)))
|
||||
|
||||
|
||||
# TODO: remove this sometime after jax 0.2.13 is released
|
||||
def host_ids(
|
||||
backend: str | xla_client.Client | None = None
|
||||
) -> list[int]:
|
||||
warnings.warn(
|
||||
"jax.host_ids has been deprecated; please use range(jax.process_count()) "
|
||||
"instead. jax.host_ids will eventually be removed; please update your "
|
||||
"code.")
|
||||
return list(range(process_count(backend)))
|
||||
"jax.host_ids has been renamed to jax.process_indices. This alias "
|
||||
"will eventually be removed; please update your code.")
|
||||
return process_indices(backend)
|
||||
|
||||
|
||||
def using_pjrt_c_api(backend=None):
|
||||
|
Loading…
x
Reference in New Issue
Block a user