mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[PJRT C API] Add parsing PJRT client create options from json file.
PiperOrigin-RevId: 518418760
This commit is contained in:
parent
a041c553f9
commit
b403c2a083
@ -20,11 +20,13 @@ XLA. There are also a handful of related casting utilities.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import partial, lru_cache
|
from functools import partial, lru_cache
|
||||||
|
import io
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import platform as py_platform
|
import platform as py_platform
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -286,19 +288,59 @@ def _get_pjrt_plugin_names_and_library_paths(
|
|||||||
return pjrt_plugins
|
return pjrt_plugins
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pjrt_plugin_config(
|
||||||
|
json_path: str,
|
||||||
|
) -> Tuple[str, Optional[Mapping[str, Union[str, int, List[int], float]]]]:
|
||||||
|
"""Gets PJRT plugin configuration from a json file.
|
||||||
|
|
||||||
|
The json file needs to have a "library_path" field for the plugin library
|
||||||
|
path. It can have an optional "create_option" field for the options used when
|
||||||
|
creating a PJRT plugin client. The value of "create_option" is key-value
|
||||||
|
pairs. Please see xla_client._NameValueMapping for the supported types of
|
||||||
|
values.
|
||||||
|
"""
|
||||||
|
with io.open(json_path, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
if 'library_path' not in config.keys():
|
||||||
|
raise ValueError(
|
||||||
|
'PJRT plugin config file should contain "library_path" field.'
|
||||||
|
)
|
||||||
|
return (config['library_path'], config.get('create_options'))
|
||||||
|
|
||||||
|
|
||||||
def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
|
def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
|
||||||
"""Registers backend factories for PJRT plugins.
|
"""Registers backend factories for PJRT plugins.
|
||||||
|
|
||||||
A backend factory will be registered for every PJRT plugin in the input
|
A backend factory will be registered for every PJRT plugin in the input
|
||||||
string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2'
|
string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2'
|
||||||
for windows). TPU PJRT plugin will be loaded and registered separately in
|
for windows). The path can be a path to the plugin library or a path to the
|
||||||
make_tpu_client.
|
plugin configuration json file. The json file needs to have a "library_path"
|
||||||
|
field for the plugin library path. It can have an optional "create_option"
|
||||||
|
field for the options used when creating a PJRT plugin client. The value of
|
||||||
|
"create_option" is key-value pairs. Please see xla_client._NameValueMapping
|
||||||
|
for the supported types of values.
|
||||||
|
|
||||||
|
TPU PJRT plugin will be loaded and registered separately in make_tpu_client.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def make_factory(name, path):
|
def make_factory(name: str, path: str):
|
||||||
def factory():
|
def factory():
|
||||||
xla_client.load_pjrt_plugin_dynamically(name, path)
|
if path.endswith('.json'):
|
||||||
return xla_client.make_c_api_client(name)
|
library_path, options = _get_pjrt_plugin_config(path)
|
||||||
|
else:
|
||||||
|
library_path = path
|
||||||
|
options = None
|
||||||
|
|
||||||
|
xla_client.load_pjrt_plugin_dynamically(name, library_path)
|
||||||
|
if lib.xla_extension_version >= 134:
|
||||||
|
return xla_client.make_c_api_client(name, options)
|
||||||
|
else:
|
||||||
|
if options:
|
||||||
|
raise ValueError(
|
||||||
|
'Setting PJRT plugin options through json file requires'
|
||||||
|
' jaxlib.xla_extension_version >= 134.'
|
||||||
|
)
|
||||||
|
return xla_client.make_c_api_client(name)
|
||||||
|
|
||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
@ -848,6 +848,7 @@ py_test(
|
|||||||
py_test(
|
py_test(
|
||||||
name = "xla_bridge_test",
|
name = "xla_bridge_test",
|
||||||
srcs = ["xla_bridge_test.py"],
|
srcs = ["xla_bridge_test.py"],
|
||||||
|
data = ["testdata/example_pjrt_plugin_config.json"],
|
||||||
deps = [
|
deps = [
|
||||||
"//jax",
|
"//jax",
|
||||||
"//jax:test_util",
|
"//jax:test_util",
|
||||||
|
9
tests/testdata/example_pjrt_plugin_config.json
vendored
Normal file
9
tests/testdata/example_pjrt_plugin_config.json
vendored
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"library_path": "/path/pjrt_plugin_name1.so",
|
||||||
|
"create_options": {
|
||||||
|
"int_option": 64,
|
||||||
|
"int_list_option": [32, 64],
|
||||||
|
"string_option": "string",
|
||||||
|
"float_option": 1.0
|
||||||
|
}
|
||||||
|
}
|
@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -90,9 +91,6 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
|||||||
xb.tpu_client_timer_callback(0.01)
|
xb.tpu_client_timer_callback(0.01)
|
||||||
|
|
||||||
def test_register_plugin(self):
|
def test_register_plugin(self):
|
||||||
if xc._version < 126:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self.assertLogs(level="WARNING") as log_output:
|
with self.assertLogs(level="WARNING") as log_output:
|
||||||
xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3")
|
xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3")
|
||||||
client_factory, priotiy = xb._backend_factories["name1"]
|
client_factory, priotiy = xb._backend_factories["name1"]
|
||||||
@ -111,7 +109,39 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
|||||||
self.assertIn("name2", xb._backend_factories)
|
self.assertIn("name2", xb._backend_factories)
|
||||||
self.assertEqual(priotiy, 400)
|
self.assertEqual(priotiy, 400)
|
||||||
mock_load_plugin.assert_called_once_with("name1", "path1")
|
mock_load_plugin.assert_called_once_with("name1", "path1")
|
||||||
mock_make.assert_called_once_with("name1")
|
if xc._version >= 134:
|
||||||
|
mock_make.assert_called_once_with("name1", None)
|
||||||
|
else:
|
||||||
|
mock_make.assert_called_once_with("name1")
|
||||||
|
|
||||||
|
def test_register_plugin_with_config(self):
|
||||||
|
if xc._version < 134:
|
||||||
|
return
|
||||||
|
test_json_file_path = os.path.join(
|
||||||
|
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
|
||||||
|
)
|
||||||
|
xb.register_pjrt_plugin_factories(f"name1:{test_json_file_path}")
|
||||||
|
client_factory, priority = xb._backend_factories["name1"]
|
||||||
|
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||||
|
with mock.patch.object(
|
||||||
|
xc, "load_pjrt_plugin_dynamically", autospec=True
|
||||||
|
) as mock_load_plugin:
|
||||||
|
client_factory()
|
||||||
|
|
||||||
|
self.assertIn("name1", xb._backend_factories)
|
||||||
|
self.assertEqual(priority, 400)
|
||||||
|
mock_load_plugin.assert_called_once_with(
|
||||||
|
"name1", "/path/pjrt_plugin_name1.so"
|
||||||
|
)
|
||||||
|
mock_make.assert_called_once_with(
|
||||||
|
"name1",
|
||||||
|
{
|
||||||
|
"int_option": 64,
|
||||||
|
"int_list_option": [32, 64],
|
||||||
|
"string_option": "string",
|
||||||
|
"float_option": 1.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GetBackendTest(jtu.JaxTestCase):
|
class GetBackendTest(jtu.JaxTestCase):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user