skip tests with extra requirements

This commit is contained in:
Matthew Johnson 2025-02-05 01:37:56 +00:00
parent 02f4531310
commit 1ae02bc069
2 changed files with 8 additions and 0 deletions

View File

@ -15,6 +15,7 @@
import threading
import time
from typing import Sequence
import unittest
from absl.testing import absltest
from absl.testing import parameterized
@ -30,6 +31,10 @@ import numpy as np
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
try:
import cloudpickle # noqa
except (ModuleNotFoundError, ImportError):
raise unittest.SkipTest("tests depend on cloudpickle library")
def _colocated_cpu_devices(
devices: Sequence[jax.Device],

View File

@ -292,6 +292,7 @@ class JaxExportTest(jtu.JaxTestCase):
exp_f.call((a, b), a=a, b=b))
def test_pytree_namedtuple(self):
if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers")
T = collections.namedtuple("SomeType", ("a", "b", "c"))
export.register_namedtuple_serialization(
T,
@ -317,6 +318,7 @@ class JaxExportTest(jtu.JaxTestCase):
tree_util.tree_structure(res))
def test_pytree_namedtuple_error(self):
if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers")
T = collections.namedtuple("SomeType", ("a", "b"))
x = T(a=1, b=2)
with self.assertRaisesRegex(
@ -363,6 +365,7 @@ class JaxExportTest(jtu.JaxTestCase):
)
def test_pytree_custom_types(self):
if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers")
x1 = collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)])
@tree_util.register_pytree_node_class