diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index f9dd3ce52..a1486a7d8 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -15,6 +15,7 @@ import threading import time from typing import Sequence +import unittest from absl.testing import absltest from absl.testing import parameterized @@ -30,6 +31,10 @@ import numpy as np config.parse_flags_with_absl() jtu.request_cpu_devices(8) +try: + import cloudpickle # noqa +except (ModuleNotFoundError, ImportError): + raise unittest.SkipTest("tests depend on cloudpickle library") def _colocated_cpu_devices( devices: Sequence[jax.Device], diff --git a/tests/export_test.py b/tests/export_test.py index 0eb777a61..e4a866639 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -292,6 +292,7 @@ class JaxExportTest(jtu.JaxTestCase): exp_f.call((a, b), a=a, b=b)) def test_pytree_namedtuple(self): + if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers") T = collections.namedtuple("SomeType", ("a", "b", "c")) export.register_namedtuple_serialization( T, @@ -317,6 +318,7 @@ class JaxExportTest(jtu.JaxTestCase): tree_util.tree_structure(res)) def test_pytree_namedtuple_error(self): + if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers") T = collections.namedtuple("SomeType", ("a", "b")) x = T(a=1, b=2) with self.assertRaisesRegex( @@ -363,6 +365,7 @@ class JaxExportTest(jtu.JaxTestCase): ) def test_pytree_custom_types(self): + if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers") x1 = collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]) @tree_util.register_pytree_node_class