2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2019 The JAX Authors.
|
2019-12-03 10:08:55 -05:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2023-03-21 16:52:49 -07:00
|
|
|
import os
|
2023-06-16 10:15:14 -04:00
|
|
|
import platform
|
2021-07-28 15:22:42 -07:00
|
|
|
import time
|
2023-08-23 11:33:43 -07:00
|
|
|
import unittest
|
2021-08-09 18:49:18 -07:00
|
|
|
import warnings
|
2019-12-02 16:07:23 -08:00
|
|
|
|
2022-05-20 09:46:04 -07:00
|
|
|
from absl import logging
|
2019-12-02 16:07:23 -08:00
|
|
|
from absl.testing import absltest
|
2023-08-15 06:38:56 -07:00
|
|
|
from jax._src import compiler
|
2023-08-23 11:33:43 -07:00
|
|
|
from jax._src import config as jax_config
|
2021-09-24 07:02:08 -07:00
|
|
|
from jax._src import test_util as jtu
|
2023-02-28 07:01:14 -08:00
|
|
|
from jax._src import xla_bridge as xb
|
2023-08-23 11:33:43 -07:00
|
|
|
from jax._src.config import config
|
|
|
|
from jax._src.interpreters import xla
|
2021-09-23 06:33:25 -07:00
|
|
|
from jax._src.lib import xla_client as xc
|
2023-08-15 15:23:17 -07:00
|
|
|
from jax._src.lib import xla_extension_version
|
2023-08-31 17:21:17 -07:00
|
|
|
|
2021-09-28 11:46:52 -07:00
|
|
|
# unittest.mock, as re-exported by absltest; the tests below patch with it.
mock = absltest.mock

