mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Rename mock_num_processes
to mock_num_gpu_processes
since this flag is only for GPUs. The naming change was a regression introduced in https://github.com/google/jax/pull/22619
PiperOrigin-RevId: 658061107
This commit is contained in:
parent
d696813b1f
commit
e3fc05ad5b
@ -84,8 +84,8 @@ _ROCM_VISIBLE_DEVICES = config.string_flag(
|
||||
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
|
||||
'comma-separate list of integer device IDs.')
|
||||
|
||||
_MOCK_NUM_PROCESSES = config.int_flag(
|
||||
name="mock_num_processes",
|
||||
_MOCK_NUM_GPU_PROCESSES = config.int_flag(
|
||||
name="mock_num_gpu_processes",
|
||||
default=0,
|
||||
help="Mock number of JAX processes in GPU client. Value zero turns "
|
||||
"off mocking.",
|
||||
@ -433,9 +433,9 @@ def make_gpu_client(
|
||||
if visible_devices != "all":
|
||||
allowed_devices = {int(x) for x in visible_devices.split(",")}
|
||||
|
||||
use_mock_gpu_client = _MOCK_NUM_PROCESSES.value > 0
|
||||
use_mock_gpu_client = _MOCK_NUM_GPU_PROCESSES.value > 0
|
||||
num_nodes = (
|
||||
_MOCK_NUM_PROCESSES.value
|
||||
_MOCK_NUM_GPU_PROCESSES.value
|
||||
if use_mock_gpu_client
|
||||
else distributed.global_state.num_processes
|
||||
)
|
||||
@ -633,10 +633,10 @@ def _options_from_jax_configs(plugin_name):
|
||||
visible_devices = CUDA_VISIBLE_DEVICES.value
|
||||
if visible_devices != 'all':
|
||||
options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
|
||||
mock_processes = _MOCK_NUM_PROCESSES.value
|
||||
options['enable_mock_nccl'] = mock_processes > 0
|
||||
mock_gpu_processes = _MOCK_NUM_GPU_PROCESSES.value
|
||||
options['enable_mock_nccl'] = mock_gpu_processes > 0
|
||||
if options['enable_mock_nccl']:
|
||||
options['num_nodes'] = mock_processes
|
||||
options['num_nodes'] = mock_gpu_processes
|
||||
|
||||
return options
|
||||
|
||||
|
@ -27,7 +27,7 @@ jax.config.parse_flags_with_absl()
|
||||
NUM_SHARDS = 4
|
||||
|
||||
|
||||
@jtu.with_config(mock_num_processes=NUM_SHARDS)
|
||||
@jtu.with_config(mock_num_gpu_processes=NUM_SHARDS)
|
||||
class MockGPUTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user