diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 5ca72ab2c..facfd3e3f 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -6,3 +6,4 @@ pillow>=8.3.1,<9.1.0 pytest-benchmark pytest-xdist wheel +numpy<1.23.0 diff --git a/tests/distributed_test.py b/tests/distributed_test.py index 7c575a1ca..53955b84f 100644 --- a/tests/distributed_test.py +++ b/tests/distributed_test.py @@ -40,6 +40,9 @@ config.parse_flags_with_absl() @unittest.skipIf(not portpicker, "Test requires portpicker") class DistributedTest(jtu.JaxTestCase): + # TODO(phawkins): Enable after https://github.com/google/jax/issues/11222 + # is fixed. + @unittest.SkipTest def testInitializeAndShutdown(self): # Tests the public APIs. Since they use global state, we cannot use # concurrency to simulate multiple tasks.