mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use xla_extension_version and remove some dead version check in xla_bridge_test.py.
Min jaxlib requires xla_extension_version >= 144. PiperOrigin-RevId: 536810415
This commit is contained in:
parent
727c121169
commit
b35c20ce5d
@ -21,6 +21,7 @@ from absl.testing import absltest
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src.config import config
|
||||
@ -99,7 +100,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
with mock.patch.object(
|
||||
xc, "load_pjrt_plugin_dynamically", autospec=True
|
||||
) as mock_load_plugin:
|
||||
if xc._version >= 152:
|
||||
if xla_extension_version >= 152:
|
||||
with mock.patch.object(
|
||||
xc, "pjrt_plugin_loaded", autospec=True
|
||||
) as mock_plugin_loaded:
|
||||
@ -115,18 +116,13 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
self.assertIn("name1", xb._backend_factories)
|
||||
self.assertIn("name2", xb._backend_factories)
|
||||
self.assertEqual(priotiy, 400)
|
||||
if xc._version >= 152:
|
||||
if xla_extension_version >= 152:
|
||||
mock_plugin_loaded.assert_called_once_with("name1")
|
||||
else:
|
||||
mock_load_plugin.assert_called_once_with("name1", "path1")
|
||||
if xc._version >= 134:
|
||||
mock_make.assert_called_once_with("name1", None)
|
||||
else:
|
||||
mock_make.assert_called_once_with("name1")
|
||||
mock_make.assert_called_once_with("name1", None)
|
||||
|
||||
def test_register_plugin_with_config(self):
|
||||
if xc._version < 134:
|
||||
return
|
||||
test_json_file_path = os.path.join(
|
||||
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
|
||||
)
|
||||
@ -137,7 +133,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
with mock.patch.object(
|
||||
xc, "load_pjrt_plugin_dynamically", autospec=True
|
||||
) as mock_load_plugin:
|
||||
if xc._version >= 152:
|
||||
if xla_extension_version >= 152:
|
||||
with mock.patch.object(
|
||||
xc, "pjrt_plugin_loaded", autospec=True
|
||||
) as mock_plugin_loaded:
|
||||
@ -147,7 +143,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertIn("name1", xb._backend_factories)
|
||||
self.assertEqual(priority, 400)
|
||||
if xc._version >= 152:
|
||||
if xla_extension_version >= 152:
|
||||
mock_plugin_loaded.assert_called_once_with("name1")
|
||||
else:
|
||||
mock_load_plugin.assert_called_once_with(
|
||||
|
Loading…
x
Reference in New Issue
Block a user