mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. JAX code currently uses the latter; switching to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While some Python namespaces within JAX previously already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log only to descendants of the top-level "jax" logger, so that, by virtue of log propagation, its records can be handled at that logger.

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
This commit is contained in:
parent
0313b2241d
commit
efd61b73f6
@ -1,3 +1,4 @@
|
||||
absl-py
|
||||
cloudpickle
|
||||
colorama>=0.4.4
|
||||
matplotlib
|
||||
|
@ -18,7 +18,6 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
|
||||
FrozenSet)
|
||||
import types
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
|
@ -12,10 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Type, Sequence, Tuple
|
||||
from absl import logging
|
||||
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ClusterEnv:
|
||||
"""Interface for defining a cluster environment.
|
||||
@ -48,7 +49,7 @@ class ClusterEnv:
|
||||
local_device_ids)
|
||||
env = next((env for env in cls._cluster_types if env.is_env_present()), None)
|
||||
if env:
|
||||
logging.vlog(1, 'Initializing distributed JAX environment via %s', env.__name__)
|
||||
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
|
||||
if coordinator_address is None:
|
||||
coordinator_address = env.get_coordinator_address()
|
||||
if num_processes is None:
|
||||
@ -64,7 +65,7 @@ class ClusterEnv:
|
||||
env.get_local_process_id() is not None):
|
||||
local_device_ids = [env.get_local_process_id()] # type: ignore[list-item]
|
||||
else:
|
||||
logging.vlog(1, 'Could not find a known environment for initializing distributed JAX. '
|
||||
logger.debug('Could not find a known environment for initializing distributed JAX. '
|
||||
'Known environments: %s', ', '.join(e.__name__ for e in cls._cluster_types))
|
||||
return (coordinator_address, num_processes, process_id, local_device_ids)
|
||||
# pytype: enable=bad-return-type
|
||||
|
@ -18,19 +18,20 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import transfer_guard_lib
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def bool_env(varname: str, default: bool) -> bool:
|
||||
"""Read an environment variable and interpret it as a boolean.
|
||||
|
||||
@ -643,7 +644,7 @@ log_compiles = config.define_bool_state(
|
||||
name='jax_log_compiles',
|
||||
default=False,
|
||||
help=('Log a message each time every time `jit` or `pmap` compiles an XLA '
|
||||
'computation. Logging is performed with `absl.logging`. When this '
|
||||
'computation. Logging is performed with `logging`. When this '
|
||||
'option is set, the log level is WARNING; otherwise the level is '
|
||||
'DEBUG.'))
|
||||
|
||||
@ -674,7 +675,7 @@ distributed_debug = config.define_bool_state(
|
||||
name='jax_distributed_debug',
|
||||
default=False,
|
||||
help=('Enable logging useful for debugging multi-process distributed '
|
||||
'computations. Logging is performed with `absl.logging` at WARNING '
|
||||
'computations. Logging is performed with `logging` at WARNING '
|
||||
'level.'))
|
||||
|
||||
|
||||
@ -772,7 +773,7 @@ def _validate_default_device(val):
|
||||
# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
|
||||
# all JAX backends use a single C++ device interface.
|
||||
if 'Device' in str(type(val)):
|
||||
logging.info(
|
||||
logger.info(
|
||||
'Allowing non-`xla_client.Device` default device: %s, type: %s',
|
||||
repr(val), type(val))
|
||||
return
|
||||
|
@ -22,14 +22,14 @@ import itertools
|
||||
import time
|
||||
from typing import (
|
||||
Any, Callable, Dict, Optional, Sequence, Set, Tuple, List, Type, Union,
|
||||
TYPE_CHECKING, Iterator)
|
||||
Iterator)
|
||||
from typing_extensions import Protocol
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
@ -82,6 +82,8 @@ CompileOptions = xc.CompileOptions
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This flag is set on exit; no logging should be attempted
|
||||
_on_exit = False
|
||||
|
||||
@ -360,7 +362,7 @@ def log_elapsed_time(fmt: str):
|
||||
start_time = time.time()
|
||||
yield
|
||||
elapsed_time = time.time() - start_time
|
||||
logging.log(log_priority, fmt.format(elapsed_time=elapsed_time))
|
||||
logger.log(log_priority, fmt.format(elapsed_time=elapsed_time))
|
||||
|
||||
|
||||
def should_tuple_args(num_args: int, platform: str):
|
||||
@ -476,7 +478,7 @@ def lower_xla_callable(
|
||||
msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
|
||||
else:
|
||||
msg = f"Compiling {fun.__name__} ({id(fun)} for args {abstract_args}."
|
||||
logging.log(log_priority, msg)
|
||||
logger.log(log_priority, msg)
|
||||
|
||||
raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr)
|
||||
|
||||
@ -1041,7 +1043,7 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
||||
cached_executable = _cache_read(serialized_computation, module_name,
|
||||
compile_options, backend)
|
||||
if cached_executable is not None:
|
||||
logging.info("Persistent compilation cache hit for '%s'", module_name)
|
||||
logger.info("Persistent compilation cache hit for '%s'", module_name)
|
||||
return cached_executable
|
||||
else:
|
||||
compiled = backend_compile(backend, serialized_computation,
|
||||
|
@ -13,16 +13,17 @@
|
||||
# limitations under the License.
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import os
|
||||
import functools
|
||||
|
||||
from typing import Any, Optional, Union, Sequence
|
||||
|
||||
from absl import logging
|
||||
from jax._src.clusters import ClusterEnv
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_extension
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class State:
|
||||
process_id: int = 0
|
||||
@ -55,7 +56,7 @@ class State:
|
||||
|
||||
if local_device_ids:
|
||||
visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr]
|
||||
logging.info('JAX distributed initialized with visible devices: %s', visible_devices)
|
||||
logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
|
||||
config.update("jax_cuda_visible_devices", visible_devices)
|
||||
config.update("jax_rocm_visible_devices", visible_devices)
|
||||
|
||||
@ -64,7 +65,7 @@ class State:
|
||||
if process_id == 0:
|
||||
if self.service is not None:
|
||||
raise RuntimeError('distributed.initialize should only be called once.')
|
||||
logging.info('Starting JAX distributed service on %s', coordinator_address)
|
||||
logger.info('Starting JAX distributed service on %s', coordinator_address)
|
||||
self.service = xla_extension.get_distributed_runtime_service(
|
||||
coordinator_address, num_processes, config.jax_coordination_service)
|
||||
|
||||
@ -75,7 +76,7 @@ class State:
|
||||
self.client = xla_extension.get_distributed_runtime_client(
|
||||
coordinator_address, process_id, config.jax_coordination_service,
|
||||
init_timeout=300)
|
||||
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
|
||||
logger.info('Connecting to JAX distributed service on %s', coordinator_address)
|
||||
self.client.connect()
|
||||
|
||||
if config.jax_coordination_service:
|
||||
|
@ -20,16 +20,13 @@ XLA. There are also a handful of related casting utilities.
|
||||
"""
|
||||
|
||||
from functools import partial, lru_cache
|
||||
import logging
|
||||
import os
|
||||
import platform as py_platform
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
# Disable "WARNING: Logging before flag parsing goes to stderr." message
|
||||
logging._warn_preinit_stderr = 0
|
||||
|
||||
import jax._src.lib as lib
|
||||
from jax._src.config import flags, bool_env, int_env
|
||||
from jax._src import distributed
|
||||
@ -58,6 +55,8 @@ ShardedBuffer = Any
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO(phawkins): Remove jax_xla_backend.
|
||||
flags.DEFINE_string(
|
||||
'jax_xla_backend', '',
|
||||
@ -126,8 +125,7 @@ def get_compile_options(
|
||||
build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
|
||||
build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids
|
||||
if device_assignment is not None:
|
||||
logging.vlog(
|
||||
2,
|
||||
logger.debug(
|
||||
'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
|
||||
num_replicas, num_partitions, device_assignment)
|
||||
device_assignment = np.array(device_assignment)
|
||||
@ -170,10 +168,10 @@ def get_compile_options(
|
||||
|
||||
def _make_tpu_driver_client():
|
||||
if tpu_driver_client is None:
|
||||
logging.info("Remote TPU is not linked into jax; skipping remote TPU.")
|
||||
logger.info("Remote TPU is not linked into jax; skipping remote TPU.")
|
||||
return None
|
||||
if FLAGS.jax_backend_target is None:
|
||||
logging.info("No --jax_backend_target was provided; skipping remote TPU.")
|
||||
logger.info("No --jax_backend_target was provided; skipping remote TPU.")
|
||||
return None
|
||||
return tpu_driver_client.TpuBackend.create(worker=FLAGS.jax_backend_target)
|
||||
|
||||
@ -354,14 +352,14 @@ def backends():
|
||||
raise RuntimeError(err_msg)
|
||||
else:
|
||||
_backends_errors[platform] = str(err)
|
||||
logging.info(err_msg)
|
||||
logger.info(err_msg)
|
||||
continue
|
||||
# We don't warn about falling back to CPU on Mac OS, because we don't
|
||||
# support anything else there at the moment and warning would be pointless.
|
||||
if (py_platform.system() != "Darwin" and
|
||||
_default_backend.platform == "cpu" and
|
||||
FLAGS.jax_platform_name != 'cpu'):
|
||||
logging.warning('No GPU/TPU found, falling back to CPU. '
|
||||
logger.warning('No GPU/TPU found, falling back to CPU. '
|
||||
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
|
||||
return _backends
|
||||
|
||||
@ -371,7 +369,7 @@ def _clear_backends():
|
||||
global _backends_errors
|
||||
global _default_backend
|
||||
|
||||
logging.info("Clearing JAX backend caches.")
|
||||
logger.info("Clearing JAX backend caches.")
|
||||
with _backend_lock:
|
||||
_backends = {}
|
||||
_backends_errors = {}
|
||||
@ -385,7 +383,7 @@ def _init_backend(platform):
|
||||
if factory is None:
|
||||
raise RuntimeError(f"Unknown backend '{platform}'")
|
||||
|
||||
logging.vlog(1, "Initializing backend '%s'" % platform)
|
||||
logger.debug("Initializing backend '%s'", platform)
|
||||
backend = factory()
|
||||
# TODO(skye): consider raising more descriptive errors directly from backend
|
||||
# factories instead of returning None.
|
||||
@ -397,7 +395,7 @@ def _init_backend(platform):
|
||||
("process_index", backend.process_index()),
|
||||
("device_count", backend.device_count()),
|
||||
("local_devices", backend.local_devices()))
|
||||
logging.vlog(1, "Backend '%s' initialized" % platform)
|
||||
logger.debug("Backend '%s' initialized", platform)
|
||||
return backend
|
||||
|
||||
|
||||
|
@ -18,14 +18,13 @@ import glob
|
||||
import gzip
|
||||
import http.server
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socketserver
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from absl import logging
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
@ -34,6 +33,8 @@ from jax._src.lib import xla_client
|
||||
|
||||
_profiler_server: Optional[xla_client.profiler.ProfilerServer] = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def start_server(port: int):
|
||||
"""Starts the profiler server on port `port`.
|
||||
@ -134,7 +135,7 @@ def _write_perfetto_trace_file(log_dir):
|
||||
raise ValueError(f"Invalid trace folder: {latest_folder}")
|
||||
trace_json, = trace_jsons
|
||||
|
||||
logging.info("Loading trace.json.gz and removing its metadata...")
|
||||
logger.info("Loading trace.json.gz and removing its metadata...")
|
||||
# Perfetto doesn't like the `metadata` field in `trace.json` so we remove
|
||||
# it.
|
||||
# TODO(sharadmv): speed this up by updating the generated `trace.json`
|
||||
@ -144,7 +145,7 @@ def _write_perfetto_trace_file(log_dir):
|
||||
del trace["metadata"]
|
||||
filename = "perfetto_trace.json.gz"
|
||||
perfetto_trace = os.path.join(latest_folder, filename)
|
||||
logging.info("Writing perfetto_trace.json.gz...")
|
||||
logger.info("Writing perfetto_trace.json.gz...")
|
||||
with gzip.open(perfetto_trace, "w") as fp:
|
||||
fp.write(json.dumps(trace).encode("utf-8"))
|
||||
return perfetto_trace
|
||||
|
@ -12,11 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from collections import namedtuple
|
||||
import functools
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
import operator
|
||||
import types
|
||||
import threading
|
||||
@ -24,13 +24,14 @@ from typing import (Any, Callable, Dict, Iterable, List, Tuple, Generic,
|
||||
TypeVar, Set, Iterator, Sequence, Optional)
|
||||
import weakref
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Seq = Sequence
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -510,7 +511,7 @@ def distributed_debug_log(*pairs):
|
||||
lines.append("DISTRIBUTED_DEBUG logging failed!")
|
||||
lines.append(f"{e}")
|
||||
lines.append("DISTRIBUTED_DEBUG_END")
|
||||
logging.warning("\n".join(lines))
|
||||
logger.warning("\n".join(lines))
|
||||
|
||||
|
||||
class OrderedSet(Generic[T]):
|
||||
|
@ -13,27 +13,29 @@
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import jax
|
||||
from jax.experimental.compilation_cache.gfile_cache import GFileCache
|
||||
from jax._src.lib import version_str as jaxlib_version_str
|
||||
from jax._src.lib import xla_client
|
||||
from jax.interpreters import xla
|
||||
from absl import logging
|
||||
|
||||
_cache = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_cache(path):
|
||||
"""Creates a global cache object. Should only be called once per process.
|
||||
"""
|
||||
global _cache
|
||||
assert _cache == None, f"The cache path has already been initialized to {_cache._path}"
|
||||
_cache = GFileCache(path)
|
||||
logging.warning("Initialized persistent compilation cache at %s", path)
|
||||
logger.warning("Initialized persistent compilation cache at %s", path)
|
||||
|
||||
|
||||
def get_executable(xla_computation, compile_options,
|
||||
@ -54,20 +56,20 @@ def put_executable(module_name, xla_computation, compile_options,
|
||||
"""Adds 'executable' to the cache, possibly evicting older entries."""
|
||||
assert _cache is not None, "initialize_cache must be called before you can call put_executable()"
|
||||
cache_key = get_cache_key(xla_computation, compile_options, backend)
|
||||
logging.info('Writing %s to persistent compilation cache with key %s.',
|
||||
logger.info('Writing %s to persistent compilation cache with key %s.',
|
||||
module_name, cache_key)
|
||||
serialized_executable = backend.serialize_executable(executable)
|
||||
_cache.put(cache_key, serialized_executable)
|
||||
|
||||
def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
|
||||
if logging.vlog_is_on(1):
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
# Log the hash of just this entry
|
||||
fresh_hash_obj = hashlib.sha256()
|
||||
hashfn(fresh_hash_obj)
|
||||
logging.vlog(1, "get_cache_key hash of serialized %s: %s", last_serialized,
|
||||
logger.debug("get_cache_key hash of serialized %s: %s", last_serialized,
|
||||
fresh_hash_obj.digest().hex())
|
||||
# Log the cumulative hash
|
||||
logging.vlog(1, "get_cache_key hash after serializing %s: %s",
|
||||
logger.debug("get_cache_key hash after serializing %s: %s",
|
||||
last_serialized, hash_obj.digest().hex())
|
||||
|
||||
def get_cache_key(xla_computation, compile_options, backend) -> str:
|
||||
@ -215,9 +217,9 @@ def _hash_xla_flags(hash_obj):
|
||||
# (e.g. --xla_force_host_platform_device_count=8) (I think).
|
||||
for flag in xla_flags:
|
||||
if flag.split('=')[0] in _xla_flags_to_exclude_from_cache_key:
|
||||
logging.vlog(1, "Not including XLA flag in cache key: %s", flag)
|
||||
logger.debug("Not including XLA flag in cache key: %s", flag)
|
||||
continue
|
||||
logging.vlog(1, "Including XLA flag in cache key: %s", flag)
|
||||
logger.debug("Including XLA flag in cache key: %s", flag)
|
||||
_hash_string(hash_obj, flag)
|
||||
|
||||
def _hash_int(hash_obj, int_var):
|
||||
|
@ -16,11 +16,11 @@
|
||||
import abc
|
||||
import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
from functools import partial
|
||||
import re
|
||||
import threading
|
||||
from typing import Callable, Sequence, Optional, Dict, Any
|
||||
from absl import logging
|
||||
|
||||
import jax
|
||||
from jax._src import distributed
|
||||
@ -40,6 +40,8 @@ _REMOVED_VALUE = 'Value removed'
|
||||
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
|
||||
_module_unique_count = itertools.count()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_async_array_from_callback(
|
||||
global_shape: array.Shape,
|
||||
@ -385,7 +387,7 @@ class AsyncManager:
|
||||
|
||||
def __del__(self):
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
logging.warning('Please add `.wait_until_finished()` in the main thread '
|
||||
logger.warning('Please add `.wait_until_finished()` in the main thread '
|
||||
'before your program finishes because there is a '
|
||||
'possibility of losing errors raised if the '
|
||||
'this class is deleted before writing is completed.')
|
||||
@ -393,20 +395,20 @@ class AsyncManager:
|
||||
def _thread_func(self):
|
||||
try:
|
||||
current_process = jax.process_index()
|
||||
logging.info('Starting commit to storage layer by process: %s',
|
||||
logger.info('Starting commit to storage layer by process: %s',
|
||||
current_process)
|
||||
for future in self._commit_futures:
|
||||
future.result()
|
||||
logging.info('Finished committing to storage layer by process: %s',
|
||||
logger.info('Finished committing to storage layer by process: %s',
|
||||
current_process)
|
||||
|
||||
# All processes will wait at the barrier. When all processes are at the
|
||||
# barrier, the barrier will be satisfied. If not, then it will timeout.
|
||||
key_for_barrier = _get_key(self._count)
|
||||
logging.info('Key used for barrier is %s for process %s',
|
||||
logger.info('Key used for barrier is %s for process %s',
|
||||
key_for_barrier, current_process)
|
||||
self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms)
|
||||
logging.info('Finished waiting at barrier for process %s',
|
||||
logger.info('Finished waiting at barrier for process %s',
|
||||
current_process)
|
||||
|
||||
if current_process == 0:
|
||||
@ -471,7 +473,7 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
|
||||
final_checkpoint_dir: Final checkpoint directory where the checkpoints
|
||||
will be moved from `temp_checkpoint_dir`.
|
||||
"""
|
||||
logging.info('Waiting for previous serialization to finish.')
|
||||
logger.info('Waiting for previous serialization to finish.')
|
||||
self.wait_until_finished()
|
||||
|
||||
commit_futures = [[] for _ in range(len(tensorstore_specs))]
|
||||
|
@ -497,14 +497,13 @@ Still to do:
|
||||
import atexit
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import threading
|
||||
import traceback
|
||||
from typing import (Any, Callable, Dict, List, Optional, Sequence,
|
||||
Tuple, cast)
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
|
||||
from jax._src import api
|
||||
from jax import core
|
||||
from jax.config import config
|
||||
@ -532,6 +531,8 @@ import numpy as np
|
||||
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _inline_host_callback() -> bool:
|
||||
return FLAGS.jax_host_callback_inline
|
||||
@ -1231,9 +1232,10 @@ def _outside_call_run_callback(
|
||||
try:
|
||||
arg = api.tree_unflatten(arg_treedef, arrays)
|
||||
unpacked_transforms = _unpack_transforms(transforms)
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(2,
|
||||
f"Outside call invoking call_func {callback}, device={device}, transforms={unpacked_transforms}")
|
||||
logger.debug(
|
||||
"Outside call invoking call_func %s, device=%s, transforms=%s",
|
||||
callback, device, unpacked_transforms
|
||||
)
|
||||
res = callback(arg, device, unpacked_transforms)
|
||||
if identity:
|
||||
return tuple(arrays)
|
||||
@ -1250,10 +1252,9 @@ def _outside_call_run_callback(
|
||||
|
||||
canonical_flat_results = tuple(util.safe_map(xla.canonicalize_dtype, actual_flat_results))
|
||||
actual_flat_results_aval = _values_to_avals(canonical_flat_results)
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(
|
||||
2,
|
||||
f"Outside call {callback} result {flat_results_aval}. Sending to infeed for device {device}."
|
||||
logger.debug(
|
||||
"Outside call %s result %s. Sending to infeed for device %s.",
|
||||
callback, flat_results_aval, device,
|
||||
)
|
||||
|
||||
if not all(ea.strip_weak_type() == ra.strip_weak_type()
|
||||
@ -1273,7 +1274,7 @@ def _outside_call_run_callback(
|
||||
return canonical_flat_results
|
||||
|
||||
except Exception as e:
|
||||
logging.error("Outside call %s threw exception %s.", callback, e)
|
||||
logger.error("Outside call %s threw exception %s.", callback, e)
|
||||
if send_infeed:
|
||||
# Prepare some results to send in case of error. We are sending something
|
||||
# with a distinctive shape (int8[12345]), one that is unlikely to be what the device
|
||||
@ -1285,12 +1286,12 @@ def _outside_call_run_callback(
|
||||
# TODO: implement a proper error handling for TPU
|
||||
if device.platform != "tpu":
|
||||
canonical_flat_results = [xla.canonicalize_dtype(np.arange(12345, dtype=np.int8))]
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(2, f"Outside call consumer {callback} exception {e}. Sending to infeed the error result.")
|
||||
logger.debug("Outside call consumer %s exception %s. Sending to infeed the error result.",
|
||||
callback, e)
|
||||
device.transfer_to_infeed(tuple(canonical_flat_results))
|
||||
else:
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(2, f"Outside call consumer {callback} exception {e}. On TPU we do not send infeed.")
|
||||
logger.debug("Outside call consumer %s exception %s. On TPU we do not send infeed.",
|
||||
callback, e)
|
||||
raise e # Let the exception propagate
|
||||
|
||||
|
||||
@ -1860,17 +1861,16 @@ _callback_handler_data = _CallbackHandlerData()
|
||||
|
||||
# This function is called from C++; it must not allow exceptions through.
|
||||
def _callback_input_received(device, consumer_id, arrays: Tuple):
|
||||
logging.vlog(
|
||||
2,
|
||||
f"Callback input received on device {device} for consumer {consumer_id} "
|
||||
+ "arrays: " + (", ".join([f"({a.dtype}{a.shape})" for a in arrays])))
|
||||
array_repr = ", ".join([f"({a.dtype}{a.shape})" for a in arrays])
|
||||
logger.debug("Callback input received on device %s for consumer %s arrays: %s",
|
||||
device, consumer_id, array_repr)
|
||||
callback = _callback_handler_data.callback_registry_by_id.get(consumer_id)
|
||||
assert callback is not None, "We should have crashed in the runtime"
|
||||
try:
|
||||
return callback(arrays, device)
|
||||
except Exception as e:
|
||||
formatted_e = traceback.format_exc()
|
||||
logging.error("Postponing exception raised in callback function: %s", formatted_e)
|
||||
logger.error("Postponing exception raised in callback function: %s", formatted_e)
|
||||
_callback_handler_data.last_callback_exception = (e, formatted_e)
|
||||
|
||||
|
||||
@ -1920,11 +1920,10 @@ def _initialize_outfeed_receiver(
|
||||
if clients_with_outfeed:
|
||||
devices_with_outfeed = list(
|
||||
itertools.chain(*[backend.local_devices() for backend in clients_with_outfeed]))
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(
|
||||
2,
|
||||
f"Starting outfeed_receiver for {[str(d) for d in devices_with_outfeed]}. "
|
||||
f"max_callback_queue_size_bytes={max_callback_queue_size_bytes}")
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
device_repr = ", ".join([str(d) for d in devices_with_outfeed])
|
||||
logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s",
|
||||
device_repr, max_callback_queue_size_bytes)
|
||||
_callback_handler_data.receiver = outfeed_receiver_module.start(
|
||||
_callback_input_received, tuple(clients_with_outfeed),
|
||||
max_callback_queue_size_bytes)
|
||||
@ -1959,43 +1958,40 @@ def barrier_wait(logging_name: Optional[str] = None):
|
||||
for this invocation. See `Debugging` in the module documentation.
|
||||
"""
|
||||
logging_name = logging_name or ""
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(2, f"barrier_wait[{logging_name}]: start")
|
||||
logger.debug("barrier_wait[%s]: start", logging_name)
|
||||
|
||||
lock = threading.Lock()
|
||||
cv = threading.Condition(lock=lock)
|
||||
devices_at_barrier = [] # Protected by lock
|
||||
def barrier_tap_received(dev_idx, _):
|
||||
device = _callback_handler_data.devices[dev_idx]
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(
|
||||
2,
|
||||
f"barrier_wait[{logging_name}]: at barrier_tap for device {device} "
|
||||
f". Thread {threading.current_thread()}")
|
||||
logger.debug(
|
||||
"barrier_wait[%s]: at barrier_tap for device %s. Thread %s",
|
||||
logging_name, device, threading.current_thread()
|
||||
)
|
||||
with lock:
|
||||
devices_at_barrier.append(device)
|
||||
if logging.vlog_is_on(2):
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
waiting_for_devices = [d for d in _callback_handler_data.devices
|
||||
if d not in devices_at_barrier]
|
||||
logging.vlog(2,
|
||||
f"barrier_wait[{logging_name}]: still waiting "
|
||||
f"for {len(waiting_for_devices)} devices at "
|
||||
f"barrier ({waiting_for_devices})")
|
||||
logger.debug(
|
||||
"barrier_wait[%s]: still waiting for %s devices at barrier (%s)",
|
||||
logging_name, len(waiting_for_devices), waiting_for_devices
|
||||
)
|
||||
cv.notify()
|
||||
|
||||
for d_idx, d in enumerate(_callback_handler_data.devices):
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(2,
|
||||
f"barrier_wait[{logging_name}]: enqueueing barrier on device {d}")
|
||||
logger.debug("barrier_wait[%s]: enqueueing barrier on device %s", logging_name, d)
|
||||
x_on_dev = api.device_put(d_idx, device=d)
|
||||
api.jit(lambda x: id_tap(barrier_tap_received, x), device=d)(x_on_dev)
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(2,
|
||||
f"barrier_wait[{logging_name}]: waiting for callbacks")
|
||||
|
||||
logger.debug("barrier_wait[%s]: waiting for callbacks", logging_name)
|
||||
|
||||
with lock:
|
||||
cv.wait_for(lambda: len(devices_at_barrier) == len(_callback_handler_data.devices))
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(2, f"barrier_wait[{logging_name}]: done")
|
||||
|
||||
logger.debug("barrier_wait[%s]: done", logging_name)
|
||||
|
||||
if _callback_handler_data.last_callback_exception is not None:
|
||||
last_exception, formatted_last_exception = _callback_handler_data.last_callback_exception
|
||||
_callback_handler_data.last_callback_exception = None
|
||||
|
@ -14,13 +14,12 @@
|
||||
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
|
||||
from functools import partial
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast
|
||||
|
||||
from absl import logging
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import config
|
||||
@ -33,7 +32,6 @@ from jax.experimental import pjit
|
||||
from jax._src import sharding
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
|
||||
@ -52,7 +50,6 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.lax import slicing as lax_slicing
|
||||
from jax._src.lax import windowed_reductions as lax_windowed_reductions
|
||||
from jax._src import lib as jaxlib
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
@ -94,6 +91,8 @@ _INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/-]")
|
||||
map = util.safe_map
|
||||
zip = util.safe_zip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _sanitize_scope_name(name):
|
||||
scope_name = _INVALID_SCOPE_CHAR.sub("_", name)
|
||||
@ -636,7 +635,7 @@ def _lower_native_and_run(fun_jax: Callable,
|
||||
lowered = fun_jax_lower(*arg_specs_jax)._lowering
|
||||
mhlo_module = lowered.mhlo()
|
||||
mhlo_module_text = mlir.module_to_string(mhlo_module)
|
||||
logging.vlog(2, f"XlaCallModule {mhlo_module_text}")
|
||||
logger.debug("XlaCallModule %s", mhlo_module_text)
|
||||
# We do not support custom_call, try to give an error for now
|
||||
if "mhlo.custom_call" in mhlo_module_text:
|
||||
# Try to give a nice error message. We could just dump the module...
|
||||
|
@ -15,12 +15,14 @@
|
||||
"""Utils for building a device mesh."""
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from absl import logging
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TPU_V2 = 'TPU v2'
|
||||
_TPU_V3 = 'TPU v3'
|
||||
_TPU_V4 = 'TPU v4'
|
||||
@ -248,14 +250,14 @@ def create_device_mesh(
|
||||
device_kind = devices[-1].device_kind
|
||||
if device_kind in (_TPU_V2, _TPU_V3):
|
||||
if len(devices) == 8:
|
||||
logging.info('Reordering mesh to physical ring order on single-tray TPU v2/v3.')
|
||||
logger.info('Reordering mesh to physical ring order on single-tray TPU v2/v3.')
|
||||
device_mesh = np.asarray(devices)
|
||||
device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)]
|
||||
device_mesh = device_mesh.reshape(mesh_shape)
|
||||
return device_mesh
|
||||
elif mesh_shape[-1] == 8:
|
||||
device_mesh = np.asarray(devices).reshape(mesh_shape)
|
||||
logging.info('Reordering mesh to physical ring order on each TPU v2/v3 tray.')
|
||||
logger.info('Reordering mesh to physical ring order on each TPU v2/v3 tray.')
|
||||
perm = np.array(_TRAY_RING_ORDER)
|
||||
device_mesh = device_mesh[..., perm]
|
||||
return device_mesh
|
||||
@ -270,7 +272,7 @@ def create_device_mesh(
|
||||
physical_mesh = _transpose_trick(physical_mesh, mesh_shape)
|
||||
device_mesh, assignment = _create_device_mesh_for_nd_torus(
|
||||
physical_mesh, mesh_shape)
|
||||
logging.info('_create_device_mesh_for_nd_torus assignment: %s', assignment)
|
||||
logger.info('_create_device_mesh_for_nd_torus assignment: %s', assignment)
|
||||
return device_mesh
|
||||
else:
|
||||
device_mesh = np.asarray(devices).reshape(mesh_shape)
|
||||
|
@ -36,6 +36,7 @@ from collections import defaultdict, OrderedDict
|
||||
import dataclasses
|
||||
from functools import partial, lru_cache
|
||||
import itertools as it
|
||||
import logging
|
||||
import operator as op
|
||||
import sys
|
||||
import warnings
|
||||
@ -45,7 +46,6 @@ from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
|
||||
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
|
||||
TYPE_CHECKING)
|
||||
|
||||
from absl import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -106,6 +106,8 @@ xe = xc._xla
|
||||
|
||||
unsafe_map, map = map, safe_map # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
|
||||
|
||||
NoSharding = pmap_lib.NoSharding
|
||||
@ -1385,19 +1387,19 @@ def lower_parallel_callable(
|
||||
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
|
||||
pci, fun, global_arg_shapes)
|
||||
|
||||
if logging.vlog_is_on(2):
|
||||
logging.vlog(2, "sharded_avals: %s", shards.sharded_avals)
|
||||
logging.vlog(2, "global_sharded_avals: %s", shards.global_sharded_avals)
|
||||
logging.vlog(2, "num_replicas: %d num_local_replicas: %d",
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("sharded_avals: %s", shards.sharded_avals)
|
||||
logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
|
||||
logger.debug("num_replicas: %d num_local_replicas: %d",
|
||||
replicas.num_global_replicas, replicas.num_local_replicas)
|
||||
logging.vlog(2, "num_partitions: %d local_num_partitions: %d",
|
||||
logger.debug("num_partitions: %d local_num_partitions: %d",
|
||||
parts.num_partitions, parts.local_num_partitions)
|
||||
logging.vlog(2, "arg_parts: %s", parts.arg_parts)
|
||||
logging.vlog(2, "local_arg_parts: %s", parts.local_arg_parts)
|
||||
logging.vlog(2, "out_parts: %s", parts.out_parts)
|
||||
logging.vlog(2, "local_out_parts: %s", parts.local_out_parts)
|
||||
logging.vlog(2, "devices: %s", devices)
|
||||
logging.vlog(2, "local_devices: %s", pci.local_devices)
|
||||
logger.debug("arg_parts: %s", parts.arg_parts)
|
||||
logger.debug("local_arg_parts: %s", parts.local_arg_parts)
|
||||
logger.debug("out_parts: %s", parts.out_parts)
|
||||
logger.debug("local_out_parts: %s", parts.local_out_parts)
|
||||
logger.debug("devices: %s", devices)
|
||||
logger.debug("local_devices: %s", pci.local_devices)
|
||||
|
||||
if (xb.process_count(backend) > 1 and must_run_on_all_devices and
|
||||
shards.num_local_shards != xb.local_device_count(backend)):
|
||||
@ -1425,7 +1427,7 @@ def lower_parallel_callable(
|
||||
f"{replicas.jaxpr_replicas} and nested_partitions={parts.num_partitions}")
|
||||
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logging.log(log_priority,
|
||||
logger.log(log_priority,
|
||||
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d"
|
||||
" num_partitions=%d)", fun.__name__, id(fun),
|
||||
shards.num_global_shards, avals, replicas.num_global_replicas,
|
||||
@ -2785,7 +2787,7 @@ def lower_sharding_computation(
|
||||
if _is_unspecified(i) else i for i in in_shardings)
|
||||
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logging.log(log_priority,
|
||||
logger.log(log_priority,
|
||||
"Compiling %s (%d) for with global shapes and types %s. "
|
||||
"Argument mapping: %s.",
|
||||
getattr(fun, '__name__', '<unnamed function>'), id(fun),
|
||||
@ -2964,7 +2966,7 @@ def lower_mesh_computation(
|
||||
global_axis_sizes = mesh.shape
|
||||
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
logging.log(log_priority,
|
||||
logger.log(log_priority,
|
||||
"Compiling %s (%d) for %s mesh with global shapes and types %s. "
|
||||
"Argument mapping: %s.",
|
||||
getattr(fun, '__name__', '<unnamed function>'), id(fun),
|
||||
|
Loading…
x
Reference in New Issue
Block a user