2019-12-03 10:08:55 -05:00
|
|
|
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
|
2021-07-28 15:22:42 -07:00
|
|
|
import time
|
2021-08-09 18:49:18 -07:00
|
|
|
import warnings
|
2019-12-02 16:07:23 -08:00
|
|
|
|
|
|
|
from absl.testing import absltest
|
2021-09-24 07:02:08 -07:00
|
|
|
from jax._src import test_util as jtu
|
2021-09-23 06:33:25 -07:00
|
|
|
from jax._src.lib import xla_bridge as xb
|
|
|
|
from jax._src.lib import xla_client as xc
|
2021-11-30 14:24:02 -08:00
|
|
|
from jax.interpreters import xla
|
2019-12-02 16:07:23 -08:00
|
|
|
|
2021-09-28 11:46:52 -07:00
|
|
|
from jax._src.config import config
|
|
|
|
# Ask the JAX config object to parse flags through absl at import time.
# NOTE(review): presumably this must run before any flag value is read;
# confirm against jax._src.config.
config.parse_flags_with_absl()
# Module-level alias for the config's flag holder.
FLAGS = config.FLAGS

# Short alias for the mock module exposed by absltest; used for patching below.
mock = absltest.mock
|
|
|
|
|
|
|
|
|
2021-08-13 17:09:12 -07:00
|
|
|
class XlaBridgeTest(jtu.JaxTestCase):
  """Tests for compile options, parameter replication and device lookup."""

  def test_set_device_assignment_no_partition(self):
    """A flat list of four devices with one partition gives one computation."""
    compile_options = xb.get_compile_options(
        num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3])
    expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: "
                                  "0 1 2 3 \n")
    self.assertEqual(repr(compile_options.device_assignment),
                     expected_device_assignment)

  def test_set_device_assignment_with_partition(self):
    """A 2x2 nested device list is rendered as two computations."""
    compile_options = xb.get_compile_options(
        num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]])
    expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: "
                                  "0 2 \nComputation 1: 1 3 \n")
    self.assertEqual(repr(compile_options.device_assignment),
                     expected_device_assignment)

  def test_parameter_replication_default(self):
    """With no extra arguments the HLO text does not mention replication."""
    c = xc.XlaBuilder("test")
    _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
    built_c = c.Build()
    # unittest assertion instead of a bare `assert`: it is not stripped under
    # `python -O` and reports both operands on failure.
    self.assertNotIn("replication", built_c.as_hlo_text())

  def test_parameter_replication(self):
    """Passing `"", False` as trailing arguments shows up in the HLO text.

    NOTE(review): the two trailing positional arguments are presumably the
    parameter name and the replicated flag; confirm against `xla.parameter`.
    """
    c = xc.XlaBuilder("test")
    _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
                      False)
    built_c = c.Build()
    self.assertIn("parameter_replication={false}", built_c.as_hlo_text())

  def test_local_devices(self):
    """`local_devices` is non-empty and rejects unknown indices and backends."""
    self.assertNotEmpty(xb.local_devices())
    with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):
      xb.local_devices(100)
    with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
      xb.local_devices(backend="foo")

  def test_timer_tpu_warning(self):
    """A warning about TPU hosts is recorded while client creation blocks.

    `make_tpu_client` is patched with a stand-in that waits until a warning has
    been recorded (or fails after five seconds) and then checks its text.
    NOTE(review): presumably `tpu_client_timer_callback(0.01)` arms a 0.01s
    timer and then calls `make_tpu_client`; confirm against `xb`.
    """
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")

      def _mock_tpu_client():
        time_to_wait = 5
        # Monotonic clock: the elapsed-time check must not be affected by
        # wall-clock adjustments while the test is polling.
        start = time.monotonic()
        while not w:
          if time.monotonic() - start > time_to_wait:
            raise ValueError(
                "This test should not hang for more than "
                f"{time_to_wait} seconds.")
          time.sleep(0.1)

        self.assertLen(w, 1)
        msg = str(w[-1].message)
        self.assertIn("Did you run your code on all TPU hosts?", msg)

      with mock.patch("jax._src.lib.xla_client.make_tpu_client",
                      side_effect=_mock_tpu_client):
        xb.tpu_client_timer_callback(0.01)
