[PJRT PLUGIN] Pass in a blocking key value get and a key value put function pointer instead of a DistributedRuntimeClient pointer when creating a GPU client.

This is to support multi-process/multi-node GPU PJRT plugin. A DistributedRuntimeClient pointer should not be passed through the C API boundary. Therefore a key value get and key value put function pointer is provided by the framework. This change focuses on changes related to the C++ GPU client. C API related changes will be a follow up change.

This change includes:
- Use kv_get and kv_put in NCCL id.
- The lead node (node_id 0) uses kv_get and kv_put to generate and put the global topology in se_gpu_pjrt_client. Other nodes use kv_get to get the global topology.
- Use kv_get, kv_put and number of nodes when creating a StreamExecutorGpuClient. kv_get and kv_put can be generated from DistributedRuntimeClient. However, DistributedRuntimeClient and DistributedRuntimeService does not expose number of nodes. Currently it will be obtained from distributed.global_state.
- Modify xla.cc to create kv_get and kv_put from DistributedRuntimeClient.
- Modify xla_bridge to pass in num_nodes.
- Change call sites of GetStreamExecutorGpuClient. Most call sites use a nullptr DistributedRuntimeClient and it is a no-op for them.

PiperOrigin-RevId: 538845061
This commit is contained in:
Jieying Luo 2023-06-08 11:33:42 -07:00 committed by jax authors
parent 5c5baccb2e
commit e598bc8e7e

View File

@ -37,6 +37,7 @@ from jax._src import lib
from jax._src import distributed
from jax._src.config import flags, bool_env, config, int_env
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src import traceback_util
from jax._src import util
@ -234,11 +235,24 @@ def make_gpu_client(
allowed_devices = None
if visible_devices != "all":
allowed_devices = {int(x) for x in visible_devices.split(",")}
return xla_client.make_gpu_client(
distributed_client=distributed.global_state.client,
node_id=distributed.global_state.process_id,
platform_name=platform_name,
allowed_devices=allowed_devices)
if xla_extension_version < 160:
return xla_client.make_gpu_client(
distributed_client=distributed.global_state.client,
node_id=distributed.global_state.process_id,
platform_name=platform_name,
allowed_devices=allowed_devices,
)
else:
# Remove `type: ignore` when the min jaxlib version (xla_extension_version)
# >= 160.
return xla_client.make_gpu_client(
distributed_client=distributed.global_state.client,
node_id=distributed.global_state.process_id,
num_nodes=distributed.global_state.num_processes,
platform_name=platform_name,
allowed_devices=allowed_devices,
) # type: ignore
if hasattr(xla_client, "make_gpu_client"):