From b403c2a083998dd475b2c12795e2c216f4d75359 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Tue, 21 Mar 2023 16:52:49 -0700 Subject: [PATCH] [PJRT C API] Add parsing PJRT client create options from json file. PiperOrigin-RevId: 518418760 --- jax/_src/xla_bridge.py | 54 ++++++++++++++++--- tests/BUILD | 1 + .../testdata/example_pjrt_plugin_config.json | 9 ++++ tests/xla_bridge_test.py | 38 +++++++++++-- 4 files changed, 92 insertions(+), 10 deletions(-) create mode 100644 tests/testdata/example_pjrt_plugin_config.json diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 2e0f9a1a2..ad814b9c7 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -20,11 +20,13 @@ XLA. There are also a handful of related casting utilities. """ from functools import partial, lru_cache +import io +import json import logging import os import platform as py_platform import threading -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import warnings import numpy as np @@ -286,19 +288,59 @@ def _get_pjrt_plugin_names_and_library_paths( return pjrt_plugins +def _get_pjrt_plugin_config( + json_path: str, +) -> Tuple[str, Optional[Mapping[str, Union[str, int, List[int], float]]]]: + """Gets PJRT plugin configuration from a json file. + + The json file needs to have a "library_path" field for the plugin library + path. It can have an optional "create_option" field for the options used when + creating a PJRT plugin client. The value of "create_option" is key-value + pairs. Please see xla_client._NameValueMapping for the supported types of + values. + """ + with io.open(json_path, 'r') as f: + config = json.load(f) + if 'library_path' not in config.keys(): + raise ValueError( + 'PJRT plugin config file should contain "library_path" field.' + ) + return (config['library_path'], config.get('create_options')) + + def register_pjrt_plugin_factories(plugins_from_env: str) -> None: """Registers backend factories for PJRT plugins. A backend factory will be registered for every PJRT plugin in the input string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' - for windows). TPU PJRT plugin will be loaded and registered separately in - make_tpu_client. + for windows). The path can be a path to the plugin library or a path to the + plugin configuration json file. The json file needs to have a "library_path" + field for the plugin library path. It can have an optional "create_option" + field for the options used when creating a PJRT plugin client. The value of + "create_option" is key-value pairs. Please see xla_client._NameValueMapping + for the supported types of values. + + TPU PJRT plugin will be loaded and registered separately in make_tpu_client. """ - def make_factory(name, path): + def make_factory(name: str, path: str): def factory(): - xla_client.load_pjrt_plugin_dynamically(name, path) - return xla_client.make_c_api_client(name) + if path.endswith('.json'): + library_path, options = _get_pjrt_plugin_config(path) + else: + library_path = path + options = None + + xla_client.load_pjrt_plugin_dynamically(name, library_path) + if lib.xla_extension_version >= 134: + return xla_client.make_c_api_client(name, options) + else: + if options: + raise ValueError( + 'Setting PJRT plugin options through json file requires' + ' jaxlib.xla_extension_version >= 134.' + ) + return xla_client.make_c_api_client(name) return factory diff --git a/tests/BUILD b/tests/BUILD index fde4db9e6..dd692e7e1 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -848,6 +848,7 @@ py_test( py_test( name = "xla_bridge_test", srcs = ["xla_bridge_test.py"], + data = ["testdata/example_pjrt_plugin_config.json"], deps = [ "//jax", "//jax:test_util", diff --git a/tests/testdata/example_pjrt_plugin_config.json b/tests/testdata/example_pjrt_plugin_config.json new file mode 100644 index 000000000..2a195727c --- /dev/null +++ b/tests/testdata/example_pjrt_plugin_config.json @@ -0,0 +1,9 @@ +{ + "library_path": "/path/pjrt_plugin_name1.so", + "create_options": { + "int_option": 64, + "int_list_option": [32, 64], + "string_option": "string", + "float_option": 1.0 + } +} diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index db35aeac1..468cd1e4b 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import time import warnings @@ -90,9 +91,6 @@ class XlaBridgeTest(jtu.JaxTestCase): xb.tpu_client_timer_callback(0.01) def test_register_plugin(self): - if xc._version < 126: - return - with self.assertLogs(level="WARNING") as log_output: xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3") client_factory, priotiy = xb._backend_factories["name1"] @@ -111,7 +109,39 @@ class XlaBridgeTest(jtu.JaxTestCase): self.assertIn("name2", xb._backend_factories) self.assertEqual(priotiy, 400) mock_load_plugin.assert_called_once_with("name1", "path1") - mock_make.assert_called_once_with("name1") + if xc._version >= 134: + mock_make.assert_called_once_with("name1", None) + else: + mock_make.assert_called_once_with("name1") + + def test_register_plugin_with_config(self): + if xc._version < 134: + return + test_json_file_path = os.path.join( + os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json" + ) + xb.register_pjrt_plugin_factories(f"name1:{test_json_file_path}") + client_factory, priority = xb._backend_factories["name1"] + with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: + with mock.patch.object( + xc, "load_pjrt_plugin_dynamically", autospec=True + ) as mock_load_plugin: + client_factory() + + self.assertIn("name1", xb._backend_factories) + self.assertEqual(priority, 400) + mock_load_plugin.assert_called_once_with( + "name1", "/path/pjrt_plugin_name1.so" + ) + mock_make.assert_called_once_with( + "name1", + { + "int_option": 64, + "int_list_option": [32, 64], + "string_option": "string", + "float_option": 1.0, + }, + ) class GetBackendTest(jtu.JaxTestCase):