diff --git a/jax/BUILD b/jax/BUILD index d09d66c8c..e1e042418 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -293,10 +293,16 @@ pytype_strict_library( name = "config", srcs = ["_src/config.py"], deps = [ + ":logging_config", "//jax/_src/lib", ], ) +pytype_strict_library( + name = "logging_config", + srcs = ["_src/logging_config.py"], +) + pytype_strict_library( name = "core", srcs = [ diff --git a/jax/__init__.py b/jax/__init__.py index 1ada9e4a9..5e289593d 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Set default logging level before any logging happens. +# Set default C++ logging level before any logging happens. import os as _os _os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1') del _os diff --git a/jax/_src/config.py b/jax/_src/config.py index fa42727e5..a578084ce 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -28,6 +28,7 @@ from jax._src import lib from jax._src.lib import jax_jit from jax._src.lib import transfer_guard_lib from jax._src.lib import xla_client +from jax._src import logging_config logger = logging.getLogger(__name__) @@ -1205,3 +1206,21 @@ def transfer_guard(new_val: str) -> Iterator[None]: stack.enter_context(transfer_guard_device_to_host(new_val)) stack.enter_context(_transfer_guard(new_val)) yield + + +def _update_debug_log_modules(module_names_str: Optional[str]): + logging_config.disable_all_debug_logging() + if not module_names_str: + return + module_names = module_names_str.split(',') + for module_name in module_names: + logging_config.enable_debug_logging(module_name) + +# Don't define a context manager since this isn't threadsafe. +config.define_string_state( + name='jax_debug_log_modules', + default='', + help=('Comma-separated list of module names (e.g. "jax" or ' + '"jax._src.xla_bridge,jax._src.dispatch") to enable debug logging ' + 'for.'), + update_global_hook=_update_debug_log_modules) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index c6564c157..00692a869 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -268,7 +268,7 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None): yield elapsed_time = time.time() - start_time if logger.isEnabledFor(log_priority): - logger.log(logging.WARNING, fmt.format( + logger.log(log_priority, fmt.format( fun_name=fun_name, elapsed_time=elapsed_time)) if event is not None: record_event_duration_secs(event, elapsed_time) diff --git a/jax/_src/logging_config.py b/jax/_src/logging_config.py new file mode 100644 index 000000000..d2f9d9c8f --- /dev/null +++ b/jax/_src/logging_config.py @@ -0,0 +1,54 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys + +_debug_handler = logging.StreamHandler(sys.stderr) +_debug_handler.setLevel(logging.DEBUG) +# Example log message: +# DEBUG:2023-06-07 00:14:40,280:jax._src.xla_bridge:590: Initializing backend 'cpu' +_debug_handler.setFormatter(logging.Formatter( + "{levelname}:{asctime}:{name}:{lineno}: {message}", style='{')) + +_debug_enabled_loggers = [] + + +def enable_debug_logging(logger_name): + """Makes the specified logger log everything to stderr. + + Also adds more useful debug information to the log messages, e.g. the time. + + Args: + logger_name: the name of the logger, e.g. "jax._src.xla_bridge". + """ + logger = logging.getLogger(logger_name) + logger.addHandler(_debug_handler) + logger.setLevel(logging.DEBUG) + _debug_enabled_loggers.append(logger) + + +def disable_all_debug_logging(): + """Disables all debug logging enabled via `enable_debug_logging`. + + The default logging behavior will still be in effect, i.e. WARNING and above + will be logged to stderr without extra message formatting. + """ + for logger in _debug_enabled_loggers: + logger.removeHandler(_debug_handler) + # Assume that the default non-debug log level is always WARNING. In theory + # we could keep track of what it was set to before. This shouldn't make a + # difference if not other handlers are attached, but set it back in case + # something else gets attached (e.g. absl logger) and for consistency. + logger.setLevel(logging.WARNING) diff --git a/tests/BUILD b/tests/BUILD index 717e45479..727b8e9e1 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1144,6 +1144,15 @@ py_test( ], ) +py_test( + name = "logging_test", + srcs = ["logging_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + exports_files( [ "api_test.py", diff --git a/tests/logging_test.py b/tests/logging_test.py new file mode 100644 index 000000000..3991f6ed8 --- /dev/null +++ b/tests/logging_test.py @@ -0,0 +1,87 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import io +import logging + + +import jax +from jax import config +import jax._src.test_util as jtu + +# Note: importing absltest causes an extra absl root log handler to be +# registered, which causes extra debug log messages. We don't expect users to +# import absl logging, so it should only affect this test. We need to use +# absltest.main and config.parse_flags_with_absl() in order for jax_test flag +# parsing to work correctly with bazel (otherwise we could avoid importing +# absltest/absl logging altogether). +from absl.testing import absltest +jax.config.parse_flags_with_absl() + + +@contextlib.contextmanager +def capture_jax_logs(): + log_output = io.StringIO() + handler = logging.StreamHandler(log_output) + logger = logging.getLogger("jax") + + logger.addHandler(handler) + try: + yield log_output + finally: + logger.removeHandler(handler) + + +class LoggingTest(jtu.JaxTestCase): + + def test_debug_logging(self): + # Warmup so we don't get "No GPU/TPU" warning later. + jax.jit(lambda x: x + 1)(1) + + # Nothing logged by default (except warning messages, which we don't expect + # here). + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertEmpty(log_output.getvalue()) + + # Turn on all debug logging. + config.update("jax_debug_log_modules", "jax") + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertIn("Finished tracing + transforming", log_output.getvalue()) + self.assertIn("Compiling ", log_output.getvalue()) + + # Turn off all debug logging. + config.update("jax_debug_log_modules", None) + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertEmpty(log_output.getvalue()) + + # Turn on one module. + config.update("jax_debug_log_modules", "jax._src.dispatch") + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertIn("Finished tracing + transforming", log_output.getvalue()) + self.assertNotIn("Compiling ", log_output.getvalue()) + + # Turn everything off again. + config.update("jax_debug_log_modules", None) + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertEmpty(log_output.getvalue()) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())