Disable profiler test for older plugins.

PiperOrigin-RevId: 581391435
This commit is contained in:
Jieying Luo 2023-11-10 15:50:07 -08:00 committed by jax authors
parent ed6fbd0166
commit 11236dbe34
2 changed files with 26 additions and 0 deletions

View File

@ -18,6 +18,7 @@ import logging
import math
import os
import tempfile
import unittest
from absl.testing import absltest
import jax
@ -36,6 +37,11 @@ config.parse_flags_with_absl()
class PgleTest(jtu.JaxTestCase):
def testPassingFDOProfile(self):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
'Profiler is not supported on PJRT C API version < 0.34.'
)
mesh = jtu.create_global_mesh((2,), ('x',))
@partial(
jax.jit,

View File

@ -84,6 +84,11 @@ class ProfilerTest(unittest.TestCase):
jax.profiler.stop_server()
def testProgrammaticProfiling(self):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
"Profiler is not supported on PJRT C API version < 0.34."
)
with tempfile.TemporaryDirectory() as tmpdir:
try:
jax.profiler.start_trace(tmpdir)
@ -107,6 +112,11 @@ class ProfilerTest(unittest.TestCase):
def testProfilerGetFDOProfile(self):
if xla_extension_version < 206:
raise unittest.SkipTest("API version < 206")
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
"Profiler is not supported on PJRT C API version < 0.34."
)
# Tests stop_and_get_fod_profile could run.
try:
jax.profiler.start_trace("test")
@ -119,6 +129,11 @@ class ProfilerTest(unittest.TestCase):
self.assertIn(b"copy", fdo_profile)
def testProgrammaticProfilingErrors(self):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
"Profiler is not supported on PJRT C API version < 0.34."
)
with self.assertRaisesRegex(RuntimeError, "No profile started"):
jax.profiler.stop_trace()
@ -134,6 +149,11 @@ class ProfilerTest(unittest.TestCase):
jax.profiler.stop_trace()
def testProgrammaticProfilingContextManager(self):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
"Profiler is not supported on PJRT C API version < 0.34."
)
with tempfile.TemporaryDirectory() as tmpdir:
with jax.profiler.trace(tmpdir):
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(