# Populate JAX's config from absl command-line flags at import time.
config.parse_flags_with_absl()
FLAGS = config.FLAGS
|
|
|
|
|
|
|
|
|
2021-08-13 17:09:12 -07:00
|
|
|
class XlaBridgeTest(jtu.JaxTestCase):
  """Tests compile options, parameter replication, device lookup and plugins."""

  def test_set_device_assignment_no_partition(self):
    compile_options = compiler.get_compile_options(
        num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3])
    expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: "
                                  "0 1 2 3 \n")
    self.assertEqual(repr(compile_options.device_assignment),
                     expected_device_assignment)

  def test_set_device_assignment_with_partition(self):
    compile_options = compiler.get_compile_options(
        num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]])
    expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: "
                                  "0 2 \nComputation 1: 1 3 \n")
    self.assertEqual(repr(compile_options.device_assignment),
                     expected_device_assignment)

  def test_set_fdo_profile(self):
    compile_options = compiler.get_compile_options(
        num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
    )
    self.assertEqual(
        compile_options.executable_build_options.fdo_profile, "test_profile"
    )

  def test_autofdo_profile(self):
    # --jax_xla_profile_version takes precedence.
    jax_flag_profile = 1
    another_profile = 2
    with jax_config.jax_xla_profile_version(jax_flag_profile):
      with mock.patch.object(compiler, "get_latest_profile_version",
                             side_effect=lambda: another_profile):
        self.assertEqual(
            compiler.get_compile_options(
                num_replicas=3, num_partitions=4
            ).profile_version,
            jax_flag_profile,
        )

    # Use whatever non-zero value the function get_latest_profile_version
    # returns if --jax_xla_profile_version is not set.
    profile_version = 1
    with mock.patch.object(compiler, "get_latest_profile_version",
                           side_effect=lambda: profile_version):
      self.assertEqual(
          compiler.get_compile_options(
              num_replicas=3, num_partitions=4
          ).profile_version,
          profile_version,
      )

    # If the function returns 0, something is wrong, so expect that we set
    # profile_version to -1 instead to ensure that no attempt is made to
    # retrieve the latest profile later.
    error_return = 0
    no_profile_dont_retrieve = -1
    with mock.patch.object(compiler, "get_latest_profile_version",
                           side_effect=lambda: error_return):
      self.assertEqual(
          compiler.get_compile_options(
              num_replicas=3, num_partitions=4
          ).profile_version,
          no_profile_dont_retrieve,
      )

  @unittest.skipIf(
      xla_extension_version < 189, "Test requires jaxlib 0.4.15 or newer"
  )
  def test_deterministic_serialization(self):
    c1 = compiler.get_compile_options(
        num_replicas=2,
        num_partitions=3,
        env_options_overrides={"1": "1", "2": "2"},
    )
    c2 = compiler.get_compile_options(
        num_replicas=2,
        num_partitions=3,
        env_options_overrides={"2": "2", "1": "1"},  # order changed
    )
    c1str = c1.SerializeAsString()

    # Idempotence.
    self.assertEqual(c1str, c1.SerializeAsString())
    # Map order does not matter.
    self.assertEqual(c1str, c2.SerializeAsString())

  def test_parameter_replication_default(self):
    c = xc.XlaBuilder("test")
    _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
    built_c = c.Build()
    # Use the TestCase assertion (not a bare `assert`, which `-O` strips).
    self.assertNotIn("replication", built_c.as_hlo_text())

  def test_parameter_replication(self):
    c = xc.XlaBuilder("test")
    _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
                      False)
    built_c = c.Build()
    self.assertIn("parameter_replication={false}", built_c.as_hlo_text())

  def test_local_devices(self):
    self.assertNotEmpty(xb.local_devices())
    with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):
      xb.local_devices(100)
    with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
      xb.local_devices(backend="foo")

  def test_timer_tpu_warning(self):
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")

      def _mock_tpu_client():
        # Block until the timer callback has emitted its warning, but never
        # longer than `time_to_wait` seconds. A monotonic clock is used so
        # that wall-clock adjustments cannot distort the timeout.
        time_to_wait = 5
        start = time.monotonic()
        while not w:
          if time.monotonic() - start > time_to_wait:
            raise ValueError(
                "This test should not hang for more than "
                f"{time_to_wait} seconds.")
          time.sleep(0.1)

        self.assertLen(w, 1)
        msg = str(w[-1].message)
        self.assertIn("Did you run your code on all TPU hosts?", msg)

      with mock.patch.object(xc, "make_tpu_client",
                             side_effect=_mock_tpu_client):
        xb.tpu_client_timer_callback(0.01)

  def _call_plugin_factory(self, registration):
    """Runs `registration.factory()` with the PJRT plugin hooks stubbed.

    For jaxlib versions with `xla_extension_version >= 183`,
    `xc.pjrt_plugin_initialized` is patched to return True and
    `xc.initialize_pjrt_plugin` is patched out. Callers are expected to
    already have `xc.make_c_api_client` patched.
    """
    if xla_extension_version < 183:
      registration.factory()
      return
    with mock.patch.object(
        xc, "pjrt_plugin_initialized", autospec=True, return_value=True
    ), mock.patch.object(xc, "initialize_pjrt_plugin", autospec=True):
      registration.factory()

  def test_register_plugin(self):
    env_value = (
        "name1;path1,name2;path2,name3"
        if platform.system() == "Windows"
        else "name1:path1,name2:path2,name3"
    )
    # patch.dict restores os.environ on exit so the variable cannot leak
    # into other tests.
    with mock.patch.dict(
        os.environ, {"PJRT_NAMES_AND_LIBRARY_PATHS": env_value}
    ):
      with self.assertLogs(level="WARNING") as log_output:
        with mock.patch.object(
            xc, "load_pjrt_plugin_dynamically", autospec=True
        ):
          xb.register_pjrt_plugin_factories_from_env()
      registration = xb._backend_factories["name1"]
      with mock.patch.object(
          xc, "make_c_api_client", autospec=True
      ) as mock_make:
        self._call_plugin_factory(registration)

    self.assertRegex(
        log_output.output[0],
        r"invalid value name3 in env var PJRT_NAMES_AND_LIBRARY_PATHS"
        r" name1.path1,name2.path2,name3",
    )
    self.assertIn("name1", xb._backend_factories)
    self.assertIn("name2", xb._backend_factories)
    self.assertEqual(registration.priority, 400)
    self.assertTrue(registration.experimental)
    mock_make.assert_called_once_with("name1", None, None)

  def test_register_plugin_with_config(self):
    test_json_file_path = os.path.join(
        os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
    )
    env_value = (
        f"name1;{test_json_file_path}"
        if platform.system() == "Windows"
        else f"name1:{test_json_file_path}"
    )
    # patch.dict restores os.environ on exit so the variable cannot leak
    # into other tests.
    with mock.patch.dict(
        os.environ, {"PJRT_NAMES_AND_LIBRARY_PATHS": env_value}
    ):
      with mock.patch.object(
          xc, "load_pjrt_plugin_dynamically", autospec=True
      ):
        xb.register_pjrt_plugin_factories_from_env()
      registration = xb._backend_factories["name1"]
      with mock.patch.object(
          xc, "make_c_api_client", autospec=True
      ) as mock_make:
        self._call_plugin_factory(registration)

    self.assertIn("name1", xb._backend_factories)
    self.assertEqual(registration.priority, 400)
    self.assertTrue(registration.experimental)
    mock_make.assert_called_once_with(
        "name1",
        {
            "int_option": 64,
            "int_list_option": [32, 64],
            "string_option": "string",
            "float_option": 1.0,
        },
        None,
    )
