Add unit test for catching log spam.

This commit is contained in:
Skye Wanderman-Milne 2023-12-14 15:21:07 -08:00
parent 64799a431a
commit e5f7598166

View File

@ -15,11 +15,15 @@
import contextlib
import io
import logging
import subprocess
import sys
import textwrap
import jax
from jax import config
import jax._src.test_util as jtu
from jax._src import xla_bridge
# Note: importing absltest causes an extra absl root log handler to be
# registered, which causes extra debug log messages. We don't expect users to
@ -46,6 +50,36 @@ def capture_jax_logs():
class LoggingTest(jtu.JaxTestCase):
def test_no_log_spam(self):
if jtu.is_cloud_tpu() and xla_bridge._backends:
raise self.skipTest(
"test requires fresh process on Cloud TPU because only one process "
"can use the TPU at a time")
if sys.executable is None:
raise self.skipTest("test requires access to python binary")
program = textwrap.dedent("""
import jax
jax.device_count()
f = jax.jit(lambda x: x + 1)
f(1)
f(2)
jax.numpy.add(1, 1)
""")
python = sys.executable
assert "python" in python
proc = subprocess.run([python, "-c", program], capture_output=True)
lines = proc.stdout.split(b"\n")
lines.extend(proc.stderr.split(b"\n"))
allowlist = [
b"",
b"An NVIDIA GPU may be present on this machine, but a CUDA-enabled "
b"jaxlib is not installed. Falling back to cpu.",
]
lines = [l for l in lines if l not in allowlist]
self.assertEmpty(lines)
def test_debug_logging(self):
# Warmup so we don't get "No GPU/TPU" warning later.
jax.jit(lambda x: x + 1)(1)