mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00

A recent change broke jax.distributed initialization, which was unsurprising because those APIs were not tested. In particular, we need to only initialize the service from the first process. Fix it and add some tests that use the distributed service from multiple threads within a unit test. Move the state of jax.distributed into an object so it can be instantiated multiple times from a test case in parallel rather than being process-global. [XLA:Python] Add GIL release guards around distributed system init/shutdown. This allows testing using multiple threads. PiperOrigin-RevId: 453480351
70 lines
2.2 KiB
Python
70 lines
2.2 KiB
Python
# Copyright 2022 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import threading
|
|
import unittest
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
|
|
import jax
|
|
from jax.config import config
|
|
import jax._src.distributed as distributed
|
|
import jax._src.lib
|
|
from jax._src import test_util as jtu
|
|
|
|
try:
|
|
import portpicker
|
|
except ImportError:
|
|
portpicker = None
|
|
|
|
# Runs at import time, before the test class below is defined.
# NOTE(review): presumably registers/parses JAX config flags via absl so they
# can be set on the test command line -- confirm against jax.config.
config.parse_flags_with_absl()
|
|
|
|
|
|
@unittest.skipIf(jax._src.lib.xla_extension_version < 73,
                 "Test requires jaxlib 0.3.12 or newer.")
@unittest.skipIf(not portpicker, "Test requires portpicker")
class DistributedTest(jtu.JaxTestCase):
  """Tests initialization and shutdown of the distributed service.

  Skipped when the xla_extension version is below 73 or when `portpicker`
  could not be imported.
  """

  def testInitializeAndShutdown(self):
    """Initializes and shuts down a single process via the public API."""
    # Tests the public APIs. Since they use global state, we cannot use
    # concurrency to simulate multiple tasks.
    port = portpicker.pick_unused_port()
    jax.distributed.initialize(coordinator_address=f"localhost:{port}",
                               num_processes=1,
                               process_id=0)
    jax.distributed.shutdown()

  @parameterized.parameters([1, 2, 4])
  def testConcurrentInitializeAndShutdown(self, n):
    """Runs `n` tasks, one per thread, each with its own `State` object.

    Args:
      n: number of tasks (threads) that initialize against one coordinator
        address and then shut down.

    Raises:
      AssertionError: if any task raised; chained to the first captured
        exception.
    """
    port = portpicker.pick_unused_port()
    # (task index, exception) pairs recorded by worker threads. An exception
    # escaping a thread's target is not re-raised by `Thread.join()`, so
    # without this the test would pass even when a task failed.
    failures = []

    def task(i):
      # We can't call the public APIs directly because they use global state.
      try:
        state = distributed.State()
        state.initialize(coordinator_address=f"localhost:{port}",
                         num_processes=n,
                         process_id=i)
        state.shutdown()
      except Exception as e:  # Thread boundary: record it, re-raise below.
        failures.append((i, e))

    threads = [threading.Thread(target=task, args=(i,)) for i in range(n)]
    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()

    if failures:
      failed_ids = sorted(i for i, _ in failures)
      _, first_error = failures[0]
      raise AssertionError(
          f"{len(failures)} of {n} distributed tasks failed "
          f"(process ids {failed_ids}); first error: {first_error!r}"
      ) from first_error
|
|
|
|
|
|
if __name__ == "__main__":
  # Script entry point: hand control to absl's test runner, using the test
  # loader supplied by jax's test utilities.
  jax_test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_test_loader)
|