|
2023-02-09 14:33:05 -08:00
|
|
|
|
2019-12-02 16:07:23 -08:00
|
|
|
|
2021-08-13 17:09:12 -07:00
|
|
|
class GetBackendTest(jtu.JaxTestCase):
  """Tests xb.get_backend backend selection using fake backend factories.

  Each test runs against an empty factory registry (plus a dummy "cpu"
  factory). The module-level backend state of `xb` and the `jax_platforms`
  flag are saved in setUp and restored in tearDown.
  """

  class _DummyBackend:
    """Minimal stand-in for a backend object returned by a factory."""

    def __init__(self, platform, device_count):
      self.platform = platform
      self._device_count = device_count

    def device_count(self):
      return self._device_count

    def process_index(self):
      return 0

    def local_devices(self):
      return []

  def _register_factory(self, platform: str, priority, device_count=1,
                        assert_used_at_most_once=False, experimental=False):
    """Registers a factory that builds a `_DummyBackend` for `platform`.

    If `assert_used_at_most_once` is set, a second call to the factory
    triggers `logging.fatal`.
    """
    used = []

    def factory():
      if assert_used_at_most_once:
        if used:
          # We need to fail aggressively here since exceptions are caught by
          # the caller and suppressed.
          logging.fatal("Backend factory for %s was called more than once",
                        platform)
        else:
          used.append(True)
      return self._DummyBackend(platform, device_count)

    xb.register_backend_factory(platform, factory, priority=priority,
                                fail_quietly=False, experimental=experimental)

  def setUp(self):
    # Chain to JaxTestCase so its per-test setup runs, as for other test cases.
    super().setUp()
    self._orig_factories = xb._backend_factories
    xb._backend_factories = {}
    self._orig_jax_platforms = config._read("jax_platforms")
    config.FLAGS.jax_platforms = ""
    self._save_backend_state()
    self._reset_backend_state()

    # get_backend logic assumes CPU platform is always present.
    self._register_factory("cpu", 0)

  def tearDown(self):
    xb._backend_factories = self._orig_factories
    config.FLAGS.jax_platforms = self._orig_jax_platforms
    self._restore_backend_state()
    super().tearDown()

  def _save_backend_state(self):
    self._orig_backends = xb._backends
    self._orig_backend_errors = xb._backend_errors
    self._orig_default_backend = xb._default_backend

  def _reset_backend_state(self):
    xb._backends = {}
    xb._backend_errors = {}
    xb._default_backend = None
    xb.get_backend.cache_clear()

  def _restore_backend_state(self):
    xb._backends = self._orig_backends
    xb._backend_errors = self._orig_backend_errors
    xb._default_backend = self._orig_default_backend
    xb.get_backend.cache_clear()

  def test_default(self):
    self._register_factory("platform_A", 20)
    self._register_factory("platform_B", 10)

    backend = xb.get_backend()
    self.assertEqual(backend.platform, "platform_A")
    # All backends initialized.
    self.assertEqual(len(xb._backends), len(xb._backend_factories))

  def test_specific_platform(self):
    self._register_factory("platform_A", 20)
    self._register_factory("platform_B", 10)

    backend = xb.get_backend("platform_B")
    self.assertEqual(backend.platform, "platform_B")
    # All backends initialized.
    self.assertEqual(len(xb._backends), len(xb._backend_factories))

  def test_unknown_backend_error(self):
    with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
      xb.get_backend("foo")

  def test_backend_init_error(self):
    def factory():
      raise RuntimeError("I'm not a real backend")

    xb.register_backend_factory("error", factory, priority=10,
                                fail_quietly=False)

    with self.assertRaisesRegex(
        RuntimeError,
        "Unable to initialize backend 'error': I'm not a real backend"
    ):
      xb.get_backend("error")

  def test_no_devices(self):
    self._register_factory("no_devices", -10, device_count=0)
    with self.assertRaisesRegex(
        RuntimeError,
        "Unable to initialize backend 'no_devices': "
        "Backend 'no_devices' provides no devices."):
      xb.get_backend("no_devices")

  def test_factory_returns_none(self):
    xb.register_backend_factory("none", lambda: None, priority=10,
                                fail_quietly=False)
    with self.assertRaisesRegex(
        RuntimeError,
        "Unable to initialize backend 'none': "
        "Could not initialize backend 'none'"):
      xb.get_backend("none")

  # NOTE(review): this method's name does not start with "test", so default
  # unittest-style loaders will not collect it. Confirm whether the warning it
  # checks is still emitted via `warnings` before renaming it to enable it.
  def cpu_fallback_warning(self):
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      xb.get_backend()
      self.assertLen(w, 1)
      msg = str(w[-1].message)
      self.assertIn("No GPU/TPU found, falling back to CPU", msg)

  def test_jax_platforms_flag(self):
    self._register_factory("platform_A", 20, assert_used_at_most_once=True)
    self._register_factory("platform_B", 10, assert_used_at_most_once=True)

    orig_jax_platforms = config._read("jax_platforms")
    try:
      config.FLAGS.jax_platforms = "cpu,platform_A"

      backend = xb.get_backend()
      self.assertEqual(backend.platform, "cpu")
      # Only specified backends initialized.
      self.assertEqual(len(xb._backends), 2)

      backend = xb.get_backend("platform_A")
      self.assertEqual(backend.platform, "platform_A")

      with self.assertRaisesRegex(RuntimeError, "Unknown backend platform_B"):
        xb.get_backend("platform_B")

    finally:
      config.FLAGS.jax_platforms = orig_jax_platforms

  def test_experimental_warning(self):
    self._register_factory("platform_A", 20, experimental=True)

    with self.assertLogs("jax._src.xla_bridge", level="WARNING") as logs:
      _ = xb.get_backend()
    self.assertIn(
        "WARNING:jax._src.xla_bridge:Platform 'platform_A' is experimental and "
        "not all JAX functionality may be correctly supported!",
        logs.output
    )
|
|
|
|
|
2023-06-20 10:00:10 -04:00
|
|
|
|
|
|
|
|
2019-12-02 16:07:23 -08:00
|
|
|
if __name__ == "__main__":
  # When executed directly, run the tests with JAX's custom test loader.
  _test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=_test_loader)
|