mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Remove 'pjrt_c_api_unimplemented' pytest mark.
Instead, we skip tests that the PJRT C API doesn't support. We had this tag for feature development so it was easy to broadly disable, but now we don't expect to need to do that.
This commit is contained in:
parent
6ed66ada0f
commit
ef5e4a4035
6
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
6
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
@ -58,14 +58,12 @@ jobs:
|
||||
env:
|
||||
JAX_PLATFORMS: tpu,cpu
|
||||
JAX_USE_PJRT_C_API_ON_TPU: ${{ matrix.pjrt }}
|
||||
EXTRA_TAGS: "${{ matrix.pjrt == 'true' && 'and not pjrt_c_api_unimplemented' || '' }}"
|
||||
run: |
|
||||
# Run single-accelerator tests in parallel
|
||||
JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
|
||||
--maxfail=20 -m "not multiaccelerator ${EXTRA_TAGS}" tests examples
|
||||
--maxfail=20 -m "not multiaccelerator" tests examples
|
||||
# Run multi-accelerator across all chips
|
||||
python -m pytest --tb=short --maxfail=20 \
|
||||
-m "multiaccelerator ${EXTRA_TAGS}" tests
|
||||
python -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
|
||||
- name: Send chat on failure
|
||||
# Don't notify when testing the workflow from a branch.
|
||||
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' }}
|
||||
|
@ -1,6 +1,5 @@
|
||||
[pytest]
|
||||
markers =
|
||||
pjrt_c_api_unimplemented: indicates that a test will fail using the PJRT C API due to unimplemented functionality
|
||||
multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators
|
||||
SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI
|
||||
filterwarnings =
|
||||
|
@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
"""Tests for release_backend_clients."""
|
||||
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
@ -23,10 +25,12 @@ from jax._src import xla_bridge as xb
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class ClearBackendsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_clear_backends(self):
|
||||
if xb.using_pjrt_c_api():
|
||||
raise unittest.SkipTest('test crashes runtime with PJRT C API')
|
||||
|
||||
g = jax.jit(lambda x, y: x * y)
|
||||
self.assertEqual(g(1, 2), 2)
|
||||
self.assertNotEmpty(xb.get_backend().live_executables())
|
||||
|
@ -54,7 +54,6 @@ def tearDownModule():
|
||||
|
||||
foo = 2
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class CliDebuggerTest(jtu.JaxTestCase):
|
||||
|
||||
def test_debugger_eof(self):
|
||||
|
@ -59,7 +59,6 @@ class DummyDevice:
|
||||
self.platform = platform
|
||||
self.id = id
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class DebugPrintTest(jtu.JaxTestCase):
|
||||
|
||||
def tearDown(self):
|
||||
@ -221,7 +220,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
"""))
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class DebugPrintTransformationTest(jtu.JaxTestCase):
|
||||
|
||||
def test_debug_print_batching(self):
|
||||
@ -499,7 +497,6 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
|
||||
jax.effects_barrier()
|
||||
self.assertEqual(output(), "hello bwd: 2.0 3.0\n")
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
def _assertLinesEqual(self, text1, text2):
|
||||
@ -736,7 +733,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
b3: 2
|
||||
"""))
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
|
||||
def _assertLinesEqual(self, text1, text2):
|
||||
@ -965,7 +961,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
f(jnp.arange(2))
|
||||
jax.effects_barrier()
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class VisualizeShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def _create_devices(self, shape):
|
||||
@ -1148,13 +1143,14 @@ class VisualizeShardingTest(jtu.JaxTestCase):
|
||||
""")
|
||||
self.assertEqual(output(), expected)
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class InspectShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_inspect_sharding_is_called_in_pjit(self):
|
||||
|
||||
if jtu.is_cloud_tpu():
|
||||
raise unittest.SkipTest("Inspect sharding is not supported on libtpu.")
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise unittest.SkipTest("Inspect sharding is not supported on Cloud TPU")
|
||||
|
||||
is_called = False
|
||||
def _cb(sd):
|
||||
|
@ -232,7 +232,6 @@ def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase,
|
||||
return assertMultiLineStrippedEqual(tst, expected, what)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -2022,7 +2021,6 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
"""Tests for hcb.call"""
|
||||
|
||||
@ -2453,7 +2451,6 @@ def call_jax_other_device(jax_outside_fun, arg, *, device):
|
||||
return make_call(arg)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class CallJaxTest(jtu.JaxTestCase):
|
||||
"""Tests using `call_jax_other_device`."""
|
||||
|
||||
@ -2530,7 +2527,6 @@ class CallJaxTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(res_jax, res_outside)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -158,7 +158,6 @@ CALL_TF_IMPLEMENTATIONS = {
|
||||
}
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # crashes runtime
|
||||
class CallToTFTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -30,7 +30,6 @@ import numpy as np
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # infeed
|
||||
class InfeedTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -590,7 +590,6 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens[foo_effect][0])
|
||||
self.assertIs(foo2_token, dispatch.runtime_tokens.tokens[foo2_effect][0])
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class EffectOrderingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_can_execute_python_callback(self):
|
||||
@ -704,7 +703,6 @@ class ParallelEffectsTest(jtu.JaxTestCase):
|
||||
return x
|
||||
jax.pmap(f)(jnp.arange(jax.local_device_count()))
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
def test_can_pmap_unordered_callback(self):
|
||||
# TODO(sharadmv): enable this test on GPU and TPU when backends are
|
||||
# supported
|
||||
|
@ -800,7 +800,6 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
self.assertAllClose(res0, res, check_dtypes=True)
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # outfeed
|
||||
def testOutfeed(self):
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise unittest.SkipTest('outfeed not implemented in PJRT C API')
|
||||
@ -1109,12 +1108,13 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
"valid for values of rank at least 4, but was applied to a value of rank 1"):
|
||||
pjit_f(jnp.array([1, 2, 3]))
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # custom partitoner
|
||||
@jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU.
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_custom_partitioner(self):
|
||||
if jtu.is_cloud_tpu():
|
||||
raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise unittest.SkipTest('custom partitioning not implemented in PJRT C API')
|
||||
|
||||
def partition(
|
||||
precision, arg_shapes, arg_shardings, result_shape, result_sharding
|
||||
|
@ -125,7 +125,6 @@ mlir.register_lowering(callback_p, callback_lowering, platform="gpu")
|
||||
mlir.register_lowering(callback_p, callback_lowering, platform="tpu")
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class PythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def tearDown(self):
|
||||
@ -455,7 +454,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
out,
|
||||
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def tearDown(self):
|
||||
@ -877,7 +875,6 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
x = np.arange(6, dtype=np.int32).reshape((3, 2))
|
||||
np.testing.assert_allclose(g(x), x)
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class IOPythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user