mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 01:56:06 +00:00

This can be used to enable debug logging for specific files (e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`) or all jax (`JAX_DEBUG_LOG_MODULES="jax"`). Example output: ``` $ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)" DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O. DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu' DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu' DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})). DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]] DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec ```
55 lines
2.0 KiB
Python
55 lines
2.0 KiB
Python
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
import logging
|
|
import sys
|
|
|
|
# Single stderr handler shared by every logger that has debug logging enabled.
# It lets through records at DEBUG level and above.
_debug_handler = logging.StreamHandler(stream=sys.stderr)
_debug_handler.setLevel(logging.DEBUG)
# A formatted record looks like:
# DEBUG:2023-06-07 00:14:40,280:jax._src.xla_bridge:590: Initializing backend 'cpu'
_debug_handler.setFormatter(
    logging.Formatter(
        fmt="{levelname}:{asctime}:{name}:{lineno}: {message}",
        style="{",
    )
)

# Loggers that `enable_debug_logging` has been called on; consulted by
# `disable_all_debug_logging` to undo the change.
_debug_enabled_loggers = list()
|
|
|
|
|
|
def enable_debug_logging(logger_name):
  """Makes the specified logger log everything to stderr.

  Also adds more useful debug information to the log messages, e.g. the time.

  Calling this repeatedly for the same logger is harmless: the handler is
  attached at most once and the logger is tracked at most once.

  Args:
    logger_name: the name of the logger, e.g. "jax._src.xla_bridge".
  """
  logger = logging.getLogger(logger_name)
  # `Logger.addHandler` is a no-op when the handler is already attached, so
  # repeated calls do not produce duplicated log lines.
  logger.addHandler(_debug_handler)
  logger.setLevel(logging.DEBUG)
  # Track each logger only once so the bookkeeping list does not grow without
  # bound when debug logging is reconfigured many times.
  if logger not in _debug_enabled_loggers:
    _debug_enabled_loggers.append(logger)
|
|
|
|
|
|
def disable_all_debug_logging():
  """Disables all debug logging enabled via `enable_debug_logging`.

  The default logging behavior will still be in effect, i.e. WARNING and above
  will be logged to stderr without extra message formatting.

  After this call no loggers are tracked any more, so calling it again is a
  no-op until `enable_debug_logging` is used again.
  """
  for logger in _debug_enabled_loggers:
    logger.removeHandler(_debug_handler)
    # Assume that the default non-debug log level is always WARNING. In theory
    # we could keep track of what it was set to before. This shouldn't make a
    # difference if no other handlers are attached, but set it back in case
    # something else gets attached (e.g. absl logger) and for consistency.
    logger.setLevel(logging.WARNING)
  # Forget the loggers that were just reset. Otherwise a later call would
  # force them back to WARNING again, overriding any level set in the
  # meantime, and the list would keep every logger referenced forever.
  # Cleared in place so other references to the list stay valid.
  _debug_enabled_loggers.clear()
|