[PJRT C API] Add parsing PJRT client create options from json file.

PiperOrigin-RevId: 518418760
This commit is contained in:
Jieying Luo 2023-03-21 16:52:49 -07:00 committed by jax authors
parent a041c553f9
commit b403c2a083
4 changed files with 92 additions and 10 deletions

View File

@ -20,11 +20,13 @@ XLA. There are also a handful of related casting utilities.
""" """
from functools import partial, lru_cache from functools import partial, lru_cache
import io
import json
import logging import logging
import os import os
import platform as py_platform import platform as py_platform
import threading import threading
from typing import Any, Callable, Dict, List, Optional, Union, Tuple from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import warnings import warnings
import numpy as np import numpy as np
@ -286,19 +288,59 @@ def _get_pjrt_plugin_names_and_library_paths(
return pjrt_plugins return pjrt_plugins
def _get_pjrt_plugin_config(
json_path: str,
) -> Tuple[str, Optional[Mapping[str, Union[str, int, List[int], float]]]]:
"""Gets PJRT plugin configuration from a json file.
The json file needs to have a "library_path" field for the plugin library
path. It can have an optional "create_option" field for the options used when
creating a PJRT plugin client. The value of "create_option" is key-value
pairs. Please see xla_client._NameValueMapping for the supported types of
values.
"""
with io.open(json_path, 'r') as f:
config = json.load(f)
if 'library_path' not in config.keys():
raise ValueError(
'PJRT plugin config file should contain "library_path" field.'
)
return (config['library_path'], config.get('create_options'))
def register_pjrt_plugin_factories(plugins_from_env: str) -> None: def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
"""Registers backend factories for PJRT plugins. """Registers backend factories for PJRT plugins.
A backend factory will be registered for every PJRT plugin in the input A backend factory will be registered for every PJRT plugin in the input
string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2'
for windows). TPU PJRT plugin will be loaded and registered separately in for windows). The path can be a path to the plugin library or a path to the
make_tpu_client. plugin configuration json file. The json file needs to have a "library_path"
field for the plugin library path. It can have an optional "create_option"
field for the options used when creating a PJRT plugin client. The value of
"create_option" is key-value pairs. Please see xla_client._NameValueMapping
for the supported types of values.
TPU PJRT plugin will be loaded and registered separately in make_tpu_client.
""" """
def make_factory(name, path): def make_factory(name: str, path: str):
def factory(): def factory():
xla_client.load_pjrt_plugin_dynamically(name, path) if path.endswith('.json'):
return xla_client.make_c_api_client(name) library_path, options = _get_pjrt_plugin_config(path)
else:
library_path = path
options = None
xla_client.load_pjrt_plugin_dynamically(name, library_path)
if lib.xla_extension_version >= 134:
return xla_client.make_c_api_client(name, options)
else:
if options:
raise ValueError(
'Setting PJRT plugin options through json file requires'
' jaxlib.xla_extension_version >= 134.'
)
return xla_client.make_c_api_client(name)
return factory return factory

View File

@ -848,6 +848,7 @@ py_test(
py_test( py_test(
name = "xla_bridge_test", name = "xla_bridge_test",
srcs = ["xla_bridge_test.py"], srcs = ["xla_bridge_test.py"],
data = ["testdata/example_pjrt_plugin_config.json"],
deps = [ deps = [
"//jax", "//jax",
"//jax:test_util", "//jax:test_util",

View File

@ -0,0 +1,9 @@
{
"library_path": "/path/pjrt_plugin_name1.so",
"create_options": {
"int_option": 64,
"int_list_option": [32, 64],
"string_option": "string",
"float_option": 1.0
}
}

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import time import time
import warnings import warnings
@ -90,9 +91,6 @@ class XlaBridgeTest(jtu.JaxTestCase):
xb.tpu_client_timer_callback(0.01) xb.tpu_client_timer_callback(0.01)
def test_register_plugin(self): def test_register_plugin(self):
if xc._version < 126:
return
with self.assertLogs(level="WARNING") as log_output: with self.assertLogs(level="WARNING") as log_output:
xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3") xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3")
client_factory, priotiy = xb._backend_factories["name1"] client_factory, priotiy = xb._backend_factories["name1"]
@ -111,7 +109,39 @@ class XlaBridgeTest(jtu.JaxTestCase):
self.assertIn("name2", xb._backend_factories) self.assertIn("name2", xb._backend_factories)
self.assertEqual(priotiy, 400) self.assertEqual(priotiy, 400)
mock_load_plugin.assert_called_once_with("name1", "path1") mock_load_plugin.assert_called_once_with("name1", "path1")
mock_make.assert_called_once_with("name1") if xc._version >= 134:
mock_make.assert_called_once_with("name1", None)
else:
mock_make.assert_called_once_with("name1")
def test_register_plugin_with_config(self):
if xc._version < 134:
return
test_json_file_path = os.path.join(
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
)
xb.register_pjrt_plugin_factories(f"name1:{test_json_file_path}")
client_factory, priority = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
client_factory()
self.assertIn("name1", xb._backend_factories)
self.assertEqual(priority, 400)
mock_load_plugin.assert_called_once_with(
"name1", "/path/pjrt_plugin_name1.so"
)
mock_make.assert_called_once_with(
"name1",
{
"int_option": 64,
"int_list_option": [32, 64],
"string_option": "string",
"float_option": 1.0,
},
)
class GetBackendTest(jtu.JaxTestCase): class GetBackendTest(jtu.JaxTestCase):