mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Raise an exception if jax.distributed.initialize() is called after backends have been initialized.
Fixes https://github.com/google/jax/issues/18237 PiperOrigin-RevId: 579936065
This commit is contained in:
parent
f54be7aa7c
commit
eeafff5891
@ -20,6 +20,7 @@ from typing import Any, Optional, Union
|
||||
|
||||
from jax._src import clusters
|
||||
from jax._src import config
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
@ -40,6 +41,9 @@ class State:
|
||||
process_id: Optional[int] = None,
|
||||
local_device_ids: Optional[Union[int, Sequence[int]]] = None,
|
||||
initialization_timeout: int = 300):
|
||||
if xla_bridge.backends_are_initialized():
|
||||
raise RuntimeError("jax.distributed.initialize() must be called before "
|
||||
"any JAX computations are executed.")
|
||||
coordinator_address = (coordinator_address or
|
||||
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
|
||||
if isinstance(local_device_ids, int):
|
||||
|
@ -578,6 +578,13 @@ def expand_platform_alias(platform: str) -> list[str]:
|
||||
def is_gpu(platform):
|
||||
return platform in ("cuda", "rocm")
|
||||
|
||||
|
||||
def backends_are_initialized() -> bool:
|
||||
"Returns true if backends have already been initialized."
|
||||
with _backend_lock:
|
||||
return _backends is not None
|
||||
|
||||
|
||||
def backends() -> dict[str, xla_client.Client]:
|
||||
global _backends
|
||||
global _backend_errors
|
||||
|
Loading…
x
Reference in New Issue
Block a user