From 5db78e7ae03350f0833b0f315ad7e9d69a14c039 Mon Sep 17 00:00:00 2001
From: Zac Cranko <zcrank@google.com>
Date: Tue, 18 Feb 2025 16:47:19 -0800
Subject: [PATCH] add distributed.is_initialized

---
 jax/_src/distributed.py        |  5 +-
 jax/distributed.py             |  1 +
 tests/BUILD                    |  5 ++
 tests/distributed_test.py      | 88 ++++++++++++++++++++++++++++++++++
 tests/multiprocess_gpu_test.py | 40 ----------------
 5 files changed, 98 insertions(+), 41 deletions(-)
 create mode 100644 tests/distributed_test.py

diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py
index 4867c1189..af50e2e9e 100644
--- a/jax/_src/distributed.py
+++ b/jax/_src/distributed.py
@@ -169,7 +169,6 @@ class State:
 
 global_state = State()
 
-
 def initialize(coordinator_address: str | None = None,
                num_processes: int | None = None,
                process_id: int | None = None,
@@ -265,6 +264,10 @@ def initialize(coordinator_address: str | None = None,
                           initialization_timeout, coordinator_bind_address)
 
 
+def is_initialized() -> bool:
+  """Check if the JAX distributed system is initialized."""
+  return global_state.client is not None
+
 def shutdown():
   """Shuts down the distributed system.
 
diff --git a/jax/distributed.py b/jax/distributed.py
index cf39b81f4..e5c5af195 100644
--- a/jax/distributed.py
+++ b/jax/distributed.py
@@ -14,5 +14,6 @@
 
 from jax._src.distributed import (
    initialize as initialize,
+   is_initialized as is_initialized,
    shutdown as shutdown,
 )
diff --git a/tests/BUILD b/tests/BUILD
index e56c5493d..84cd16448 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -125,6 +125,11 @@ jax_multiplatform_test(
     srcs = ["debug_nans_test.py"],
 )
 
+jax_multiplatform_test(
+    name = "distributed_test",
+    srcs = ["distributed_test.py"],
+)
+
 jax_py_test(
     name = "multiprocess_gpu_test",
     srcs = ["multiprocess_gpu_test.py"],
diff --git a/tests/distributed_test.py b/tests/distributed_test.py
new file mode 100644
index 000000000..3961932df
--- /dev/null
+++ b/tests/distributed_test.py
@@ -0,0 +1,88 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import sys
+import threading
+import unittest
+
+from absl.testing import absltest, parameterized
+
+import jax
+from jax._src import distributed
+from jax._src import test_util as jtu
+
+try:
+  import portpicker
+except ImportError:
+  portpicker = None
+
+jax.config.parse_flags_with_absl()
+
+
+@unittest.skipIf(not portpicker, "Test requires portpicker")
+class DistributedTest(jtu.JaxTestCase):
+  # TODO(phawkins): Enable after https://github.com/jax-ml/jax/issues/11222
+  # is fixed.
+  @unittest.SkipTest
+  def testInitializeAndShutdown(self):
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("Test only works with GPUs.")
+    # Tests the public APIs. Since they use global state, we cannot use
+    # concurrency to simulate multiple tasks.
+    port = portpicker.pick_unused_port()
+    jax.distributed.initialize(
+      coordinator_address=f"localhost:{port}", num_processes=1, process_id=0
+    )
+    jax.distributed.shutdown()
+
+  @parameterized.parameters([1, 2, 4])
+  def testConcurrentInitializeAndShutdown(self, n):
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("Test only works with GPUs.")
+    port = portpicker.pick_unused_port()
+
+    def task(i):
+      # We can't call the public APIs directly because they use global state.
+      state = distributed.State()
+      state.initialize(
+        coordinator_address=f"localhost:{port}", num_processes=n, process_id=i
+      )
+      state.shutdown()
+
+    threads = [threading.Thread(target=task, args=(i,)) for i in range(n)]
+    for thread in threads:
+      thread.start()
+    for thread in threads:
+      thread.join()
+
+  def test_is_distributed_initialized(self):
+    # Run in subprocess to isolate side effects from jax.distributed.initialize which conflict with other
+    # tests. Unfortunately this can't be avoided by calling jax.distributed.shutdown, as the XLA backend
+    # will be warmed up, which yields a RuntimeError on subsequent calls to initialize.
+    port = portpicker.pick_unused_port() # type: ignore
+    cmd = f"""import jax;
+    assert not jax.distributed.is_initialized();
+    jax.distributed.initialize('localhost:{port}', 1, 0);
+    assert jax.distributed.is_initialized();
+    """.replace("\n", ' ')
+
+    result = subprocess.run([sys.executable, "-c", cmd], capture_output=True)
+    self.assertEqual(
+      result.returncode, 0, msg=f"Test failed with:\n{result.stdout}\n{result.stderr}"
+    )
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py
index 5c84f8c69..fe9922148 100644
--- a/tests/multiprocess_gpu_test.py
+++ b/tests/multiprocess_gpu_test.py
@@ -17,17 +17,14 @@ import os
 import shutil
 import subprocess
 import sys
-import threading
 import unittest
 import functools
 
 from absl.testing import absltest
-from absl.testing import parameterized
 import numpy as np
 
 import jax
 from jax._src import core
-from jax._src import distributed
 from jax._src import test_util as jtu
 from jax._src import util
 from jax.experimental import pjit
@@ -43,43 +40,6 @@ except ImportError:
 
 jax.config.parse_flags_with_absl()
 
-@unittest.skipIf(not portpicker, "Test requires portpicker")
-class DistributedTest(jtu.JaxTestCase):
-
-  # TODO(phawkins): Enable after https://github.com/jax-ml/jax/issues/11222
-  # is fixed.
-  @unittest.SkipTest
-  def testInitializeAndShutdown(self):
-    if not jtu.test_device_matches(['gpu']):
-      self.skipTest('Test only works with GPUs.')
-    # Tests the public APIs. Since they use global state, we cannot use
-    # concurrency to simulate multiple tasks.
-    port = portpicker.pick_unused_port()
-    jax.distributed.initialize(coordinator_address=f"localhost:{port}",
-                               num_processes=1,
-                               process_id=0)
-    jax.distributed.shutdown()
-
-
-  @parameterized.parameters([1, 2, 4])
-  def testConcurrentInitializeAndShutdown(self, n):
-    if not jtu.test_device_matches(['gpu']):
-      self.skipTest('Test only works with GPUs.')
-    port = portpicker.pick_unused_port()
-    def task(i):
-      # We can't call the public APIs directly because they use global state.
-      state = distributed.State()
-      state.initialize(coordinator_address=f"localhost:{port}",
-                       num_processes=n,
-                       process_id=i)
-      state.shutdown()
-
-    threads = [threading.Thread(target=task, args=(i,)) for i in range(n)]
-    for thread in threads:
-      thread.start()
-    for thread in threads:
-      thread.join()
-
 
 @unittest.skipIf(not portpicker, "Test requires portpicker")
 class MultiProcessGpuTest(jtu.JaxTestCase):