mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

Also move the reading of the jax config values to right before the client is created. Previously they were read before calling register_plugin, which happens during import and before any call to jax.config.update. The decorator in mock_gpu_test was used incorrectly: jtu.run_on_devices creates the client before jax.config.update is called, which is not desired. Removing the decorator will not make the CPU/TPU tests fail, because the mesh will check num_shards against the number of devices in the client and skip the test if they do not match. generate_pjrt_gpu_plugin_options is only used in places that do not require compatibility, so the xla_client version does not need to be updated. PiperOrigin-RevId: 611610915
64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from functools import partial
|
|
import math
|
|
|
|
from absl.testing import absltest
|
|
import jax
|
|
from jax import config
|
|
from jax._src import test_util as jtu
|
|
import jax.numpy as jnp
|
|
from jax.sharding import NamedSharding
|
|
from jax.sharding import PartitionSpec as P
|
|
import numpy as np
|
|
|
|
# Hand flag parsing to absl at import time, before any test runs.
# NOTE(review): presumably this lets JAX config options be supplied as
# command-line flags to the test binary -- confirm against the jax.config docs.
config.parse_flags_with_absl()
|
|
|
|
|
|
class MockGPUTest(jtu.JaxTestCase):
  """Checks sharded lowering while the `use_mock_gpu_client` option is on.

  Each test runs with `use_mock_gpu_client` set to True; afterwards both
  `use_mock_gpu_client` and `mock_num_gpus` are put back to False and 1.
  """

  def setUp(self):
    """Run the base-class setup, then switch the mock GPU client on."""
    super().setUp()
    jax.config.update('use_mock_gpu_client', True)

  def tearDown(self):
    """Reset the mock-GPU options, then run the base-class teardown."""
    jax.config.update('use_mock_gpu_client', False)
    jax.config.update('mock_num_gpus', 1)
    super().tearDown()

  def testMockWithSharding(self):
    """Lower a jitted double matmul over a 16-way mesh and check its IR.

    The lowered IR text must contain the sharding annotation that splits the
    leading dimension across all 16 devices.
    """
    num_shards = 16
    # The device count is configured before the mesh is requested.
    jax.config.update('mock_num_gpus', num_shards)
    mesh = jtu.create_global_mesh((num_shards,), ('x',))

    # The same partitioning along mesh axis 'x' is used for inputs and output.
    x_sharding = NamedSharding(mesh, P('x',))

    @partial(jax.jit, in_shardings=x_sharding, out_shardings=x_sharding)
    def f(x, y):
      return (x @ y) @ y

    shape = (64, 64)
    num_elements = math.prod(shape)
    x = jnp.arange(num_elements).reshape(shape).astype(np.float32)
    y = x + 1

    # Only lowering is exercised; the function is never executed.
    hlo_text = str(f.lower(x, y).compiler_ir())
    self.assertIn('sharding = "{devices=[16,1]<=[16]}"', hlo_text)
|
|
|
|
|
|
# Script entry point: run this module's tests through absltest, using the
# test loader provided by jax's test utilities.
if __name__ == '__main__':
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|