diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 23b255ef1..28148761c 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -90,6 +90,13 @@ _MOCK_NUM_GPU_PROCESSES = config.int_flag(
     help="Mock number of JAX processes in GPU client. Value zero turns "
     "off mocking.",
 )
+_MOCK_GPU_TOPOLOGY = config.string_flag(
+    name="jax_mock_gpu_topology",
+    default="",
+    help='Mock multi-host GPU topology in GPU client. The value should '
+         'be of the form "<number-of-slices> x <number-of-hosts-per-slice> x '
+         '<number-of-devices-per-host>". Empty string turns off mocking.',
+)
 
 _CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag(
     name="jax_cpu_enable_gloo_collectives",
@@ -425,6 +432,14 @@ def _check_cuda_versions(raise_on_first_error: bool = False,
                        f'following issues with CUDA components:\n'
                        f'{join_str.join(errors)}')
 
+def _get_num_nodes_from_gpu_topology(topology: str) -> int:
+  try:
+    slices_str, hosts_per_slice_str, _ = topology.split("x", 2)
+    return int(slices_str) * int(hosts_per_slice_str)
+  except (IndexError, ValueError):
+    raise ValueError('Mock topology must be of the form '
+                     '"<number-of-slices> x <number-of-hosts-per-slice> x '
+                     '<number-of-devices-per-host>".')
 
 def make_gpu_client(
     *, platform_name: str, visible_devices_flag: config.Flag[str]
@@ -434,12 +449,14 @@ def make_gpu_client(
   if visible_devices != "all":
     allowed_devices = {int(x) for x in visible_devices.split(",")}
 
-  use_mock_gpu_client = _MOCK_NUM_GPU_PROCESSES.value > 0
-  num_nodes = (
-      _MOCK_NUM_GPU_PROCESSES.value
-      if use_mock_gpu_client
-      else distributed.global_state.num_processes
-  )
+  mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None
+  mock_num_gpu_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if
+      mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value)
+
+  use_mock_gpu_client = mock_num_gpu_processes > 0
+  num_nodes = (mock_num_gpu_processes if use_mock_gpu_client
+               else distributed.global_state.num_processes)
+
   if platform_name == "cuda":
     if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"):
       _check_cuda_versions()
@@ -634,10 +651,14 @@ def _options_from_jax_configs(plugin_name):
     visible_devices = CUDA_VISIBLE_DEVICES.value
     if visible_devices != 'all':
       options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
-    mock_gpu_processes = _MOCK_NUM_GPU_PROCESSES.value
-    options['enable_mock_nccl'] = mock_gpu_processes > 0
-    if options['enable_mock_nccl']:
-      options['num_nodes'] = mock_gpu_processes
+    mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None
+    mock_num_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if
+        mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value)
+    options['enable_mock_nccl'] = mock_num_processes > 0
+    if mock_num_processes > 0:
+      options['num_nodes'] = mock_num_processes
+      if mock_gpu_topology:
+        options['mock_gpu_topology'] = mock_gpu_topology
 
   return options
 
diff --git a/tests/BUILD b/tests/BUILD
index dc81c408c..1f6ea90b7 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -321,6 +321,21 @@ jax_multiplatform_test(
     ],
 )
 
+jax_multiplatform_test(
+    name = "mock_gpu_topology_test",
+    srcs = ["mock_gpu_topology_test.py"],
+    enable_backends = ["gpu"],
+    enable_configs = [
+        "gpu_h100",
+    ],
+    tags = [
+        "config-cuda-only",
+    ],
+    deps = [
+        "//jax:experimental",
+    ],
+)
+
 jax_multiplatform_test(
     name = "array_test",
     srcs = ["array_test.py"],
diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py
new file mode 100644
index 000000000..44ec4e2f9
--- /dev/null
+++ b/tests/mock_gpu_topology_test.py
@@ -0,0 +1,60 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest
+import jax
+from jax._src import test_util as jtu
+import jax.numpy as jnp
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec as P
+
+jax.config.parse_flags_with_absl()
+
+NUM_SLICES = 2
+NUM_HOSTS_PER_SLICE = 4
+
+
+@jtu.with_config(
+    jax_mock_gpu_topology=f"{NUM_SLICES}x{NUM_HOSTS_PER_SLICE}x1",
+    jax_cuda_visible_devices="0")
+class MockGPUTopologyTest(jtu.JaxTestCase):
+
+  def setUp(self):
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("Mocking devices only works on the GPU backend.")
+    super().setUp()
+
+  @jtu.skip_under_pytest("Test must run in an isolated process")
+  def testMockDeviceCount(self):
+    self.assertEqual(jax.device_count(), NUM_SLICES * NUM_HOSTS_PER_SLICE)
+
+  @jtu.skip_under_pytest("Test must run in an isolated process")
+  def testMockWithSharding(self):
+    mesh = jax.sharding.Mesh(jax.devices(), ('x',))
+    f = jax.jit(jnp.sum,
+                in_shardings=NamedSharding(mesh, P('x')),
+                out_shardings=NamedSharding(mesh, P()))
+
+    f_lowered = f.lower(jnp.arange(16))
+    hlo = f_lowered.compiler_ir()
+
+    mocked_count = NUM_SLICES * NUM_HOSTS_PER_SLICE
+    self.assertIn(f'num_partitions = {mocked_count}', str(hlo))
+    self.assertIn(
+        f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"',
+        str(hlo)
+    )
+
+if __name__ == '__main__':
+  absltest.main(testLoader=jtu.JaxTestLoader())
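---

Usage sketch (illustrative, not part of the patch): with this change, a mocked
multi-host GPU topology can be requested on a single-GPU machine via the new
--jax_mock_gpu_topology flag. The script below is a minimal sketch assuming one
visible CUDA device; the file name is hypothetical, and like the new test it
only lowers a sharded computation and inspects the HLO rather than executing it.

    # mock_topology_demo.py -- hypothetical demo script; run as:
    #   python mock_topology_demo.py --jax_mock_gpu_topology=2x4x1 \
    #       --jax_cuda_visible_devices=0
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Pick up the --jax_* flags from the command line, as the new test does.
    jax.config.parse_flags_with_absl()

    # "2x4x1" mocks 2 slices x 4 hosts per slice x 1 device per host, so the
    # client reports 2 * 4 * 1 = 8 devices backed by the single real GPU.
    print("device count:", jax.device_count())

    # Lowering a computation sharded across all 8 mocked devices works; the
    # resulting HLO should show num_partitions = 8.
    mesh = Mesh(jax.devices(), ("x",))
    f = jax.jit(jnp.sum,
                in_shardings=NamedSharding(mesh, P("x")),
                out_shardings=NamedSharding(mesh, P()))
    print(f.lower(jnp.arange(16)).compiler_ir())

Note that _options_from_jax_configs forwards both the derived num_nodes and the
raw topology string to the CUDA plugin (alongside enable_mock_nccl), presumably
so the plugin can lay the mocked devices out across slices and hosts rather
than treating them as a flat process count.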