1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 11:06:06 +00:00

Fix logging_test failures on Linux with NVIDIA driver only.

Some GPU tests in //tests/logging_test fail on Linux with NVIDIA driver only when we use hermetic CUDA (CUDA isn't installed on Linux).

Reason: the method `tsl::Env::Default()->GetExecutablePath()` doesn't work properly when Python is invoked with the command flag (`-c`). As a result, the subprocess couldn't get the path to the logging_test.py file and convert it to the path of the runtime directory where the hermetic CUDA libraries are placed.

Solution: save the Python program to a file in the runtime directory, then run the script from that file.
PiperOrigin-RevId: 738152663
This commit is contained in:
jax authors 2025-03-18 14:50:49 -07:00
parent 54691b125a
commit 080804c78d

@ -15,9 +15,9 @@
import contextlib import contextlib
import io import io
import logging import logging
import os
import platform import platform
import re import re
import shlex
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
@ -78,6 +78,31 @@ def capture_jax_logs():
logger.removeHandler(handler) logger.removeHandler(handler)
# Saves and runs script from the file in order to fix the problem with
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
# command flag.
def _run(program, env_var=None):
  """Writes `program` to a temporary script and runs it in a Python subprocess.

  The program is executed from a file in the current working directory
  instead of being passed with `-c`, because
  `tsl::Env::Default()->GetExecutablePath()` does not work properly with the
  command flag.

  Args:
    program: Source text of the script. Leading whitespace is removed from
      every line, so the script must not depend on indentation.
    env_var: Optional mapping of extra environment variables for the child
      process. Its entries take precedence over the inherited environment,
      and the mapping itself is never modified.

  Returns:
    An object exposing the child's captured `stdout` and `stderr` as text.
  """
  # Strip the leading whitespace from every line of the program script. The
  # pattern also swallows blank lines, so no further dedenting is required.
  program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)

  # Start from a copy of the inherited environment and layer the requested
  # variables on top, so that the caller's values win and neither
  # `os.environ` nor the caller's mapping is mutated.
  env = dict(os.environ)
  if env_var:
    env.update(env_var)

  with tempfile.NamedTemporaryFile(
      mode="w+", encoding="utf-8", suffix=".py", dir=os.getcwd()
  ) as f:
    f.write(program)
    # The child opens the file by name, so the contents must be on disk.
    f.flush()
    python = sys.executable
    assert "python" in python
    p = subprocess.run(
        [python, f.name], env=env, capture_output=True, text=True
    )

  # Callers only read `.stdout` and `.stderr`; an anonymous holder is enough.
  return type("", (object,), {"stdout": p.stdout, "stderr": p.stderr})
