mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove test for CUPTI multi-subscriber error message that needed cupti-python and a subprocess.
This commit is contained in:
parent
1bba1ea2e2
commit
d8f811e790
@ -20,4 +20,3 @@ matplotlib~=3.8.4; python_version=="3.10"
|
||||
matplotlib; python_version>="3.11"
|
||||
opt-einsum
|
||||
auditwheel
|
||||
cupti-python
|
||||
|
@ -74,7 +74,6 @@ _py_deps = {
|
||||
"absl/flags": ["@pypi_absl_py//:pkg"],
|
||||
"cloudpickle": ["@pypi_cloudpickle//:pkg"],
|
||||
"colorama": ["@pypi_colorama//:pkg"],
|
||||
"cupti-python": ["@pypi_cupti_python//:pkg"],
|
||||
"epath": ["@pypi_etils//:pkg"], # etils.epath
|
||||
"filelock": ["@pypi_filelock//:pkg"],
|
||||
"flatbuffers": ["@pypi_flatbuffers//:pkg"],
|
||||
|
@ -325,7 +325,7 @@ jax_multiplatform_test(
|
||||
],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
] + py_deps("cupti-python"),
|
||||
],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
|
@ -18,17 +18,10 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
try:
|
||||
from cupti import cupti
|
||||
except ImportError:
|
||||
cupti = None
|
||||
import jax
|
||||
from jax._src import api
|
||||
from jax._src import compilation_cache as cc
|
||||
@ -474,44 +467,5 @@ class PgleTest(jtu.JaxTestCase):
|
||||
self.assertLen(w, 1)
|
||||
self.assertIn("PERSISTENT CACHE WRITE with key jit_h-", str(w[0].message))
|
||||
|
||||
def testAutoPgleClashWithOtherCuptiTools(self):
|
||||
if cupti is None:
|
||||
self.skipTest("Multiple CUPTI subscriber test requires cupti-python")
|
||||
|
||||
# XLA does not recover from CUPTI errors, which this test intentionally
|
||||
# triggers, and we cannot robustly require that this test case runs last
|
||||
program = r"""
|
||||
from cupti import cupti
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import config
|
||||
|
||||
def callback(userdata, domain, callback_id, callback_data):
|
||||
pass
|
||||
|
||||
userdata = dict()
|
||||
|
||||
# Would throw if we tried to run the test with a CUPTI subscriber already active
|
||||
subscriber = cupti.subscribe(callback, userdata)
|
||||
|
||||
x = jnp.arange(16)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return x + 42
|
||||
|
||||
with config.enable_pgle(True), config.pgle_profiling_runs(1):
|
||||
f(x)
|
||||
"""
|
||||
p = subprocess.run(
|
||||
[sys.executable, "-c", textwrap.dedent(program)],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
self.assertIn(
|
||||
"check for CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED from XLA",
|
||||
p.stderr
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user