Raise an exception if jax.distributed.initialize() is called after backends have been initialized.

Fixes https://github.com/google/jax/issues/18237

PiperOrigin-RevId: 579936065
This commit is contained in:
Peter Hawkins 2023-11-06 13:10:49 -08:00 committed by jax authors
parent f54be7aa7c
commit eeafff5891
2 changed files with 11 additions and 0 deletions

View File

@ -20,6 +20,7 @@ from typing import Any, Optional, Union
from jax._src import clusters
from jax._src import config
from jax._src import xla_bridge
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
@ -40,6 +41,9 @@ class State:
process_id: Optional[int] = None,
local_device_ids: Optional[Union[int, Sequence[int]]] = None,
initialization_timeout: int = 300):
if xla_bridge.backends_are_initialized():
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
if isinstance(local_device_ids, int):

View File

@ -578,6 +578,13 @@ def expand_platform_alias(platform: str) -> list[str]:
def is_gpu(platform):
return platform in ("cuda", "rocm")
def backends_are_initialized() -> bool:
"Returns true if backends have already been initialized."
with _backend_lock:
return _backends is not None
def backends() -> dict[str, xla_client.Client]:
global _backends
global _backend_errors