mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add jax_debug_log_modules
config option.
This can be used to enable debug logging for specific files (e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`) or all jax (`JAX_DEBUG_LOG_MODULES="jax"`). Example output: ``` $ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)" DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O. DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu' DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu' DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})). DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]] DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec ```
This commit is contained in:
parent
640ee1e815
commit
8b58e38ec5
@ -293,10 +293,16 @@ pytype_strict_library(
|
||||
name = "config",
|
||||
srcs = ["_src/config.py"],
|
||||
deps = [
|
||||
":logging_config",
|
||||
"//jax/_src/lib",
|
||||
],
|
||||
)
|
||||
|
||||
# Library target for _src/logging_config.py; it declares no deps of its own.
pytype_strict_library(
    name = "logging_config",
    srcs = ["_src/logging_config.py"],
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "core",
|
||||
srcs = [
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Set default logging level before any logging happens.
|
||||
# Set default C++ logging level before any logging happens.
|
||||
import os as _os
|
||||
_os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
|
||||
del _os
|
||||
|
@ -28,6 +28,7 @@ from jax._src import lib
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import transfer_guard_lib
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import logging_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -1205,3 +1206,21 @@ def transfer_guard(new_val: str) -> Iterator[None]:
|
||||
stack.enter_context(transfer_guard_device_to_host(new_val))
|
||||
stack.enter_context(_transfer_guard(new_val))
|
||||
yield
|
||||
|
||||
|
||||
def _update_debug_log_modules(module_names_str: Optional[str]):
  """Update hook that enables debug logging for exactly the listed modules.

  Any debug logging previously enabled through `logging_config` is switched
  off first, so the new value fully replaces the old one.

  Args:
    module_names_str: comma-separated logger names (e.g. "jax" or
      "jax._src.xla_bridge,jax._src.dispatch"). Whitespace around each name is
      ignored. `None` or an empty string disables debug logging entirely.
  """
  logging_config.disable_all_debug_logging()
  if not module_names_str:
    return
  for raw_name in module_names_str.split(','):
    module_name = raw_name.strip()
    # Skip blank entries (e.g. from a trailing comma): logging.getLogger('')
    # returns the root logger, which would turn on debug output for every
    # library in the process rather than just the requested modules.
    if module_name:
      logging_config.enable_debug_logging(module_name)
|
||||
|
||||
# Don't define a context manager since this isn't threadsafe: the update hook
# attaches/detaches a handler and changes levels on shared `logging` loggers.
config.define_string_state(
    name='jax_debug_log_modules',
    default='',
    help=('Comma-separated list of module names (e.g. "jax" or '
          '"jax._src.xla_bridge,jax._src.dispatch") to enable debug logging '
          'for.'),
    update_global_hook=_update_debug_log_modules)
|
||||
|
@ -268,7 +268,7 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
|
||||
yield
|
||||
elapsed_time = time.time() - start_time
|
||||
if logger.isEnabledFor(log_priority):
|
||||
logger.log(logging.WARNING, fmt.format(
|
||||
logger.log(log_priority, fmt.format(
|
||||
fun_name=fun_name, elapsed_time=elapsed_time))
|
||||
if event is not None:
|
||||
record_event_duration_secs(event, elapsed_time)
|
||||
|
54
jax/_src/logging_config.py
Normal file
54
jax/_src/logging_config.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
# Single shared handler that is attached to every logger put into debug mode.
_debug_handler = logging.StreamHandler(sys.stderr)
_debug_handler.setLevel(logging.DEBUG)
# Example log message:
# DEBUG:2023-06-07 00:14:40,280:jax._src.xla_bridge:590: Initializing backend 'cpu'
_debug_handler.setFormatter(logging.Formatter(
    "{levelname}:{asctime}:{name}:{lineno}: {message}", style='{'))

# Loggers currently in debug mode via `enable_debug_logging`. Emptied again by
# `disable_all_debug_logging`.
_debug_enabled_loggers = []


def enable_debug_logging(logger_name):
  """Makes the specified logger log everything to stderr.

  Also adds more useful debug information to the log messages, e.g. the time.
  Calling this more than once for the same logger has no additional effect.

  Args:
    logger_name: the name of the logger, e.g. "jax._src.xla_bridge".
  """
  logger = logging.getLogger(logger_name)
  # addHandler ignores a handler that is already attached, so this is safe to
  # repeat.
  logger.addHandler(_debug_handler)
  logger.setLevel(logging.DEBUG)
  # Track each logger once so the bookkeeping list doesn't accumulate
  # duplicates.
  if logger not in _debug_enabled_loggers:
    _debug_enabled_loggers.append(logger)


def disable_all_debug_logging():
  """Disables all debug logging enabled via `enable_debug_logging`.

  The default logging behavior will still be in effect, i.e. WARNING and above
  will be logged to stderr without extra message formatting.

  Loggers are forgotten once reset, so a later call only touches loggers that
  were enabled again in the meantime.
  """
  for logger in _debug_enabled_loggers:
    logger.removeHandler(_debug_handler)
    # Assume that the default non-debug log level is always WARNING. In theory
    # we could keep track of what it was set to before. This shouldn't make a
    # difference if no other handlers are attached, but set it back in case
    # something else gets attached (e.g. absl logger) and for consistency.
    logger.setLevel(logging.WARNING)
  # Drop the processed loggers: otherwise the list grows with every
  # enable/disable cycle and stale entries get their level reset to WARNING
  # again on every later call.
  _debug_enabled_loggers.clear()
|
@ -1144,6 +1144,15 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Test target for logging_test.py; depends on the jax package and its test utilities.
py_test(
    name = "logging_test",
    srcs = ["logging_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"api_test.py",
|
||||
|
87
tests/logging_test.py
Normal file
87
tests/logging_test.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
|
||||
|
||||
import jax
|
||||
from jax import config
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
# Note: importing absltest causes an extra absl root log handler to be
|
||||
# registered, which causes extra debug log messages. We don't expect users to
|
||||
# import absl logging, so it should only affect this test. We need to use
|
||||
# absltest.main and config.parse_flags_with_absl() in order for jax_test flag
|
||||
# parsing to work correctly with bazel (otherwise we could avoid importing
|
||||
# absltest/absl logging altogether).
|
||||
from absl.testing import absltest
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@contextlib.contextmanager
def capture_jax_logs():
  """Context manager yielding a buffer that collects "jax" logger output.

  A temporary stream handler writing into an in-memory `io.StringIO` is
  attached to the "jax" logger for the duration of the `with` body and is
  detached again on exit, whether or not the body raised.

  Yields:
    The `io.StringIO` receiving the log text; read it with `getvalue()`.
  """
  jax_logger = logging.getLogger("jax")
  buffer = io.StringIO()
  stream_handler = logging.StreamHandler(buffer)

  jax_logger.addHandler(stream_handler)
  with contextlib.ExitStack() as cleanup:
    # Registered callbacks run when the stack unwinds, including on error.
    cleanup.callback(jax_logger.removeHandler, stream_handler)
    yield buffer
|
||||
|
||||
|
||||
class LoggingTest(jtu.JaxTestCase):
  """Tests for the `jax_debug_log_modules` config option."""

  def test_debug_logging(self):
    """Toggles debug logging on and off and checks what the "jax" logger emits."""
    # The option changes process-wide logger state. Reset it even if an
    # assertion below fails, so later tests don't inherit debug logging.
    self.addCleanup(config.update, "jax_debug_log_modules", None)

    # Warmup so we don't get "No GPU/TPU" warning later.
    jax.jit(lambda x: x + 1)(1)

    # Nothing logged by default (except warning messages, which we don't expect
    # here).
    with capture_jax_logs() as log_output:
      jax.jit(lambda x: x + 1)(1)
    self.assertEmpty(log_output.getvalue())

    # Turn on all debug logging.
    config.update("jax_debug_log_modules", "jax")
    with capture_jax_logs() as log_output:
      jax.jit(lambda x: x + 1)(1)
    self.assertIn("Finished tracing + transforming", log_output.getvalue())
    self.assertIn("Compiling <lambda>", log_output.getvalue())

    # Turn off all debug logging.
    config.update("jax_debug_log_modules", None)
    with capture_jax_logs() as log_output:
      jax.jit(lambda x: x + 1)(1)
    self.assertEmpty(log_output.getvalue())

    # Turn on one module.
    config.update("jax_debug_log_modules", "jax._src.dispatch")
    with capture_jax_logs() as log_output:
      jax.jit(lambda x: x + 1)(1)
    self.assertIn("Finished tracing + transforming", log_output.getvalue())
    self.assertNotIn("Compiling <lambda>", log_output.getvalue())

    # Turn everything off again.
    config.update("jax_debug_log_modules", None)
    with capture_jax_logs() as log_output:
      jax.jit(lambda x: x + 1)(1)
    self.assertEmpty(log_output.getvalue())
|
||||
|
||||
|
||||
# Script entry point: run the tests through absltest using jtu.JaxTestLoader.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user