mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add unit test for catching log spam.
This commit is contained in:
parent
64799a431a
commit
e5f7598166
@ -15,11 +15,15 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
|
||||
import jax
|
||||
from jax import config
|
||||
import jax._src.test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
|
||||
# Note: importing absltest causes an extra absl root log handler to be
|
||||
# registered, which causes extra debug log messages. We don't expect users to
|
||||
@ -46,6 +50,36 @@ def capture_jax_logs():
|
||||
|
||||
class LoggingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_no_log_spam(self):
|
||||
if jtu.is_cloud_tpu() and xla_bridge._backends:
|
||||
raise self.skipTest(
|
||||
"test requires fresh process on Cloud TPU because only one process "
|
||||
"can use the TPU at a time")
|
||||
if sys.executable is None:
|
||||
raise self.skipTest("test requires access to python binary")
|
||||
|
||||
program = textwrap.dedent("""
|
||||
import jax
|
||||
jax.device_count()
|
||||
f = jax.jit(lambda x: x + 1)
|
||||
f(1)
|
||||
f(2)
|
||||
jax.numpy.add(1, 1)
|
||||
""")
|
||||
python = sys.executable
|
||||
assert "python" in python
|
||||
proc = subprocess.run([python, "-c", program], capture_output=True)
|
||||
|
||||
lines = proc.stdout.split(b"\n")
|
||||
lines.extend(proc.stderr.split(b"\n"))
|
||||
allowlist = [
|
||||
b"",
|
||||
b"An NVIDIA GPU may be present on this machine, but a CUDA-enabled "
|
||||
b"jaxlib is not installed. Falling back to cpu.",
|
||||
]
|
||||
lines = [l for l in lines if l not in allowlist]
|
||||
self.assertEmpty(lines)
|
||||
|
||||
def test_debug_logging(self):
|
||||
# Warmup so we don't get "No GPU/TPU" warning later.
|
||||
jax.jit(lambda x: x + 1)(1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user