From 080804c78dcf9695396c298cd3760ea8bda778ee Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 18 Mar 2025 14:50:49 -0700 Subject: [PATCH] Fix logging_test fails on Linux with NVIDIA Driver only. Some GPU tests in //tests/logging_test fail on Linux with NVIDIA driver only when we use hermetic CUDA (CUDA isn't installed on Linux). Reason: method tsl::Env::Default()->GetExecutablePath()` doesn't work properly with command flag (-c). As result subprocessor couldn't get path to logging_test.py file and convert it to path of runtime where CUDA hermetic libraries are placed. Solution: Save python program to file in runtime directory then run script from the file. PiperOrigin-RevId: 738152663 --- tests/logging_test.py | 148 +++++++++++++++++------------------------- 1 file changed, 61 insertions(+), 87 deletions(-) diff --git a/tests/logging_test.py b/tests/logging_test.py index a83058095..cfe10c5a9 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -15,9 +15,9 @@ import contextlib import io import logging +import os import platform import re -import shlex import subprocess import sys import tempfile @@ -78,6 +78,31 @@ def capture_jax_logs(): logger.removeHandler(handler) +# Saves and runs script from the file in order to fix the problem with +# `tsl::Env::Default()->GetExecutablePath()` not working properly with +# command flag. +def _run(program, env_var = {}): + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + with tempfile.NamedTemporaryFile( + mode="w+", encoding="utf-8", suffix=".py", dir=os.getcwd() + ) as f: + f.write(textwrap.dedent(program)) + f.flush() + python = sys.executable + assert "python" in python + if env_var: + env_var.update(os.environ) + else: + env_var = os.environ + + # Make sure C++ logging is at default level for the test process. + p = subprocess.run([python, f.name], env=env_var, capture_output=True, text=True) + + return type("", (object,), { "stdout": p.stdout, "stderr": p.stderr }) + + class LoggingTest(jtu.JaxTestCase): @unittest.skipIf(platform.system() == "Windows", @@ -90,36 +115,25 @@ class LoggingTest(jtu.JaxTestCase): if sys.executable is None: raise self.skipTest("test requires access to python binary") - # Save script in file to fix the problem with - # `tsl::Env::Default()->GetExecutablePath()` not working properly with - # command flag. - with tempfile.NamedTemporaryFile( - mode="w+", encoding="utf-8", suffix=".py" - ) as f: - f.write(textwrap.dedent(""" + o = _run(""" import jax jax.device_count() f = jax.jit(lambda x: x + 1) f(1) f(2) jax.numpy.add(1, 1) - """)) - python = sys.executable - assert "python" in python - # Make sure C++ logging is at default level for the test process. - proc = subprocess.run([python, f.name], capture_output=True) + """) - lines = proc.stdout.split(b"\n") - lines.extend(proc.stderr.split(b"\n")) - allowlist = [ - b"", - ( - b"An NVIDIA GPU may be present on this machine, but a" - b" CUDA-enabled jaxlib is not installed. Falling back to cpu." - ), - ] - lines = [l for l in lines if l not in allowlist] - self.assertEmpty(lines) + lines = o.stdout.split("\n") + lines.extend(o.stderr.split("\n")) + allowlist = [ + ( + "An NVIDIA GPU may be present on this machine, but a" + " CUDA-enabled jaxlib is not installed. Falling back to cpu." + ), + ] + lines = [l for l in lines if l in allowlist] + self.assertEmpty(lines) def test_debug_logging(self): # Warmup so we don't get "No GPU/TPU" warning later. @@ -164,19 +178,12 @@ class LoggingTest(jtu.JaxTestCase): if sys.executable is None: raise self.skipTest("test requires access to python binary") - program = """ - import jax # this prints INFO logging from backend imports - jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) - """ + o = _run(""" + import jax # this prints INFO logging from backend imports + jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) + """, { "JAX_LOGGING_LEVEL": "INFO" }) - # strip the leading whitespace from the program script - program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) - - # test INFO - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - log_output = p.stderr + log_output = o.stderr info_lines = log_output.split("\n") self.assertGreater(len(info_lines), 0) self.assertIn("INFO", log_output) @@ -194,22 +201,14 @@ class LoggingTest(jtu.JaxTestCase): jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) """ - # strip the leading whitespace from the program script - program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" }) - # test DEBUG - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - log_output = p.stderr + log_output = o.stderr self.assertIn("INFO", log_output) self.assertIn("DEBUG", log_output) - # test JAX_DEBUG_MODULES - cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - log_output = p.stderr + o = _run(program, { "JAX_DEBUG_LOG_MODULES": "jax" }) + log_output = o.stderr self.assertIn("DEBUG", log_output) @jtu.skip_on_devices("tpu") @@ -220,22 +219,15 @@ class LoggingTest(jtu.JaxTestCase): raise self.skipTest("test requires access to python binary") _separator = "---------------------------" - program = f""" + o = _run(f""" import sys import jax # this prints INFO logging from backend imports jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) jax.config.update("jax_logging_level", None) sys.stderr.write("{_separator}") jax.jit(lambda x: x)(1) # should not log anything now - """ - - # strip the leading whitespace from the program script - program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) - - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - log_output = p.stderr + """, {"JAX_LOGGING_LEVEL": "DEBUG"}) + log_output = o.stderr m = re.search(_separator, log_output) self.assertTrue(m is not None) log_output_verbose = log_output[:m.start()] @@ -252,19 +244,13 @@ class LoggingTest(jtu.JaxTestCase): if sys.executable is None: raise self.skipTest("test requires access to python binary") - program = """ + o = _run(""" import jax # this prints INFO logging from backend imports jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch") jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) - """ + """, { "JAX_LOGGING_LEVEL": "DEBUG" }) - # strip the leading whitespace from the program script - program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) - - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - log_output = p.stderr + log_output = o.stderr self.assertNotEmpty(log_output) log_lines = log_output.strip().split("\n") # only one tracing line should be printed, if there's more than one @@ -285,31 +271,19 @@ class LoggingTest(jtu.JaxTestCase): jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0) """ - # strip the leading whitespace from the program script - program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" }) + self.assertIn("Initializing CoordinationService", o.stderr) - # verbose logging: DEBUG, VERBOSE - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - self.assertIn("Initializing CoordinationService", p.stderr) - - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - self.assertIn("Initializing CoordinationService", p.stderr) + o = _run(program, { "JAX_LOGGING_LEVEL": "INFO" }) + self.assertIn("Initializing CoordinationService", o.stderr) # verbose logging: WARNING, None - cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) - self.assertNotIn("Initializing CoordinationService", p.stderr) + o = _run(program, { "JAX_LOGGING_LEVEL": "WARNING" }) + self.assertNotIn("Initializing CoordinationService", o.stderr) - cmd = shlex.split(f"{sys.executable} -c" - f" '{program}'") - p = subprocess.run(cmd, capture_output=True, text=True) + o = _run(program) if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1: - self.assertNotIn("Initializing CoordinationService", p.stderr) + self.assertNotIn("Initializing CoordinationService", o.stderr) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())