From 5db78e7ae03350f0833b0f315ad7e9d69a14c039 Mon Sep 17 00:00:00 2001
From: Zac Cranko <zcrank@google.com>
Date: Tue, 18 Feb 2025 16:47:19 -0800
Subject: [PATCH] add distributed.is_initialized

Add is_initialized() to jax/_src/distributed.py, returning whether
global_state.client is set, and re-export it from jax.distributed.

Move DistributedTest out of tests/multiprocess_gpu_test.py into a new
tests/distributed_test.py with its own BUILD target, and add a test that
checks is_initialized() before and after initialize() in a subprocess.

In the moved test, testInitializeAndShutdown is now disabled with
@unittest.skip(...) instead of @unittest.SkipTest. Used as a decorator,
SkipTest rebinds the method to an exception instance, which is not a
callable test method, so the test was dropped instead of being reported
as skipped.

---
 jax/_src/distributed.py        |  5 +-
 jax/distributed.py             |  1 +
 tests/BUILD                    |  5 ++
 tests/distributed_test.py      | 88 ++++++++++++++++++++++++++++++++++
 tests/multiprocess_gpu_test.py | 40 ----------------
 5 files changed, 98 insertions(+), 41 deletions(-)
 create mode 100644 tests/distributed_test.py

diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py
index 4867c1189..af50e2e9e 100644
--- a/jax/_src/distributed.py
+++ b/jax/_src/distributed.py
@@ -169,7 +169,6 @@ class State:
 
 global_state = State()
 
