mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
skip tests with extra requirements
This commit is contained in:
parent
02f4531310
commit
1ae02bc069
@ -15,6 +15,7 @@
|
||||
import threading
|
||||
import time
|
||||
from typing import Sequence
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -30,6 +31,10 @@ import numpy as np
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
||||
try:
|
||||
import cloudpickle # noqa
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
raise unittest.SkipTest("tests depend on cloudpickle library")
|
||||
|
||||
def _colocated_cpu_devices(
|
||||
devices: Sequence[jax.Device],
|
||||
|
@ -292,6 +292,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp_f.call((a, b), a=a, b=b))
|
||||
|
||||
def test_pytree_namedtuple(self):
|
||||
if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers")
|
||||
T = collections.namedtuple("SomeType", ("a", "b", "c"))
|
||||
export.register_namedtuple_serialization(
|
||||
T,
|
||||
@ -317,6 +318,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
tree_util.tree_structure(res))
|
||||
|
||||
def test_pytree_namedtuple_error(self):
|
||||
if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers")
|
||||
T = collections.namedtuple("SomeType", ("a", "b"))
|
||||
x = T(a=1, b=2)
|
||||
with self.assertRaisesRegex(
|
||||
@ -363,6 +365,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
def test_pytree_custom_types(self):
|
||||
if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers")
|
||||
x1 = collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)])
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
|
Loading…
x
Reference in New Issue
Block a user