class LoggingTest(jtu.JaxTestCase): class LoggingTest(jtu.JaxTestCase):
@unittest.skipIf(platform.system() == "Windows", @unittest.skipIf(platform.system() == "Windows",
@ -90,35 +115,24 @@ class LoggingTest(jtu.JaxTestCase):
if sys.executable is None: if sys.executable is None:
raise self.skipTest("test requires access to python binary") raise self.skipTest("test requires access to python binary")
# Save script in file to fix the problem with o = _run("""
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
# command flag.
with tempfile.NamedTemporaryFile(
mode="w+", encoding="utf-8", suffix=".py"
) as f:
f.write(textwrap.dedent("""
import jax import jax
jax.device_count() jax.device_count()
f = jax.jit(lambda x: x + 1) f = jax.jit(lambda x: x + 1)
f(1) f(1)
f(2) f(2)
jax.numpy.add(1, 1) jax.numpy.add(1, 1)
""")) """)
python = sys.executable
assert "python" in python
# Make sure C++ logging is at default level for the test process.
proc = subprocess.run([python, f.name], capture_output=True)
lines = proc.stdout.split(b"\n") lines = o.stdout.split("\n")
lines.extend(proc.stderr.split(b"\n")) lines.extend(o.stderr.split("\n"))
allowlist = [ allowlist = [
b"",
( (
b"An NVIDIA GPU may be present on this machine, but a" "An NVIDIA GPU may be present on this machine, but a"
b" CUDA-enabled jaxlib is not installed. Falling back to cpu." " CUDA-enabled jaxlib is not installed. Falling back to cpu."
), ),
] ]
lines = [l for l in lines if l not in allowlist] lines = [l for l in lines if l in allowlist]
self.assertEmpty(lines) self.assertEmpty(lines)
def test_debug_logging(self): def test_debug_logging(self):
@ -164,19 +178,12 @@ class LoggingTest(jtu.JaxTestCase):
if sys.executable is None: if sys.executable is None:
raise self.skipTest("test requires access to python binary") raise self.skipTest("test requires access to python binary")
program = """ o = _run("""
import jax # this prints INFO logging from backend imports import jax # this prints INFO logging from backend imports
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
""" """, { "JAX_LOGGING_LEVEL": "INFO" })
# strip the leading whitespace from the program script log_output = o.stderr
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
# test INFO
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
info_lines = log_output.split("\n") info_lines = log_output.split("\n")
self.assertGreater(len(info_lines), 0) self.assertGreater(len(info_lines), 0)
self.assertIn("INFO", log_output) self.assertIn("INFO", log_output)
@ -194,22 +201,14 @@ class LoggingTest(jtu.JaxTestCase):
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
""" """
# strip the leading whitespace from the program script o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
# test DEBUG log_output = o.stderr
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
self.assertIn("INFO", log_output) self.assertIn("INFO", log_output)
self.assertIn("DEBUG", log_output) self.assertIn("DEBUG", log_output)
# test JAX_DEBUG_MODULES o = _run(program, { "JAX_DEBUG_LOG_MODULES": "jax" })
cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c" log_output = o.stderr
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
self.assertIn("DEBUG", log_output) self.assertIn("DEBUG", log_output)
@jtu.skip_on_devices("tpu") @jtu.skip_on_devices("tpu")
@ -220,22 +219,15 @@ class LoggingTest(jtu.JaxTestCase):
raise self.skipTest("test requires access to python binary") raise self.skipTest("test requires access to python binary")
_separator = "---------------------------" _separator = "---------------------------"
program = f""" o = _run(f"""
import sys import sys
import jax # this prints INFO logging from backend imports import jax # this prints INFO logging from backend imports
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
jax.config.update("jax_logging_level", None) jax.config.update("jax_logging_level", None)
sys.stderr.write("{_separator}") sys.stderr.write("{_separator}")
jax.jit(lambda x: x)(1) # should not log anything now jax.jit(lambda x: x)(1) # should not log anything now
""" """, {"JAX_LOGGING_LEVEL": "DEBUG"})
log_output = o.stderr
# strip the leading whitespace from the program script
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
m = re.search(_separator, log_output) m = re.search(_separator, log_output)
self.assertTrue(m is not None) self.assertTrue(m is not None)
log_output_verbose = log_output[:m.start()] log_output_verbose = log_output[:m.start()]
@ -252,19 +244,13 @@ class LoggingTest(jtu.JaxTestCase):
if sys.executable is None: if sys.executable is None:
raise self.skipTest("test requires access to python binary") raise self.skipTest("test requires access to python binary")
program = """ o = _run("""
import jax # this prints INFO logging from backend imports import jax # this prints INFO logging from backend imports
jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch") jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch")
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
""" """, { "JAX_LOGGING_LEVEL": "DEBUG" })
# strip the leading whitespace from the program script log_output = o.stderr
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
log_output = p.stderr
self.assertNotEmpty(log_output) self.assertNotEmpty(log_output)
log_lines = log_output.strip().split("\n") log_lines = log_output.strip().split("\n")
# only one tracing line should be printed, if there's more than one # only one tracing line should be printed, if there's more than one
@ -285,31 +271,19 @@ class LoggingTest(jtu.JaxTestCase):
jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0) jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0)
""" """
# strip the leading whitespace from the program script o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) self.assertIn("Initializing CoordinationService", o.stderr)
# verbose logging: DEBUG, VERBOSE o = _run(program, { "JAX_LOGGING_LEVEL": "INFO" })
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" self.assertIn("Initializing CoordinationService", o.stderr)
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
self.assertIn("Initializing CoordinationService", p.stderr)
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
self.assertIn("Initializing CoordinationService", p.stderr)
# verbose logging: WARNING, None # verbose logging: WARNING, None
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c" o = _run(program, { "JAX_LOGGING_LEVEL": "WARNING" })
f" '{program}'") self.assertNotIn("Initializing CoordinationService", o.stderr)
p = subprocess.run(cmd, capture_output=True, text=True)
self.assertNotIn("Initializing CoordinationService", p.stderr)
cmd = shlex.split(f"{sys.executable} -c" o = _run(program)
f" '{program}'")
p = subprocess.run(cmd, capture_output=True, text=True)
if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1: if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1:
self.assertNotIn("Initializing CoordinationService", p.stderr) self.assertNotIn("Initializing CoordinationService", o.stderr)
if __name__ == "__main__": if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader()) absltest.main(testLoader=jtu.JaxTestLoader())