From 11236dbe34d0c55f90f7e5353d0524301dc04dca Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Fri, 10 Nov 2023 15:50:07 -0800 Subject: [PATCH] Disable profiler test for older plugins. PiperOrigin-RevId: 581391435 --- tests/pgle_test.py | 6 ++++++ tests/profiler_test.py | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 466da7f27..e0adf8963 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -18,6 +18,7 @@ import logging import math import os import tempfile +import unittest from absl.testing import absltest import jax @@ -36,6 +37,11 @@ config.parse_flags_with_absl() class PgleTest(jtu.JaxTestCase): def testPassingFDOProfile(self): + # TODO(jieying): remove after 01/10/2023. + if not jtu.pjrt_c_api_version_at_least(0, 34): + raise unittest.SkipTest( + 'Profiler is not supported on PJRT C API version < 0.34.' + ) mesh = jtu.create_global_mesh((2,), ('x',)) @partial( jax.jit, diff --git a/tests/profiler_test.py b/tests/profiler_test.py index e5f6e8a50..49b1fce88 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -84,6 +84,11 @@ class ProfilerTest(unittest.TestCase): jax.profiler.stop_server() def testProgrammaticProfiling(self): + # TODO(jieying): remove after 01/10/2023. + if not jtu.pjrt_c_api_version_at_least(0, 34): + raise unittest.SkipTest( + "Profiler is not supported on PJRT C API version < 0.34." + ) with tempfile.TemporaryDirectory() as tmpdir: try: jax.profiler.start_trace(tmpdir) @@ -107,6 +112,11 @@ class ProfilerTest(unittest.TestCase): def testProfilerGetFDOProfile(self): if xla_extension_version < 206: raise unittest.SkipTest("API version < 206") + # TODO(jieying): remove after 01/10/2023. + if not jtu.pjrt_c_api_version_at_least(0, 34): + raise unittest.SkipTest( + "Profiler is not supported on PJRT C API version < 0.34." + ) # Tests stop_and_get_fod_profile could run. try: jax.profiler.start_trace("test") @@ -119,6 +129,11 @@ class ProfilerTest(unittest.TestCase): self.assertIn(b"copy", fdo_profile) def testProgrammaticProfilingErrors(self): + # TODO(jieying): remove after 01/10/2023. + if not jtu.pjrt_c_api_version_at_least(0, 34): + raise unittest.SkipTest( + "Profiler is not supported on PJRT C API version < 0.34." + ) with self.assertRaisesRegex(RuntimeError, "No profile started"): jax.profiler.stop_trace() @@ -134,6 +149,11 @@ class ProfilerTest(unittest.TestCase): jax.profiler.stop_trace() def testProgrammaticProfilingContextManager(self): + # TODO(jieying): remove after 01/10/2023. + if not jtu.pjrt_c_api_version_at_least(0, 34): + raise unittest.SkipTest( + "Profiler is not supported on PJRT C API version < 0.34." + ) with tempfile.TemporaryDirectory() as tmpdir: with jax.profiler.trace(tmpdir): jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(