Only initialize GPU backends if they are not already initialized

PiperOrigin-RevId: 456664792
This commit is contained in:
Yash Katariya 2022-06-22 19:39:18 -07:00 committed by jax authors
parent b623ed58b0
commit 1908da33af
2 changed files with 4 additions and 0 deletions

View File

@ -6,3 +6,4 @@ pillow>=8.3.1,<9.1.0
pytest-benchmark
pytest-xdist
wheel
numpy<1.23.0

View File

@ -40,6 +40,9 @@ config.parse_flags_with_absl()
@unittest.skipIf(not portpicker, "Test requires portpicker")
class DistributedTest(jtu.JaxTestCase):
# TODO(phawkins): Enable after https://github.com/google/jax/issues/11222
# is fixed.
@unittest.SkipTest
def testInitializeAndShutdown(self):
# Tests the public APIs. Since they use global state, we cannot use
# concurrency to simulate multiple tasks.