mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add pjrt_c_api_unimplemented
pytest marker to skip unsupported tests.
Also adds `test_util.pytest_mark_if_available` helper function.
This commit is contained in:
parent
68c43e6c99
commit
f90b5eed52
@ -332,6 +332,17 @@ def skip_on_flag(flag_name, skip_value):
|
||||
return skip
|
||||
|
||||
|
||||
def pytest_mark_if_available(marker: str):
|
||||
"""A decorator for test classes or methods to pytest.mark if installed."""
|
||||
def wrap(func_or_class):
|
||||
try:
|
||||
import pytest
|
||||
except ImportError:
|
||||
return func_or_class
|
||||
return getattr(pytest.mark, marker)(func_or_class)
|
||||
return wrap
|
||||
|
||||
|
||||
def format_test_name_suffix(opname, shapes, dtypes):
|
||||
arg_descriptions = (format_shape_dtype_string(shape, dtype)
|
||||
for shape, dtype in zip(shapes, dtypes))
|
||||
|
@ -1,5 +1,6 @@
|
||||
[pytest]
|
||||
markers =
|
||||
pjrt_c_api_unimplemented: indicates that a test will fail using the PJRT C API due to unimplemented functionality
|
||||
multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
|
||||
SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
|
||||
filterwarnings =
|
||||
|
@ -23,6 +23,7 @@ from jax._src.lib import xla_bridge as xb
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class ClearBackendsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_clear_backends(self):
|
||||
|
@ -55,6 +55,7 @@ def tearDownModule():
|
||||
|
||||
foo = 2
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class CliDebuggerTest(jtu.JaxTestCase):
|
||||
|
||||
def test_debugger_eof(self):
|
||||
|
@ -60,6 +60,7 @@ class DummyDevice:
|
||||
self.platform = platform
|
||||
self.id = id
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class DebugPrintTest(jtu.JaxTestCase):
|
||||
|
||||
def tearDown(self):
|
||||
@ -223,6 +224,7 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class DebugPrintTransformationTest(jtu.JaxTestCase):
|
||||
|
||||
def test_debug_print_batching(self):
|
||||
@ -500,6 +502,7 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
|
||||
jax.effects_barrier()
|
||||
self.assertEqual(output(), "hello bwd: 2.0 3.0\n")
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
def _assertLinesEqual(self, text1, text2):
|
||||
@ -736,6 +739,7 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
b3: 2
|
||||
"""))
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
|
||||
def _assertLinesEqual(self, text1, text2):
|
||||
@ -978,6 +982,7 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
f(jnp.arange(2))
|
||||
jax.effects_barrier()
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class VisualizeShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def _create_devices(self, shape):
|
||||
@ -1160,6 +1165,7 @@ class VisualizeShardingTest(jtu.JaxTestCase):
|
||||
""")
|
||||
self.assertEqual(output(), expected)
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class InspectShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_inspect_sharding_is_called_in_pjit(self):
|
||||
|
@ -234,6 +234,7 @@ def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase,
|
||||
return assertMultiLineStrippedEqual(tst, expected, what)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -2027,6 +2028,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
"""Tests for hcb.call"""
|
||||
|
||||
@ -2455,6 +2457,7 @@ def call_jax_other_device(jax_outside_fun, arg, *, device):
|
||||
return make_call(arg)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class CallJaxTest(jtu.JaxTestCase):
|
||||
"""Tests using `call_jax_other_device`."""
|
||||
|
||||
@ -2529,6 +2532,7 @@ class CallJaxTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(res_jax, res_outside)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -157,6 +157,7 @@ CALL_TF_IMPLEMENTATIONS = {
|
||||
}
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class CallToTFTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -27,6 +27,7 @@ import numpy as np
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # infeed
|
||||
class InfeedTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion.
|
||||
|
@ -573,6 +573,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0])
|
||||
self.assertIs(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0])
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class EffectOrderingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_can_execute_python_callback(self):
|
||||
@ -686,6 +687,7 @@ class ParallelEffectsTest(jtu.JaxTestCase):
|
||||
return x
|
||||
jax.pmap(f)(jnp.arange(jax.local_device_count()))
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
def test_can_pmap_unordered_callback(self):
|
||||
# TODO(sharadmv): enable this test on GPU and TPU when backends are
|
||||
# supported
|
||||
|
@ -775,6 +775,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
self.assertAllClose(res0, res, check_dtypes=True)
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # outfeed
|
||||
def testOutfeed(self):
|
||||
devices = np.array(jax.local_devices())
|
||||
nr_devices = len(devices)
|
||||
|
@ -122,6 +122,7 @@ mlir.register_lowering(callback_p, callback_lowering, platform="gpu")
|
||||
mlir.register_lowering(callback_p, callback_lowering, platform="tpu")
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class PythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def tearDown(self):
|
||||
@ -451,6 +452,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
out,
|
||||
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def tearDown(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user