rocm_jax/tests/distributed_test.py
Peter Hawkins 3e699ddec0 Unbreak jax.distributed initialization.
A recent change broke jax.distributed initialization, which was unsurprising because those APIs were not tested. In particular, the service must be initialized only from the first process.

Fix it and add some tests that use the distributed service from multiple threads within a unit test. Move the state of jax.distributed into an object so it can be instantiated multiple times from a test case in parallel rather than being process-global.

[XLA:Python] Add GIL release guards around distributed system init/shutdown. This allows testing with multiple threads.

PiperOrigin-RevId: 453480351
2022-06-07 11:10:17 -07:00

70 lines
2.2 KiB
Python

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax.config import config
import jax._src.distributed as distributed
import jax._src.lib
from jax._src import test_util as jtu
# portpicker is an optional dependency: the tests use it to choose a free
# coordinator port. When it cannot be imported, the name is bound to None so
# that the skipIf decorator on DistributedTest skips the whole test class.
try:
  import portpicker
except ImportError:
  portpicker = None

# NOTE(review): presumably applies absl command-line flags to jax's config
# before the tests run; confirm against jax.config.
config.parse_flags_with_absl()
@unittest.skipIf(jax._src.lib.xla_extension_version < 73,
                 "Test requires jaxlib 0.3.12 or newer.")
@unittest.skipIf(not portpicker, "Test requires portpicker")
class DistributedTest(jtu.JaxTestCase):
  """Tests for initializing and shutting down the jax.distributed service."""

  def testInitializeAndShutdown(self):
    """Initializes and shuts down a one-process cluster via the public API."""
    # Tests the public APIs. Since they use global state, we cannot use
    # concurrency to simulate multiple tasks.
    port = portpicker.pick_unused_port()
    jax.distributed.initialize(coordinator_address=f"localhost:{port}",
                               num_processes=1,
                               process_id=0)
    jax.distributed.shutdown()

  @parameterized.parameters([1, 2, 4])
  def testConcurrentInitializeAndShutdown(self, n):
    """Simulates an ``n``-process cluster using one thread per process.

    Each thread owns a private ``distributed.State`` and initializes it with
    its own ``process_id`` against the same coordinator address, then shuts
    it down.

    Args:
      n: number of simulated processes (and threads).

    Raises:
      AssertionError: if any simulated process raised during initialize or
        shutdown; the original exception is chained as the cause.
    """
    port = portpicker.pick_unused_port()
    # errors[i] holds the exception raised by task i, if any. join() does not
    # propagate exceptions raised inside a Thread's target, so without this
    # the test would pass even if every task failed. Each thread writes only
    # its own slot, so no extra locking is needed.
    errors = [None] * n

    def task(i):
      try:
        # We can't call the public APIs directly because they use global state.
        state = distributed.State()
        state.initialize(coordinator_address=f"localhost:{port}",
                         num_processes=n,
                         process_id=i)
        state.shutdown()
      except Exception as e:  # pylint: disable=broad-except
        # Forwarded to the main thread and re-raised after all threads join.
        errors[i] = e

    threads = [threading.Thread(target=task, args=(i,)) for i in range(n)]
    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()

    for i, error in enumerate(errors):
      if error is not None:
        raise AssertionError(f"Task {i} of {n} failed: {error!r}") from error
if __name__ == "__main__":
  # Script entry point: run every test in this module through absltest,
  # using JAX's custom test loader.
  test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=test_loader)