Migrate JAX internals to builtin Python logging

This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. JAX currently logs through absl-py; switching to the builtin logging module is advisable for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or stream handlers writing to different IO streams; a sketch of this follows the list.
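
As a minimal sketch of the second point (the file name and levels are illustrative, not part of this commit), a downstream program could dump all JAX logs to a file:

```py
import logging

# Illustrative: collect everything logged under the "jax" logger hierarchy
# in a file, on top of whatever handlers the root logger already has.
file_handler = logging.FileHandler("jax_debug.log")  # hypothetical path
file_handler.setLevel(logging.DEBUG)

jax_logger = logging.getLogger("jax")
jax_logger.addHandler(file_handler)
jax_logger.setLevel(logging.DEBUG)
```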

Logging in JAX is ported to take place at the module level. Some Python namespaces within JAX previously used module-scoped logging via absl.vlog; the following idiom provides the same functionality with Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```
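
Builtin logging has no counterpart to absl's numeric verbosity levels, so `vlog` call sites map onto `DEBUG` and `vlog_is_on` guards map onto `isEnabledFor`. A rough before/after sketch of the translation applied throughout this commit (payloads are placeholders):

```py
import logging

logger = logging.getLogger(__name__)

# absl: logging.vlog(1, "Initializing backend '%s'", platform)
logger.debug("Initializing backend '%s'", "cpu")

# absl: if logging.vlog_is_on(2): ...  (guards expensive message construction)
if logger.isEnabledFor(logging.DEBUG):
    logger.debug("devices: %s", list(range(8)))
```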

The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX-internal code logs to descendants of the top-level "jax" logger, so by virtue of log propagation its output can be captured or adjusted centrally on that logger.
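
For instance (a sketch; the module name is chosen only for illustration), a downstream application can configure the root logger as usual and still raise the verbosity of a single JAX subsystem:

```py
import logging

# Downstream root-logger setup; JAX leaves this untouched.
logging.basicConfig(level=logging.INFO)

# Module-scoped loggers propagate towards the top-level "jax" logger, so
# verbosity can be raised for one subsystem in isolation (illustrative name):
logging.getLogger("jax._src.dispatch").setLevel(logging.DEBUG)
```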

The package `absl-py` was removed from JAX's install requirements and added to its test requirements.
Author: Nicholas Junge
Date: 2022-10-13 17:06:22 +02:00
Commit: efd61b73f6 (parent: 0313b2241d)
16 changed files with 133 additions and 126 deletions

View File

@@ -1,3 +1,4 @@
absl-py
cloudpickle
colorama>=0.4.4
matplotlib

View File

@@ -18,7 +18,6 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
FrozenSet)
import types
from absl import logging
import numpy as np
import jax

View File

@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional, Type, Sequence, Tuple
from absl import logging
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
logger = logging.getLogger(__name__)
class ClusterEnv:
"""Interface for defining a cluster environment.
@@ -48,7 +49,7 @@ class ClusterEnv:
local_device_ids)
env = next((env for env in cls._cluster_types if env.is_env_present()), None)
if env:
logging.vlog(1, 'Initializing distributed JAX environment via %s', env.__name__)
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
if coordinator_address is None:
coordinator_address = env.get_coordinator_address()
if num_processes is None:
@@ -64,7 +65,7 @@
env.get_local_process_id() is not None):
local_device_ids = [env.get_local_process_id()] # type: ignore[list-item]
else:
logging.vlog(1, 'Could not find a known environment for initializing distributed JAX. '
logger.debug('Could not find a known environment for initializing distributed JAX. '
'Known environments: %s', ', '.join(e.__name__ for e in cls._cluster_types))
return (coordinator_address, num_processes, process_id, local_device_ids)
# pytype: enable=bad-return-type

View File

@@ -18,19 +18,20 @@
import contextlib
import functools
import itertools
import logging
import os
import sys
import threading
from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
import warnings
from absl import logging
from jax._src import lib
from jax._src.lib import jax_jit
from jax._src.lib import transfer_guard_lib
from jax._src.lib import xla_client
logger = logging.getLogger(__name__)
def bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
@@ -643,7 +644,7 @@ log_compiles = config.define_bool_state(
name='jax_log_compiles',
default=False,
help=('Log a message each time every time `jit` or `pmap` compiles an XLA '
'computation. Logging is performed with `absl.logging`. When this '
'computation. Logging is performed with `logging`. When this '
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))
@@ -674,7 +675,7 @@ distributed_debug = config.define_bool_state(
name='jax_distributed_debug',
default=False,
help=('Enable logging useful for debugging multi-process distributed '
'computations. Logging is performed with `absl.logging` at WARNING '
'computations. Logging is performed with `logging` at WARNING '
'level.'))
@@ -772,7 +773,7 @@ def _validate_default_device(val):
# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
# all JAX backends use a single C++ device interface.
if 'Device' in str(type(val)):
logging.info(
logger.info(
'Allowing non-`xla_client.Device` default device: %s, type: %s',
repr(val), type(val))
return

View File

@@ -22,14 +22,14 @@ import itertools
import time
from typing import (
Any, Callable, Dict, Optional, Sequence, Set, Tuple, List, Type, Union,
TYPE_CHECKING, Iterator)
Iterator)
from typing_extensions import Protocol
import logging
import os
import re
import threading
import warnings
from absl import logging
import numpy as np
import jax
@@ -82,6 +82,8 @@ CompileOptions = xc.CompileOptions
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
logger = logging.getLogger(__name__)
# This flag is set on exit; no logging should be attempted
_on_exit = False
@@ -360,7 +362,7 @@ def log_elapsed_time(fmt: str):
start_time = time.time()
yield
elapsed_time = time.time() - start_time
logging.log(log_priority, fmt.format(elapsed_time=elapsed_time))
logger.log(log_priority, fmt.format(elapsed_time=elapsed_time))
def should_tuple_args(num_args: int, platform: str):
@@ -476,7 +478,7 @@ def lower_xla_callable(
msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
else:
msg = f"Compiling {fun.__name__} ({id(fun)} for args {abstract_args}."
logging.log(log_priority, msg)
logger.log(log_priority, msg)
raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr)
@@ -1041,7 +1043,7 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
cached_executable = _cache_read(serialized_computation, module_name,
compile_options, backend)
if cached_executable is not None:
logging.info("Persistent compilation cache hit for '%s'", module_name)
logger.info("Persistent compilation cache hit for '%s'", module_name)
return cached_executable
else:
compiled = backend_compile(backend, serialized_computation,

View File

@@ -13,16 +13,17 @@
# limitations under the License.
import atexit
import logging
import os
import functools
from typing import Any, Optional, Union, Sequence
from absl import logging
from jax._src.clusters import ClusterEnv
from jax._src.config import config
from jax._src.lib import xla_extension
logger = logging.getLogger(__name__)
class State:
process_id: int = 0
@@ -55,7 +56,7 @@ class State:
if local_device_ids:
visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr]
logging.info('JAX distributed initialized with visible devices: %s', visible_devices)
logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
config.update("jax_cuda_visible_devices", visible_devices)
config.update("jax_rocm_visible_devices", visible_devices)
@@ -64,7 +65,7 @@
if process_id == 0:
if self.service is not None:
raise RuntimeError('distributed.initialize should only be called once.')
logging.info('Starting JAX distributed service on %s', coordinator_address)
logger.info('Starting JAX distributed service on %s', coordinator_address)
self.service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes, config.jax_coordination_service)
@@ -75,7 +76,7 @@
self.client = xla_extension.get_distributed_runtime_client(
coordinator_address, process_id, config.jax_coordination_service,
init_timeout=300)
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
logger.info('Connecting to JAX distributed service on %s', coordinator_address)
self.client.connect()
if config.jax_coordination_service:

View File

@@ -20,16 +20,13 @@ XLA. There are also a handful of related casting utilities.
"""
from functools import partial, lru_cache
import logging
import os
import platform as py_platform
import threading
from typing import Any, Dict, List, Optional, Union
import warnings
from absl import logging
# Disable "WARNING: Logging before flag parsing goes to stderr." message
logging._warn_preinit_stderr = 0
import jax._src.lib as lib
from jax._src.config import flags, bool_env, int_env
from jax._src import distributed
@@ -58,6 +55,8 @@ ShardedBuffer = Any
FLAGS = flags.FLAGS
logger = logging.getLogger(__name__)
# TODO(phawkins): Remove jax_xla_backend.
flags.DEFINE_string(
'jax_xla_backend', '',
@@ -126,8 +125,7 @@ def get_compile_options(
build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids
if device_assignment is not None:
logging.vlog(
2,
logger.debug(
'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
num_replicas, num_partitions, device_assignment)
device_assignment = np.array(device_assignment)
@@ -170,10 +168,10 @@ def get_compile_options(
def _make_tpu_driver_client():
if tpu_driver_client is None:
logging.info("Remote TPU is not linked into jax; skipping remote TPU.")
logger.info("Remote TPU is not linked into jax; skipping remote TPU.")
return None
if FLAGS.jax_backend_target is None:
logging.info("No --jax_backend_target was provided; skipping remote TPU.")
logger.info("No --jax_backend_target was provided; skipping remote TPU.")
return None
return tpu_driver_client.TpuBackend.create(worker=FLAGS.jax_backend_target)
@@ -354,14 +352,14 @@ def backends():
raise RuntimeError(err_msg)
else:
_backends_errors[platform] = str(err)
logging.info(err_msg)
logger.info(err_msg)
continue
# We don't warn about falling back to CPU on Mac OS, because we don't
# support anything else there at the moment and warning would be pointless.
if (py_platform.system() != "Darwin" and
_default_backend.platform == "cpu" and
FLAGS.jax_platform_name != 'cpu'):
logging.warning('No GPU/TPU found, falling back to CPU. '
logger.warning('No GPU/TPU found, falling back to CPU. '
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
return _backends
@@ -371,7 +369,7 @@ def _clear_backends():
global _backends_errors
global _default_backend
logging.info("Clearing JAX backend caches.")
logger.info("Clearing JAX backend caches.")
with _backend_lock:
_backends = {}
_backends_errors = {}
@@ -385,7 +383,7 @@ def _init_backend(platform):
if factory is None:
raise RuntimeError(f"Unknown backend '{platform}'")
logging.vlog(1, "Initializing backend '%s'" % platform)
logger.debug("Initializing backend '%s'", platform)
backend = factory()
# TODO(skye): consider raising more descriptive errors directly from backend
# factories instead of returning None.
@@ -397,7 +395,7 @@ def _init_backend(platform):
("process_index", backend.process_index()),
("device_count", backend.device_count()),
("local_devices", backend.local_devices()))
logging.vlog(1, "Backend '%s' initialized" % platform)
logger.debug("Backend '%s' initialized", platform)
return backend

View File

@@ -18,14 +18,13 @@ import glob
import gzip
import http.server
import json
import logging
import os
import socketserver
import threading
import warnings
from typing import Callable, Optional
from absl import logging
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
@@ -34,6 +33,8 @@ from jax._src.lib import xla_client
_profiler_server: Optional[xla_client.profiler.ProfilerServer] = None
logger = logging.getLogger(__name__)
def start_server(port: int):
"""Starts the profiler server on port `port`.
@@ -134,7 +135,7 @@ def _write_perfetto_trace_file(log_dir):
raise ValueError(f"Invalid trace folder: {latest_folder}")
trace_json, = trace_jsons
logging.info("Loading trace.json.gz and removing its metadata...")
logger.info("Loading trace.json.gz and removing its metadata...")
# Perfetto doesn't like the `metadata` field in `trace.json` so we remove
# it.
# TODO(sharadmv): speed this up by updating the generated `trace.json`
@@ -144,7 +145,7 @@ def _write_perfetto_trace_file(log_dir):
del trace["metadata"]
filename = "perfetto_trace.json.gz"
perfetto_trace = os.path.join(latest_folder, filename)
logging.info("Writing perfetto_trace.json.gz...")
logger.info("Writing perfetto_trace.json.gz...")
with gzip.open(perfetto_trace, "w") as fp:
fp.write(json.dumps(trace).encode("utf-8"))
return perfetto_trace

View File

@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import functools
from functools import partial
import itertools as it
from collections import namedtuple
import logging
import operator
import types
import threading
@@ -24,13 +24,14 @@ from typing import (Any, Callable, Dict, Iterable, List, Tuple, Generic,
TypeVar, Set, Iterator, Sequence, Optional)
import weakref
from absl import logging
import numpy as np
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax.config import config
logger = logging.getLogger(__name__)
Seq = Sequence
T = TypeVar("T")
@@ -510,7 +511,7 @@ def distributed_debug_log(*pairs):
lines.append("DISTRIBUTED_DEBUG logging failed!")
lines.append(f"{e}")
lines.append("DISTRIBUTED_DEBUG_END")
logging.warning("\n".join(lines))
logger.warning("\n".join(lines))
class OrderedSet(Generic[T]):

View File

@@ -13,27 +13,29 @@
# limitations under the License.
import hashlib
import logging
import os
import re
import sys
from typing import List, Optional
import jax
from jax.experimental.compilation_cache.gfile_cache import GFileCache
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import xla_client
from jax.interpreters import xla
from absl import logging
_cache = None
logger = logging.getLogger(__name__)
def initialize_cache(path):
"""Creates a global cache object. Should only be called once per process.
"""
global _cache
assert _cache == None, f"The cache path has already been initialized to {_cache._path}"
_cache = GFileCache(path)
logging.warning("Initialized persistent compilation cache at %s", path)
logger.warning("Initialized persistent compilation cache at %s", path)
def get_executable(xla_computation, compile_options,
@@ -54,20 +56,20 @@ def put_executable(module_name, xla_computation, compile_options,
"""Adds 'executable' to the cache, possibly evicting older entries."""
assert _cache is not None, "initialize_cache must be called before you can call put_executable()"
cache_key = get_cache_key(xla_computation, compile_options, backend)
logging.info('Writing %s to persistent compilation cache with key %s.',
logger.info('Writing %s to persistent compilation cache with key %s.',
module_name, cache_key)
serialized_executable = backend.serialize_executable(executable)
_cache.put(cache_key, serialized_executable)
def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
if logging.vlog_is_on(1):
if logger.isEnabledFor(logging.DEBUG):
# Log the hash of just this entry
fresh_hash_obj = hashlib.sha256()
hashfn(fresh_hash_obj)
logging.vlog(1, "get_cache_key hash of serialized %s: %s", last_serialized,
logger.debug("get_cache_key hash of serialized %s: %s", last_serialized,
fresh_hash_obj.digest().hex())
# Log the cumulative hash
logging.vlog(1, "get_cache_key hash after serializing %s: %s",
logger.debug("get_cache_key hash after serializing %s: %s",
last_serialized, hash_obj.digest().hex())
def get_cache_key(xla_computation, compile_options, backend) -> str:
@@ -215,9 +217,9 @@ def _hash_xla_flags(hash_obj):
# (e.g. --xla_force_host_platform_device_count=8) (I think).
for flag in xla_flags:
if flag.split('=')[0] in _xla_flags_to_exclude_from_cache_key:
logging.vlog(1, "Not including XLA flag in cache key: %s", flag)
logger.debug("Not including XLA flag in cache key: %s", flag)
continue
logging.vlog(1, "Including XLA flag in cache key: %s", flag)
logger.debug("Including XLA flag in cache key: %s", flag)
_hash_string(hash_obj, flag)
def _hash_int(hash_obj, int_var):

View File

@@ -16,11 +16,11 @@
import abc
import asyncio
import itertools
import logging
from functools import partial
import re
import threading
from typing import Callable, Sequence, Optional, Dict, Any
from absl import logging
import jax
from jax._src import distributed
@@ -40,6 +40,8 @@ _REMOVED_VALUE = 'Value removed'
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
_module_unique_count = itertools.count()
logger = logging.getLogger(__name__)
async def create_async_array_from_callback(
global_shape: array.Shape,
@@ -385,7 +387,7 @@ class AsyncManager:
def __del__(self):
if self._thread is not None and self._thread.is_alive():
logging.warning('Please add `.wait_until_finished()` in the main thread '
logger.warning('Please add `.wait_until_finished()` in the main thread '
'before your program finishes because there is a '
'possibility of losing errors raised if the '
'this class is deleted before writing is completed.')
@@ -393,20 +395,20 @@
def _thread_func(self):
try:
current_process = jax.process_index()
logging.info('Starting commit to storage layer by process: %s',
logger.info('Starting commit to storage layer by process: %s',
current_process)
for future in self._commit_futures:
future.result()
logging.info('Finished committing to storage layer by process: %s',
logger.info('Finished committing to storage layer by process: %s',
current_process)
# All processes will wait at the barrier. When all processes are at the
# barrier, the barrier will be satisfied. If not, then it will timeout.
key_for_barrier = _get_key(self._count)
logging.info('Key used for barrier is %s for process %s',
logger.info('Key used for barrier is %s for process %s',
key_for_barrier, current_process)
self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms)
logging.info('Finished waiting at barrier for process %s',
logger.info('Finished waiting at barrier for process %s',
current_process)
if current_process == 0:
@@ -471,7 +473,7 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
final_checkpoint_dir: Final checkpoint directory where the checkpoints
will be moved from `temp_checkpoint_dir`.
"""
logging.info('Waiting for previous serialization to finish.')
logger.info('Waiting for previous serialization to finish.')
self.wait_until_finished()
commit_futures = [[] for _ in range(len(tensorstore_specs))]

View File

@@ -497,14 +497,13 @@ Still to do:
import atexit
import functools
import itertools
import logging
import threading
import traceback
from typing import (Any, Callable, Dict, List, Optional, Sequence,
Tuple, cast)
import warnings
from absl import logging
from jax._src import api
from jax import core
from jax.config import config
@@ -532,6 +531,8 @@ import numpy as np
FLAGS = config.FLAGS
logger = logging.getLogger(__name__)
def _inline_host_callback() -> bool:
return FLAGS.jax_host_callback_inline
@@ -1231,9 +1232,10 @@ def _outside_call_run_callback(
try:
arg = api.tree_unflatten(arg_treedef, arrays)
unpacked_transforms = _unpack_transforms(transforms)
if logging.vlog_is_on(2):
logging.vlog(2,
f"Outside call invoking call_func {callback}, device={device}, transforms={unpacked_transforms}")
logger.debug(
"Outside call invoking call_func %s, device=%s, transforms=%s",
callback, device, unpacked_transforms
)
res = callback(arg, device, unpacked_transforms)
if identity:
return tuple(arrays)
@@ -1250,10 +1252,9 @@ def _outside_call_run_callback(
canonical_flat_results = tuple(util.safe_map(xla.canonicalize_dtype, actual_flat_results))
actual_flat_results_aval = _values_to_avals(canonical_flat_results)
if logging.vlog_is_on(2):
logging.vlog(
2,
f"Outside call {callback} result {flat_results_aval}. Sending to infeed for device {device}."
logger.debug(
"Outside call %s result %s. Sending to infeed for device %s.",
callback, flat_results_aval, device,
)
if not all(ea.strip_weak_type() == ra.strip_weak_type()
@@ -1273,7 +1274,7 @@ def _outside_call_run_callback(
return canonical_flat_results
except Exception as e:
logging.error("Outside call %s threw exception %s.", callback, e)
logger.error("Outside call %s threw exception %s.", callback, e)
if send_infeed:
# Prepare some results to send in case of error. We are sending something
# with a distinctive shape (int8[12345]), one that is unlikely to be what the device
@@ -1285,12 +1286,12 @@
# TODO: implement a proper error handling for TPU
if device.platform != "tpu":
canonical_flat_results = [xla.canonicalize_dtype(np.arange(12345, dtype=np.int8))]
if logging.vlog_is_on(2):
logging.vlog(2, f"Outside call consumer {callback} exception {e}. Sending to infeed the error result.")
logger.debug("Outside call consumer %s exception %s. Sending to infeed the error result.",
callback, e)
device.transfer_to_infeed(tuple(canonical_flat_results))
else:
if logging.vlog_is_on(2):
logging.vlog(2, f"Outside call consumer {callback} exception {e}. On TPU we do not send infeed.")
logger.debug("Outside call consumer %s exception %s. On TPU we do not send infeed.",
callback, e)
raise e # Let the exception propagate
@@ -1860,17 +1861,16 @@ _callback_handler_data = _CallbackHandlerData()
# This function is called from C++; it must not allow exceptions through.
def _callback_input_received(device, consumer_id, arrays: Tuple):
logging.vlog(
2,
f"Callback input received on device {device} for consumer {consumer_id} "
+ "arrays: " + (", ".join([f"({a.dtype}{a.shape})" for a in arrays])))
array_repr = ", ".join([f"({a.dtype}{a.shape})" for a in arrays])
logger.debug("Callback input received on device %s for consumer %s arrays: %s",
device, consumer_id, array_repr)
callback = _callback_handler_data.callback_registry_by_id.get(consumer_id)
assert callback is not None, "We should have crashed in the runtime"
try:
return callback(arrays, device)
except Exception as e:
formatted_e = traceback.format_exc()
logging.error("Postponing exception raised in callback function: %s", formatted_e)
logger.error("Postponing exception raised in callback function: %s", formatted_e)
_callback_handler_data.last_callback_exception = (e, formatted_e)
@@ -1920,11 +1920,10 @@ def _initialize_outfeed_receiver(
if clients_with_outfeed:
devices_with_outfeed = list(
itertools.chain(*[backend.local_devices() for backend in clients_with_outfeed]))
if logging.vlog_is_on(2):
logging.vlog(
2,
f"Starting outfeed_receiver for {[str(d) for d in devices_with_outfeed]}. "
f"max_callback_queue_size_bytes={max_callback_queue_size_bytes}")
if logger.isEnabledFor(logging.DEBUG):
device_repr = ", ".join([str(d) for d in devices_with_outfeed])
logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s",
device_repr, max_callback_queue_size_bytes)
_callback_handler_data.receiver = outfeed_receiver_module.start(
_callback_input_received, tuple(clients_with_outfeed),
max_callback_queue_size_bytes)
@@ -1959,43 +1958,40 @@ def barrier_wait(logging_name: Optional[str] = None):
for this invocation. See `Debugging` in the module documentation.
"""
logging_name = logging_name or ""
if logging.vlog_is_on(2):
logging.vlog(2, f"barrier_wait[{logging_name}]: start")
logger.debug("barrier_wait[%s]: start", logging_name)
lock = threading.Lock()
cv = threading.Condition(lock=lock)
devices_at_barrier = [] # Protected by lock
def barrier_tap_received(dev_idx, _):
device = _callback_handler_data.devices[dev_idx]
if logging.vlog_is_on(2):
logging.vlog(
2,
f"barrier_wait[{logging_name}]: at barrier_tap for device {device} "
f". Thread {threading.current_thread()}")
logger.debug(
"barrier_wait[%s]: at barrier_tap for device %s. Thread %s",
logging_name, device, threading.current_thread()
)
with lock:
devices_at_barrier.append(device)
if logging.vlog_is_on(2):
if logger.isEnabledFor(logging.DEBUG):
waiting_for_devices = [d for d in _callback_handler_data.devices
if d not in devices_at_barrier]
logging.vlog(2,
f"barrier_wait[{logging_name}]: still waiting "
f"for {len(waiting_for_devices)} devices at "
f"barrier ({waiting_for_devices})")
logger.debug(
"barrier_wait[%s]: still waiting for %s devices at barrier (%s)",
logging_name, len(waiting_for_devices), waiting_for_devices
)
cv.notify()
for d_idx, d in enumerate(_callback_handler_data.devices):
if logging.vlog_is_on(2):
logging.vlog(2,
f"barrier_wait[{logging_name}]: enqueueing barrier on device {d}")
logger.debug("barrier_wait[%s]: enqueueing barrier on device %s", logging_name, d)
x_on_dev = api.device_put(d_idx, device=d)
api.jit(lambda x: id_tap(barrier_tap_received, x), device=d)(x_on_dev)
if logging.vlog_is_on(2):
logging.vlog(2,
f"barrier_wait[{logging_name}]: waiting for callbacks")
logger.debug("barrier_wait[%s]: waiting for callbacks", logging_name)
with lock:
cv.wait_for(lambda: len(devices_at_barrier) == len(_callback_handler_data.devices))
if logging.vlog_is_on(2):
logging.vlog(2, f"barrier_wait[{logging_name}]: done")
logger.debug("barrier_wait[%s]: done", logging_name)
if _callback_handler_data.last_callback_exception is not None:
last_exception, formatted_last_exception = _callback_handler_data.last_callback_exception
_callback_handler_data.last_callback_exception = None

View File

@@ -14,13 +14,12 @@
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
from functools import partial
import contextlib
import logging
import os
import re
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast
from absl import logging
import jax
from jax import lax
from jax import config
@@ -33,7 +32,6 @@ from jax.experimental import pjit
from jax._src import sharding
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import partial_eval
from jax.interpreters import pxla
from jax.interpreters import xla
@@ -52,7 +50,6 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import slicing as lax_slicing
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src import lib as jaxlib
from jax._src.lib import xla_client
from jax.experimental.global_device_array import GlobalDeviceArray
@@ -94,6 +91,8 @@ _INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/-]")
map = util.safe_map
zip = util.safe_zip
logger = logging.getLogger(__name__)
def _sanitize_scope_name(name):
scope_name = _INVALID_SCOPE_CHAR.sub("_", name)
@@ -636,7 +635,7 @@ def _lower_native_and_run(fun_jax: Callable,
lowered = fun_jax_lower(*arg_specs_jax)._lowering
mhlo_module = lowered.mhlo()
mhlo_module_text = mlir.module_to_string(mhlo_module)
logging.vlog(2, f"XlaCallModule {mhlo_module_text}")
logger.debug("XlaCallModule %s", mhlo_module_text)
# We do not support custom_call, try to give an error for now
if "mhlo.custom_call" in mhlo_module_text:
# Try to give a nice error message. We could just dump the module...

View File

@@ -15,12 +15,14 @@
"""Utils for building a device mesh."""
import itertools
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple
from absl import logging
import jax
import numpy as np
logger = logging.getLogger(__name__)
_TPU_V2 = 'TPU v2'
_TPU_V3 = 'TPU v3'
_TPU_V4 = 'TPU v4'
@@ -248,14 +250,14 @@ def create_device_mesh(
device_kind = devices[-1].device_kind
if device_kind in (_TPU_V2, _TPU_V3):
if len(devices) == 8:
logging.info('Reordering mesh to physical ring order on single-tray TPU v2/v3.')
logger.info('Reordering mesh to physical ring order on single-tray TPU v2/v3.')
device_mesh = np.asarray(devices)
device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)]
device_mesh = device_mesh.reshape(mesh_shape)
return device_mesh
elif mesh_shape[-1] == 8:
device_mesh = np.asarray(devices).reshape(mesh_shape)
logging.info('Reordering mesh to physical ring order on each TPU v2/v3 tray.')
logger.info('Reordering mesh to physical ring order on each TPU v2/v3 tray.')
perm = np.array(_TRAY_RING_ORDER)
device_mesh = device_mesh[..., perm]
return device_mesh
@@ -270,7 +272,7 @@ def create_device_mesh(
physical_mesh = _transpose_trick(physical_mesh, mesh_shape)
device_mesh, assignment = _create_device_mesh_for_nd_torus(
physical_mesh, mesh_shape)
logging.info('_create_device_mesh_for_nd_torus assignment: %s', assignment)
logger.info('_create_device_mesh_for_nd_torus assignment: %s', assignment)
return device_mesh
else:
device_mesh = np.asarray(devices).reshape(mesh_shape)

View File

@@ -36,6 +36,7 @@ from collections import defaultdict, OrderedDict
import dataclasses
from functools import partial, lru_cache
import itertools as it
import logging
import operator as op
import sys
import warnings
@@ -45,7 +46,6 @@ from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
TYPE_CHECKING)
from absl import logging
import numpy as np
@@ -106,6 +106,8 @@ xe = xc._xla
unsafe_map, map = map, safe_map # type: ignore
logger = logging.getLogger(__name__)
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
NoSharding = pmap_lib.NoSharding
@@ -1385,19 +1387,19 @@ def lower_parallel_callable(
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
pci, fun, global_arg_shapes)
if logging.vlog_is_on(2):
logging.vlog(2, "sharded_avals: %s", shards.sharded_avals)
logging.vlog(2, "global_sharded_avals: %s", shards.global_sharded_avals)
logging.vlog(2, "num_replicas: %d num_local_replicas: %d",
if logger.isEnabledFor(logging.DEBUG):
logger.debug("sharded_avals: %s", shards.sharded_avals)
logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
logger.debug("num_replicas: %d num_local_replicas: %d",
replicas.num_global_replicas, replicas.num_local_replicas)
logging.vlog(2, "num_partitions: %d local_num_partitions: %d",
logger.debug("num_partitions: %d local_num_partitions: %d",
parts.num_partitions, parts.local_num_partitions)
logging.vlog(2, "arg_parts: %s", parts.arg_parts)
logging.vlog(2, "local_arg_parts: %s", parts.local_arg_parts)
logging.vlog(2, "out_parts: %s", parts.out_parts)
logging.vlog(2, "local_out_parts: %s", parts.local_out_parts)
logging.vlog(2, "devices: %s", devices)
logging.vlog(2, "local_devices: %s", pci.local_devices)
logger.debug("arg_parts: %s", parts.arg_parts)
logger.debug("local_arg_parts: %s", parts.local_arg_parts)
logger.debug("out_parts: %s", parts.out_parts)
logger.debug("local_out_parts: %s", parts.local_out_parts)
logger.debug("devices: %s", devices)
logger.debug("local_devices: %s", pci.local_devices)
if (xb.process_count(backend) > 1 and must_run_on_all_devices and
shards.num_local_shards != xb.local_device_count(backend)):
@@ -1425,7 +1427,7 @@ def lower_parallel_callable(
f"{replicas.jaxpr_replicas} and nested_partitions={parts.num_partitions}")
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
logger.log(log_priority,
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d"
" num_partitions=%d)", fun.__name__, id(fun),
shards.num_global_shards, avals, replicas.num_global_replicas,
@@ -2785,7 +2787,7 @@ def lower_sharding_computation(
if _is_unspecified(i) else i for i in in_shardings)
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
logger.log(log_priority,
"Compiling %s (%d) for with global shapes and types %s. "
"Argument mapping: %s.",
getattr(fun, '__name__', '<unnamed function>'), id(fun),
@@ -2964,7 +2966,7 @@ def lower_mesh_computation(
global_axis_sizes = mesh.shape
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
logger.log(log_priority,
"Compiling %s (%d) for %s mesh with global shapes and types %s. "
"Argument mapping: %s.",
getattr(fun, '__name__', '<unnamed function>'), id(fun),

View File

@@ -64,7 +64,6 @@ setup(
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
python_requires='>=3.7',
install_requires=[
'absl-py',
'numpy>=1.20',
'opt_einsum',
'scipy>=1.5',