-
 def initialize(coordinator_address: str | None = None,
                num_processes: int | None = None,
                process_id: int | None = None,
@@ -265,6 +264,10 @@ def initialize(coordinator_address: str | None = None,
                           initialization_timeout, coordinator_bind_address)
 
 
+def is_initialized() -> bool:
+  """Check if the JAX distributed system is initialized."""
+  return global_state.client is not None
+
 def shutdown():
   """Shuts down the distributed system.
 
diff --git a/jax/distributed.py b/jax/distributed.py
index cf39b81f4..e5c5af195 100644
--- a/jax/distributed.py
+++ b/jax/distributed.py
@@ -14,5 +14,6 @@
 
 from jax._src.distributed import (
   initialize as initialize,
+  is_initialized as is_initialized,
   shutdown as shutdown,
 )
diff --git a/tests/BUILD b/tests/BUILD
index e56c5493d..84cd16448 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -125,6 +125,11 @@ jax_multiplatform_test(
     srcs = ["debug_nans_test.py"],
 )
 
+jax_multiplatform_test(
+    name = "distributed_test",
+    srcs = ["distributed_test.py"],
+)
+
 jax_py_test(
     name = "multiprocess_gpu_test",
     srcs = ["multiprocess_gpu_test.py"],
diff --git a/tests/distributed_test.py b/tests/distributed_test.py
new file mode 100644
index 000000000..3961932df
--- /dev/null
+++ b/tests/distributed_test.py
@@ -0,0 +1,88 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import sys
+import threading
+import unittest
+
+from absl.testing import absltest, parameterized
+
+import jax
+from jax._src import distributed
+from jax._src import test_util as jtu
+
+try:
+  import portpicker
+except ImportError:
+  portpicker = None
+
+jax.config.parse_flags_with_absl()
+
+
+@unittest.skipIf(not portpicker, "Test requires portpicker")
+class DistributedTest(jtu.JaxTestCase):
+  # TODO(phawkins): Enable after https://github.com/jax-ml/jax/issues/11222
+  # is fixed.
+  @unittest.skip("Disabled until https://github.com/jax-ml/jax/issues/11222 is fixed.")
+  def testInitializeAndShutdown(self):
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("Test only works with GPUs.")
+    # Tests the public APIs. Since they use global state, we cannot use
+    # concurrency to simulate multiple tasks.
+    port = portpicker.pick_unused_port()
+    jax.distributed.initialize(
+        coordinator_address=f"localhost:{port}", num_processes=1, process_id=0
+    )
+    jax.distributed.shutdown()
+
+  @parameterized.parameters([1, 2, 4])
+  def testConcurrentInitializeAndShutdown(self, n):
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("Test only works with GPUs.")
+    port = portpicker.pick_unused_port()
+
+    def task(i):
+      # We can't call the public APIs directly because they use global state.
+      state = distributed.State()
+      state.initialize(
+          coordinator_address=f"localhost:{port}", num_processes=n, process_id=i
+      )
+      state.shutdown()
+
+    threads = [threading.Thread(target=task, args=(i,)) for i in range(n)]
+    for thread in threads:
+      thread.start()
+    for thread in threads:
+      thread.join()
+
+  def test_is_distributed_initialized(self):
+    # Run in subprocess to isolate side effects from jax.distributed.initialize which conflict with other
+    # tests. Unfortunately this can't be avoided by calling jax.distributed.shutdown, as the XLA backend
+    # will be warmed up, which yields a RuntimeError on subsequent calls to initialize.
+    port = portpicker.pick_unused_port()  # type: ignore
+    cmd = f"""import jax;
+    assert not jax.distributed.is_initialized();
+    jax.distributed.initialize('localhost:{port}', 1, 0);
+    assert jax.distributed.is_initialized();
+    """.replace("\n", ' ')
+
+    result = subprocess.run([sys.executable, "-c", cmd], capture_output=True)
+    self.assertEqual(
+        result.returncode, 0, msg=f"Test failed with:\n{result.stdout}\n{result.stderr}"
+    )
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py
index 5c84f8c69..fe9922148 100644
--- a/tests/multiprocess_gpu_test.py
+++ b/tests/multiprocess_gpu_test.py
@@ -17,17 +14,14 @@ import os
 import shutil
 import subprocess
 import sys
-import threading
 import unittest
 import functools
 
 from absl.testing import absltest
-from absl.testing import parameterized
 import numpy as np
 
 import jax
 from jax._src import core
-from jax._src import distributed
 from jax._src import test_util as jtu
 from jax._src import util
 from jax.experimental import pjit
@@ -43,43 +40,6 @@ except ImportError:
 jax.config.parse_flags_with_absl()
 
 
-@unittest.skipIf(not portpicker, "Test requires portpicker")
-class DistributedTest(jtu.JaxTestCase):
-
-  # TODO(phawkins): Enable after https://github.com/jax-ml/jax/issues/11222
-  # is fixed.
-  @unittest.SkipTest
-  def testInitializeAndShutdown(self):
-    if not jtu.test_device_matches(['gpu']):
-      self.skipTest('Test only works with GPUs.')
-    # Tests the public APIs. Since they use global state, we cannot use
-    # concurrency to simulate multiple tasks.
-    port = portpicker.pick_unused_port()
-    jax.distributed.initialize(coordinator_address=f"localhost:{port}",
-                               num_processes=1,
-                               process_id=0)
-    jax.distributed.shutdown()
-
-
-  @parameterized.parameters([1, 2, 4])
-  def testConcurrentInitializeAndShutdown(self, n):
-    if not jtu.test_device_matches(['gpu']):
-      self.skipTest('Test only works with GPUs.')
-    port = portpicker.pick_unused_port()
-    def task(i):
-      # We can't call the public APIs directly because they use global state.
-      state = distributed.State()
-      state.initialize(coordinator_address=f"localhost:{port}",
-                       num_processes=n,
-                       process_id=i)
-      state.shutdown()
-
-    threads = [threading.Thread(target=task, args=(i,)) for i in range(n)]
-    for thread in threads:
-      thread.start()
-    for thread in threads:
-      thread.join()
-
 @unittest.skipIf(not portpicker, "Test requires portpicker")
 class MultiProcessGpuTest(jtu.JaxTestCase):
 