Add pjrt_c_api_unimplemented pytest marker to skip unsupported tests.

Also adds `test_util.pytest_mark_if_available` helper function.
This commit is contained in:
Skye Wanderman-Milne 2022-12-27 21:32:00 +00:00
parent 68c43e6c99
commit f90b5eed52
11 changed files with 31 additions and 0 deletions

View File

@ -332,6 +332,17 @@ def skip_on_flag(flag_name, skip_value):
return skip
def pytest_mark_if_available(marker: str):
"""A decorator for test classes or methods to pytest.mark if installed."""
def wrap(func_or_class):
try:
import pytest
except ImportError:
return func_or_class
return getattr(pytest.mark, marker)(func_or_class)
return wrap
def format_test_name_suffix(opname, shapes, dtypes):
arg_descriptions = (format_shape_dtype_string(shape, dtype)
for shape, dtype in zip(shapes, dtypes))

View File

@ -1,5 +1,6 @@
[pytest]
markers =
pjrt_c_api_unimplemented: indicates that a test will fail using the PJRT C API due to unimplemented functionality
multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
filterwarnings =

View File

@ -23,6 +23,7 @@ from jax._src.lib import xla_bridge as xb
config.parse_flags_with_absl()
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class ClearBackendsTest(jtu.JaxTestCase):
def test_clear_backends(self):

View File

@ -55,6 +55,7 @@ def tearDownModule():
foo = 2
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class CliDebuggerTest(jtu.JaxTestCase):
def test_debugger_eof(self):

View File

@ -60,6 +60,7 @@ class DummyDevice:
self.platform = platform
self.id = id
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class DebugPrintTest(jtu.JaxTestCase):
def tearDown(self):
@ -223,6 +224,7 @@ class DebugPrintTest(jtu.JaxTestCase):
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class DebugPrintTransformationTest(jtu.JaxTestCase):
def test_debug_print_batching(self):
@ -500,6 +502,7 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
jax.effects_barrier()
self.assertEqual(output(), "hello bwd: 2.0 3.0\n")
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class DebugPrintControlFlowTest(jtu.JaxTestCase):
def _assertLinesEqual(self, text1, text2):
@ -736,6 +739,7 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
b3: 2
"""))
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class DebugPrintParallelTest(jtu.JaxTestCase):
def _assertLinesEqual(self, text1, text2):
@ -978,6 +982,7 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
f(jnp.arange(2))
jax.effects_barrier()
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class VisualizeShardingTest(jtu.JaxTestCase):
def _create_devices(self, shape):
@ -1160,6 +1165,7 @@ class VisualizeShardingTest(jtu.JaxTestCase):
""")
self.assertEqual(output(), expected)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class InspectShardingTest(jtu.JaxTestCase):
def test_inspect_sharding_is_called_in_pjit(self):

View File

@ -234,6 +234,7 @@ def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase,
return assertMultiLineStrippedEqual(tst, expected, what)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class HostCallbackTapTest(jtu.JaxTestCase):
def setUp(self):
@ -2027,6 +2028,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class HostCallbackCallTest(jtu.JaxTestCase):
"""Tests for hcb.call"""
@ -2455,6 +2457,7 @@ def call_jax_other_device(jax_outside_fun, arg, *, device):
return make_call(arg)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class CallJaxTest(jtu.JaxTestCase):
"""Tests using `call_jax_other_device`."""
@ -2529,6 +2532,7 @@ class CallJaxTest(jtu.JaxTestCase):
self.assertAllClose(res_jax, res_outside)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class OutfeedRewriterTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -157,6 +157,7 @@ CALL_TF_IMPLEMENTATIONS = {
}
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class CallToTFTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -27,6 +27,7 @@ import numpy as np
config.parse_flags_with_absl()
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # infeed
class InfeedTest(jtu.JaxTestCase):
@jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion.

View File

@ -573,6 +573,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0])
self.assertIs(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0])
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class EffectOrderingTest(jtu.JaxTestCase):
def test_can_execute_python_callback(self):
@ -686,6 +687,7 @@ class ParallelEffectsTest(jtu.JaxTestCase):
return x
jax.pmap(f)(jnp.arange(jax.local_device_count()))
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
def test_can_pmap_unordered_callback(self):
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported

View File

@ -775,6 +775,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(res0, res, check_dtypes=True)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # outfeed
def testOutfeed(self):
devices = np.array(jax.local_devices())
nr_devices = len(devices)

View File

@ -122,6 +122,7 @@ mlir.register_lowering(callback_p, callback_lowering, platform="gpu")
mlir.register_lowering(callback_p, callback_lowering, platform="tpu")
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class PythonCallbackTest(jtu.JaxTestCase):
def tearDown(self):
@ -451,6 +452,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
out,
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class PurePythonCallbackTest(jtu.JaxTestCase):
def tearDown(self):