Remove test for CUPTI multi-subscriber error message that needed cupti-python and a subprocess.

This commit is contained in:
Olli Lupton 2025-02-06 08:30:34 +00:00
parent 1bba1ea2e2
commit d8f811e790
4 changed files with 1 additions and 49 deletions

View File

@ -20,4 +20,3 @@ matplotlib~=3.8.4; python_version=="3.10"
matplotlib; python_version>="3.11"
opt-einsum
auditwheel
cupti-python

View File

@ -74,7 +74,6 @@ _py_deps = {
"absl/flags": ["@pypi_absl_py//:pkg"],
"cloudpickle": ["@pypi_cloudpickle//:pkg"],
"colorama": ["@pypi_colorama//:pkg"],
"cupti-python": ["@pypi_cupti_python//:pkg"],
"epath": ["@pypi_etils//:pkg"], # etils.epath
"filelock": ["@pypi_filelock//:pkg"],
"flatbuffers": ["@pypi_flatbuffers//:pkg"],

View File

@ -325,7 +325,7 @@ jax_multiplatform_test(
],
deps = [
"//jax:experimental",
] + py_deps("cupti-python"),
],
)
jax_multiplatform_test(

View File

@ -18,17 +18,10 @@ import logging
import math
import os
import shutil
import subprocess
import sys
import tempfile
import textwrap
import warnings
from absl.testing import absltest
try:
from cupti import cupti
except ImportError:
cupti = None
import jax
from jax._src import api
from jax._src import compilation_cache as cc
@ -474,44 +467,5 @@ class PgleTest(jtu.JaxTestCase):
self.assertLen(w, 1)
self.assertIn("PERSISTENT CACHE WRITE with key jit_h-", str(w[0].message))
def testAutoPgleClashWithOtherCuptiTools(self):
if cupti is None:
self.skipTest("Multiple CUPTI subscriber test requires cupti-python")
# XLA does not recover from CUPTI errors, which this test intentionally
# triggers, and we cannot robustly require that this test case runs last
program = r"""
from cupti import cupti
import jax
import jax.numpy as jnp
from jax._src import config
def callback(userdata, domain, callback_id, callback_data):
pass
userdata = dict()
# Would throw if we tried to run the test with a CUPTI subscriber already active
subscriber = cupti.subscribe(callback, userdata)
x = jnp.arange(16)
@jax.jit
def f(x):
return x + 42
with config.enable_pgle(True), config.pgle_profiling_runs(1):
f(x)
"""
p = subprocess.run(
[sys.executable, "-c", textwrap.dedent(program)],
capture_output=True,
text=True
)
self.assertIn(
"check for CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED from XLA",
p.stderr
)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())