[PJRT C API] Add parsing PJRT client create options from json file.

PiperOrigin-RevId: 518418760
This commit is contained in:
Jieying Luo 2023-03-21 16:52:49 -07:00 committed by jax authors
parent a041c553f9
commit b403c2a083
4 changed files with 92 additions and 10 deletions

View File

@ -20,11 +20,13 @@ XLA. There are also a handful of related casting utilities.
"""
from functools import partial, lru_cache
import io
import json
import logging
import os
import platform as py_platform
import threading
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import warnings
import numpy as np
@ -286,19 +288,59 @@ def _get_pjrt_plugin_names_and_library_paths(
return pjrt_plugins
def _get_pjrt_plugin_config(
json_path: str,
) -> Tuple[str, Optional[Mapping[str, Union[str, int, List[int], float]]]]:
"""Gets PJRT plugin configuration from a json file.
The json file needs to have a "library_path" field for the plugin library
path. It can have an optional "create_option" field for the options used when
creating a PJRT plugin client. The value of "create_option" is key-value
pairs. Please see xla_client._NameValueMapping for the supported types of
values.
"""
with io.open(json_path, 'r') as f:
config = json.load(f)
if 'library_path' not in config.keys():
raise ValueError(
'PJRT plugin config file should contain "library_path" field.'
)
return (config['library_path'], config.get('create_options'))
def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
"""Registers backend factories for PJRT plugins.
A backend factory will be registered for every PJRT plugin in the input
string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2'
for windows). TPU PJRT plugin will be loaded and registered separately in
make_tpu_client.
for windows). The path can be a path to the plugin library or a path to the
plugin configuration json file. The json file needs to have a "library_path"
field for the plugin library path. It can have an optional "create_option"
field for the options used when creating a PJRT plugin client. The value of
"create_option" is key-value pairs. Please see xla_client._NameValueMapping
for the supported types of values.
TPU PJRT plugin will be loaded and registered separately in make_tpu_client.
"""
def make_factory(name, path):
def make_factory(name: str, path: str):
def factory():
xla_client.load_pjrt_plugin_dynamically(name, path)
return xla_client.make_c_api_client(name)
if path.endswith('.json'):
library_path, options = _get_pjrt_plugin_config(path)
else:
library_path = path
options = None
xla_client.load_pjrt_plugin_dynamically(name, library_path)
if lib.xla_extension_version >= 134:
return xla_client.make_c_api_client(name, options)
else:
if options:
raise ValueError(
'Setting PJRT plugin options through json file requires'
' jaxlib.xla_extension_version >= 134.'
)
return xla_client.make_c_api_client(name)
return factory

View File

@ -848,6 +848,7 @@ py_test(
py_test(
name = "xla_bridge_test",
srcs = ["xla_bridge_test.py"],
data = ["testdata/example_pjrt_plugin_config.json"],
deps = [
"//jax",
"//jax:test_util",

View File

@ -0,0 +1,9 @@
{
"library_path": "/path/pjrt_plugin_name1.so",
"create_options": {
"int_option": 64,
"int_list_option": [32, 64],
"string_option": "string",
"float_option": 1.0
}
}

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import warnings
@ -90,9 +91,6 @@ class XlaBridgeTest(jtu.JaxTestCase):
xb.tpu_client_timer_callback(0.01)
def test_register_plugin(self):
if xc._version < 126:
return
with self.assertLogs(level="WARNING") as log_output:
xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3")
client_factory, priotiy = xb._backend_factories["name1"]
@ -111,7 +109,39 @@ class XlaBridgeTest(jtu.JaxTestCase):
self.assertIn("name2", xb._backend_factories)
self.assertEqual(priotiy, 400)
mock_load_plugin.assert_called_once_with("name1", "path1")
mock_make.assert_called_once_with("name1")
if xc._version >= 134:
mock_make.assert_called_once_with("name1", None)
else:
mock_make.assert_called_once_with("name1")
def test_register_plugin_with_config(self):
if xc._version < 134:
return
test_json_file_path = os.path.join(
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
)
xb.register_pjrt_plugin_factories(f"name1:{test_json_file_path}")
client_factory, priority = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
client_factory()
self.assertIn("name1", xb._backend_factories)
self.assertEqual(priority, 400)
mock_load_plugin.assert_called_once_with(
"name1", "/path/pjrt_plugin_name1.so"
)
mock_make.assert_called_once_with(
"name1",
{
"int_option": 64,
"int_list_option": [32, 64],
"string_option": "string",
"float_option": 1.0,
},
)
class GetBackendTest(jtu.JaxTestCase):