rocm_jax/tests/xla_bridge_test.py
Yang Chen 08d81e45d4 Use backend._get_all_devices() to validate devices.
PiperOrigin-RevId: 719367913
2025-01-24 11:09:16 -08:00

366 lines
13 KiB
Python

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
from absl import logging
from absl.testing import absltest
from jax import version
from jax._src import compiler
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
config.parse_flags_with_absl()
mock = absltest.mock
class XlaBridgeTest(jtu.JaxTestCase):
def test_set_device_assignment_no_partition(self):
compile_options = compiler.get_compile_options(
num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3])
expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: "
"0 1 2 3 \n")
self.assertEqual(compile_options.device_assignment.__repr__(),
expected_device_assignment)
def test_set_device_assignment_with_partition(self):
compile_options = compiler.get_compile_options(
num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]])
expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: "
"0 2 \nComputation 1: 1 3 \n")
self.assertEqual(compile_options.device_assignment.__repr__(),
expected_device_assignment)
def test_set_fdo_profile(self):
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
)
self.assertEqual(
compile_options.executable_build_options.fdo_profile, b"test_profile")
def test_autofdo_profile(self):
class _DummyBackend:
platform: str = "tpu"
# --jax_xla_profile_version takes precedence.
jax_flag_profile = 1
another_profile = 2
with config.jax_xla_profile_version(jax_flag_profile):
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda _: another_profile):
self.assertEqual(
compiler.get_compile_options(
num_replicas=3, num_partitions=4, backend=_DummyBackend(),
).profile_version,
jax_flag_profile,
)
# Use whatever non-zero value the function get_latest_profile_version
# returns if --jax_xla_profile_version is not set.
profile_version = 1
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda _: profile_version):
self.assertEqual(
compiler.get_compile_options(
num_replicas=3, num_partitions=4, backend=_DummyBackend(),
).profile_version,
profile_version,
)
# If the function returns 0, something is wrong, so expect that we set
# profile_version to -1 instead to ensure that no attempt is made to
# retrieve the latest profile later.
error_return = 0
no_profile_dont_retrieve = -1
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda _: error_return):
self.assertEqual(
compiler.get_compile_options(
num_replicas=3, num_partitions=4, backend=_DummyBackend(),
).profile_version,
no_profile_dont_retrieve,
)
def test_deterministic_serialization(self):
c1 = compiler.get_compile_options(
num_replicas=2,
num_partitions=3,
env_options_overrides={"1": "1", "2": "2"},
)
c2 = compiler.get_compile_options(
num_replicas=2,
num_partitions=3,
env_options_overrides={"2": "2", "1": "1"}, # order changed
)
c1str = c1.SerializeAsString()
# Idempotence.
self.assertEqual(c1str, c1.SerializeAsString())
# Map order does not matter.
self.assertEqual(c1str, c2.SerializeAsString())
def test_local_devices(self):
self.assertNotEmpty(xb.local_devices())
with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):
xb.local_devices(100)
with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
xb.local_devices(backend="foo")
def test_register_plugin(self):
with self.assertLogs(level="WARNING") as log_output:
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
if platform.system() == "Windows":
os.environ["PJRT_NAMES_AND_LIBRARY_PATHS"] = (
"name1;path1,name2;path2,name3"
)
else:
os.environ["PJRT_NAMES_AND_LIBRARY_PATHS"] = (
"name1:path1,name2:path2,name3"
)
with mock.patch.object(
xc.profiler, "register_plugin_profiler", autospec=True
):
xb.register_pjrt_plugin_factories_from_env()
registration = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(
xc, "pjrt_plugin_initialized", autospec=True, return_vale=True
):
with mock.patch.object(xc, "initialize_pjrt_plugin", autospec=True):
registration.factory()
self.assertRegex(
log_output[1][0],
r"invalid value name3 in env var PJRT_NAMES_AND_LIBRARY_PATHS"
r" name1.path1,name2.path2,name3",
)
self.assertIn("name1", xb._backend_factories)
self.assertIn("name2", xb._backend_factories)
self.assertEqual(registration.priority, 400)
self.assertTrue(registration.experimental)
options = {}
if xb.get_backend().platform == 'tpu':
options["ml_framework_name"] = "JAX"
options["ml_framework_version"] = version.__version__
mock_make.assert_called_once_with("name1", options, None)
def test_register_plugin_with_config(self):
test_json_file_path = os.path.join(
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
)
os.environ["PJRT_NAMES_AND_LIBRARY_PATHS"] = (
f"name1;{test_json_file_path}"
if platform.system() == "Windows"
else f"name1:{test_json_file_path}"
)
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
with mock.patch.object(
xc.profiler, "register_plugin_profiler", autospec=True
):
xb.register_pjrt_plugin_factories_from_env()
registration = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(
xc, "pjrt_plugin_initialized", autospec=True, return_vale=True
):
with mock.patch.object(xc, "initialize_pjrt_plugin", autospec=True):
registration.factory()
self.assertIn("name1", xb._backend_factories)
self.assertEqual(registration.priority, 400)
self.assertTrue(registration.experimental)
# The expectation is specified in example_pjrt_plugin_config.json.
options = {
"int_option": 64,
"int_list_option": [32, 64],
"string_option": "string",
"float_option": 1.0,
}
if xb.get_backend().platform == 'tpu':
options["ml_framework_name"] = "JAX"
options["ml_framework_version"] = version.__version__
mock_make.assert_called_once_with("name1", options, None)
class GetBackendTest(jtu.JaxTestCase):
class _DummyBackend:
def __init__(self, platform, device_count):
self.platform = platform
self._device_count = device_count
def device_count(self):
return self._device_count
def process_index(self):
return 0
def devices(self):
return []
def local_devices(self):
return []
def _get_all_devices(self):
return self.devices()
def _register_factory(self, platform: str, priority, device_count=1,
assert_used_at_most_once=False, experimental=False):
if assert_used_at_most_once:
used = []
def factory():
if assert_used_at_most_once:
if used:
# We need to fail aggressively here since exceptions are caught by
# the caller and suppressed.
logging.fatal("Backend factory for %s was called more than once",
platform)
else:
used.append(True)
return self._DummyBackend(platform, device_count)
xb.register_backend_factory(platform, factory, priority=priority,
fail_quietly=False, experimental=experimental)
def setUp(self):
super().setUp()
self._orig_factories = xb._backend_factories
xb._backend_factories = {}
self.enter_context(config.jax_platforms(""))
self._save_backend_state()
self._reset_backend_state()
# get_backend logic assumes CPU platform is always present.
self._register_factory("cpu", 0)
def tearDown(self):
super().tearDown()
xb._backend_factories = self._orig_factories
self._restore_backend_state()
def _save_backend_state(self):
self._orig_backends = xb._backends
self._orig_backend_errors = xb._backend_errors
self._orig_default_backend = xb._default_backend
def _reset_backend_state(self):
xb._backends = {}
xb._backend_errors = {}
xb._default_backend = None
xb.get_backend.cache_clear()
def _restore_backend_state(self):
xb._backends = self._orig_backends
xb._backend_errors = self._orig_backend_errors
xb._default_backend = self._orig_default_backend
xb.get_backend.cache_clear()
def test_default(self):
self._register_factory("platform_A", 20)
self._register_factory("platform_B", 10)
backend = xb.get_backend()
self.assertEqual(backend.platform, "platform_A")
# All backends initialized.
self.assertEqual(len(xb._backends), len(xb._backend_factories))
def test_specific_platform(self):
self._register_factory("platform_A", 20)
self._register_factory("platform_B", 10)
backend = xb.get_backend("platform_B")
self.assertEqual(backend.platform, "platform_B")
# All backends initialized.
self.assertEqual(len(xb._backends), len(xb._backend_factories))
def test_unknown_backend_error(self):
with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
xb.get_backend("foo")
def test_backend_init_error(self):
def factory():
raise RuntimeError("I'm not a real backend")
xb.register_backend_factory("error", factory, priority=10,
fail_quietly=False)
with self.assertRaisesRegex(
RuntimeError,
"Unable to initialize backend 'error': I'm not a real backend"
):
xb.get_backend("error")
def test_no_devices(self):
self._register_factory("no_devices", -10, device_count=0)
with self.assertRaisesRegex(
RuntimeError,
"Unable to initialize backend 'no_devices': "
"Backend 'no_devices' provides no devices."):
xb.get_backend("no_devices")
def test_factory_returns_none(self):
xb.register_backend_factory("none", lambda: None, priority=10,
fail_quietly=False)
with self.assertRaisesRegex(
RuntimeError,
"Unable to initialize backend 'none': "
"Could not initialize backend 'none'"):
xb.get_backend("none")
def cpu_fallback_warning(self):
with self.assertWarnsRegex(UserWarning, "No GPU/TPU found, falling back to CPU"):
xb.get_backend()
def test_jax_platforms_flag(self):
self._register_factory("platform_A", 20, assert_used_at_most_once=True)
self._register_factory("platform_B", 10, assert_used_at_most_once=True)
with config.jax_platforms("cpu,platform_A"):
backend = xb.get_backend()
self.assertEqual(backend.platform, "cpu")
# Only specified backends initialized.
self.assertEqual(len(xb._backends), 2)
backend = xb.get_backend("platform_A")
self.assertEqual(backend.platform, "platform_A")
with self.assertRaisesRegex(RuntimeError, "Unknown backend platform_B"):
backend = xb.get_backend("platform_B")
def test_experimental_warning(self):
self._register_factory("platform_A", 20, experimental=True)
with self.assertLogs("jax._src.xla_bridge", level="WARNING") as logs:
_ = xb.get_backend()
self.assertIn(
"WARNING:jax._src.xla_bridge:Platform 'platform_A' is experimental and "
"not all JAX functionality may be correctly supported!",
logs.output
)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())