Use xla_extension_version and remove some dead version check in xla_bridge_test.py.

Min jaxlib requires xla_extension_version >= 144.

PiperOrigin-RevId: 536810415
This commit is contained in:
Jieying Luo 2023-05-31 13:41:53 -07:00 committed by jax authors
parent 727c121169
commit b35c20ce5d

View File

@ -21,6 +21,7 @@ from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax.interpreters import xla
from jax._src.config import config
@ -99,7 +100,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
if xc._version >= 152:
if xla_extension_version >= 152:
with mock.patch.object(
xc, "pjrt_plugin_loaded", autospec=True
) as mock_plugin_loaded:
@ -115,18 +116,13 @@ class XlaBridgeTest(jtu.JaxTestCase):
self.assertIn("name1", xb._backend_factories)
self.assertIn("name2", xb._backend_factories)
self.assertEqual(priotiy, 400)
if xc._version >= 152:
if xla_extension_version >= 152:
mock_plugin_loaded.assert_called_once_with("name1")
else:
mock_load_plugin.assert_called_once_with("name1", "path1")
if xc._version >= 134:
mock_make.assert_called_once_with("name1", None)
else:
mock_make.assert_called_once_with("name1")
mock_make.assert_called_once_with("name1", None)
def test_register_plugin_with_config(self):
if xc._version < 134:
return
test_json_file_path = os.path.join(
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
)
@ -137,7 +133,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
if xc._version >= 152:
if xla_extension_version >= 152:
with mock.patch.object(
xc, "pjrt_plugin_loaded", autospec=True
) as mock_plugin_loaded:
@ -147,7 +143,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
self.assertIn("name1", xb._backend_factories)
self.assertEqual(priority, 400)
if xc._version >= 152:
if xla_extension_version >= 152:
mock_plugin_loaded.assert_called_once_with("name1")
else:
mock_load_plugin.assert_called_once_with(