mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[PJRT C API] Add parsing PJRT client create options from json file.
PiperOrigin-RevId: 518418760
This commit is contained in:
parent
a041c553f9
commit
b403c2a083
@ -20,11 +20,13 @@ XLA. There are also a handful of related casting utilities.
|
||||
"""
|
||||
|
||||
from functools import partial, lru_cache
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform as py_platform
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -286,19 +288,59 @@ def _get_pjrt_plugin_names_and_library_paths(
|
||||
return pjrt_plugins
|
||||
|
||||
|
||||
def _get_pjrt_plugin_config(
|
||||
json_path: str,
|
||||
) -> Tuple[str, Optional[Mapping[str, Union[str, int, List[int], float]]]]:
|
||||
"""Gets PJRT plugin configuration from a json file.
|
||||
|
||||
The json file needs to have a "library_path" field for the plugin library
|
||||
path. It can have an optional "create_option" field for the options used when
|
||||
creating a PJRT plugin client. The value of "create_option" is key-value
|
||||
pairs. Please see xla_client._NameValueMapping for the supported types of
|
||||
values.
|
||||
"""
|
||||
with io.open(json_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
if 'library_path' not in config.keys():
|
||||
raise ValueError(
|
||||
'PJRT plugin config file should contain "library_path" field.'
|
||||
)
|
||||
return (config['library_path'], config.get('create_options'))
|
||||
|
||||
|
||||
def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
|
||||
"""Registers backend factories for PJRT plugins.
|
||||
|
||||
A backend factory will be registered for every PJRT plugin in the input
|
||||
string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2'
|
||||
for windows). TPU PJRT plugin will be loaded and registered separately in
|
||||
make_tpu_client.
|
||||
for windows). The path can be a path to the plugin library or a path to the
|
||||
plugin configuration json file. The json file needs to have a "library_path"
|
||||
field for the plugin library path. It can have an optional "create_option"
|
||||
field for the options used when creating a PJRT plugin client. The value of
|
||||
"create_option" is key-value pairs. Please see xla_client._NameValueMapping
|
||||
for the supported types of values.
|
||||
|
||||
TPU PJRT plugin will be loaded and registered separately in make_tpu_client.
|
||||
"""
|
||||
|
||||
def make_factory(name, path):
|
||||
def make_factory(name: str, path: str):
|
||||
def factory():
|
||||
xla_client.load_pjrt_plugin_dynamically(name, path)
|
||||
return xla_client.make_c_api_client(name)
|
||||
if path.endswith('.json'):
|
||||
library_path, options = _get_pjrt_plugin_config(path)
|
||||
else:
|
||||
library_path = path
|
||||
options = None
|
||||
|
||||
xla_client.load_pjrt_plugin_dynamically(name, library_path)
|
||||
if lib.xla_extension_version >= 134:
|
||||
return xla_client.make_c_api_client(name, options)
|
||||
else:
|
||||
if options:
|
||||
raise ValueError(
|
||||
'Setting PJRT plugin options through json file requires'
|
||||
' jaxlib.xla_extension_version >= 134.'
|
||||
)
|
||||
return xla_client.make_c_api_client(name)
|
||||
|
||||
return factory
|
||||
|
||||
|
@ -848,6 +848,7 @@ py_test(
|
||||
py_test(
|
||||
name = "xla_bridge_test",
|
||||
srcs = ["xla_bridge_test.py"],
|
||||
data = ["testdata/example_pjrt_plugin_config.json"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
|
9
tests/testdata/example_pjrt_plugin_config.json
vendored
Normal file
9
tests/testdata/example_pjrt_plugin_config.json
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"library_path": "/path/pjrt_plugin_name1.so",
|
||||
"create_options": {
|
||||
"int_option": 64,
|
||||
"int_list_option": [32, 64],
|
||||
"string_option": "string",
|
||||
"float_option": 1.0
|
||||
}
|
||||
}
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
|
||||
@ -90,9 +91,6 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
xb.tpu_client_timer_callback(0.01)
|
||||
|
||||
def test_register_plugin(self):
|
||||
if xc._version < 126:
|
||||
return
|
||||
|
||||
with self.assertLogs(level="WARNING") as log_output:
|
||||
xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3")
|
||||
client_factory, priotiy = xb._backend_factories["name1"]
|
||||
@ -111,7 +109,39 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
self.assertIn("name2", xb._backend_factories)
|
||||
self.assertEqual(priotiy, 400)
|
||||
mock_load_plugin.assert_called_once_with("name1", "path1")
|
||||
mock_make.assert_called_once_with("name1")
|
||||
if xc._version >= 134:
|
||||
mock_make.assert_called_once_with("name1", None)
|
||||
else:
|
||||
mock_make.assert_called_once_with("name1")
|
||||
|
||||
def test_register_plugin_with_config(self):
|
||||
if xc._version < 134:
|
||||
return
|
||||
test_json_file_path = os.path.join(
|
||||
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
|
||||
)
|
||||
xb.register_pjrt_plugin_factories(f"name1:{test_json_file_path}")
|
||||
client_factory, priority = xb._backend_factories["name1"]
|
||||
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||
with mock.patch.object(
|
||||
xc, "load_pjrt_plugin_dynamically", autospec=True
|
||||
) as mock_load_plugin:
|
||||
client_factory()
|
||||
|
||||
self.assertIn("name1", xb._backend_factories)
|
||||
self.assertEqual(priority, 400)
|
||||
mock_load_plugin.assert_called_once_with(
|
||||
"name1", "/path/pjrt_plugin_name1.so"
|
||||
)
|
||||
mock_make.assert_called_once_with(
|
||||
"name1",
|
||||
{
|
||||
"int_option": 64,
|
||||
"int_list_option": [32, 64],
|
||||
"string_option": "string",
|
||||
"float_option": 1.0,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class GetBackendTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user