|
2021-07-28 15:22:42 -07:00
|
|
|
|
2019-12-02 16:07:23 -08:00
|
|
|
|
2021-08-13 17:09:12 -07:00
|
|
|
class GetBackendTest(jtu.JaxTestCase):
  """Tests for `xb.get_backend` driven by fake registered backend factories.

  Each test runs against an emptied factory registry and emptied backend
  caches; the original module state is put back in `tearDown`.
  """

  class _DummyBackend:
    """Fake backend object with a fixed platform name and device count."""

    def __init__(self, platform, device_count):
      self.platform = platform
      self._device_count = device_count

    def device_count(self):
      return self._device_count

    def process_index(self):
      return 0

    def local_devices(self):
      return []

  def _register_factory(self, platform: str, priority, device_count=1):
    """Registers a factory that builds a `_DummyBackend` for `platform`."""
    xb.register_backend_factory(
        platform, lambda: self._DummyBackend(platform, device_count),
        priority=priority)

  def setUp(self):
    """Swaps in an empty factory registry and clears cached backend state."""
    # Run the base-class fixture first so its own setup is not skipped.
    super().setUp()
    self._orig_factories = xb._backend_factories
    xb._backend_factories = {}
    self._save_backend_state()
    self._reset_backend_state()

    # get_backend logic assumes CPU platform is always present.
    self._register_factory("cpu", 0)

  def tearDown(self):
    """Restores the factory registry and backend state saved in `setUp`."""
    xb._backend_factories = self._orig_factories
    self._restore_backend_state()
    # Base-class teardown runs last, mirroring the order used in setUp.
    super().tearDown()

  def _save_backend_state(self):
    """Remembers the module-level backend caches of `xb`."""
    self._orig_backends = xb._backends
    self._orig_backends_errors = xb._backends_errors
    self._orig_default_backend = xb._default_backend

  def _reset_backend_state(self):
    """Empties the backend caches of `xb` and clears `get_backend`'s cache."""
    xb._backends = {}
    xb._backends_errors = {}
    xb._default_backend = None
    xb.get_backend.cache_clear()

  def _restore_backend_state(self):
    """Puts back the caches captured by `_save_backend_state`."""
    xb._backends = self._orig_backends
    xb._backends_errors = self._orig_backends_errors
    xb._default_backend = self._orig_default_backend
    xb.get_backend.cache_clear()

  def test_default(self):
    """With priorities 20, 10 and 0 (cpu), the default is platform_A."""
    self._register_factory("platform_A", 20)
    self._register_factory("platform_B", 10)

    backend = xb.get_backend()
    self.assertEqual(backend.platform, "platform_A")
    # All backends initialized.
    self.assertLen(xb._backends, len(xb._backend_factories))

  def test_specific_platform(self):
    """Requesting a platform by name returns that platform's backend."""
    self._register_factory("platform_A", 20)
    self._register_factory("platform_B", 10)

    backend = xb.get_backend("platform_B")
    self.assertEqual(backend.platform, "platform_B")
    # All backends initialized.
    self.assertLen(xb._backends, len(xb._backend_factories))

  def test_unknown_backend_error(self):
    """Requesting an unregistered platform raises RuntimeError."""
    with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
      xb.get_backend("foo")

  def test_backend_init_error(self):
    """A raising factory is skipped by default but surfaces when requested."""
    def factory():
      raise RuntimeError("I'm not a real backend")

    xb.register_backend_factory("error", factory, priority=10)
    # No error raised if there's a fallback backend.
    default_backend = xb.get_backend()
    self.assertEqual(default_backend.platform, "cpu")

    with self.assertRaisesRegex(RuntimeError, "I'm not a real backend"):
      xb.get_backend("error")

  def test_no_devices(self):
    """A backend with zero devices is never the default and errors by name.

    The scenario is checked twice: once with the empty backend registered
    below cpu's priority (-10) and once above it (10).
    """
    self._register_factory("no_devices", -10, device_count=0)
    default_backend = xb.get_backend()
    self.assertEqual(default_backend.platform, "cpu")
    with self.assertRaisesRegex(
        RuntimeError,
        "Backend 'no_devices' failed to initialize: "
        "Backend 'no_devices' provides no devices."):
      xb.get_backend("no_devices")

    self._reset_backend_state()

    self._register_factory("no_devices2", 10, device_count=0)
    default_backend = xb.get_backend()
    self.assertEqual(default_backend.platform, "cpu")
    with self.assertRaisesRegex(
        RuntimeError,
        "Backend 'no_devices2' failed to initialize: "
        "Backend 'no_devices2' provides no devices."):
      xb.get_backend("no_devices2")

  def test_factory_returns_none(self):
    """A factory returning None falls back to cpu and errors when requested."""
    xb.register_backend_factory("none", lambda: None, priority=10)
    default_backend = xb.get_backend()
    self.assertEqual(default_backend.platform, "cpu")
    with self.assertRaisesRegex(
        RuntimeError,
        "Backend 'none' failed to initialize: "
        "Could not initialize backend 'none'"):
      xb.get_backend("none")

  # NOTE(review): this method's name does not start with "test", so the default
  # unittest/absltest loader will not collect it. The name is left unchanged
  # here because enabling it may change suite results; confirm whether the
  # warning is actually emitted via the `warnings` module before renaming.
  def cpu_fallback_warning(self):
    """Expects exactly one CPU-fallback warning from `xb.get_backend()`."""
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      xb.get_backend()
      self.assertLen(w, 1)
      msg = str(w[-1].message)
      self.assertIn("No GPU/TPU found, falling back to CPU", msg)

  def test_jax_platforms_flag(self):
    """Setting `jax_platforms` limits which backends are initialized."""
    self._register_factory("platform_A", 20)
    self._register_factory("platform_B", 10)

    orig_jax_platforms = config._read("jax_platforms")
    try:
      config.FLAGS.jax_platforms = "cpu,platform_A"

      backend = xb.get_backend()
      self.assertEqual(backend.platform, "cpu")
      # Only specified backends initialized.
      self.assertLen(xb._backends, 2)

      backend = xb.get_backend("platform_A")
      self.assertEqual(backend.platform, "platform_A")

      with self.assertRaisesRegex(RuntimeError, "Unknown backend platform_B"):
        xb.get_backend("platform_B")

    finally:
      # Always put the flag back so later tests see the original value.
      config.FLAGS.jax_platforms = orig_jax_platforms
|
|
|
|
|
2021-08-13 17:09:12 -07:00
|
|
|
|
2019-12-02 16:07:23 -08:00
|
|
|
if __name__ == "__main__":
  # When executed as a script, run the tests through absltest using the
  # loader provided by JAX's test utilities.
  absltest.main(testLoader=jtu.JaxTestLoader())
|