mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add more API set up for Mock GPU client. Also clean up previous mock GPU client
API. PiperOrigin-RevId: 570153877
This commit is contained in:
parent
d45fa22424
commit
c12929b012
@ -84,6 +84,18 @@ _ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string(
|
||||
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
|
||||
'comma-separate list of integer device IDs.')
|
||||
|
||||
_USE_MOCK_GPU_CLIENT = jax_config.DEFINE_bool(
|
||||
name="use_mock_gpu_client",
|
||||
default=False,
|
||||
help="If True, use a mock GPU client instead of a real one.",
|
||||
)
|
||||
|
||||
_MOCK_NUM_GPUS = jax_config.DEFINE_integer(
|
||||
name="mock_num_gpus",
|
||||
default=1,
|
||||
help="Mock GPU client number of gpus.",
|
||||
)
|
||||
|
||||
|
||||
# Backends
|
||||
|
||||
@ -221,12 +233,28 @@ def make_gpu_client(
|
||||
if platform_name == "cuda":
|
||||
_check_cuda_versions()
|
||||
|
||||
if xla_extension_version <= 199:
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
num_nodes=distributed.global_state.num_processes,
|
||||
platform_name=platform_name,
|
||||
allowed_devices=allowed_devices,
|
||||
)
|
||||
use_mock_gpu_client = _USE_MOCK_GPU_CLIENT.value
|
||||
num_nodes = (
|
||||
_MOCK_NUM_GPUS.value
|
||||
if use_mock_gpu_client
|
||||
else distributed.global_state.num_processes
|
||||
)
|
||||
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
num_nodes=distributed.global_state.num_processes,
|
||||
num_nodes=num_nodes,
|
||||
platform_name=platform_name,
|
||||
allowed_devices=allowed_devices,
|
||||
mock=use_mock_gpu_client, # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
|
||||
|
15
tests/BUILD
15
tests/BUILD
@ -231,6 +231,21 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "mock_gpu_test",
|
||||
srcs = ["mock_gpu_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
tags = [
|
||||
"config-cuda-only",
|
||||
],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "array_test",
|
||||
srcs = ["array_test.py"],
|
||||
|
69
tests/mock_gpu_test.py
Normal file
69
tests/mock_gpu_test.py
Normal file
@ -0,0 +1,69 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class MockGPUTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
jax.config.update('use_mock_gpu_client', True)
|
||||
|
||||
def tearDown(self):
|
||||
jax.config.update('use_mock_gpu_client', False)
|
||||
jax.config.update('mock_num_gpus', 1)
|
||||
super().tearDown()
|
||||
|
||||
def testMockWithSharding(self):
|
||||
if xla_extension_version < 200:
|
||||
return
|
||||
num_shards = 16
|
||||
jax.config.update('mock_num_gpus', num_shards)
|
||||
mesh_shape = (num_shards,)
|
||||
axis_names = ('x',)
|
||||
mesh_devices = np.array(jax.devices()).reshape(mesh_shape)
|
||||
mesh = jax.sharding.Mesh(mesh_devices, axis_names)
|
||||
@partial(
|
||||
jax.jit,
|
||||
in_shardings=NamedSharding(mesh, P('x',)),
|
||||
out_shardings=NamedSharding(mesh, P('x',)),
|
||||
)
|
||||
def f(x, y):
|
||||
z = x @ y
|
||||
return z @ y
|
||||
|
||||
shape = (64, 64)
|
||||
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
|
||||
y = x + 1
|
||||
f_lowered = f.lower(x, y)
|
||||
hlo = f_lowered.compiler_ir()
|
||||
self.assertIn('sharding = "{devices=[16,1]<=[16]}"', str(hlo))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user