From e598bc8e7e560fb66ccb2e0cb78645d8f1d6d156 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 8 Jun 2023 11:33:42 -0700 Subject: [PATCH] [PJRT PLUGIN] Pass in a blocking key value get and a key value put function pointer instead of a DistributedRuntimeClient pointer when creating a GPU client. This is to support multi-process/multi-node GPU PJRT plugin. A DistributedRuntimeClient pointer should not be passed through the C API boundary. Therefore a key value get and key value put function pointer is provided by the framework. This change focuses on changes related to the C++ GPU client. C API related changes will be a follow up change. This change includes: - Use kv_get and kv_put in NCCL id. - The lead node (node_id 0) uses kv_get and kv_put to generate and put the global topology in se_gpu_pjrt_client. Other nodes use kv_get to get the global topology. - Use kv_get, kv_put and number of nodes when creating a StreamExecutorGpuClient. kv_get and kv_put can be generated from DistributedRuntimeClient. However, DistributedRuntimeClient and DistributedRuntimeService does not expose number of nodes. Currently it will be obtained from distributed.global_state. - Modify xla.cc to create kv_get and kv_put from DistributedRuntimeClient. - Modify xla_bridge to pass in num_nodes. - Change call sites of GetStreamExecutorGpuClient. Most call sites use a nullptr DistributedRuntimeClient and it is a no-op for them. PiperOrigin-RevId: 538845061 --- jax/_src/xla_bridge.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index cee5381d8..42b2c6a74 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -37,6 +37,7 @@ from jax._src import lib from jax._src import distributed from jax._src.config import flags, bool_env, config, int_env from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version from jax._src import traceback_util from jax._src import util @@ -234,11 +235,24 @@ def make_gpu_client( allowed_devices = None if visible_devices != "all": allowed_devices = {int(x) for x in visible_devices.split(",")} - return xla_client.make_gpu_client( - distributed_client=distributed.global_state.client, - node_id=distributed.global_state.process_id, - platform_name=platform_name, - allowed_devices=allowed_devices) + + if xla_extension_version < 160: + return xla_client.make_gpu_client( + distributed_client=distributed.global_state.client, + node_id=distributed.global_state.process_id, + platform_name=platform_name, + allowed_devices=allowed_devices, + ) + else: + # Remove `type: ignore` when the min jaxlib version (xla_extension_version) + # >= 160. + return xla_client.make_gpu_client( + distributed_client=distributed.global_state.client, + node_id=distributed.global_state.process_id, + num_nodes=distributed.global_state.num_processes, + platform_name=platform_name, + allowed_devices=allowed_devices, + ) # type: ignore if hasattr(xla_client, "make_gpu_client"):