Rename mock_num_processes to mock_num_gpu_processes since this flag is only for GPUs. The naming change was a regression introduced in https://github.com/google/jax/pull/22619

PiperOrigin-RevId: 658061107
This commit is contained in:
Yash Katariya 2024-07-31 10:40:22 -07:00 committed by jax authors
parent d696813b1f
commit e3fc05ad5b
2 changed files with 8 additions and 8 deletions

View File

@ -84,8 +84,8 @@ _ROCM_VISIBLE_DEVICES = config.string_flag(
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
'comma-separate list of integer device IDs.')
_MOCK_NUM_PROCESSES = config.int_flag(
name="mock_num_processes",
_MOCK_NUM_GPU_PROCESSES = config.int_flag(
name="mock_num_gpu_processes",
default=0,
help="Mock number of JAX processes in GPU client. Value zero turns "
"off mocking.",
@ -433,9 +433,9 @@ def make_gpu_client(
if visible_devices != "all":
allowed_devices = {int(x) for x in visible_devices.split(",")}
use_mock_gpu_client = _MOCK_NUM_PROCESSES.value > 0
use_mock_gpu_client = _MOCK_NUM_GPU_PROCESSES.value > 0
num_nodes = (
_MOCK_NUM_PROCESSES.value
_MOCK_NUM_GPU_PROCESSES.value
if use_mock_gpu_client
else distributed.global_state.num_processes
)
@ -633,10 +633,10 @@ def _options_from_jax_configs(plugin_name):
visible_devices = CUDA_VISIBLE_DEVICES.value
if visible_devices != 'all':
options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
mock_processes = _MOCK_NUM_PROCESSES.value
options['enable_mock_nccl'] = mock_processes > 0
mock_gpu_processes = _MOCK_NUM_GPU_PROCESSES.value
options['enable_mock_nccl'] = mock_gpu_processes > 0
if options['enable_mock_nccl']:
options['num_nodes'] = mock_processes
options['num_nodes'] = mock_gpu_processes
return options

View File

@ -27,7 +27,7 @@ jax.config.parse_flags_with_absl()
NUM_SHARDS = 4
@jtu.with_config(mock_num_processes=NUM_SHARDS)
@jtu.with_config(mock_num_gpu_processes=NUM_SHARDS)
class MockGPUTest(jtu.JaxTestCase):
def setUp(self):