Remove 'pjrt_c_api_unimplemented' pytest mark.

Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
This commit is contained in:
Skye Wanderman-Milne 2023-03-24 20:55:04 +00:00
parent 6ed66ada0f
commit ef5e4a4035
11 changed files with 11 additions and 26 deletions

View File

@ -58,14 +58,12 @@ jobs:
env:
JAX_PLATFORMS: tpu,cpu
JAX_USE_PJRT_C_API_ON_TPU: ${{ matrix.pjrt }}
EXTRA_TAGS: "${{ matrix.pjrt == 'true' && 'and not pjrt_c_api_unimplemented' || '' }}"
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
--maxfail=20 -m "not multiaccelerator ${EXTRA_TAGS}" tests examples
--maxfail=20 -m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python -m pytest --tb=short --maxfail=20 \
-m "multiaccelerator ${EXTRA_TAGS}" tests
python -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' }}

View File

@ -1,6 +1,5 @@
[pytest]
markers =
pjrt_c_api_unimplemented: indicates that a test will fail using the PJRT C API due to unimplemented functionality
multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
filterwarnings =

View File

@ -13,6 +13,8 @@
# limitations under the License.
"""Tests for release_backend_clients."""
import unittest
from absl.testing import absltest
import jax
@ -23,10 +25,12 @@ from jax._src import xla_bridge as xb
config.parse_flags_with_absl()
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class ClearBackendsTest(jtu.JaxTestCase):
def test_clear_backends(self):
if xb.using_pjrt_c_api():
raise unittest.SkipTest('test crashes runtime with PJRT C API')
g = jax.jit(lambda x, y: x * y)
self.assertEqual(g(1, 2), 2)
self.assertNotEmpty(xb.get_backend().live_executables())

View File

@ -54,7 +54,6 @@ def tearDownModule():
foo = 2
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class CliDebuggerTest(jtu.JaxTestCase):
def test_debugger_eof(self):

View File

@ -59,7 +59,6 @@ class DummyDevice:
self.platform = platform
self.id = id
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class DebugPrintTest(jtu.JaxTestCase):
def tearDown(self):
@ -221,7 +220,6 @@ class DebugPrintTest(jtu.JaxTestCase):
"""))
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class DebugPrintTransformationTest(jtu.JaxTestCase):
def test_debug_print_batching(self):
@ -499,7 +497,6 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
jax.effects_barrier()
self.assertEqual(output(), "hello bwd: 2.0 3.0\n")
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class DebugPrintControlFlowTest(jtu.JaxTestCase):
def _assertLinesEqual(self, text1, text2):
@ -736,7 +733,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
b3: 2
"""))
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class DebugPrintParallelTest(jtu.JaxTestCase):
def _assertLinesEqual(self, text1, text2):
@ -965,7 +961,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
f(jnp.arange(2))
jax.effects_barrier()
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class VisualizeShardingTest(jtu.JaxTestCase):
def _create_devices(self, shape):
@ -1148,13 +1143,14 @@ class VisualizeShardingTest(jtu.JaxTestCase):
""")
self.assertEqual(output(), expected)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class InspectShardingTest(jtu.JaxTestCase):
def test_inspect_sharding_is_called_in_pjit(self):
if jtu.is_cloud_tpu():
raise unittest.SkipTest("Inspect sharding is not supported on libtpu.")
if xla_bridge.using_pjrt_c_api():
raise unittest.SkipTest("Inspect sharding is not supported on Cloud TPU")
is_called = False
def _cb(sd):

View File

@ -232,7 +232,6 @@ def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase,
return assertMultiLineStrippedEqual(tst, expected, what)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class HostCallbackTapTest(jtu.JaxTestCase):
def setUp(self):
@ -2022,7 +2021,6 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class HostCallbackCallTest(jtu.JaxTestCase):
"""Tests for hcb.call"""
@ -2453,7 +2451,6 @@ def call_jax_other_device(jax_outside_fun, arg, *, device):
return make_call(arg)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class CallJaxTest(jtu.JaxTestCase):
"""Tests using `call_jax_other_device`."""
@ -2530,7 +2527,6 @@ class CallJaxTest(jtu.JaxTestCase):
self.assertAllClose(res_jax, res_outside)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class OutfeedRewriterTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -158,7 +158,6 @@ CALL_TF_IMPLEMENTATIONS = {
}
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
class CallToTFTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -30,7 +30,6 @@ import numpy as np
config.parse_flags_with_absl()
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # infeed
class InfeedTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -590,7 +590,6 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens[foo_effect][0])
self.assertIs(foo2_token, dispatch.runtime_tokens.tokens[foo2_effect][0])
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class EffectOrderingTest(jtu.JaxTestCase):
def test_can_execute_python_callback(self):
@ -704,7 +703,6 @@ class ParallelEffectsTest(jtu.JaxTestCase):
return x
jax.pmap(f)(jnp.arange(jax.local_device_count()))
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
def test_can_pmap_unordered_callback(self):
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported

View File

@ -800,7 +800,6 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(res0, res, check_dtypes=True)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # outfeed
def testOutfeed(self):
if xla_bridge.using_pjrt_c_api():
raise unittest.SkipTest('outfeed not implemented in PJRT C API')
@ -1109,12 +1108,13 @@ class PJitTest(jtu.BufferDonationTestCase):
"valid for values of rank at least 4, but was applied to a value of rank 1"):
pjit_f(jnp.array([1, 2, 3]))
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # custom partitoner
@jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU.
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_custom_partitioner(self):
if jtu.is_cloud_tpu():
raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
if xla_bridge.using_pjrt_c_api():
raise unittest.SkipTest('custom partitioning not implemented in PJRT C API')
def partition(
precision, arg_shapes, arg_shardings, result_shape, result_sharding

View File

@ -125,7 +125,6 @@ mlir.register_lowering(callback_p, callback_lowering, platform="gpu")
mlir.register_lowering(callback_p, callback_lowering, platform="tpu")
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class PythonCallbackTest(jtu.JaxTestCase):
def tearDown(self):
@ -455,7 +454,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
out,
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class PurePythonCallbackTest(jtu.JaxTestCase):
def tearDown(self):
@ -877,7 +875,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
x = np.arange(6, dtype=np.int32).reshape((3, 2))
np.testing.assert_allclose(g(x), x)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class IOPythonCallbackTest(jtu.JaxTestCase):
def setUp(self):