mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Fix logging_test fails on Linux with NVIDIA Driver only.
Some GPU tests in //tests/logging_test fail on Linux with NVIDIA driver only when we use hermetic CUDA (CUDA isn't installed on Linux). Reason: method tsl::Env::Default()->GetExecutablePath()` doesn't work properly with command flag (-c). As result subprocessor couldn't get path to logging_test.py file and convert it to path of runtime where CUDA hermetic libraries are placed. Solution: Save python program to file in runtime directory then run script from the file. PiperOrigin-RevId: 738152663
This commit is contained in:
parent
54691b125a
commit
080804c78d
@ -15,9 +15,9 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@ -78,6 +78,31 @@ def capture_jax_logs():
|
||||
logger.removeHandler(handler)
|
||||
|
||||
|
||||
# Saves and runs script from the file in order to fix the problem with
|
||||
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
|
||||
# command flag.
|
||||
def _run(program, env_var = {}):
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w+", encoding="utf-8", suffix=".py", dir=os.getcwd()
|
||||
) as f:
|
||||
f.write(textwrap.dedent(program))
|
||||
f.flush()
|
||||
python = sys.executable
|
||||
assert "python" in python
|
||||
if env_var:
|
||||
env_var.update(os.environ)
|
||||
else:
|
||||
env_var = os.environ
|
||||
|
||||
# Make sure C++ logging is at default level for the test process.
|
||||
p = subprocess.run([python, f.name], env=env_var, capture_output=True, text=True)
|
||||
|
||||
return type("", (object,), { "stdout": p.stdout, "stderr": p.stderr })
|
||||
|
||||
|
||||
class LoggingTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(platform.system() == "Windows",
|
||||
@ -90,36 +115,25 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
if sys.executable is None:
|
||||
raise self.skipTest("test requires access to python binary")
|
||||
|
||||
# Save script in file to fix the problem with
|
||||
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
|
||||
# command flag.
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w+", encoding="utf-8", suffix=".py"
|
||||
) as f:
|
||||
f.write(textwrap.dedent("""
|
||||
o = _run("""
|
||||
import jax
|
||||
jax.device_count()
|
||||
f = jax.jit(lambda x: x + 1)
|
||||
f(1)
|
||||
f(2)
|
||||
jax.numpy.add(1, 1)
|
||||
"""))
|
||||
python = sys.executable
|
||||
assert "python" in python
|
||||
# Make sure C++ logging is at default level for the test process.
|
||||
proc = subprocess.run([python, f.name], capture_output=True)
|
||||
""")
|
||||
|
||||
lines = proc.stdout.split(b"\n")
|
||||
lines.extend(proc.stderr.split(b"\n"))
|
||||
allowlist = [
|
||||
b"",
|
||||
(
|
||||
b"An NVIDIA GPU may be present on this machine, but a"
|
||||
b" CUDA-enabled jaxlib is not installed. Falling back to cpu."
|
||||
),
|
||||
]
|
||||
lines = [l for l in lines if l not in allowlist]
|
||||
self.assertEmpty(lines)
|
||||
lines = o.stdout.split("\n")
|
||||
lines.extend(o.stderr.split("\n"))
|
||||
allowlist = [
|
||||
(
|
||||
"An NVIDIA GPU may be present on this machine, but a"
|
||||
" CUDA-enabled jaxlib is not installed. Falling back to cpu."
|
||||
),
|
||||
]
|
||||
lines = [l for l in lines if l in allowlist]
|
||||
self.assertEmpty(lines)
|
||||
|
||||
def test_debug_logging(self):
|
||||
# Warmup so we don't get "No GPU/TPU" warning later.
|
||||
@ -164,19 +178,12 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
if sys.executable is None:
|
||||
raise self.skipTest("test requires access to python binary")
|
||||
|
||||
program = """
|
||||
import jax # this prints INFO logging from backend imports
|
||||
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
|
||||
"""
|
||||
o = _run("""
|
||||
import jax # this prints INFO logging from backend imports
|
||||
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
|
||||
""", { "JAX_LOGGING_LEVEL": "INFO" })
|
||||
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
|
||||
# test INFO
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log_output = p.stderr
|
||||
log_output = o.stderr
|
||||
info_lines = log_output.split("\n")
|
||||
self.assertGreater(len(info_lines), 0)
|
||||
self.assertIn("INFO", log_output)
|
||||
@ -194,22 +201,14 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
|
||||
"""
|
||||
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
|
||||
|
||||
# test DEBUG
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log_output = p.stderr
|
||||
log_output = o.stderr
|
||||
self.assertIn("INFO", log_output)
|
||||
self.assertIn("DEBUG", log_output)
|
||||
|
||||
# test JAX_DEBUG_MODULES
|
||||
cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log_output = p.stderr
|
||||
o = _run(program, { "JAX_DEBUG_LOG_MODULES": "jax" })
|
||||
log_output = o.stderr
|
||||
self.assertIn("DEBUG", log_output)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@ -220,22 +219,15 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
raise self.skipTest("test requires access to python binary")
|
||||
|
||||
_separator = "---------------------------"
|
||||
program = f"""
|
||||
o = _run(f"""
|
||||
import sys
|
||||
import jax # this prints INFO logging from backend imports
|
||||
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
|
||||
jax.config.update("jax_logging_level", None)
|
||||
sys.stderr.write("{_separator}")
|
||||
jax.jit(lambda x: x)(1) # should not log anything now
|
||||
"""
|
||||
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log_output = p.stderr
|
||||
""", {"JAX_LOGGING_LEVEL": "DEBUG"})
|
||||
log_output = o.stderr
|
||||
m = re.search(_separator, log_output)
|
||||
self.assertTrue(m is not None)
|
||||
log_output_verbose = log_output[:m.start()]
|
||||
@ -252,19 +244,13 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
if sys.executable is None:
|
||||
raise self.skipTest("test requires access to python binary")
|
||||
|
||||
program = """
|
||||
o = _run("""
|
||||
import jax # this prints INFO logging from backend imports
|
||||
jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch")
|
||||
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
|
||||
"""
|
||||
""", { "JAX_LOGGING_LEVEL": "DEBUG" })
|
||||
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log_output = p.stderr
|
||||
log_output = o.stderr
|
||||
self.assertNotEmpty(log_output)
|
||||
log_lines = log_output.strip().split("\n")
|
||||
# only one tracing line should be printed, if there's more than one
|
||||
@ -285,31 +271,19 @@ class LoggingTest(jtu.JaxTestCase):
|
||||
jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0)
|
||||
"""
|
||||
|
||||
# strip the leading whitespace from the program script
|
||||
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
|
||||
o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
|
||||
self.assertIn("Initializing CoordinationService", o.stderr)
|
||||
|
||||
# verbose logging: DEBUG, VERBOSE
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
self.assertIn("Initializing CoordinationService", p.stderr)
|
||||
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
self.assertIn("Initializing CoordinationService", p.stderr)
|
||||
o = _run(program, { "JAX_LOGGING_LEVEL": "INFO" })
|
||||
self.assertIn("Initializing CoordinationService", o.stderr)
|
||||
|
||||
# verbose logging: WARNING, None
|
||||
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
self.assertNotIn("Initializing CoordinationService", p.stderr)
|
||||
o = _run(program, { "JAX_LOGGING_LEVEL": "WARNING" })
|
||||
self.assertNotIn("Initializing CoordinationService", o.stderr)
|
||||
|
||||
cmd = shlex.split(f"{sys.executable} -c"
|
||||
f" '{program}'")
|
||||
p = subprocess.run(cmd, capture_output=True, text=True)
|
||||
o = _run(program)
|
||||
if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1:
|
||||
self.assertNotIn("Initializing CoordinationService", p.stderr)
|
||||
self.assertNotIn("Initializing CoordinationService", o.stderr)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user