Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
Add an option to specify mock GPU topology

commit eedd01118b
parent a99ccd9341
```diff
@@ -90,6 +90,13 @@ _MOCK_NUM_GPU_PROCESSES = config.int_flag(
     help="Mock number of JAX processes in GPU client. Value zero turns "
     "off mocking.",
 )
+_MOCK_GPU_TOPOLOGY = config.string_flag(
+    name="jax_mock_gpu_topology",
+    default="",
+    help='Mock multi-host GPU topology in GPU client. The value should '
+    'be of the form "<number-of-slices> x <number-of-hosts-per-slice> x '
+    '<number-of-devices-per-host>". Empty string turns off mocking.',
+)
 
 _CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag(
     name="jax_cpu_enable_gloo_collectives",
```
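For orientation, here is a minimal usage sketch of the new flag. It is not part of the commit: it assumes a single visible CUDA GPU, that both flags are applied before the GPU backend is initialized, and that setting them through `jax.config.update` behaves like the `@jtu.with_config` decorator used in the new test below.

```python
# Hypothetical usage sketch (not from the commit): mock a
# 2-slice x 4-hosts-per-slice x 1-GPU-per-host topology on a single-GPU box.
import jax

# Both updates must happen before any computation forces the GPU client to
# be created; flag names are taken from the diff above.
jax.config.update("jax_mock_gpu_topology", "2x4x1")   # "" would turn mocking off
jax.config.update("jax_cuda_visible_devices", "0")    # one real GPU backs the mock

print(jax.device_count())  # expected: 2 * 4 * 1 = 8 mocked devices
```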
```diff
@@ -425,6 +432,14 @@ def _check_cuda_versions(raise_on_first_error: bool = False,
                      f'following issues with CUDA components:\n'
                      f'{join_str.join(errors)}')
 
+def _get_num_nodes_from_gpu_topology(topology: str) -> int:
+  try:
+    slices_str, hosts_per_slice_str, _ = topology.split("x", 2)
+    return int(slices_str) * int(hosts_per_slice_str)
+  except (IndexError, ValueError):
+    raise ValueError('Mock topology must be of the form '
+                     '"<number-of-slices> x <number-of-hosts-per-slice> x '
+                     '<number-of-devices-per-host>".')
 
 def make_gpu_client(
     *, platform_name: str, visible_devices_flag: config.Flag[str]
```
```diff
@@ -434,12 +449,14 @@ def make_gpu_client(
   if visible_devices != "all":
     allowed_devices = {int(x) for x in visible_devices.split(",")}
 
-  use_mock_gpu_client = _MOCK_NUM_GPU_PROCESSES.value > 0
-  num_nodes = (
-      _MOCK_NUM_GPU_PROCESSES.value
-      if use_mock_gpu_client
-      else distributed.global_state.num_processes
-  )
+  mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None
+  mock_num_gpu_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if
+                            mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value)
+
+  use_mock_gpu_client = mock_num_gpu_processes > 0
+  num_nodes = (mock_num_gpu_processes if use_mock_gpu_client
+               else distributed.global_state.num_processes)
+
   if platform_name == "cuda":
     if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"):
       _check_cuda_versions()
```
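Taken together, the new helper and the selection logic above mean that a non-empty topology string overrides the older mock-process-count flag, and that only the first two fields of the string contribute to the node count. A standalone sketch of that behaviour (plain Python written for this note, not code from the commit; `resolve_mock_processes` is a name invented here):

```python
# Re-derivation of the rule implemented by _get_num_nodes_from_gpu_topology
# and its caller in make_gpu_client.
def resolve_mock_processes(topology: str, mock_num_processes_flag: int) -> int:
    if topology:                                     # topology flag wins when set
        slices, hosts_per_slice, _ = topology.split("x", 2)
        return int(slices) * int(hosts_per_slice)    # devices-per-host is ignored
    return mock_num_processes_flag                   # otherwise the old flag applies

assert resolve_mock_processes("2x4x1", 0) == 8    # topology alone enables mocking
assert resolve_mock_processes("1x16x8", 0) == 16  # 8 GPUs per host add no nodes
assert resolve_mock_processes("", 3) == 3         # old flag still works
assert resolve_mock_processes("", 0) == 0         # zero keeps mocking off
# A malformed value such as "2x4" raises ValueError, as in the helper above.
```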
```diff
@@ -634,10 +651,14 @@ def _options_from_jax_configs(plugin_name):
     visible_devices = CUDA_VISIBLE_DEVICES.value
     if visible_devices != 'all':
       options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
-    mock_gpu_processes = _MOCK_NUM_GPU_PROCESSES.value
-    options['enable_mock_nccl'] = mock_gpu_processes > 0
-    if options['enable_mock_nccl']:
-      options['num_nodes'] = mock_gpu_processes
+    mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None
+    mock_num_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if
+                          mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value)
+    options['enable_mock_nccl'] = mock_num_processes > 0
+    if mock_num_processes > 0:
+      options['num_nodes'] = mock_num_processes
+      if mock_gpu_topology:
+        options['mock_gpu_topology'] = mock_gpu_topology
 
   return options
 
```
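To make the effect on the plugin options concrete, the following sketch spells out roughly what `_options_from_jax_configs` would now produce for the CUDA plugin when `jax_mock_gpu_topology="2x4x1"` is set. The literal dict is illustrative, not output captured from the code.

```python
# Illustrative expectation (assumed, not captured output): with the topology
# flag set to "2x4x1", the mock-NCCL options would contain roughly
expected_options_subset = {
    "enable_mock_nccl": True,        # mocking is on because 2 * 4 > 0
    "num_nodes": 8,                  # slices * hosts-per-slice
    "mock_gpu_topology": "2x4x1",    # forwarded to the backend only when set
}
# With the flag left empty, only the old behaviour remains: num_nodes falls
# back to the mock-process-count flag and "mock_gpu_topology" is never added.
```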
tests/BUILD (+15 lines)

```diff
@@ -321,6 +321,21 @@ jax_multiplatform_test(
     ],
 )
 
+jax_multiplatform_test(
+    name = "mock_gpu_topology_test",
+    srcs = ["mock_gpu_topology_test.py"],
+    enable_backends = ["gpu"],
+    enable_configs = [
+        "gpu_h100",
+    ],
+    tags = [
+        "config-cuda-only",
+    ],
+    deps = [
+        "//jax:experimental",
+    ],
+)
+
 jax_multiplatform_test(
     name = "array_test",
     srcs = ["array_test.py"],
```
tests/mock_gpu_topology_test.py (new file, 60 lines)

```python
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
import jax
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

jax.config.parse_flags_with_absl()

NUM_SLICES = 2
NUM_HOSTS_PER_SLICE = 4


@jtu.with_config(
    jax_mock_gpu_topology=f"{NUM_SLICES}x{NUM_HOSTS_PER_SLICE}x1",
    jax_cuda_visible_devices="0")
class MockGPUTopologyTest(jtu.JaxTestCase):

  def setUp(self):
    if not jtu.test_device_matches(["gpu"]):
      self.skipTest("Mocking devices only works on the GPU backend.")
    super().setUp()

  @jtu.skip_under_pytest("Test must run in an isolated process")
  def testMockDeviceCount(self):
    self.assertEqual(jax.device_count(), NUM_SLICES * NUM_HOSTS_PER_SLICE)

  @jtu.skip_under_pytest("Test must run in an isolated process")
  def testMockWithSharding(self):
    mesh = jax.sharding.Mesh(jax.devices(), ('x',))
    f = jax.jit(jnp.sum,
                in_shardings=NamedSharding(mesh, P('x')),
                out_shardings=NamedSharding(mesh, P()))

    f_lowered = f.lower(jnp.arange(16))
    hlo = f_lowered.compiler_ir()

    mocked_count = NUM_SLICES * NUM_HOSTS_PER_SLICE
    self.assertIn(f'num_partitions = {mocked_count}', str(hlo))
    self.assertIn(
        f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"',
        str(hlo)
    )

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
```
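A note on the two HLO assertions in `testMockWithSharding`: with eight mocked devices in a one-dimensional mesh and `P('x')`, the 16-element input is split into eight tiles assigned to devices 0..7 in order, so the lowered module should report `num_partitions = 8` and print the tile assignment in XLA's compact iota form. The small sketch below only reconstructs the string the test searches for; it needs no GPU or mocked backend, and it is not part of the commit.

```python
# Reconstructs the sharding string asserted above. The "<=[8]" suffix is
# XLA's compact notation for an iota device list (devices 0..7 in order).
NUM_SLICES = 2
NUM_HOSTS_PER_SLICE = 4
mocked_count = NUM_SLICES * NUM_HOSTS_PER_SLICE            # 8 mocked devices

expected = f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"'
assert expected == 'sharding = "{devices=[8]<=[8]}"'
```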