mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[PJRT PLUGIN] Pass in a blocking key value get and a key value put function pointer instead of a DistributedRuntimeClient pointer when creating a GPU client.
This is to support the multi-process/multi-node GPU PJRT plugin. A DistributedRuntimeClient pointer should not be passed through the C API boundary. Therefore, key value get and key value put function pointers are provided by the framework. This change focuses on changes related to the C++ GPU client. C API related changes will be made in a follow-up change.
This change includes:
- Use kv_get and kv_put for the NCCL id.
- The lead node (node_id 0) uses kv_get and kv_put to generate and put the global topology in se_gpu_pjrt_client. Other nodes use kv_get to get the global topology.
- Use kv_get, kv_put and the number of nodes when creating a StreamExecutorGpuClient. kv_get and kv_put can be generated from DistributedRuntimeClient. However, DistributedRuntimeClient and DistributedRuntimeService do not expose the number of nodes. Currently it is obtained from distributed.global_state.
- Modify xla.cc to create kv_get and kv_put from DistributedRuntimeClient.
- Modify xla_bridge to pass in num_nodes.
- Change call sites of GetStreamExecutorGpuClient. Most call sites use a nullptr DistributedRuntimeClient, so this is a no-op for them.
PiperOrigin-RevId: 538845061
This commit is contained in:
parent
5c5baccb2e
commit
e598bc8e7e
@ -37,6 +37,7 @@ from jax._src import lib
|
||||
from jax._src import distributed
|
||||
from jax._src.config import flags, bool_env, config, int_env
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
|
||||
@ -234,11 +235,24 @@ def make_gpu_client(
|
||||
allowed_devices = None
|
||||
if visible_devices != "all":
|
||||
allowed_devices = {int(x) for x in visible_devices.split(",")}
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
platform_name=platform_name,
|
||||
allowed_devices=allowed_devices)
|
||||
|
||||
if xla_extension_version < 160:
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
platform_name=platform_name,
|
||||
allowed_devices=allowed_devices,
|
||||
)
|
||||
else:
|
||||
# Remove `type: ignore` when the min jaxlib version (xla_extension_version)
|
||||
# >= 160.
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
num_nodes=distributed.global_state.num_processes,
|
||||
platform_name=platform_name,
|
||||
allowed_devices=allowed_devices,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
if hasattr(xla_client, "make_gpu_client"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user